OntoRAG: Ontology-Based Knowledge Extraction and Query Technical Specification (#523)

* Onto-rag tech spec

* New processor kg-extract-ontology, use 'ontology' objects from config to guide triple extraction

* Also entity contexts

* Integrate with ontology extractor from workbench

This is the first phase: the extraction is tested and working, and GraphRAG queries over the extracted knowledge also work.
This commit is contained in:
cybermaggedon 2025-11-12 20:38:08 +00:00 committed by GitHub
parent 4c3db4dbbe
commit c69f5207a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 11824 additions and 0 deletions

View file

@ -0,0 +1,54 @@
"""
OntoRAG Query System.
Ontology-driven natural language query processing with multi-backend support.
Provides semantic query understanding, ontology matching, and answer generation.
"""
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse
from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .backend_router import BackendRouter, BackendType, QueryRoute
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
from .cypher_generator import CypherGenerator, CypherQuery
from .cypher_executor import CypherExecutor, CypherResult
from .answer_generator import AnswerGenerator, GeneratedAnswer, AnswerMetadata
# Public API of the query package; groups mirror the query-pipeline stages.
__all__ = [
    # Main service
    'OntoRAGQueryService',
    'QueryRequest',
    'QueryResponse',
    # Question analysis
    'QuestionAnalyzer',
    'QuestionComponents',
    'QuestionType',
    # Ontology matching
    'OntologyMatcher',
    'QueryOntologySubset',
    # Backend routing
    'BackendRouter',
    'BackendType',
    'QueryRoute',
    # SPARQL components
    'SPARQLGenerator',
    'SPARQLQuery',
    'SPARQLCassandraEngine',
    'SPARQLResult',
    # Cypher components
    'CypherGenerator',
    'CypherQuery',
    'CypherExecutor',
    'CypherResult',
    # Answer generation
    'AnswerGenerator',
    'GeneratedAnswer',
    'AnswerMetadata',
]

View file

@ -0,0 +1,521 @@
"""
Answer generator for natural language responses.
Converts query results into natural language answers using LLM assistance.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_cassandra import SPARQLResult
from .cypher_executor import CypherResult
logger = logging.getLogger(__name__)
@dataclass
class AnswerMetadata:
    """Metadata about answer generation."""
    # QuestionType.value of the analyzed question (e.g. 'factual')
    query_type: str
    # Name of the backend that executed the query (e.g. 'cassandra')
    backend_used: str
    # Answer-generation time in seconds
    execution_time: float
    # Number of bindings/records in the raw results
    result_count: int
    # 0.0-1.0; set to 0.0 when generation failed
    confidence: float
    explanation: str
    # Data source identifiers; currently always empty (see _build_answer_response)
    sources: List[str]
@dataclass
class GeneratedAnswer:
    """Generated natural language answer."""
    # Final natural-language answer text
    answer: str
    metadata: AnswerMetadata
    # Up to 5 fact strings extracted from the raw results
    supporting_facts: List[str]
    # Unmodified backend results, kept for callers that need them
    raw_results: Union[SPARQLResult, CypherResult]
    # Seconds spent generating the answer (same value as metadata.execution_time)
    generation_time: float
class AnswerGenerator:
    """Turns backend query results into natural-language answers."""

    def __init__(self, prompt_service=None):
        """Initialize the generator.

        Args:
            prompt_service: Optional service for LLM-based answer generation;
                when absent, template-based answers are produced instead.
        """
        self.prompt_service = prompt_service
        # Canned phrasings keyed by result category, used as the non-LLM fallback.
        self.templates = dict(
            count="There are {count} {entity_type}.",
            boolean_true="Yes, {statement} is true.",
            boolean_false="No, {statement} is not true.",
            list="The {entity_type} are: {items}.",
            single="The {property} of {entity} is {value}.",
            none="No results were found for your query.",
            error="I encountered an error processing your query: {error}",
        )
async def generate_answer(self,
question_components: QuestionComponents,
query_results: Union[SPARQLResult, CypherResult],
ontology_subset: QueryOntologySubset,
backend_used: str) -> GeneratedAnswer:
"""Generate natural language answer from query results.
Args:
question_components: Original question analysis
query_results: Results from query execution
ontology_subset: Ontology subset used
backend_used: Backend that executed the query
Returns:
Generated answer with metadata
"""
start_time = datetime.now()
try:
# Try LLM-based generation first
if self.prompt_service:
llm_answer = await self._generate_with_llm(
question_components, query_results, ontology_subset
)
if llm_answer:
execution_time = (datetime.now() - start_time).total_seconds()
return self._build_answer_response(
llm_answer, question_components, query_results,
backend_used, execution_time
)
# Fall back to template-based generation
template_answer = self._generate_with_template(
question_components, query_results, ontology_subset
)
execution_time = (datetime.now() - start_time).total_seconds()
return self._build_answer_response(
template_answer, question_components, query_results,
backend_used, execution_time
)
except Exception as e:
logger.error(f"Answer generation failed: {e}")
execution_time = (datetime.now() - start_time).total_seconds()
error_answer = self.templates['error'].format(error=str(e))
return self._build_answer_response(
error_answer, question_components, query_results,
backend_used, execution_time, confidence=0.0
)
async def _generate_with_llm(self,
question_components: QuestionComponents,
query_results: Union[SPARQLResult, CypherResult],
ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Generate answer using LLM.
Args:
question_components: Question analysis
query_results: Query results
ontology_subset: Ontology subset
Returns:
Generated answer or None if failed
"""
try:
prompt = self._build_answer_prompt(
question_components, query_results, ontology_subset
)
response = await self.prompt_service.generate_answer(prompt=prompt)
if response and isinstance(response, dict):
return response.get('answer', '').strip()
elif isinstance(response, str):
return response.strip()
except Exception as e:
logger.error(f"LLM answer generation failed: {e}")
return None
    def _generate_with_template(self,
                                question_components: QuestionComponents,
                                query_results: Union[SPARQLResult, CypherResult],
                                ontology_subset: QueryOntologySubset) -> str:
        """Generate answer using templates.

        Branch order matters: empty results short-circuit first, then the
        BOOLEAN, AGGREGATION, RETRIEVAL and FACTUAL question types are
        handled in turn, with a generic "Found: ..." listing as fallback.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Template-based answer
        """
        # Handle empty results
        if not self._has_results(query_results):
            return self.templates['none']
        # Handle boolean queries
        if question_components.question_type == QuestionType.BOOLEAN:
            if hasattr(query_results, 'ask_result'):
                # SPARQL ASK result
                statement = self._extract_boolean_statement(question_components)
                if query_results.ask_result:
                    return self.templates['boolean_true'].format(statement=statement)
                else:
                    return self.templates['boolean_false'].format(statement=statement)
            else:
                # Cypher boolean (check if any results)
                has_results = len(query_results.records) > 0
                statement = self._extract_boolean_statement(question_components)
                if has_results:
                    return self.templates['boolean_true'].format(statement=statement)
                else:
                    return self.templates['boolean_false'].format(statement=statement)
        # Handle count queries
        if question_components.question_type == QuestionType.AGGREGATION:
            count = self._extract_count(query_results)
            entity_type = self._infer_entity_type(question_components, ontology_subset)
            return self.templates['count'].format(count=count, entity_type=entity_type)
        # Handle retrieval queries
        if question_components.question_type == QuestionType.RETRIEVAL:
            items = self._extract_items(query_results)
            if len(items) == 1:
                # Single result
                entity = question_components.entities[0] if question_components.entities else "entity"
                # Placeholder: the actual property name is not recovered from the query
                property_name = "value"
                return self.templates['single'].format(
                    property=property_name, entity=entity, value=items[0]
                )
            else:
                # Multiple results
                entity_type = self._infer_entity_type(question_components, ontology_subset)
                items_str = ", ".join(items)
                return self.templates['list'].format(entity_type=entity_type, items=items_str)
        # Handle factual queries
        if question_components.question_type == QuestionType.FACTUAL:
            facts = self._extract_facts(query_results)
            return ". ".join(facts) if facts else self.templates['none']
        # Default fallback: list up to the first five items
        items = self._extract_items(query_results)
        if items:
            return f"Found: {', '.join(items[:5])}" + ("..." if len(items) > 5 else "")
        else:
            return self.templates['none']
    def _build_answer_prompt(self,
                             question_components: QuestionComponents,
                             query_results: Union[SPARQLResult, CypherResult],
                             ontology_subset: QueryOntologySubset) -> str:
        """Build prompt for LLM answer generation.

        The triple-quoted prompt below is a runtime string sent to the LLM;
        its wording must stay stable.

        Args:
            question_components: Question analysis
            query_results: Query results
            ontology_subset: Ontology subset

        Returns:
            Formatted prompt string
        """
        # Format results for prompt
        results_str = self._format_results_for_prompt(query_results)
        # Extract ontology context (capped at 5 of each to bound prompt size)
        context_classes = list(ontology_subset.classes.keys())[:5]
        context_properties = list(ontology_subset.object_properties.keys())[:5]
        prompt = f"""Generate a natural language answer for the following question based on the query results.
ORIGINAL QUESTION: {question_components.original_question}
QUESTION TYPE: {question_components.question_type.value}
EXPECTED ANSWER: {question_components.expected_answer_type}
ONTOLOGY CONTEXT:
- Classes: {', '.join(context_classes)}
- Properties: {', '.join(context_properties)}
QUERY RESULTS:
{results_str}
INSTRUCTIONS:
- Provide a clear, concise answer in natural language
- Use the original question's tone and style
- Include specific facts from the results
- If no results, explain that no information was found
- Be accurate and don't make assumptions beyond the data
- Limit response to 2-3 sentences unless the question requires more detail
ANSWER:"""
        return prompt
def _format_results_for_prompt(self, query_results: Union[SPARQLResult, CypherResult]) -> str:
"""Format query results for prompt inclusion.
Args:
query_results: Query results to format
Returns:
Formatted results string
"""
if isinstance(query_results, SPARQLResult):
if hasattr(query_results, 'ask_result') and query_results.ask_result is not None:
return f"Boolean result: {query_results.ask_result}"
if not query_results.bindings:
return "No results found"
# Format SPARQL bindings
lines = []
for binding in query_results.bindings[:10]: # Limit to first 10
formatted = []
for var, value in binding.items():
if isinstance(value, dict):
formatted.append(f"{var}: {value.get('value', value)}")
else:
formatted.append(f"{var}: {value}")
lines.append("- " + ", ".join(formatted))
if len(query_results.bindings) > 10:
lines.append(f"... and {len(query_results.bindings) - 10} more results")
return "\n".join(lines)
else: # CypherResult
if not query_results.records:
return "No results found"
# Format Cypher records
lines = []
for record in query_results.records[:10]: # Limit to first 10
if isinstance(record, dict):
formatted = [f"{k}: {v}" for k, v in record.items()]
lines.append("- " + ", ".join(formatted))
else:
lines.append(f"- {record}")
if len(query_results.records) > 10:
lines.append(f"... and {len(query_results.records) - 10} more results")
return "\n".join(lines)
def _has_results(self, query_results: Union[SPARQLResult, CypherResult]) -> bool:
"""Check if query results contain data.
Args:
query_results: Query results to check
Returns:
True if results contain data
"""
if isinstance(query_results, SPARQLResult):
return bool(query_results.bindings) or query_results.ask_result is not None
else: # CypherResult
return bool(query_results.records)
def _extract_count(self, query_results: Union[SPARQLResult, CypherResult]) -> int:
"""Extract count from aggregation query results.
Args:
query_results: Query results
Returns:
Count value
"""
if isinstance(query_results, SPARQLResult):
if query_results.bindings:
binding = query_results.bindings[0]
# Look for count variable
for var, value in binding.items():
if 'count' in var.lower():
if isinstance(value, dict):
return int(value.get('value', 0))
return int(value)
return len(query_results.bindings)
else: # CypherResult
if query_results.records:
record = query_results.records[0]
if isinstance(record, dict):
# Look for count key
for key, value in record.items():
if 'count' in key.lower():
return int(value)
elif isinstance(record, (int, float)):
return int(record)
return len(query_results.records)
def _extract_items(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
"""Extract items from query results.
Args:
query_results: Query results
Returns:
List of extracted items
"""
items = []
if isinstance(query_results, SPARQLResult):
for binding in query_results.bindings:
for var, value in binding.items():
if isinstance(value, dict):
item_value = value.get('value', str(value))
else:
item_value = str(value)
# Clean up URIs
if item_value.startswith('http'):
item_value = item_value.split('/')[-1].split('#')[-1]
items.append(item_value)
break # Take first value per binding
else: # CypherResult
for record in query_results.records:
if isinstance(record, dict):
# Take first value from record
for key, value in record.items():
items.append(str(value))
break
else:
items.append(str(record))
return items
def _extract_facts(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]:
"""Extract facts from query results.
Args:
query_results: Query results
Returns:
List of facts
"""
facts = []
if isinstance(query_results, SPARQLResult):
for binding in query_results.bindings:
fact_parts = []
for var, value in binding.items():
if isinstance(value, dict):
val_str = value.get('value', str(value))
else:
val_str = str(value)
# Clean up URIs
if val_str.startswith('http'):
val_str = val_str.split('/')[-1].split('#')[-1]
fact_parts.append(f"{var}: {val_str}")
facts.append(", ".join(fact_parts))
else: # CypherResult
for record in query_results.records:
if isinstance(record, dict):
fact_parts = [f"{k}: {v}" for k, v in record.items()]
facts.append(", ".join(fact_parts))
else:
facts.append(str(record))
return facts
def _extract_boolean_statement(self, question_components: QuestionComponents) -> str:
"""Extract statement for boolean answer.
Args:
question_components: Question analysis
Returns:
Statement string
"""
# Extract the key assertion from the question
question = question_components.original_question.lower()
# Remove question words
statement = question.replace('is ', '').replace('are ', '').replace('does ', '')
statement = statement.replace('?', '').strip()
return statement
def _infer_entity_type(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> str:
"""Infer entity type from question and ontology.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Entity type string
"""
# Try to match entities to ontology classes
for entity in question_components.entities:
entity_lower = entity.lower()
for class_id in ontology_subset.classes:
if class_id.lower() == entity_lower or entity_lower in class_id.lower():
return class_id
# Fallback to first entity or generic term
if question_components.entities:
return question_components.entities[0]
else:
return "entities"
def _build_answer_response(self,
answer: str,
question_components: QuestionComponents,
query_results: Union[SPARQLResult, CypherResult],
backend_used: str,
execution_time: float,
confidence: float = 0.8) -> GeneratedAnswer:
"""Build final answer response.
Args:
answer: Generated answer text
question_components: Question analysis
query_results: Query results
backend_used: Backend used for query
execution_time: Answer generation time
confidence: Confidence score
Returns:
Complete answer response
"""
# Extract supporting facts
supporting_facts = self._extract_facts(query_results)
# Build metadata
result_count = 0
if isinstance(query_results, SPARQLResult):
result_count = len(query_results.bindings)
else: # CypherResult
result_count = len(query_results.records)
metadata = AnswerMetadata(
query_type=question_components.question_type.value,
backend_used=backend_used,
execution_time=execution_time,
result_count=result_count,
confidence=confidence,
explanation=f"Generated answer using {backend_used} backend",
sources=[] # Could be populated with data source information
)
return GeneratedAnswer(
answer=answer,
metadata=metadata,
supporting_facts=supporting_facts[:5], # Limit to top 5
raw_results=query_results,
generation_time=execution_time
)

View file

@ -0,0 +1,350 @@
"""
Backend router for ontology query system.
Routes queries to appropriate backend based on configuration.
"""
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum
from .question_analyzer import QuestionComponents
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
class BackendType(Enum):
    """Supported backend types.

    The value strings match the backend names used in router configuration
    ('primary' / 'fallback' entries).
    """
    CASSANDRA = "cassandra"  # routed with SPARQL (see _route_by_priority)
    NEO4J = "neo4j"          # routed with Cypher
    MEMGRAPH = "memgraph"    # routed with Cypher
    FALKORDB = "falkordb"    # routed with Cypher
@dataclass
class BackendConfig:
    """Configuration for a backend."""
    type: BackendType
    # Higher wins in priority routing; primary gets 100, fallbacks 50, 40, ...
    priority: int = 0
    enabled: bool = True
    # Backend-specific settings; _parse_backend_config always supplies a dict,
    # so None only occurs when constructed directly.
    config: Optional[Dict[str, Any]] = None
@dataclass
class QueryRoute:
    """Routing decision for a query."""
    backend_type: BackendType
    query_language: str  # 'sparql' or 'cypher'
    confidence: float    # 0.0-1.0 confidence in the routing decision
    reasoning: str       # human-readable explanation of the choice
class BackendRouter:
    """Routes queries to appropriate backends based on configuration and heuristics."""
    def __init__(self, config: Dict[str, Any]):
        """Initialize backend router.

        Args:
            config: Router configuration; recognised keys are 'primary',
                'fallback', 'routing_strategy' and 'enable_fallback', plus
                one optional per-backend section keyed by backend name.
        """
        self.config = config
        # BackendType -> BackendConfig, parsed from 'primary'/'fallback'.
        self.backends = self._parse_backend_config(config)
        # One of 'priority' (default), 'adaptive', 'round_robin'.
        self.routing_strategy = config.get('routing_strategy', 'priority')
        self.enable_fallback = config.get('enable_fallback', True)
def _parse_backend_config(self, config: Dict[str, Any]) -> Dict[BackendType, BackendConfig]:
"""Parse backend configuration.
Args:
config: Configuration dictionary
Returns:
Dictionary of backend type to configuration
"""
backends = {}
# Parse primary backend
primary = config.get('primary', 'cassandra')
if primary:
try:
backend_type = BackendType(primary)
backends[backend_type] = BackendConfig(
type=backend_type,
priority=100,
enabled=True,
config=config.get(primary, {})
)
except ValueError:
logger.warning(f"Unknown primary backend type: {primary}")
# Parse fallback backends
fallbacks = config.get('fallback', [])
for i, fallback in enumerate(fallbacks):
try:
backend_type = BackendType(fallback)
backends[backend_type] = BackendConfig(
type=backend_type,
priority=50 - i * 10, # Decreasing priority
enabled=True,
config=config.get(fallback, {})
)
except ValueError:
logger.warning(f"Unknown fallback backend type: {fallback}")
return backends
def route_query(self,
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
"""Route a query to the best backend.
Args:
question_components: Analyzed question
ontology_subsets: Relevant ontology subsets
Returns:
QueryRoute with routing decision
"""
if self.routing_strategy == 'priority':
return self._route_by_priority()
elif self.routing_strategy == 'adaptive':
return self._route_adaptive(question_components, ontology_subsets)
elif self.routing_strategy == 'round_robin':
return self._route_round_robin()
else:
return self._route_by_priority()
def _route_by_priority(self) -> QueryRoute:
"""Route based on backend priority.
Returns:
QueryRoute to highest priority backend
"""
# Find highest priority enabled backend
best_backend = None
best_priority = -1
for backend_type, backend_config in self.backends.items():
if backend_config.enabled and backend_config.priority > best_priority:
best_backend = backend_type
best_priority = backend_config.priority
if best_backend is None:
raise RuntimeError("No enabled backends available")
# Determine query language
query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
return QueryRoute(
backend_type=best_backend,
query_language=query_language,
confidence=1.0,
reasoning=f"Priority routing to {best_backend.value}"
)
def _route_adaptive(self,
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset]) -> QueryRoute:
"""Route based on question characteristics and ontology complexity.
Args:
question_components: Analyzed question
ontology_subsets: Relevant ontology subsets
Returns:
QueryRoute with adaptive decision
"""
scores = {}
for backend_type, backend_config in self.backends.items():
if not backend_config.enabled:
continue
score = self._calculate_backend_score(
backend_type, question_components, ontology_subsets
)
scores[backend_type] = score
if not scores:
raise RuntimeError("No enabled backends available")
# Select backend with highest score
best_backend = max(scores.keys(), key=lambda k: scores[k])
best_score = scores[best_backend]
# Determine query language
query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher'
return QueryRoute(
backend_type=best_backend,
query_language=query_language,
confidence=best_score,
reasoning=f"Adaptive routing: {best_backend.value} scored {best_score:.2f}"
)
def _calculate_backend_score(self,
backend_type: BackendType,
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset]) -> float:
"""Calculate score for a backend based on query characteristics.
Args:
backend_type: Backend to score
question_components: Question analysis
ontology_subsets: Ontology subsets
Returns:
Score (0.0 to 1.0)
"""
score = 0.0
# Base priority score
backend_config = self.backends[backend_type]
score += backend_config.priority / 100.0
# Question type preferences
if backend_type == BackendType.CASSANDRA:
# SPARQL is good for hierarchical and complex reasoning
if question_components.question_type.value in ['factual', 'aggregation']:
score += 0.3
# Good for ontology-heavy queries
if len(ontology_subsets) > 1:
score += 0.2
else:
# Cypher is good for graph traversal and relationships
if question_components.question_type.value in ['relationship', 'retrieval']:
score += 0.3
# Good for simple graph patterns
if len(question_components.relationships) > 0:
score += 0.2
# Complexity considerations
total_elements = sum(
len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
for subset in ontology_subsets
)
if backend_type == BackendType.CASSANDRA:
# SPARQL handles complex ontologies well
if total_elements > 20:
score += 0.2
else:
# Cypher is efficient for simpler queries
if total_elements <= 10:
score += 0.2
# Aggregation considerations
if question_components.aggregations:
if backend_type == BackendType.CASSANDRA:
score += 0.1 # SPARQL has built-in aggregation
else:
score += 0.2 # Cypher has excellent aggregation
return min(score, 1.0)
def _route_round_robin(self) -> QueryRoute:
"""Route using round-robin strategy.
Returns:
QueryRoute using round-robin selection
"""
# Simple round-robin implementation
enabled_backends = [
bt for bt, bc in self.backends.items() if bc.enabled
]
if not enabled_backends:
raise RuntimeError("No enabled backends available")
# For simplicity, just return the first enabled backend
# In a real implementation, you'd track state
backend_type = enabled_backends[0]
query_language = 'sparql' if backend_type == BackendType.CASSANDRA else 'cypher'
return QueryRoute(
backend_type=backend_type,
query_language=query_language,
confidence=0.8,
reasoning=f"Round-robin routing to {backend_type.value}"
)
def get_fallback_route(self, failed_backend: BackendType) -> Optional[QueryRoute]:
"""Get fallback route when a backend fails.
Args:
failed_backend: Backend that failed
Returns:
Fallback route or None if no fallback available
"""
if not self.enable_fallback:
return None
# Find next best backend
fallback_backends = [
(bt, bc) for bt, bc in self.backends.items()
if bc.enabled and bt != failed_backend
]
if not fallback_backends:
return None
# Sort by priority
fallback_backends.sort(key=lambda x: x[1].priority, reverse=True)
fallback_type = fallback_backends[0][0]
query_language = 'sparql' if fallback_type == BackendType.CASSANDRA else 'cypher'
return QueryRoute(
backend_type=fallback_type,
query_language=query_language,
confidence=0.7,
reasoning=f"Fallback from {failed_backend.value} to {fallback_type.value}"
)
def get_available_backends(self) -> List[BackendType]:
"""Get list of available backends.
Returns:
List of enabled backend types
"""
return [bt for bt, bc in self.backends.items() if bc.enabled]
def is_backend_enabled(self, backend_type: BackendType) -> bool:
"""Check if a backend is enabled.
Args:
backend_type: Backend to check
Returns:
True if backend is enabled
"""
backend_config = self.backends.get(backend_type)
return backend_config is not None and backend_config.enabled
def update_backend_status(self, backend_type: BackendType, enabled: bool):
"""Update backend enabled status.
Args:
backend_type: Backend to update
enabled: New enabled status
"""
if backend_type in self.backends:
self.backends[backend_type].enabled = enabled
logger.info(f"Backend {backend_type.value} {'enabled' if enabled else 'disabled'}")
else:
logger.warning(f"Unknown backend type: {backend_type}")
def get_backend_config(self, backend_type: BackendType) -> Optional[Dict[str, Any]]:
"""Get configuration for a backend.
Args:
backend_type: Backend type
Returns:
Configuration dictionary or None
"""
backend_config = self.backends.get(backend_type)
return backend_config.config if backend_config else None

View file

@ -0,0 +1,651 @@
"""
Caching system for OntoRAG query results and computations.
Provides multiple cache backends and intelligent cache management.
"""
import logging
import time
import json
import pickle
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from abc import ABC, abstractmethod
from pathlib import Path
import threading
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
    """Cache entry with metadata."""
    key: str
    value: Any
    created_at: datetime
    accessed_at: datetime
    access_count: int
    # None means "never expires"
    ttl_seconds: Optional[int] = None
    # Callers always pass a list (or []); Optional only for direct construction
    tags: Optional[List[str]] = None
    # Approximate pickled size; maintained by the cache backends
    size_bytes: int = 0
    def is_expired(self) -> bool:
        """Check if cache entry is expired."""
        if self.ttl_seconds is None:
            return False
        return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds
    def touch(self):
        """Update access time and count."""
        self.accessed_at = datetime.now()
        self.access_count += 1
@dataclass
class CacheStats:
    """Cache performance statistics."""
    hits: int = 0
    misses: int = 0
    # Counts expiries and explicit deletes as well as capacity evictions
    evictions: int = 0
    total_entries: int = 0
    total_size_bytes: int = 0
    # Derived value; refreshed via update_hit_rate()
    hit_rate: float = 0.0
    def update_hit_rate(self):
        """Update hit rate calculation."""
        total_requests = self.hits + self.misses
        self.hit_rate = self.hits / total_requests if total_requests > 0 else 0.0
class CacheBackend(ABC):
    """Abstract base class for cache backends."""
    @abstractmethod
    def get(self, key: str) -> Optional[CacheEntry]:
        """Get cache entry by key.

        Returns:
            The entry, or None on miss or expiry.
        """
        pass
    @abstractmethod
    def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: Optional[List[str]] = None):
        """Set cache entry.

        Args:
            key: Cache key
            value: Value to store
            ttl_seconds: Optional time-to-live; None means no expiry
            tags: Optional tags for group invalidation via clear()
        """
        pass
    @abstractmethod
    def delete(self, key: str) -> bool:
        """Delete cache entry.

        Returns:
            True if an entry was removed.
        """
        pass
    @abstractmethod
    def clear(self, tags: Optional[List[str]] = None):
        """Clear cache entries (all of them, or only those matching any tag)."""
        pass
    @abstractmethod
    def get_stats(self) -> CacheStats:
        """Get cache statistics."""
        pass
    @abstractmethod
    def cleanup_expired(self):
        """Clean up expired entries."""
        pass
class InMemoryCache(CacheBackend):
    """In-memory cache backend."""
    def __init__(self, max_size: int = 1000, max_size_bytes: int = 100 * 1024 * 1024):
        """Initialize in-memory cache.

        Args:
            max_size: Maximum number of entries
            max_size_bytes: Maximum total size in bytes
        """
        self.max_size = max_size
        self.max_size_bytes = max_size_bytes
        # key -> CacheEntry; all access guarded by _lock
        self.entries: Dict[str, CacheEntry] = {}
        self.stats = CacheStats()
        # RLock because public methods holding it call other public methods
        # (e.g. clear() -> delete())
        self._lock = threading.RLock()
def get(self, key: str) -> Optional[CacheEntry]:
"""Get cache entry by key."""
with self._lock:
entry = self.entries.get(key)
if entry is None:
self.stats.misses += 1
self.stats.update_hit_rate()
return None
if entry.is_expired():
del self.entries[key]
self.stats.misses += 1
self.stats.evictions += 1
self.stats.total_entries -= 1
self.stats.total_size_bytes -= entry.size_bytes
self.stats.update_hit_rate()
return None
entry.touch()
self.stats.hits += 1
self.stats.update_hit_rate()
return entry
def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
"""Set cache entry."""
with self._lock:
# Calculate size
try:
size_bytes = len(pickle.dumps(value))
except Exception:
size_bytes = len(str(value).encode('utf-8'))
# Create entry
now = datetime.now()
entry = CacheEntry(
key=key,
value=value,
created_at=now,
accessed_at=now,
access_count=1,
ttl_seconds=ttl_seconds,
tags=tags or [],
size_bytes=size_bytes
)
# Check if we need to evict
self._ensure_capacity(size_bytes)
# Store entry
old_entry = self.entries.get(key)
if old_entry:
self.stats.total_size_bytes -= old_entry.size_bytes
else:
self.stats.total_entries += 1
self.entries[key] = entry
self.stats.total_size_bytes += size_bytes
def delete(self, key: str) -> bool:
"""Delete cache entry."""
with self._lock:
entry = self.entries.pop(key, None)
if entry:
self.stats.total_entries -= 1
self.stats.total_size_bytes -= entry.size_bytes
self.stats.evictions += 1
return True
return False
def clear(self, tags: Optional[List[str]] = None):
"""Clear cache entries."""
with self._lock:
if tags is None:
# Clear all
self.stats.evictions += len(self.entries)
self.entries.clear()
self.stats.total_entries = 0
self.stats.total_size_bytes = 0
else:
# Clear by tags
to_delete = []
for key, entry in self.entries.items():
if any(tag in entry.tags for tag in tags):
to_delete.append(key)
for key in to_delete:
self.delete(key)
def get_stats(self) -> CacheStats:
"""Get cache statistics."""
with self._lock:
return CacheStats(
hits=self.stats.hits,
misses=self.stats.misses,
evictions=self.stats.evictions,
total_entries=self.stats.total_entries,
total_size_bytes=self.stats.total_size_bytes,
hit_rate=self.stats.hit_rate
)
def cleanup_expired(self):
"""Clean up expired entries."""
with self._lock:
to_delete = []
for key, entry in self.entries.items():
if entry.is_expired():
to_delete.append(key)
for key in to_delete:
self.delete(key)
def _ensure_capacity(self, new_size_bytes: int):
"""Ensure cache has capacity for new entry."""
# Check size limit
if self.stats.total_size_bytes + new_size_bytes > self.max_size_bytes:
self._evict_by_size(new_size_bytes)
# Check count limit
if len(self.entries) >= self.max_size:
self._evict_by_count()
def _evict_by_size(self, needed_bytes: int):
"""Evict entries to free up space."""
# Sort by access time (LRU)
sorted_entries = sorted(
self.entries.items(),
key=lambda x: (x[1].accessed_at, x[1].access_count)
)
freed_bytes = 0
for key, entry in sorted_entries:
if freed_bytes >= needed_bytes:
break
freed_bytes += entry.size_bytes
del self.entries[key]
self.stats.total_entries -= 1
self.stats.total_size_bytes -= entry.size_bytes
self.stats.evictions += 1
def _evict_by_count(self):
"""Evict least recently used entry."""
if not self.entries:
return
# Find LRU entry
lru_key = min(
self.entries.keys(),
key=lambda k: (self.entries[k].accessed_at, self.entries[k].access_count)
)
self.delete(lru_key)
class FileCache(CacheBackend):
    """File-based cache backend."""
    def __init__(self, cache_dir: str, max_files: int = 10000):
        """Initialize file cache.

        Args:
            cache_dir: Directory to store cache files
            max_files: Maximum number of cache files
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_files = max_files
        self.stats = CacheStats()
        self._lock = threading.RLock()
        # Load existing stats
        # NOTE(review): _load_stats is defined later in this module (not
        # visible in this chunk); it restores persisted CacheStats.
        self._load_stats()
def get(self, key: str) -> Optional[CacheEntry]:
"""Get cache entry by key."""
with self._lock:
cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
if not cache_file.exists():
self.stats.misses += 1
self.stats.update_hit_rate()
return None
try:
with open(cache_file, 'rb') as f:
entry = pickle.load(f)
if entry.is_expired():
cache_file.unlink()
self.stats.misses += 1
self.stats.evictions += 1
self.stats.total_entries -= 1
self.stats.update_hit_rate()
return None
entry.touch()
# Update file modification time
cache_file.touch()
self.stats.hits += 1
self.stats.update_hit_rate()
return entry
except Exception as e:
logger.error(f"Error reading cache file {cache_file}: {e}")
cache_file.unlink(missing_ok=True)
self.stats.misses += 1
self.stats.update_hit_rate()
return None
def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None):
"""Set cache entry."""
with self._lock:
cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
# Create entry
now = datetime.now()
entry = CacheEntry(
key=key,
value=value,
created_at=now,
accessed_at=now,
access_count=1,
ttl_seconds=ttl_seconds,
tags=tags or []
)
try:
# Ensure capacity
self._ensure_capacity()
# Write to file
with open(cache_file, 'wb') as f:
pickle.dump(entry, f)
entry.size_bytes = cache_file.stat().st_size
if not cache_file.exists():
self.stats.total_entries += 1
self.stats.total_size_bytes += entry.size_bytes
self._save_stats()
except Exception as e:
logger.error(f"Error writing cache file {cache_file}: {e}")
def delete(self, key: str) -> bool:
"""Delete cache entry."""
with self._lock:
cache_file = self.cache_dir / f"{self._safe_key(key)}.cache"
if cache_file.exists():
size = cache_file.stat().st_size
cache_file.unlink()
self.stats.total_entries -= 1
self.stats.total_size_bytes -= size
self.stats.evictions += 1
self._save_stats()
return True
return False
def clear(self, tags: Optional[List[str]] = None):
"""Clear cache entries."""
with self._lock:
if tags is None:
# Clear all
for cache_file in self.cache_dir.glob("*.cache"):
cache_file.unlink()
self.stats.evictions += self.stats.total_entries
self.stats.total_entries = 0
self.stats.total_size_bytes = 0
else:
# Clear by tags
for cache_file in self.cache_dir.glob("*.cache"):
try:
with open(cache_file, 'rb') as f:
entry = pickle.load(f)
if any(tag in entry.tags for tag in tags):
size = cache_file.stat().st_size
cache_file.unlink()
self.stats.total_entries -= 1
self.stats.total_size_bytes -= size
self.stats.evictions += 1
except Exception:
continue
self._save_stats()
def get_stats(self) -> CacheStats:
"""Get cache statistics."""
with self._lock:
return CacheStats(
hits=self.stats.hits,
misses=self.stats.misses,
evictions=self.stats.evictions,
total_entries=self.stats.total_entries,
total_size_bytes=self.stats.total_size_bytes,
hit_rate=self.stats.hit_rate
)
def cleanup_expired(self):
"""Clean up expired entries."""
with self._lock:
for cache_file in self.cache_dir.glob("*.cache"):
try:
with open(cache_file, 'rb') as f:
entry = pickle.load(f)
if entry.is_expired():
size = cache_file.stat().st_size
cache_file.unlink()
self.stats.total_entries -= 1
self.stats.total_size_bytes -= size
self.stats.evictions += 1
except Exception:
# Remove corrupted files
cache_file.unlink()
self._save_stats()
def _safe_key(self, key: str) -> str:
"""Convert key to safe filename."""
import hashlib
return hashlib.md5(key.encode()).hexdigest()
def _ensure_capacity(self):
"""Ensure cache has capacity for new entry."""
cache_files = list(self.cache_dir.glob("*.cache"))
if len(cache_files) >= self.max_files:
# Remove oldest file
oldest_file = min(cache_files, key=lambda f: f.stat().st_mtime)
size = oldest_file.stat().st_size
oldest_file.unlink()
self.stats.total_entries -= 1
self.stats.total_size_bytes -= size
self.stats.evictions += 1
def _load_stats(self):
"""Load statistics from file."""
stats_file = self.cache_dir / "stats.json"
if stats_file.exists():
try:
with open(stats_file, 'r') as f:
data = json.load(f)
self.stats = CacheStats(**data)
except Exception:
pass
def _save_stats(self):
"""Save statistics to file."""
stats_file = self.cache_dir / "stats.json"
try:
with open(stats_file, 'w') as f:
json.dump(asdict(self.stats), f, default=str)
except Exception:
pass
class CacheManager:
    """Cache manager with multiple backends and intelligent caching strategies.

    Owns one CacheBackend per configured backend name ('memory',
    'file'), routes get/set/delete/clear calls to a named backend, and
    starts a daemon thread that periodically evicts expired entries.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initialize cache manager.

        Args:
            config: Cache configuration; recognised keys are
                'default_backend', 'default_ttl_seconds',
                'cleanup_interval_seconds' and a 'backends' mapping.
        """
        self.config = config
        self.backends: Dict[str, CacheBackend] = {}
        self.default_backend = config.get('default_backend', 'memory')
        self.default_ttl = config.get('default_ttl_seconds', 3600) # 1 hour
        # Initialize backends
        self._init_backends()
        # Start cleanup task (daemon thread; runs for process lifetime)
        self.cleanup_interval = config.get('cleanup_interval_seconds', 300) # 5 minutes
        self._start_cleanup_task()
    def _init_backends(self):
        """Initialize cache backends.

        A backend is created when it appears in config['backends'] OR
        when it is the configured default backend.
        """
        backends_config = self.config.get('backends', {})
        # Memory backend
        if 'memory' in backends_config or self.default_backend == 'memory':
            memory_config = backends_config.get('memory', {})
            self.backends['memory'] = InMemoryCache(
                max_size=memory_config.get('max_size', 1000),
                max_size_bytes=memory_config.get('max_size_bytes', 100 * 1024 * 1024)
            )
        # File backend
        if 'file' in backends_config or self.default_backend == 'file':
            file_config = backends_config.get('file', {})
            self.backends['file'] = FileCache(
                cache_dir=file_config.get('cache_dir', './cache'),
                max_files=file_config.get('max_files', 10000)
            )
    def get(self, key: str, backend: Optional[str] = None) -> Optional[Any]:
        """Get value from cache.

        Args:
            key: Cache key
            backend: Backend name (optional; defaults to default_backend)

        Returns:
            Cached value or None.  Note a stored value of None is
            indistinguishable from a cache miss.
        """
        backend_name = backend or self.default_backend
        cache_backend = self.backends.get(backend_name)
        if cache_backend is None:
            logger.warning(f"Cache backend '{backend_name}' not found")
            return None
        entry = cache_backend.get(key)
        return entry.value if entry else None
    def set(self,
            key: str,
            value: Any,
            ttl_seconds: Optional[int] = None,
            tags: Optional[List[str]] = None,
            backend: Optional[str] = None):
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl_seconds: Time to live in seconds (default_ttl when None)
            tags: Cache tags
            backend: Backend name (optional)
        """
        backend_name = backend or self.default_backend
        cache_backend = self.backends.get(backend_name)
        if cache_backend is None:
            # Unknown backend: log and drop the value silently.
            logger.warning(f"Cache backend '{backend_name}' not found")
            return
        ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl
        cache_backend.set(key, value, ttl, tags)
    def delete(self, key: str, backend: Optional[str] = None) -> bool:
        """Delete value from cache.

        Args:
            key: Cache key
            backend: Backend name (optional)

        Returns:
            True if deleted; False on miss or unknown backend.
        """
        backend_name = backend or self.default_backend
        cache_backend = self.backends.get(backend_name)
        if cache_backend is None:
            return False
        return cache_backend.delete(key)
    def clear(self, tags: Optional[List[str]] = None, backend: Optional[str] = None):
        """Clear cache entries.

        Args:
            tags: Tags to clear (optional; None clears everything)
            backend: Backend name (optional; None clears all backends)
        """
        if backend:
            cache_backend = self.backends.get(backend)
            if cache_backend:
                cache_backend.clear(tags)
        else:
            # Clear all backends
            for cache_backend in self.backends.values():
                cache_backend.clear(tags)
    def get_stats(self) -> Dict[str, CacheStats]:
        """Get statistics for all backends.

        Returns:
            Dictionary of backend name to statistics
        """
        return {name: backend.get_stats() for name, backend in self.backends.items()}
    def cleanup_expired(self):
        """Clean up expired entries in all backends (errors logged)."""
        for backend in self.backends.values():
            try:
                backend.cleanup_expired()
            except Exception as e:
                logger.error(f"Error cleaning up cache backend: {e}")
    def _start_cleanup_task(self):
        """Start periodic cleanup task.

        The worker sleeps first, so the initial cleanup happens one
        full interval after construction.  The thread is a daemon and
        is never joined; it dies with the process.
        """
        def cleanup_worker():
            while True:
                try:
                    time.sleep(self.cleanup_interval)
                    self.cleanup_expired()
                except Exception as e:
                    logger.error(f"Cache cleanup error: {e}")
        cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        cleanup_thread.start()
# Cache decorators and utilities
def cache_result(cache_manager: CacheManager,
                 key_func: Optional[callable] = None,
                 ttl_seconds: Optional[int] = None,
                 tags: Optional[List[str]] = None,
                 backend: Optional[str] = None):
    """Decorator to cache function results.

    Args:
        cache_manager: Cache manager instance
        key_func: Function to generate a cache key from (*args, **kwargs).
            If omitted, the key is derived from the function name and
            ``hash()`` of the arguments -- the arguments must then be
            hashable, and the key is only stable within a single process
            because string hashing is randomized per interpreter run.
        ttl_seconds: Time to live
        tags: Cache tags
        backend: Backend name

    Note:
        A cached value of ``None`` is indistinguishable from a cache
        miss (``CacheManager.get`` returns None for both), so functions
        that return None are re-executed on every call.
    """
    import functools

    def decorator(func):
        # functools.wraps preserves the wrapped function's name,
        # docstring and module, which the bare wrapper previously lost.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Generate cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                cache_key = f"{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}"
            # Try to get from cache
            cached_result = cache_manager.get(cache_key, backend)
            if cached_result is not None:
                return cached_result
            # Execute function and cache its result
            result = func(*args, **kwargs)
            cache_manager.set(cache_key, result, ttl_seconds, tags, backend)
            return result
        return wrapper
    return decorator

View file

@ -0,0 +1,610 @@
"""
Cypher executor for multiple graph databases.
Executes Cypher queries against Neo4j, Memgraph, and FalkorDB.
"""
import logging
import asyncio
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
from .cypher_generator import CypherQuery
logger = logging.getLogger(__name__)
# Try to import various database drivers
try:
from neo4j import GraphDatabase, Driver as Neo4jDriver
NEO4J_AVAILABLE = True
except ImportError:
NEO4J_AVAILABLE = False
Neo4jDriver = None
try:
import redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
@dataclass
class CypherResult:
    """Result from Cypher query execution."""
    records: List[Dict[str, Any]]  # one dict per returned row
    summary: Dict[str, Any]  # execution metadata (record_count, engine, or error)
    execution_time: float  # wall-clock seconds spent executing
    database_type: str  # e.g. 'neo4j', 'memgraph', 'falkordb'
    query_plan: Optional[Dict[str, Any]] = None  # optional plan details, unset by executors here
class CypherExecutorBase(ABC):
    """Abstract base class for Cypher executors.

    NOTE(review): concrete subclasses in this module also implement an
    async ``connect()`` method which CypherExecutor.execute_cypher()
    calls directly, even though it is not declared here -- consider
    promoting connect() to this interface.
    """
    @abstractmethod
    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query and return a CypherResult."""
        pass
    @abstractmethod
    async def close(self):
        """Close database connection."""
        pass
    @abstractmethod
    def is_connected(self) -> bool:
        """Check if connected to database."""
        pass
class Neo4jExecutor(CypherExecutorBase):
    """Cypher executor for Neo4j database.

    Blocking neo4j-driver calls are run in the default thread-pool
    executor so the asyncio event loop is never blocked.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize Neo4j executor.

        Args:
            config: Neo4j configuration (uri, username, password,
                connection_pool_size, connection_timeout, max_retry_time)

        Raises:
            RuntimeError: If the neo4j driver package is not installed.
        """
        if not NEO4J_AVAILABLE:
            raise RuntimeError("Neo4j driver not available")
        self.config = config
        self.driver: Optional[Neo4jDriver] = None
        self._connection_pool_size = config.get('connection_pool_size', 10)

    async def connect(self):
        """Connect to Neo4j database.

        Raises:
            Exception: Propagates any driver connection error.
        """
        try:
            uri = self.config.get('uri', 'bolt://localhost:7687')
            username = self.config.get('username')
            password = self.config.get('password')
            auth = (username, password) if username and password else None
            # Create driver with connection pool
            self.driver = GraphDatabase.driver(
                uri,
                auth=auth,
                max_connection_pool_size=self._connection_pool_size,
                connection_timeout=self.config.get('connection_timeout', 30),
                max_retry_time=self.config.get('max_retry_time', 15)
            )
            # Verify connectivity off the event loop.
            # asyncio.get_running_loop() is the supported call inside a
            # coroutine; get_event_loop() is deprecated here since 3.10.
            await asyncio.get_running_loop().run_in_executor(
                None, self.driver.verify_connectivity
            )
            logger.info(f"Connected to Neo4j at {uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Neo4j: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against Neo4j.

        Connects lazily on first use.  Execution errors are captured in
        the result summary rather than raised.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Query results
        """
        if not self.driver:
            await self.connect()
        import time
        start_time = time.time()
        try:
            # Run the blocking session work in a worker thread.
            records = await asyncio.get_running_loop().run_in_executor(
                None, self._execute_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={'record_count': len(records)},
                execution_time=execution_time,
                database_type='neo4j'
            )
        except Exception as e:
            logger.error(f"Neo4j query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='neo4j'
            )

    def _execute_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Execute query synchronously in the thread executor.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries with driver objects converted
            to plain JSON-serializable values.
        """
        with self.driver.session() as session:
            result = session.run(cypher_query.query, cypher_query.parameters)
            return [
                {key: self._format_neo4j_value(record[key]) for key in record.keys()}
                for record in result
            ]

    def _format_neo4j_value(self, value):
        """Format a Neo4j driver value for JSON serialization.

        Nodes, relationships and paths are detected by duck typing
        (checked in that order) and converted to plain dicts; anything
        else is returned unchanged.

        Args:
            value: Neo4j value

        Returns:
            JSON-serializable value
        """
        # Handle Neo4j node objects (have .labels and .items)
        if hasattr(value, 'labels') and hasattr(value, 'items'):
            return {
                'labels': list(value.labels),
                'properties': dict(value.items())
            }
        # Handle Neo4j relationship objects (have .type and .items)
        elif hasattr(value, 'type') and hasattr(value, 'items'):
            return {
                'type': value.type,
                'properties': dict(value.items())
            }
        # Handle Neo4j path objects (have .nodes and .relationships)
        elif hasattr(value, 'nodes') and hasattr(value, 'relationships'):
            return {
                'nodes': [self._format_neo4j_value(n) for n in value.nodes],
                'relationships': [self._format_neo4j_value(r) for r in value.relationships]
            }
        else:
            return value

    async def close(self):
        """Close the Neo4j connection (no-op when not connected)."""
        if self.driver:
            await asyncio.get_running_loop().run_in_executor(
                None, self.driver.close
            )
            self.driver = None
            logger.info("Neo4j connection closed")

    def is_connected(self) -> bool:
        """Check if connected to Neo4j (a driver instance exists)."""
        return self.driver is not None
class MemgraphExecutor(CypherExecutorBase):
    """Cypher executor for Memgraph database.

    Memgraph speaks the Bolt protocol, so this executor reuses the
    Neo4j driver with a different default port and smaller pool.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initialize Memgraph executor.

        Args:
            config: Memgraph configuration

        Raises:
            RuntimeError: If the neo4j driver package is not installed.
        """
        if not NEO4J_AVAILABLE: # Memgraph uses Neo4j driver
            raise RuntimeError("Neo4j driver required for Memgraph")
        self.config = config
        self.driver: Optional[Neo4jDriver] = None
    async def connect(self):
        """Connect to Memgraph database.

        Raises:
            Exception: Propagates any driver connection error.
        """
        try:
            uri = self.config.get('uri', 'bolt://localhost:7688')
            username = self.config.get('username')
            password = self.config.get('password')
            auth = (username, password) if username and password else None
            # Memgraph uses Neo4j driver but with different defaults
            self.driver = GraphDatabase.driver(
                uri,
                auth=auth,
                max_connection_pool_size=self.config.get('connection_pool_size', 5),
                connection_timeout=self.config.get('connection_timeout', 10)
            )
            # Verify connectivity
            # NOTE(review): asyncio.get_event_loop() inside a coroutine
            # is deprecated since Python 3.10; prefer get_running_loop().
            await asyncio.get_event_loop().run_in_executor(
                None, self.driver.verify_connectivity
            )
            logger.info(f"Connected to Memgraph at {uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Memgraph: {e}")
            raise
    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against Memgraph.

        Connects lazily on first use; errors are captured in the
        result summary rather than raised.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Query results
        """
        if not self.driver:
            await self.connect()
        import time
        start_time = time.time()
        try:
            # Execute query with Memgraph-specific optimizations
            records = await asyncio.get_event_loop().run_in_executor(
                None, self._execute_memgraph_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={
                    'record_count': len(records),
                    'engine': 'memgraph'
                },
                execution_time=execution_time,
                database_type='memgraph'
            )
        except Exception as e:
            logger.error(f"Memgraph query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='memgraph'
            )
    def _execute_memgraph_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Execute query synchronously for Memgraph.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries (driver values passed through
            unconverted, unlike the Neo4j executor).
        """
        with self.driver.session() as session:
            # Add Memgraph-specific query hints if available
            query = cypher_query.query
            if cypher_query.database_hints and cypher_query.database_hints.get('memory_limit'):
                # Memgraph supports memory limits
                # NOTE(review): this only prepends a `//` Cypher comment,
                # so it annotates the query text without enforcing
                # anything -- confirm whether a real directive is intended.
                query = f"// Memory limit: {cypher_query.database_hints['memory_limit']}\n{query}"
            result = session.run(query, cypher_query.parameters)
            records = []
            for record in result:
                record_dict = {}
                for key in record.keys():
                    record_dict[key] = record[key]
                records.append(record_dict)
            return records
    async def close(self):
        """Close the Memgraph connection (no-op when not connected)."""
        if self.driver:
            await asyncio.get_event_loop().run_in_executor(
                None, self.driver.close
            )
            self.driver = None
            logger.info("Memgraph connection closed")
    def is_connected(self) -> bool:
        """Check if connected to Memgraph (a driver instance exists)."""
        return self.driver is not None
class FalkorDBExecutor(CypherExecutorBase):
    """Cypher executor for FalkorDB (Redis-based graph database)."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize FalkorDB executor.

        Args:
            config: FalkorDB configuration (host, port, password, db,
                graph_name, connection/socket timeouts)

        Raises:
            RuntimeError: If the redis driver package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise RuntimeError("Redis driver required for FalkorDB")
        self.config = config
        self.redis_client: Optional[redis.Redis] = None
        self.graph_name = config.get('graph_name', 'knowledge_graph')

    async def connect(self):
        """Connect to FalkorDB (Redis).

        Raises:
            Exception: Propagates any connection error.
        """
        try:
            self.redis_client = redis.Redis(
                host=self.config.get('host', 'localhost'),
                port=self.config.get('port', 6379),
                password=self.config.get('password'),
                db=self.config.get('db', 0),
                decode_responses=True,
                socket_connect_timeout=self.config.get('connection_timeout', 10),
                socket_timeout=self.config.get('socket_timeout', 10)
            )
            # Test the connection off the event loop (ping blocks on I/O).
            # get_running_loop() is the supported call inside a coroutine.
            await asyncio.get_running_loop().run_in_executor(
                None, self.redis_client.ping
            )
            logger.info(f"Connected to FalkorDB at {self.config.get('host', 'localhost')}")
        except Exception as e:
            logger.error(f"Failed to connect to FalkorDB: {e}")
            raise

    async def execute(self, cypher_query: CypherQuery) -> CypherResult:
        """Execute Cypher query against FalkorDB.

        Connects lazily on first use.  Errors are captured in the
        result summary rather than raised.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Query results
        """
        if not self.redis_client:
            await self.connect()
        import time
        start_time = time.time()
        try:
            # Execute query using FalkorDB's GRAPH.QUERY command
            records = await asyncio.get_running_loop().run_in_executor(
                None, self._execute_falkordb_sync, cypher_query
            )
            execution_time = time.time() - start_time
            return CypherResult(
                records=records,
                summary={
                    'record_count': len(records),
                    'engine': 'falkordb'
                },
                execution_time=execution_time,
                database_type='falkordb'
            )
        except Exception as e:
            logger.error(f"FalkorDB query execution error: {e}")
            execution_time = time.time() - start_time
            return CypherResult(
                records=[],
                summary={'error': str(e)},
                execution_time=execution_time,
                database_type='falkordb'
            )

    def _execute_falkordb_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]:
        """Execute query synchronously for FalkorDB.

        Parameters are substituted textually into the query string.
        Longer parameter names are substituted first so that ``$name``
        cannot clobber the prefix of ``$name2``, and string values have
        backslashes and double quotes escaped so they cannot break out
        of the quoted literal.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            List of record dictionaries
        """
        query = cypher_query.query
        for param in sorted(cypher_query.parameters, key=len, reverse=True):
            value = cypher_query.parameters[param]
            if isinstance(value, str):
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                replacement = f'"{escaped}"'
            else:
                replacement = str(value)
            query = query.replace(f'${param}', replacement)
        # Execute using FalkorDB GRAPH.QUERY command
        result = self.redis_client.execute_command(
            'GRAPH.QUERY', self.graph_name, query
        )
        # Parse FalkorDB result format: [header, data rows, statistics]
        records = []
        if result and len(result) > 1:
            headers = result[0] if result[0] else []
            data_rows = result[1] if len(result) > 1 else []
            for row in data_rows:
                record = {}
                for i, header in enumerate(headers):
                    if i < len(row):
                        record[header] = self._format_falkordb_value(row[i])
                records.append(record)
        return records

    def _format_falkordb_value(self, value):
        """Format FalkorDB value for JSON serialization.

        FalkorDB encodes graph entities as 3-element lists whose first
        element tags the kind (1 = node, 2 = relationship); anything
        else is returned unchanged.

        Args:
            value: FalkorDB value

        Returns:
            JSON-serializable value
        """
        if isinstance(value, list) and len(value) == 3:
            if value[0] == 1:  # Node
                return {
                    'type': 'node',
                    'labels': value[1],
                    'properties': value[2]
                }
            elif value[0] == 2:  # Relationship
                return {
                    'type': 'relationship',
                    'rel_type': value[1],
                    'properties': value[2]
                }
        return value

    async def close(self):
        """Close the FalkorDB connection (no-op when not connected)."""
        if self.redis_client:
            await asyncio.get_running_loop().run_in_executor(
                None, self.redis_client.close
            )
            self.redis_client = None
            logger.info("FalkorDB connection closed")

    def is_connected(self) -> bool:
        """Check if connected to FalkorDB (a client instance exists)."""
        return self.redis_client is not None
class CypherExecutor:
    """Multi-database Cypher executor with automatic routing.

    Owns one CypherExecutorBase per configured backend and routes each
    query to the requested database type.
    """
    def __init__(self, config: Dict[str, Any]):
        """Initialize multi-database executor.

        Args:
            config: Configuration for all database types, keyed by
                'neo4j', 'memgraph', 'falkordb'.

        Raises:
            RuntimeError: If no executor could be initialized.
        """
        self.config = config
        self.executors: Dict[str, CypherExecutorBase] = {}
        # Initialize available executors
        self._initialize_executors()
    def _initialize_executors(self):
        """Initialize database executors based on configuration.

        An executor is created only when its config section is present
        AND its driver package imported successfully; individual
        failures are logged and skipped.
        """
        # Neo4j executor
        if 'neo4j' in self.config and NEO4J_AVAILABLE:
            try:
                self.executors['neo4j'] = Neo4jExecutor(self.config['neo4j'])
                logger.info("Neo4j executor initialized")
            except Exception as e:
                logger.error(f"Failed to initialize Neo4j executor: {e}")
        # Memgraph executor (also requires the Neo4j/Bolt driver)
        if 'memgraph' in self.config and NEO4J_AVAILABLE:
            try:
                self.executors['memgraph'] = MemgraphExecutor(self.config['memgraph'])
                logger.info("Memgraph executor initialized")
            except Exception as e:
                logger.error(f"Failed to initialize Memgraph executor: {e}")
        # FalkorDB executor
        if 'falkordb' in self.config and REDIS_AVAILABLE:
            try:
                self.executors['falkordb'] = FalkorDBExecutor(self.config['falkordb'])
                logger.info("FalkorDB executor initialized")
            except Exception as e:
                logger.error(f"Failed to initialize FalkorDB executor: {e}")
        if not self.executors:
            raise RuntimeError("No database executors could be initialized")
    async def execute_cypher(self, cypher_query: CypherQuery,
                             database_type: str) -> CypherResult:
        """Execute Cypher query on specified database.

        Args:
            cypher_query: Cypher query to execute
            database_type: Target database type

        Returns:
            Query results

        Raises:
            ValueError: If database_type has no initialized executor.
        """
        if database_type not in self.executors:
            raise ValueError(f"Database type {database_type} not available. "
                             f"Available: {list(self.executors.keys())}")
        executor = self.executors[database_type]
        # Ensure connection
        # NOTE(review): connect() is not declared on CypherExecutorBase,
        # though every concrete executor in this module implements it.
        if not executor.is_connected():
            await executor.connect()
        # Execute query
        return await executor.execute(cypher_query)
    async def execute_on_all(self, cypher_query: CypherQuery) -> Dict[str, CypherResult]:
        """Execute query on all available databases.

        All tasks are created up-front so the per-database queries run
        concurrently; individual failures become error results instead
        of propagating.

        Args:
            cypher_query: Cypher query to execute

        Returns:
            Results from all databases, keyed by database type
        """
        results = {}
        tasks = []
        for db_type, executor in self.executors.items():
            task = asyncio.create_task(
                self.execute_cypher(cypher_query, db_type),
                name=f"cypher_query_{db_type}"
            )
            tasks.append((db_type, task))
        # Wait for all tasks to complete
        for db_type, task in tasks:
            try:
                results[db_type] = await task
            except Exception as e:
                logger.error(f"Query failed on {db_type}: {e}")
                results[db_type] = CypherResult(
                    records=[],
                    summary={'error': str(e)},
                    execution_time=0.0,
                    database_type=db_type
                )
        return results
    def get_available_databases(self) -> List[str]:
        """Get list of available database types.

        Returns:
            List of available database type names
        """
        return list(self.executors.keys())
    async def close_all(self):
        """Close all database connections."""
        for executor in self.executors.values():
            await executor.close()
        logger.info("All Cypher executor connections closed")

View file

@ -0,0 +1,628 @@
"""
Cypher query generator for ontology-sensitive queries.
Converts natural language questions to Cypher queries for graph databases.
"""
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
@dataclass
class CypherQuery:
    """Generated Cypher query with metadata."""
    query: str  # Cypher text; may contain $param placeholders
    parameters: Dict[str, Any]  # values for the $param placeholders
    variables: List[str]  # names of the variables the query returns
    explanation: str  # human-readable description of what the query does
    complexity_score: float  # heuristic complexity estimate (higher = more complex)
    database_hints: Optional[Dict[str, Any]] = None  # Database-specific optimization hints
class CypherGenerator:
"""Generates Cypher queries from natural language questions using LLM assistance."""
    def __init__(self, prompt_service=None):
        """Initialize Cypher generator.

        Args:
            prompt_service: Service for LLM-based query generation;
                when None, only template and fallback generation run.
        """
        self.prompt_service = prompt_service
        # Cypher query templates for common patterns.  Placeholders in
        # {braces} are filled with str.format(); $dollar names remain
        # as Cypher query parameters.
        self.templates = {
            'simple_node_query': """
        MATCH (n:{node_label})
        RETURN n.name AS name, n.{property} AS {property}
        LIMIT {limit}""",
            'relationship_query': """
        MATCH (a:{source_label})-[r:{relationship}]->(b:{target_label})
        WHERE a.name = $source_name
        RETURN b.name AS name, r.{rel_property} AS property""",
            'path_query': """
        MATCH path = (start:{start_label})-[*1..{max_depth}]->(end:{end_label})
        WHERE start.name = $start_name
        RETURN path, length(path) AS path_length
        ORDER BY path_length""",
            'count_query': """
        MATCH (n:{node_label})
        {where_clause}
        RETURN count(n) AS count""",
            'aggregation_query': """
        MATCH (n:{node_label})
        {where_clause}
        RETURN
        count(n) AS count,
        avg(n.{numeric_property}) AS average,
        sum(n.{numeric_property}) AS total""",
            'boolean_query': """
        MATCH (a:{source_label})-[:{relationship}]->(b:{target_label})
        WHERE a.name = $source_name AND b.name = $target_name
        RETURN count(*) > 0 AS exists""",
            'hierarchy_query': """
        MATCH (child:{child_label})-[:SUBCLASS_OF*]->(parent:{parent_label})
        WHERE parent.name = $parent_name
        RETURN child.name AS child_name, parent.name AS parent_name""",
            'property_filter_query': """
        MATCH (n:{node_label})
        WHERE n.{property} {operator} ${property}_value
        RETURN n.name AS name, n.{property} AS {property}
        ORDER BY n.{property}"""
        }
async def generate_cypher(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset,
database_type: str = "neo4j") -> CypherQuery:
"""Generate Cypher query for a question.
Args:
question_components: Analyzed question components
ontology_subset: Relevant ontology subset
database_type: Target database (neo4j, memgraph, falkordb)
Returns:
Generated Cypher query
"""
# Try template-based generation first
template_query = self._try_template_generation(
question_components, ontology_subset, database_type
)
if template_query:
logger.debug("Generated Cypher using template")
return template_query
# Fall back to LLM-based generation
if self.prompt_service:
llm_query = await self._generate_with_llm(
question_components, ontology_subset, database_type
)
if llm_query:
logger.debug("Generated Cypher using LLM")
return llm_query
# Final fallback to simple pattern
logger.warning("Falling back to simple Cypher pattern")
return self._generate_fallback_query(question_components, ontology_subset)
    def _try_template_generation(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 database_type: str) -> Optional[CypherQuery]:
        """Try to generate query using templates.

        The question is checked against a small set of known shapes in
        order (simple retrieval, count, relationship, boolean); the
        first shape whose ontology terms can be matched wins.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset
            database_type: Target database type

        Returns:
            Generated query or None if no template matches
        """
        # Simple node query (What are the animals?)
        # Fires only for single-entity retrieval questions.
        if (question_components.question_type == QuestionType.RETRIEVAL and
            len(question_components.entities) == 1):
            node_label = self._find_matching_node_label(
                question_components.entities[0], ontology_subset
            )
            if node_label:
                query = self.templates['simple_node_query'].format(
                    node_label=node_label,
                    property='name',
                    limit=100
                )
                return CypherQuery(
                    query=query,
                    parameters={},
                    variables=['name'],
                    explanation=f"Retrieve all nodes of type {node_label}",
                    complexity_score=0.2,
                    database_hints=self._get_database_hints(database_type, 'simple')
                )
        # Count query (How many animals are there?)
        if (question_components.question_type == QuestionType.AGGREGATION and
            'count' in question_components.aggregations):
            # Default to a generic 'Entity' label when no entity was found.
            node_label = self._find_matching_node_label(
                question_components.entities[0] if question_components.entities else 'Entity',
                ontology_subset
            )
            if node_label:
                where_clause = self._build_where_clause(question_components)
                query = self.templates['count_query'].format(
                    node_label=node_label,
                    where_clause=where_clause
                )
                return CypherQuery(
                    query=query,
                    parameters=self._extract_parameters(question_components),
                    variables=['count'],
                    explanation=f"Count nodes of type {node_label}",
                    complexity_score=0.3,
                    database_hints=self._get_database_hints(database_type, 'aggregation')
                )
        # Relationship query (Which documents were authored by John Smith?)
        # NOTE(review): entities[1] is treated as the source and
        # entities[0] as the target -- confirm this matches the
        # analyzer's entity ordering.
        if (question_components.question_type == QuestionType.RETRIEVAL and
            len(question_components.entities) >= 2):
            source_label = self._find_matching_node_label(
                question_components.entities[1], ontology_subset
            )
            target_label = self._find_matching_node_label(
                question_components.entities[0], ontology_subset
            )
            relationship = self._find_matching_relationship(
                question_components, ontology_subset
            )
            if source_label and target_label and relationship:
                query = self.templates['relationship_query'].format(
                    source_label=source_label,
                    target_label=target_label,
                    relationship=relationship,
                    rel_property='name'
                )
                return CypherQuery(
                    query=query,
                    parameters={'source_name': question_components.entities[1]},
                    variables=['name'],
                    explanation=f"Find {target_label} related to {source_label} via {relationship}",
                    complexity_score=0.4,
                    database_hints=self._get_database_hints(database_type, 'relationship')
                )
        # Boolean query (Is X related to Y?)
        if question_components.question_type == QuestionType.BOOLEAN:
            if len(question_components.entities) >= 2:
                source_label = self._find_matching_node_label(
                    question_components.entities[0], ontology_subset
                )
                target_label = self._find_matching_node_label(
                    question_components.entities[1], ontology_subset
                )
                relationship = self._find_matching_relationship(
                    question_components, ontology_subset
                )
                if source_label and target_label and relationship:
                    query = self.templates['boolean_query'].format(
                        source_label=source_label,
                        target_label=target_label,
                        relationship=relationship
                    )
                    return CypherQuery(
                        query=query,
                        parameters={
                            'source_name': question_components.entities[0],
                            'target_name': question_components.entities[1]
                        },
                        variables=['exists'],
                        explanation="Boolean check for relationship existence",
                        complexity_score=0.3,
                        database_hints=self._get_database_hints(database_type, 'boolean')
                    )
        # No template shape matched; caller falls through to LLM/fallback.
        return None
    async def _generate_with_llm(self,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 database_type: str) -> Optional[CypherQuery]:
        """Generate Cypher using LLM.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset
            database_type: Target database type

        Returns:
            Generated query or None if generation failed or the
            response did not look like a Cypher statement.
        """
        try:
            prompt = self._build_cypher_prompt(
                question_components, ontology_subset, database_type
            )
            response = await self.prompt_service.generate_cypher(prompt=prompt)
            if response and isinstance(response, dict):
                query = response.get('query', '').strip()
                # Accept only responses that start with a Cypher keyword.
                if query.upper().startswith(('MATCH', 'CREATE', 'MERGE', 'DELETE', 'RETURN')):
                    return CypherQuery(
                        query=query,
                        parameters=response.get('parameters', {}),
                        variables=self._extract_variables(query),
                        explanation=response.get('explanation', 'Generated by LLM'),
                        complexity_score=self._calculate_complexity(query),
                        database_hints=self._get_database_hints(database_type, 'complex')
                    )
        except Exception as e:
            # Failures are non-fatal; caller falls back to a simple query.
            logger.error(f"LLM Cypher generation failed: {e}")
        return None
    def _build_cypher_prompt(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset,
                             database_type: str) -> str:
        """Build prompt for LLM Cypher generation.

        Args:
            question_components: Question analysis
            ontology_subset: Ontology subset
            database_type: Target database type

        Returns:
            Formatted prompt string
        """
        # Format ontology elements as node labels and relationships
        node_labels = self._format_node_labels(ontology_subset.classes)
        relationships = self._format_relationships(
            ontology_subset.object_properties,
            ontology_subset.datatype_properties
        )
        prompt = f"""Generate a Cypher query for the following question using the provided ontology.
QUESTION: {question_components.original_question}
TARGET DATABASE: {database_type}
AVAILABLE NODE LABELS (from classes):
{node_labels}
AVAILABLE RELATIONSHIP TYPES (from properties):
{relationships}
RULES:
- Use MATCH patterns for graph traversal
- Include WHERE clauses for filters
- Use aggregation functions when needed (COUNT, SUM, AVG)
- Optimize for {database_type} performance
- Consider index hints for large datasets
- Use parameters for values (e.g., $name)
QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}
DATABASE-SPECIFIC OPTIMIZATIONS:
{self._get_database_specific_hints(database_type)}
Generate a complete Cypher query with parameters:"""
        return prompt
def _generate_fallback_query(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset) -> CypherQuery:
    """Generate a simple fallback query when LLM generation is unavailable.

    Args:
        question_components: Question analysis (keywords are used for matching).
        ontology_subset: Ontology subset; an arbitrary class becomes the label.

    Returns:
        Basic Cypher query doing substring matching on node names.
    """
    # Pick any class as the node label, defaulting to a generic 'Entity'.
    if ontology_subset.classes:
        node_label = next(iter(ontology_subset.classes))
    else:
        node_label = 'Entity'

    fallback_query = f"""MATCH (n:{node_label})
WHERE n.name CONTAINS $keyword
RETURN n.name AS name, labels(n) AS types
LIMIT 10"""

    keywords = question_components.keywords
    search_term = keywords[0] if keywords else 'entity'

    return CypherQuery(
        query=fallback_query,
        parameters={'keyword': search_term},
        variables=['name', 'types'],
        explanation="Fallback query for basic pattern matching",
        complexity_score=0.1
    )
def _find_matching_node_label(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Find matching node label in ontology subset.
Args:
entity: Entity string to match
ontology_subset: Ontology subset
Returns:
Matching node label or None
"""
entity_lower = entity.lower()
# Direct match
for class_id in ontology_subset.classes:
if class_id.lower() == entity_lower:
return class_id
# Label match
for class_id, class_def in ontology_subset.classes.items():
labels = class_def.get('labels', [])
for label in labels:
if isinstance(label, dict):
label_value = label.get('value', '').lower()
if label_value == entity_lower:
return class_id
# Partial match
for class_id in ontology_subset.classes:
if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
return class_id
return None
def _find_matching_relationship(self,
                                question_components: QuestionComponents,
                                ontology_subset: QueryOntologySubset) -> Optional[str]:
    """Find a relationship type matching the question's keywords.

    Checks the ontology's object properties first, then a table of common
    English relationship words, and finally falls back to a generic type.

    Args:
        question_components: Question analysis (keywords drive the match).
        ontology_subset: Ontology subset providing object properties.

    Returns:
        Matching relationship type (always a value; RELATED_TO by default).
    """
    # 1. Bidirectional substring match against ontology object properties.
    for keyword in question_components.keywords:
        kw = keyword.lower()
        for prop_id in ontology_subset.object_properties:
            prop_lower = prop_id.lower()
            if kw in prop_lower or prop_lower in kw:
                return prop_id.upper().replace('-', '_')

    # 2. Common relationship words mapped to canonical relationship types.
    relationship_mappings = {
        'author': 'AUTHORED_BY',
        'created': 'CREATED_BY',
        'owns': 'OWNS',
        'has': 'HAS',
        'contains': 'CONTAINS',
        'parent': 'PARENT_OF',
        'child': 'CHILD_OF',
        'related': 'RELATED_TO'
    }
    for keyword in question_components.keywords:
        mapped = relationship_mappings.get(keyword.lower())
        if mapped is not None:
            return mapped

    # 3. Generic default when nothing matched.
    return 'RELATED_TO'
def _build_where_clause(self, question_components: QuestionComponents) -> str:
"""Build WHERE clause for Cypher query.
Args:
question_components: Question analysis
Returns:
WHERE clause string
"""
conditions = []
for constraint in question_components.constraints:
if 'greater than' in constraint.lower():
import re
numbers = re.findall(r'\d+', constraint)
if numbers:
conditions.append(f"n.value > {numbers[0]}")
elif 'less than' in constraint.lower():
numbers = re.findall(r'\d+', constraint)
if numbers:
conditions.append(f"n.value < {numbers[0]}")
if conditions:
return f"WHERE {' AND '.join(conditions)}"
return ""
def _extract_parameters(self, question_components: QuestionComponents) -> Dict[str, Any]:
"""Extract parameters from question components.
Args:
question_components: Question analysis
Returns:
Parameters dictionary
"""
parameters = {}
# Extract numeric values
import re
for constraint in question_components.constraints:
numbers = re.findall(r'\d+', constraint)
for i, number in enumerate(numbers):
parameters[f'value_{i}'] = int(number)
return parameters
def _format_node_labels(self, classes: Dict[str, Any]) -> str:
"""Format classes as node labels for prompt.
Args:
classes: Classes dictionary
Returns:
Formatted node labels string
"""
if not classes:
return "None"
lines = []
for class_id, definition in classes.items():
comment = definition.get('comment', '')
lines.append(f"- :{class_id} - {comment}")
return '\n'.join(lines)
def _format_relationships(self,
object_props: Dict[str, Any],
datatype_props: Dict[str, Any]) -> str:
"""Format properties as relationships for prompt.
Args:
object_props: Object properties
datatype_props: Datatype properties
Returns:
Formatted relationships string
"""
lines = []
for prop_id, definition in object_props.items():
domain = definition.get('domain', 'Any')
range_val = definition.get('range', 'Any')
comment = definition.get('comment', '')
rel_type = prop_id.upper().replace('-', '_')
lines.append(f"- :{rel_type} ({domain} -> {range_val}) - {comment}")
return '\n'.join(lines) if lines else "None"
def _extract_variables(self, query: str) -> List[str]:
"""Extract variables from Cypher query.
Args:
query: Cypher query string
Returns:
List of variable names
"""
import re
# Extract RETURN clause variables
return_match = re.search(r'RETURN\s+(.+?)(?:ORDER|LIMIT|$)', query, re.IGNORECASE | re.DOTALL)
if return_match:
return_clause = return_match.group(1)
variables = re.findall(r'(\w+)(?:\s+AS\s+(\w+))?', return_clause)
return [var[1] if var[1] else var[0] for var in variables]
return []
def _calculate_complexity(self, query: str) -> float:
"""Calculate complexity score for Cypher query.
Args:
query: Cypher query string
Returns:
Complexity score (0.0 to 1.0)
"""
complexity = 0.0
query_upper = query.upper()
# Count different Cypher features
if 'JOIN' in query_upper or 'UNION' in query_upper:
complexity += 0.3
if 'WHERE' in query_upper:
complexity += 0.2
if 'OPTIONAL' in query_upper:
complexity += 0.1
if 'ORDER BY' in query_upper:
complexity += 0.1
if '*' in query: # Variable length paths
complexity += 0.2
if any(agg in query_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
complexity += 0.2
# Count path length
path_matches = re.findall(r'\[.*?\*(\d+)\.\.(\d+).*?\]', query)
for start, end in path_matches:
complexity += (int(end) - int(start)) * 0.05
return min(complexity, 1.0)
def _get_database_hints(self, database_type: str, query_category: str) -> Dict[str, Any]:
"""Get database-specific optimization hints.
Args:
database_type: Target database
query_category: Category of query
Returns:
Optimization hints
"""
hints = {}
if database_type == "neo4j":
hints.update({
'use_index': True,
'explain_plan': 'EXPLAIN',
'profile_query': 'PROFILE'
})
elif database_type == "memgraph":
hints.update({
'use_index': True,
'explain_plan': 'EXPLAIN',
'memory_limit': '1GB'
})
elif database_type == "falkordb":
hints.update({
'use_index': False, # Redis-based, different indexing
'cache_result': True
})
return hints
def _get_database_specific_hints(self, database_type: str) -> str:
"""Get database-specific optimization hints as text.
Args:
database_type: Target database
Returns:
Hints as formatted string
"""
if database_type == "neo4j":
return """- Use USING INDEX hints for large datasets
- Consider PROFILE for query optimization
- Prefer MERGE over CREATE when appropriate"""
elif database_type == "memgraph":
return """- Leverage in-memory processing advantages
- Use streaming for large result sets
- Consider query parallelization"""
elif database_type == "falkordb":
return """- Optimize for Redis memory constraints
- Use simple patterns for best performance
- Leverage Redis data structures when possible"""
else:
return "- Use standard Cypher optimization patterns"

View file

@ -0,0 +1,557 @@
"""
Error handling and recovery mechanisms for OntoRAG.
Provides comprehensive error handling, retry logic, and graceful degradation.
"""
import asyncio
import logging
import time
import traceback
from dataclasses import dataclass, field
from enum import Enum
from functools import wraps
from typing import Dict, Any, List, Optional, Callable, Union, Type
logger = logging.getLogger(__name__)
class ErrorSeverity(Enum):
    """Error severity levels, ordered from least to most serious."""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class ErrorCategory(Enum):
    """Error categories for better handling.

    Each category can carry its own retry policy and fallback strategy
    (see ErrorRecoveryStrategy).
    """
    ONTOLOGY_LOADING = "ontology_loading"
    QUESTION_ANALYSIS = "question_analysis"
    QUERY_GENERATION = "query_generation"
    QUERY_EXECUTION = "query_execution"
    ANSWER_GENERATION = "answer_generation"
    BACKEND_CONNECTION = "backend_connection"
    CACHE_ERROR = "cache_error"
    VALIDATION_ERROR = "validation_error"
    TIMEOUT_ERROR = "timeout_error"
    AUTHENTICATION_ERROR = "authentication_error"
@dataclass
class ErrorContext:
    """Context information for an error.

    Carries the classification (category/severity), where it happened
    (component/operation), user-facing and technical detail, and the
    retry bookkeeping used by ErrorRecoveryStrategy.
    """
    category: ErrorCategory
    severity: ErrorSeverity
    component: str
    operation: str
    user_message: Optional[str] = None
    technical_details: Optional[str] = None
    suggestion: Optional[str] = None
    retry_count: int = 0          # incremented by the retry machinery
    max_retries: int = 3
    # BUG FIX: was `Dict[str, Any] = None`, which both mistyped the field and
    # crashed consumers that call context.metadata.get(...) on the default.
    # default_factory gives each instance its own empty dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
class OntoRAGError(Exception):
    """Base exception for OntoRAG system."""

    def __init__(self,
                 message: str,
                 context: Optional[ErrorContext] = None,
                 cause: Optional[Exception] = None):
        """Initialize OntoRAG error.

        Args:
            message: Error message
            context: Error context; when omitted a generic medium-severity
                validation-error context is attached so downstream reporting
                never has to null-check.
            cause: Original exception that caused this error
        """
        super().__init__(message)
        self.message = message
        self.context = context or ErrorContext(
            category=ErrorCategory.VALIDATION_ERROR,
            severity=ErrorSeverity.MEDIUM,
            component="unknown",
            operation="unknown"
        )
        self.cause = cause
        # Creation time (epoch seconds); used by ErrorReporter recency windows.
        self.timestamp = time.time()
class OntologyLoadingError(OntoRAGError):
    """Error loading ontology."""
    pass


class QuestionAnalysisError(OntoRAGError):
    """Error analyzing question."""
    pass


class QueryGenerationError(OntoRAGError):
    """Error generating query."""
    pass


class QueryExecutionError(OntoRAGError):
    """Error executing query."""
    pass


class AnswerGenerationError(OntoRAGError):
    """Error generating answer."""
    pass


class BackendConnectionError(OntoRAGError):
    """Error connecting to backend."""
    pass


class TimeoutError(OntoRAGError):
    """Operation timeout error.

    NOTE(review): shadows the builtin ``TimeoutError`` within this module,
    so ``except TimeoutError`` here will not catch builtin/asyncio timeouts.
    Consider renaming (e.g. OntoRAGTimeoutError) in a follow-up.
    """
    pass
@dataclass
class RetryConfig:
    """Configuration for retry logic."""
    max_retries: int = 3               # retries after the initial failure
    base_delay: float = 1.0            # seconds before the first retry
    max_delay: float = 60.0            # cap on any single backoff delay
    exponential_backoff: bool = True   # double the delay per attempt
    jitter: bool = True                # randomize delay (0.5x - 1.5x)
    # None means "retry on nothing"; consumers read it as `... or []`.
    retry_on_exceptions: Optional[List[Type[Exception]]] = None
class ErrorRecoveryStrategy:
    """Strategy for handling and recovering from errors.

    Combines three mechanisms:
      * per-category retry with (optionally jittered) exponential backoff,
      * a time-windowed circuit breaker keyed by "<category>:<component>",
      * per-category fallback callables returning degraded-but-usable results.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize error recovery strategy.

        Args:
            config: Recovery configuration. Recognized keys:
                'circuit_breaker_threshold' (errors per window, default 10)
                and 'circuit_breaker_window' (seconds, default 300).
        """
        self.config = config or {}
        self.retry_configs = self._build_retry_configs()
        self.fallback_strategies = self._build_fallback_strategies()
        # Error counts keyed by "<category>:<component>".
        self.error_counters: Dict[str, int] = {}
        # Circuit-breaker window state, same keys.
        self.circuit_breakers: Dict[str, Dict[str, Any]] = {}

    def _build_retry_configs(self) -> Dict[ErrorCategory, RetryConfig]:
        """Build retry configurations for different error categories."""
        # NOTE: TimeoutError below is this module's OntoRAGError subclass
        # (it shadows the builtin), so builtin timeouts are NOT retried.
        return {
            ErrorCategory.BACKEND_CONNECTION: RetryConfig(
                max_retries=5,
                base_delay=2.0,
                retry_on_exceptions=[BackendConnectionError, ConnectionError, TimeoutError]
            ),
            ErrorCategory.QUERY_EXECUTION: RetryConfig(
                max_retries=3,
                base_delay=1.0,
                retry_on_exceptions=[QueryExecutionError, TimeoutError]
            ),
            ErrorCategory.ONTOLOGY_LOADING: RetryConfig(
                max_retries=2,
                base_delay=0.5,
                retry_on_exceptions=[OntologyLoadingError, IOError]
            ),
            ErrorCategory.QUESTION_ANALYSIS: RetryConfig(
                max_retries=2,
                base_delay=1.0,
                retry_on_exceptions=[QuestionAnalysisError, TimeoutError]
            ),
            ErrorCategory.ANSWER_GENERATION: RetryConfig(
                max_retries=2,
                base_delay=1.0,
                retry_on_exceptions=[AnswerGenerationError, TimeoutError]
            )
        }

    def _build_fallback_strategies(self) -> Dict[ErrorCategory, Callable]:
        """Build fallback strategies for different error categories."""
        return {
            ErrorCategory.QUESTION_ANALYSIS: self._fallback_question_analysis,
            ErrorCategory.QUERY_GENERATION: self._fallback_query_generation,
            ErrorCategory.QUERY_EXECUTION: self._fallback_query_execution,
            ErrorCategory.ANSWER_GENERATION: self._fallback_answer_generation,
            ErrorCategory.BACKEND_CONNECTION: self._fallback_backend_connection
        }

    async def handle_error(self,
                           error: Exception,
                           context: ErrorContext,
                           operation: Callable,
                           *args,
                           **kwargs) -> Any:
        """Handle error with recovery strategies.

        Order of evaluation: count the error, short-circuit to fallback if
        the circuit breaker is open, retry if the category's budget allows
        and the exception type is retryable, else fall back.

        Args:
            error: The exception that occurred
            context: Error context
            operation: Function to retry
            *args: Operation arguments
            **kwargs: Operation keyword arguments

        Returns:
            Result of successful operation or fallback
        """
        logger.error(f"Handling error in {context.component}.{context.operation}: {error}")
        # Update error counters
        error_key = f"{context.category.value}:{context.component}"
        self.error_counters[error_key] = self.error_counters.get(error_key, 0) + 1
        # Check circuit breaker
        if self._is_circuit_open(error_key):
            return await self._execute_fallback(context, *args, **kwargs)
        # Try retry if configured
        retry_config = self.retry_configs.get(context.category)
        if retry_config and context.retry_count < retry_config.max_retries:
            if any(isinstance(error, exc_type) for exc_type in retry_config.retry_on_exceptions or []):
                return await self._retry_operation(
                    operation, context, retry_config, *args, **kwargs
                )
        # Execute fallback strategy
        return await self._execute_fallback(context, *args, **kwargs)

    async def _retry_operation(self,
                               operation: Callable,
                               context: ErrorContext,
                               retry_config: RetryConfig,
                               *args,
                               **kwargs) -> Any:
        """Retry operation with backoff."""
        context.retry_count += 1
        # Calculate delay
        delay = retry_config.base_delay
        if retry_config.exponential_backoff:
            delay *= (2 ** (context.retry_count - 1))
        delay = min(delay, retry_config.max_delay)
        # Add jitter: scale delay by a random factor in [0.5, 1.5)
        if retry_config.jitter:
            import random
            delay *= (0.5 + random.random())
        logger.info(f"Retrying {context.component}.{context.operation} "
                    f"(attempt {context.retry_count}) after {delay:.2f}s")
        await asyncio.sleep(delay)
        try:
            if asyncio.iscoroutinefunction(operation):
                return await operation(*args, **kwargs)
            else:
                return operation(*args, **kwargs)
        except Exception as e:
            # Recurse into handle_error so the next failure re-evaluates the
            # retry budget and circuit breaker with the updated retry_count.
            return await self.handle_error(e, context, operation, *args, **kwargs)

    async def _execute_fallback(self,
                                context: ErrorContext,
                                *args,
                                **kwargs) -> Any:
        """Execute fallback strategy for the context's category."""
        fallback_func = self.fallback_strategies.get(context.category)
        if fallback_func:
            logger.info(f"Executing fallback for {context.category.value}")
            try:
                if asyncio.iscoroutinefunction(fallback_func):
                    return await fallback_func(context, *args, **kwargs)
                else:
                    return fallback_func(context, *args, **kwargs)
            except Exception as e:
                # A failing fallback degrades further to the default response.
                logger.error(f"Fallback strategy failed: {e}")
        # Default fallback
        return self._default_fallback(context)

    def _is_circuit_open(self, error_key: str) -> bool:
        """Check if circuit breaker is open (too many errors in the window)."""
        circuit = self.circuit_breakers.get(error_key, {})
        error_count = self.error_counters.get(error_key, 0)
        error_threshold = self.config.get('circuit_breaker_threshold', 10)
        window_seconds = self.config.get('circuit_breaker_window', 300)  # 5 minutes
        current_time = time.time()
        # NOTE(review): for a key with no stored circuit state, window_start
        # defaults to 'now', so the window only starts aging after the first
        # reset below — confirm this is the intended warm-up behavior.
        window_start = circuit.get('window_start', current_time)
        # Reset window (and counters) if expired
        if current_time - window_start > window_seconds:
            self.circuit_breakers[error_key] = {'window_start': current_time}
            self.error_counters[error_key] = 0
            return False
        return error_count >= error_threshold

    def _default_fallback(self, context: ErrorContext) -> Any:
        """Default fallback response when no category strategy succeeds."""
        if context.category == ErrorCategory.ANSWER_GENERATION:
            return "I'm sorry, I encountered an error while processing your question. Please try again."
        elif context.category == ErrorCategory.QUERY_EXECUTION:
            return {"error": "Query execution failed", "results": []}
        else:
            return None

    # Fallback strategy implementations
    async def _fallback_question_analysis(self, context: ErrorContext, question: str, **kwargs):
        """Fallback for question analysis: crude keyword-based heuristics."""
        from .question_analyzer import QuestionComponents, QuestionType
        # Simple keyword-based analysis
        question_lower = question.lower()
        # Determine question type from surface cues
        if any(word in question_lower for word in ['how many', 'count', 'number']):
            question_type = QuestionType.AGGREGATION
        elif question_lower.startswith(('is', 'are', 'does', 'can')):
            question_type = QuestionType.BOOLEAN
        elif any(word in question_lower for word in ['what', 'which', 'who', 'where']):
            question_type = QuestionType.RETRIEVAL
        else:
            question_type = QuestionType.FACTUAL
        # Extract simple entities (longer words that are not question words)
        import re
        words = re.findall(r'\b[a-zA-Z]+\b', question)
        entities = [word for word in words if len(word) > 3 and word.lower() not in
                    {'what', 'which', 'where', 'when', 'who', 'how', 'does', 'are', 'the'}]
        return QuestionComponents(
            original_question=question,
            normalized_question=question.lower(),
            question_type=question_type,
            entities=entities[:3],  # Limit to 3 entities
            keywords=words[:5],  # Limit to 5 keywords
            relationships=[],
            constraints=[],
            aggregations=['count'] if question_type == QuestionType.AGGREGATION else [],
            expected_answer_type='text'
        )

    async def _fallback_query_generation(self, context: ErrorContext, **kwargs):
        """Fallback for query generation: a trivial catch-all query."""
        # NOTE(review): context.metadata may be None given ErrorContext's
        # default — this .get() would then raise; ensure callers populate it.
        if 'sparql' in context.metadata.get('query_language', '').lower():
            query = """
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subject ?predicate ?object WHERE {
?subject ?predicate ?object .
}
LIMIT 10
"""
            from .sparql_generator import SPARQLQuery
            return SPARQLQuery(
                query=query,
                variables=['subject', 'predicate', 'object'],
                query_type='SELECT',
                explanation='Fallback SPARQL query',
                complexity_score=0.1
            )
        else:
            query = "MATCH (n) RETURN n LIMIT 10"
            from .cypher_generator import CypherQuery
            # NOTE(review): verify CypherQuery accepts a 'query_type' kwarg —
            # the generator-side dataclass may not define that field.
            return CypherQuery(
                query=query,
                variables=['n'],
                query_type='MATCH',
                explanation='Fallback Cypher query',
                complexity_score=0.1
            )

    async def _fallback_query_execution(self, context: ErrorContext, **kwargs):
        """Fallback for query execution: an empty result set."""
        # NOTE(review): same metadata-None caveat as query generation.
        if 'sparql' in context.metadata.get('query_language', '').lower():
            from .sparql_cassandra import SPARQLResult
            return SPARQLResult(
                bindings=[],
                variables=[],
                execution_time=0.0
            )
        else:
            from .cypher_executor import CypherResult
            return CypherResult(
                records=[],
                summary={'type': 'fallback'},
                metadata={'query': 'fallback'},
                execution_time=0.0
            )

    async def _fallback_answer_generation(self, context: ErrorContext, question: str = None, **kwargs):
        """Fallback for answer generation: a canned apology message."""
        fallback_messages = [
            "I'm experiencing some technical difficulties. Please try rephrasing your question.",
            "I couldn't process your question at the moment. Could you try asking it differently?",
            "There seems to be an issue with my analysis. Please try again in a moment.",
            "I'm having trouble understanding your question right now. Please try again."
        ]
        import random
        return random.choice(fallback_messages)

    async def _fallback_backend_connection(self, context: ErrorContext, **kwargs):
        """Fallback for backend connection failures (currently just logs)."""
        logger.warning(f"Backend connection failed for {context.component}")
        # Could switch to alternative backend here
        return None
def with_error_handling(category: ErrorCategory,
                        component: str,
                        operation: str,
                        severity: ErrorSeverity = ErrorSeverity.MEDIUM):
    """Decorator for automatic error handling.

    Async targets route failures through the instance's ``_error_strategy``
    (when the first positional argument exposes one); otherwise the original
    exception is re-raised wrapped in OntoRAGError with a populated context.

    Args:
        category: Error category
        component: Component name
        operation: Operation name
        severity: Error severity
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                # Supports decorating either coroutine or plain functions.
                if asyncio.iscoroutinefunction(func):
                    return await func(*args, **kwargs)
                else:
                    return func(*args, **kwargs)
            except Exception as e:
                context = ErrorContext(
                    category=category,
                    severity=severity,
                    component=component,
                    operation=operation,
                    technical_details=str(e),
                    metadata={'args': str(args), 'kwargs': str(kwargs)}
                )
                # Get error recovery strategy from first argument if it's
                # available (convention: decorated methods on objects that
                # carry a `_error_strategy` attribute).
                error_strategy = None
                if args and hasattr(args[0], '_error_strategy'):
                    error_strategy = args[0]._error_strategy
                if error_strategy:
                    return await error_strategy.handle_error(e, context, func, *args, **kwargs)
                else:
                    # Re-raise as OntoRAG error
                    raise OntoRAGError(
                        f"Error in {component}.{operation}: {str(e)}",
                        context=context,
                        cause=e
                    )

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # NOTE(review): the sync path never consults _error_strategy
            # (recovery is async-only here) — sync callers always get the
            # wrapped OntoRAGError. Confirm this asymmetry is intended.
            try:
                return func(*args, **kwargs)
            except Exception as e:
                context = ErrorContext(
                    category=category,
                    severity=severity,
                    component=component,
                    operation=operation,
                    technical_details=str(e),
                    metadata={'args': str(args), 'kwargs': str(kwargs)}
                )
                raise OntoRAGError(
                    f"Error in {component}.{operation}: {str(e)}",
                    context=context,
                    cause=e
                )

        # Choose the wrapper matching the decorated function's flavor.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
    return decorator
class ErrorReporter:
    """Reports and tracks errors for monitoring and debugging."""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize error reporter.

        Args:
            config: Reporter configuration; supports 'max_log_size'
                (number of retained entries, default 1000).
        """
        self.config = config or {}
        # Bounded in-memory log of structured error entries (oldest dropped).
        self.error_log: List[Dict[str, Any]] = []
        self.max_log_size = self.config.get('max_log_size', 1000)

    def report_error(self, error: "OntoRAGError"):
        """Record a structured entry for *error* and log it by severity.

        Args:
            error: The error to report.
        """
        # BUG FIX: traceback.format_exc() formats the exception *currently
        # being handled*, which is empty when report_error is called outside
        # an except block.  Format the recorded cause explicitly instead.
        if error.cause is not None:
            stack_trace = ''.join(
                traceback.format_exception(
                    type(error.cause), error.cause, error.cause.__traceback__
                )
            )
        else:
            stack_trace = None
        error_entry = {
            'timestamp': error.timestamp,
            'message': error.message,
            'category': error.context.category.value,
            'severity': error.context.severity.value,
            'component': error.context.component,
            'operation': error.context.operation,
            'retry_count': error.context.retry_count,
            'technical_details': error.context.technical_details,
            'stack_trace': stack_trace
        }
        self.error_log.append(error_entry)
        # Trim log if too large
        if len(self.error_log) > self.max_log_size:
            self.error_log = self.error_log[-self.max_log_size:]
        # Log based on severity
        if error.context.severity == ErrorSeverity.CRITICAL:
            logger.critical(f"CRITICAL ERROR: {error.message}")
        elif error.context.severity == ErrorSeverity.HIGH:
            logger.error(f"HIGH SEVERITY: {error.message}")
        elif error.context.severity == ErrorSeverity.MEDIUM:
            logger.warning(f"MEDIUM SEVERITY: {error.message}")
        else:
            logger.info(f"LOW SEVERITY: {error.message}")

    def get_error_summary(self) -> Dict[str, Any]:
        """Get summary of recent errors.

        Returns:
            Error summary statistics; breakdowns cover the last hour only.
        """
        if not self.error_log:
            return {'total_errors': 0}
        recent_errors = [
            e for e in self.error_log
            if time.time() - e['timestamp'] < 3600  # Last hour
        ]
        category_counts = {}
        severity_counts = {}
        component_counts = {}
        for error in recent_errors:
            category_counts[error['category']] = category_counts.get(error['category'], 0) + 1
            severity_counts[error['severity']] = severity_counts.get(error['severity'], 0) + 1
            component_counts[error['component']] = component_counts.get(error['component'], 0) + 1
        return {
            'total_errors': len(self.error_log),
            'recent_errors': len(recent_errors),
            'category_breakdown': category_counts,
            'severity_breakdown': severity_counts,
            'component_breakdown': component_counts,
            'most_recent_error': self.error_log[-1] if self.error_log else None
        }

View file

@ -0,0 +1,737 @@
"""
Performance monitoring and metrics collection for OntoRAG.
Provides comprehensive monitoring of system performance, query patterns, and resource usage.
"""
import logging
import time
import asyncio
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import defaultdict, deque
import statistics
from enum import Enum
logger = logging.getLogger(__name__)
class MetricType(Enum):
    """Types of metrics to collect."""
    COUNTER = "counter"      # monotonically accumulated value
    GAUGE = "gauge"          # last-set value
    HISTOGRAM = "histogram"  # declared but not currently produced here
    TIMER = "timer"          # duration measurement in seconds
@dataclass
class Metric:
    """Individual metric data point (one sample in a time series)."""
    name: str
    value: float
    timestamp: datetime
    labels: Dict[str, str] = field(default_factory=dict)
    metric_type: MetricType = MetricType.GAUGE
@dataclass
class TimerMetric:
    """In-flight timer capturing the start time of a measured operation."""
    name: str
    start_time: float
    labels: Dict[str, str] = field(default_factory=dict)

    def stop(self) -> float:
        """Stop the timer and return the elapsed time in seconds."""
        elapsed = time.time() - self.start_time
        return elapsed
@dataclass
class PerformanceStats:
    """Rolling performance statistics for a single component."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    avg_response_time: float = 0.0
    min_response_time: float = float('inf')  # inf until first sample arrives
    max_response_time: float = 0.0
    p95_response_time: float = 0.0
    p99_response_time: float = 0.0
    throughput_per_second: float = 0.0
    error_rate: float = 0.0  # failed / total, in [0, 1]
@dataclass
class SystemHealth:
    """Overall system health metrics snapshot."""
    status: str = "healthy"  # healthy, degraded, unhealthy
    uptime_seconds: float = 0.0
    cpu_usage_percent: float = 0.0
    memory_usage_percent: float = 0.0
    active_connections: int = 0
    queue_size: int = 0
    cache_hit_rate: float = 0.0  # hits / requests, in [0, 1]
    error_rate: float = 0.0      # failed / total requests, in [0, 1]
class MetricsCollector:
    """Collects and stores metrics data (counters, gauges, timers).

    Thread-safe: all reads and writes of the shared structures are guarded
    by a reentrant lock.
    """

    def __init__(self, max_metrics: int = 10000, retention_hours: int = 24):
        """Initialize metrics collector.

        Args:
            max_metrics: Maximum number of raw data points retained per metric.
            retention_hours: Hours to retain metrics before cleanup.
        """
        self.max_metrics = max_metrics
        self.retention_hours = retention_hours
        # Raw time-series points per metric name (bounded deque drops oldest).
        self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_metrics))
        # Aggregated values keyed by "name{label=value,...}" (see _build_key).
        self.counters: Dict[str, float] = defaultdict(float)
        self.gauges: Dict[str, float] = defaultdict(float)
        self.timers: Dict[str, List[float]] = defaultdict(list)
        self._lock = threading.RLock()

    def increment(self, name: str, value: float = 1.0, labels: Dict[str, str] = None):
        """Increment a counter metric.

        Args:
            name: Metric name
            value: Value to increment by
            labels: Metric labels
        """
        with self._lock:
            self.counters[self._build_key(name, labels)] += value
            self._add_metric(name, value, MetricType.COUNTER, labels)

    def set_gauge(self, name: str, value: float, labels: Dict[str, str] = None):
        """Set a gauge metric value.

        Args:
            name: Metric name
            value: Gauge value
            labels: Metric labels
        """
        with self._lock:
            self.gauges[self._build_key(name, labels)] = value
            self._add_metric(name, value, MetricType.GAUGE, labels)

    def record_timer(self, name: str, duration: float, labels: Dict[str, str] = None):
        """Record a timer measurement.

        Args:
            name: Metric name
            duration: Duration in seconds
            labels: Metric labels
        """
        with self._lock:
            metric_key = self._build_key(name, labels)
            self.timers[metric_key].append(duration)
            # Bound per-key history so long-running processes don't grow it
            # without limit.
            max_timer_values = 1000
            if len(self.timers[metric_key]) > max_timer_values:
                self.timers[metric_key] = self.timers[metric_key][-max_timer_values:]
            self._add_metric(name, duration, MetricType.TIMER, labels)

    def start_timer(self, name: str, labels: Dict[str, str] = None) -> TimerMetric:
        """Start a timer; pass the returned object to stop_timer().

        Args:
            name: Metric name
            labels: Metric labels

        Returns:
            Timer metric object
        """
        return TimerMetric(name=name, start_time=time.time(), labels=labels or {})

    def stop_timer(self, timer: TimerMetric):
        """Stop a timer, record the measurement, and return the duration.

        Args:
            timer: Timer metric to stop
        """
        duration = timer.stop()
        self.record_timer(timer.name, duration, timer.labels)
        return duration

    def get_counter(self, name: str, labels: Dict[str, str] = None) -> float:
        """Get counter value (0.0 if never incremented).

        Args:
            name: Metric name
            labels: Metric labels
        """
        # FIX: read under the lock like every writer does; .get() also avoids
        # materializing defaultdict entries for unseen keys.
        with self._lock:
            return self.counters.get(self._build_key(name, labels), 0.0)

    def get_gauge(self, name: str, labels: Dict[str, str] = None) -> float:
        """Get gauge value (0.0 if never set).

        Args:
            name: Metric name
            labels: Metric labels
        """
        with self._lock:
            return self.gauges.get(self._build_key(name, labels), 0.0)

    def get_timer_stats(self, name: str, labels: Dict[str, str] = None) -> Dict[str, float]:
        """Get timer statistics (count/sum/avg/min/max/p50/p95/p99).

        Args:
            name: Metric name
            labels: Metric labels

        Returns:
            Statistics dict, empty when no measurements exist.
        """
        # FIX: snapshot the list under the lock so a concurrent record_timer
        # cannot mutate it while we sort/aggregate.
        with self._lock:
            values = list(self.timers.get(self._build_key(name, labels), []))
        if not values:
            return {}
        sorted_values = sorted(values)
        return {
            'count': len(values),
            'sum': sum(values),
            'avg': statistics.mean(values),
            'min': min(values),
            'max': max(values),
            'p50': sorted_values[int(len(sorted_values) * 0.5)],
            'p95': sorted_values[int(len(sorted_values) * 0.95)],
            'p99': sorted_values[int(len(sorted_values) * 0.99)]
        }

    def get_metrics(self,
                    name_pattern: Optional[str] = None,
                    since: Optional[datetime] = None) -> List[Metric]:
        """Get raw metric points matching a name substring and time range.

        Args:
            name_pattern: Substring to match metric names (None matches all)
            since: Only return metrics since this time (defaults to the
                retention window)

        Returns:
            Matching metrics sorted by timestamp.
        """
        with self._lock:
            cutoff_time = since or datetime.now() - timedelta(hours=self.retention_hours)
            results = [
                metric
                for metric_name, metric_queue in self.metrics.items()
                if not name_pattern or name_pattern in metric_name
                for metric in metric_queue
                if metric.timestamp >= cutoff_time
            ]
        return sorted(results, key=lambda m: m.timestamp)

    def cleanup_old_metrics(self):
        """Remove metric points older than the retention period."""
        with self._lock:
            cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
            for metric_name in list(self.metrics.keys()):
                metric_queue = self.metrics[metric_name]
                # Points are appended in time order, so pop from the left
                # until the head is within the retention window.
                while metric_queue and metric_queue[0].timestamp < cutoff_time:
                    metric_queue.popleft()
                # Drop empty series entirely.
                if not metric_queue:
                    del self.metrics[metric_name]

    def _add_metric(self, name: str, value: float, metric_type: MetricType, labels: Dict[str, str]):
        """Append a raw data point; caller must hold the lock."""
        self.metrics[name].append(Metric(
            name=name,
            value=value,
            timestamp=datetime.now(),
            labels=labels or {},
            metric_type=metric_type
        ))

    def _build_key(self, name: str, labels: Dict[str, str]) -> str:
        """Build a stable storage key: name plus sorted '{k=v,...}' labels."""
        if not labels:
            return name
        label_str = ','.join(f"{k}={v}" for k, v in sorted(labels.items()))
        return f"{name}{{{label_str}}}"
class PerformanceMonitor:
"""Monitors system performance and component health."""
def __init__(self, config: Dict[str, Any] = None):
    """Initialize performance monitor.

    Args:
        config: Monitor configuration. Recognized keys (all optional):
            'max_metrics', 'retention_hours', 'enabled' (default True).
    """
    self.config = config or {}
    self.metrics_collector = MetricsCollector(
        max_metrics=self.config.get('max_metrics', 10000),
        retention_hours=self.config.get('retention_hours', 24)
    )
    # Rolling per-component request statistics, keyed by component name.
    self.component_stats: Dict[str, PerformanceStats] = {}
    self.start_time = time.time()
    self.monitoring_enabled = self.config.get('enabled', True)
    # Start background monitoring tasks
    if self.monitoring_enabled:
        self._start_background_tasks()  # defined later in this class; presumably schedules periodic cleanup — confirm
def record_request(self,
                   component: str,
                   operation: str,
                   duration: float,
                   success: bool = True,
                   labels: Dict[str, str] = None):
    """Record the completion of a single request.

    Updates the shared metrics collector (counters + timer) and refreshes
    the per-component rolling statistics.

    Args:
        component: Component name.
        operation: Operation name.
        duration: Request duration in seconds.
        success: Whether the request was successful.
        labels: Additional labels merged over component/operation.
    """
    if not self.monitoring_enabled:
        return

    merged_labels = {'component': component, 'operation': operation, **(labels or {})}

    collector = self.metrics_collector
    collector.increment('requests_total', labels=merged_labels)
    collector.record_timer('request_duration', duration, merged_labels)
    outcome_metric = 'requests_successful' if success else 'requests_failed'
    collector.increment(outcome_metric, labels=merged_labels)

    self._update_component_stats(component, duration, success)
def record_query_complexity(self,
                            complexity_score: float,
                            query_type: str,
                            backend: str):
    """Record the complexity of a generated query as a gauge.

    Args:
        complexity_score: Query complexity score (0.0 to 1.0).
        query_type: Type of query (SPARQL, Cypher).
        backend: Backend used.
    """
    if not self.monitoring_enabled:
        return
    self.metrics_collector.set_gauge(
        'query_complexity',
        complexity_score,
        {'query_type': query_type, 'backend': backend}
    )
def record_cache_access(self, hit: bool, cache_type: str = 'default'):
    """Record a single cache access as a hit or miss.

    Args:
        hit: Whether it was a cache hit.
        cache_type: Type of cache being accessed.
    """
    if not self.monitoring_enabled:
        return
    tags = {'cache_type': cache_type}
    collector = self.metrics_collector
    collector.increment('cache_requests_total', labels=tags)
    collector.increment('cache_hits_total' if hit else 'cache_misses_total', labels=tags)
def record_ontology_selection(self,
                              selected_elements: int,
                              total_elements: int,
                              ontology_id: str):
    """Record how much of an ontology was selected for a query.

    Args:
        selected_elements: Number of selected ontology elements.
        total_elements: Total ontology elements available.
        ontology_id: Ontology identifier.
    """
    if not self.monitoring_enabled:
        return
    tags = {'ontology_id': ontology_id}
    set_gauge = self.metrics_collector.set_gauge
    set_gauge('ontology_elements_selected', selected_elements, tags)
    set_gauge('ontology_elements_total', total_elements, tags)
    # Guard the ratio against empty ontologies.
    ratio = selected_elements / total_elements if total_elements > 0 else 0
    set_gauge('ontology_selection_ratio', ratio, tags)
def get_component_stats(self, component: str) -> Optional[PerformanceStats]:
    """Look up the recorded performance statistics for one component.

    Args:
        component: Component name

    Returns:
        Performance statistics, or None when nothing was recorded yet
    """
    try:
        return self.component_stats[component]
    except KeyError:
        return None
def get_system_health(self) -> SystemHealth:
    """Summarise overall service health from request and cache counters.

    Returns:
        System health metrics (status, uptime, error rate, cache hit rate)
    """
    uptime_seconds = time.time() - self.start_time

    def _safe_ratio(numerator_metric: str, denominator_metric: str) -> float:
        # Avoid division by zero before any traffic has been recorded.
        denominator = self.metrics_collector.get_counter(denominator_metric)
        numerator = self.metrics_collector.get_counter(numerator_metric)
        return numerator / denominator if denominator > 0 else 0.0

    error_rate = _safe_ratio('requests_failed', 'requests_total')
    cache_hit_rate = _safe_ratio('cache_hits_total', 'cache_requests_total')

    # Thresholds: >10% errors degrades the service, >30% marks it unhealthy.
    if error_rate > 0.3:
        status = "unhealthy"
    elif error_rate > 0.1:
        status = "degraded"
    else:
        status = "healthy"

    return SystemHealth(
        status=status,
        uptime_seconds=uptime_seconds,
        error_rate=error_rate,
        cache_hit_rate=cache_hit_rate
    )
def get_performance_report(self) -> Dict[str, Any]:
    """Get comprehensive performance report.

    Aggregates system health, per-component statistics, the slowest
    operations by p95 latency, and per-cache hit rates into a single dict.
    The 'error_patterns' and 'ontology_usage' sections are placeholders
    that are not populated by this method.

    Returns:
        Performance report
    """
    report = {
        'system_health': self.get_system_health(),
        'component_stats': {},
        'top_slow_operations': [],
        'error_patterns': {},
        'cache_performance': {},
        'ontology_usage': {}
    }
    # Component statistics (stored objects are shared, not copied)
    for component, stats in self.component_stats.items():
        report['component_stats'][component] = stats
    # Top slow operations: gather stats for every request-duration timer
    timer_stats = {}
    for metric_name in self.metrics_collector.timers.keys():
        if 'request_duration' in metric_name:
            stats = self.metrics_collector.get_timer_stats(metric_name)
            if stats:
                timer_stats[metric_name] = stats
    # Sort by p95 latency, highest first, and keep the ten worst offenders
    slow_ops = sorted(
        timer_stats.items(),
        key=lambda x: x[1].get('p95', 0),
        reverse=True
    )[:10]
    report['top_slow_operations'] = [
        {'operation': op, 'stats': stats} for op, stats in slow_ops
    ]
    # Cache performance.
    # NOTE(review): cache types are recovered by parsing label-encoded metric
    # names, e.g. "...{cache_type=default,...}" -- this assumes the collector
    # encodes labels into counter keys in exactly this form; confirm against
    # the MetricsCollector implementation.
    cache_types = set()
    for metric_name in self.metrics_collector.counters.keys():
        if 'cache_type=' in metric_name:
            cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0]
            cache_types.add(cache_type)
    for cache_type in cache_types:
        labels = {'cache_type': cache_type}
        hits = self.metrics_collector.get_counter('cache_hits_total', labels)
        requests = self.metrics_collector.get_counter('cache_requests_total', labels)
        hit_rate = hits / requests if requests > 0 else 0.0
        report['cache_performance'][cache_type] = {
            'hit_rate': hit_rate,
            'total_requests': requests,
            'total_hits': hits
        }
    return report
def _update_component_stats(self, component: str, duration: float, success: bool):
    """Update component performance statistics.

    Maintains the running PerformanceStats entry for *component*: request
    counts, min/max latency, percentile figures pulled back from the
    metrics collector, error rate and a rough one-minute throughput.
    """
    if component not in self.component_stats:
        self.component_stats[component] = PerformanceStats()
    stats = self.component_stats[component]
    stats.total_requests += 1
    if success:
        stats.successful_requests += 1
    else:
        stats.failed_requests += 1
    # Update response time stats
    stats.min_response_time = min(stats.min_response_time, duration)
    stats.max_response_time = max(stats.max_response_time, duration)
    # Get timer stats for percentiles -- avg/p95/p99 come from the collector,
    # which sees the full duration history, not just this one sample.
    timer_stats = self.metrics_collector.get_timer_stats(
        'request_duration', {'component': component}
    )
    if timer_stats:
        stats.avg_response_time = timer_stats.get('avg', 0.0)
        stats.p95_response_time = timer_stats.get('p95', 0.0)
        stats.p99_response_time = timer_stats.get('p99', 0.0)
    # Calculate rates
    stats.error_rate = stats.failed_requests / stats.total_requests
    # Calculate throughput (requests per second over last minute).
    # NOTE(review): assumes get_metrics() yields samples with .labels and
    # .timestamp attributes -- confirm against the collector's metric type.
    recent_requests = len([
        m for m in self.metrics_collector.get_metrics('requests_total')
        if m.labels.get('component') == component and
        m.timestamp > datetime.now() - timedelta(minutes=1)
    ])
    stats.throughput_per_second = recent_requests / 60.0
def _start_background_tasks(self):
"""Start background monitoring tasks."""
def cleanup_worker():
"""Worker to clean up old metrics."""
while True:
try:
time.sleep(300) # 5 minutes
self.metrics_collector.cleanup_old_metrics()
except Exception as e:
logger.error(f"Metrics cleanup error: {e}")
cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
cleanup_thread.start()
# Monitoring decorators
def monitor_performance(component: str,
                        operation: str,
                        monitor: Optional[PerformanceMonitor] = None):
    """Decorator that times a callable and reports it to a PerformanceMonitor.

    Works for both plain functions and coroutines; when no monitor is given
    (or monitoring is disabled) the wrapped callable runs untouched.

    Fixes over the original: wrappers now carry the wrapped function's
    metadata via functools.wraps, and the coroutine path no longer re-checks
    iscoroutinefunction on every call.

    Args:
        component: Component name used as a metric label
        operation: Operation name used as a metric label
        monitor: Performance monitor instance (None disables monitoring)
    """
    import functools

    def decorator(func):
        if asyncio.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if not monitor or not monitor.monitoring_enabled:
                    return await func(*args, **kwargs)
                timer = monitor.metrics_collector.start_timer(
                    'request_duration',
                    {'component': component, 'operation': operation}
                )
                success = True
                try:
                    return await func(*args, **kwargs)
                except Exception:
                    success = False
                    raise
                finally:
                    # Record the timing even when the call raised.
                    duration = monitor.metrics_collector.stop_timer(timer)
                    monitor.record_request(component, operation, duration, success)
            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not monitor or not monitor.monitoring_enabled:
                return func(*args, **kwargs)
            timer = monitor.metrics_collector.start_timer(
                'request_duration',
                {'component': component, 'operation': operation}
            )
            success = True
            try:
                return func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                # Record the timing even when the call raised.
                duration = monitor.metrics_collector.stop_timer(timer)
                monitor.record_request(component, operation, duration, success)
        return wrapper

    return decorator
class QueryPatternAnalyzer:
    """Analyzes query patterns for optimization insights."""

    def __init__(self, monitor: PerformanceMonitor):
        """Initialize query pattern analyzer.

        Args:
            monitor: Performance monitor instance
        """
        self.monitor = monitor
        # Keyed by "<question_type>:<entity_count>"; each value holds the
        # pattern dicts recorded over the last 24 hours.
        self.query_patterns: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    def record_query_pattern(self,
                             question_type: str,
                             entities: List[str],
                             complexity: float,
                             backend: str,
                             duration: float,
                             success: bool):
        """Record a query pattern for analysis.

        Args:
            question_type: Type of question
            entities: Entities in question
            complexity: Query complexity score
            backend: Backend used
            duration: Query duration
            success: Whether query succeeded
        """
        key = f"{question_type}:{len(entities)}"
        self.query_patterns[key].append({
            'timestamp': datetime.now(),
            'question_type': question_type,
            'entity_count': len(entities),
            'entities': entities,
            'complexity': complexity,
            'backend': backend,
            'duration': duration,
            'success': success
        })
        # Keep analysis recent: drop entries older than 24 hours.
        horizon = datetime.now() - timedelta(hours=24)
        self.query_patterns[key] = [
            entry for entry in self.query_patterns[key]
            if entry['timestamp'] > horizon
        ]

    def get_optimization_insights(self) -> Dict[str, Any]:
        """Get insights for query optimization.

        Returns:
            Optimization insights and recommendations
        """
        insights: Dict[str, Any] = {
            'slow_patterns': [],
            'common_failures': [],
            'backend_performance': {},
            'complexity_analysis': {},
            'recommendations': []
        }

        # Flag patterns that are slow on average or fail often.
        for key, entries in self.query_patterns.items():
            if not entries:
                continue
            mean_duration = statistics.mean([entry['duration'] for entry in entries])
            ok_fraction = sum(1 for entry in entries if entry['success']) / len(entries)
            if mean_duration > 5.0:  # slower than 5 seconds on average
                insights['slow_patterns'].append({
                    'pattern': key,
                    'avg_duration': mean_duration,
                    'count': len(entries),
                    'success_rate': ok_fraction
                })
            if ok_fraction < 0.8:  # fails more than 20% of the time
                insights['common_failures'].append({
                    'pattern': key,
                    'success_rate': ok_fraction,
                    'count': len(entries)
                })

        # Per-backend latency summary.
        per_backend: Dict[str, List[float]] = defaultdict(list)
        for entries in self.query_patterns.values():
            for entry in entries:
                per_backend[entry['backend']].append(entry['duration'])
        for backend, times in per_backend.items():
            insights['backend_performance'][backend] = {
                'avg_duration': statistics.mean(times),
                'p95_duration': sorted(times)[int(len(times) * 0.95)],
                'query_count': len(times)
            }

        # Turn the findings into human-readable recommendations.
        recommendations = []
        for slow_pattern in insights['slow_patterns']:
            recommendations.append(
                f"Consider optimizing {slow_pattern['pattern']} queries - "
                f"average duration {slow_pattern['avg_duration']:.2f}s"
            )
        if len(insights['backend_performance']) > 1:
            fastest_backend = min(
                insights['backend_performance'].items(),
                key=lambda item: item[1]['avg_duration']
            )[0]
            recommendations.append(
                f"Consider routing more queries to {fastest_backend} "
                f"for better performance"
            )
        insights['recommendations'] = recommendations
        return insights

View file

@ -0,0 +1,656 @@
"""
Multi-language support for OntoRAG.
Provides language detection, translation, and multilingual query processing.
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class Language(Enum):
    """Supported languages (values are ISO 639-1 two-letter codes)."""
    ENGLISH = "en"
    SPANISH = "es"
    FRENCH = "fr"
    GERMAN = "de"
    ITALIAN = "it"
    PORTUGUESE = "pt"
    CHINESE = "zh"
    JAPANESE = "ja"
    KOREAN = "ko"
    ARABIC = "ar"
    RUSSIAN = "ru"
    DUTCH = "nl"
@dataclass
class LanguageDetectionResult:
    """Language detection result."""
    language: Language  # best-guess language (detector default on failure)
    confidence: float  # detection confidence in [0.0, 1.0]; 0.0 when detection failed
    detected_text: str  # the text that was analyzed
    # Runner-up candidates as (language, probability) pairs; only populated
    # by the langdetect backend, otherwise left as None.
    alternative_languages: Optional[List[Tuple[Language, float]]] = None
@dataclass
class TranslationResult:
    """Translation result."""
    original_text: str  # text as submitted for translation
    translated_text: str  # translated text (equals original_text when translation was skipped or failed)
    source_language: Language  # detected or caller-supplied source language
    target_language: Language  # requested target language
    confidence: float  # backend-specific confidence estimate in [0.0, 1.0]
class LanguageDetector:
    """Detects the language of input text.

    Prefers the langdetect library, falls back to TextBlob, and finally to
    a small keyword-based heuristic when neither is installed.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize language detector.

        Args:
            config: Detector configuration. Recognised keys:
                default_language: ISO code used when detection fails ('en')
                confidence_threshold: minimum confidence callers should trust (0.7)
        """
        self.config = config or {}
        self.default_language = Language(self.config.get('default_language', 'en'))
        self.confidence_threshold = self.config.get('confidence_threshold', 0.7)
        # Backend name: 'langdetect', 'textblob' or 'rule_based'.
        self.detector = None
        self._init_detector()

    def _init_detector(self):
        """Pick the best available language detection backend."""
        try:
            # Try langdetect first
            import langdetect
            self.detector = 'langdetect'
            logger.info("Using langdetect for language detection")
        except ImportError:
            try:
                # Try textblob as fallback
                from textblob import TextBlob
                self.detector = 'textblob'
                logger.info("Using TextBlob for language detection")
            except ImportError:
                logger.warning("No language detection library available, using rule-based detection")
                self.detector = 'rule_based'

    def detect_language(self, text: str) -> LanguageDetectionResult:
        """Detect language of input text.

        Args:
            text: Text to analyze

        Returns:
            Language detection result; falls back to the configured default
            language with confidence 0.0 on empty input or backend failure.
        """
        if not text or not text.strip():
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )
        try:
            if self.detector == 'langdetect':
                return self._detect_with_langdetect(text)
            elif self.detector == 'textblob':
                return self._detect_with_textblob(text)
            else:
                return self._detect_with_rules(text)
        except Exception as e:
            logger.error(f"Language detection failed: {e}")
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_langdetect(self, text: str) -> LanguageDetectionResult:
        """Detect language using the langdetect library."""
        import langdetect
        from langdetect.lang_detect_exception import LangDetectException
        try:
            # Get per-language probabilities, best first
            probabilities = langdetect.detect_langs(text)
            if not probabilities:
                return LanguageDetectionResult(
                    language=self.default_language,
                    confidence=0.0,
                    detected_text=text
                )
            best_match = probabilities[0]
            detected_lang_code = best_match.lang
            confidence = best_match.prob
            # Map to our Language enum
            try:
                detected_language = Language(detected_lang_code)
            except ValueError:
                # Map common variations / regional codes onto supported languages
                lang_mapping = {
                    'ca': Language.SPANISH,  # Catalan -> Spanish
                    'eu': Language.SPANISH,  # Basque -> Spanish
                    'gl': Language.SPANISH,  # Galician -> Spanish
                    'zh-cn': Language.CHINESE,
                    'zh-tw': Language.CHINESE,
                }
                detected_language = lang_mapping.get(detected_lang_code, self.default_language)
            # Collect up to two runner-up candidates
            alternatives = []
            for lang_prob in probabilities[1:3]:
                try:
                    alt_lang = Language(lang_prob.lang)
                    alternatives.append((alt_lang, lang_prob.prob))
                except ValueError:
                    continue
            return LanguageDetectionResult(
                language=detected_language,
                confidence=confidence,
                detected_text=text,
                alternative_languages=alternatives
            )
        except LangDetectException:
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_textblob(self, text: str) -> LanguageDetectionResult:
        """Detect language using TextBlob.

        NOTE(review): TextBlob's detect_language() calls a web service and was
        removed in newer releases -- confirm the pinned version supports it.
        """
        from textblob import TextBlob
        try:
            blob = TextBlob(text)
            detected_lang_code = blob.detect_language()
            try:
                detected_language = Language(detected_lang_code)
            except ValueError:
                detected_language = self.default_language
            # TextBlob doesn't provide confidence, so estimate based on text length
            confidence = min(0.8, len(text) / 100.0) if len(text) > 10 else 0.5
            return LanguageDetectionResult(
                language=detected_language,
                confidence=confidence,
                detected_text=text
            )
        except Exception:
            return LanguageDetectionResult(
                language=self.default_language,
                confidence=0.0,
                detected_text=text
            )

    def _detect_with_rules(self, text: str) -> LanguageDetectionResult:
        """Rule-based language detection fallback using question keywords."""
        text_lower = text.lower()
        # Simple keyword-based detection.
        # Bug fix: the French list previously contained an empty string (a
        # mis-encoded 'où'); '' is a substring of every text, so every input
        # scored at least 1 for French and detection was biased towards it.
        language_keywords = {
            Language.SPANISH: ['qué', 'cuál', 'cuándo', 'dónde', 'cómo', 'por qué', 'cuántos'],
            Language.FRENCH: ['que', 'quel', 'quand', 'où', 'comment', 'pourquoi', 'combien'],
            Language.GERMAN: ['was', 'welche', 'wann', 'wo', 'wie', 'warum', 'wieviele'],
            Language.ITALIAN: ['che', 'quale', 'quando', 'dove', 'come', 'perché', 'quanti'],
            Language.PORTUGUESE: ['que', 'qual', 'quando', 'onde', 'como', 'por que', 'quantos'],
            Language.DUTCH: ['wat', 'welke', 'wanneer', 'waar', 'hoe', 'waarom', 'hoeveel']
        }
        best_match = self.default_language
        best_score = 0
        for language, keywords in language_keywords.items():
            score = sum(1 for keyword in keywords if keyword in text_lower)
            if score > best_score:
                best_score = score
                best_match = language
        # Three keyword hits saturate confidence at 0.8; no hits => 0.1.
        confidence = min(0.8, best_score / 3.0) if best_score > 0 else 0.1
        return LanguageDetectionResult(
            language=best_match,
            confidence=confidence,
            detected_text=text
        )
class TextTranslator:
    """Translates text between languages.

    Uses googletrans when available, falling back to TextBlob; when neither
    is installed, translate() returns the input text unchanged with
    confidence 0.0 (best-effort, never raises to the caller).
    """
    def __init__(self, config: Dict[str, Any] = None):
        """Initialize text translator.

        Args:
            config: Translator configuration
        """
        self.config = config or {}
        # googletrans Translator instance; stays None for other backends.
        self.translator = None
        self._init_translator()
    def _init_translator(self):
        """Initialize translation backend.

        Sets self.backend to 'googletrans', 'textblob' or None and, for
        googletrans, also instantiates self.translator.
        """
        try:
            # Try Google Translate first.
            # NOTE(review): googletrans 4.x made Translator.translate a
            # coroutine; this code assumes the legacy synchronous API --
            # confirm the pinned package version.
            from googletrans import Translator
            self.translator = Translator()
            self.backend = 'googletrans'
            logger.info("Using Google Translate for translation")
        except ImportError:
            try:
                # Try TextBlob as fallback
                from textblob import TextBlob
                self.backend = 'textblob'
                logger.info("Using TextBlob for translation")
            except ImportError:
                logger.warning("No translation library available")
                self.backend = None
    def translate(self,
                  text: str,
                  target_language: Language,
                  source_language: Optional[Language] = None) -> TranslationResult:
        """Translate text to target language.

        Args:
            text: Text to translate
            target_language: Target language
            source_language: Source language (auto-detect if None)

        Returns:
            Translation result. On empty input, missing backend or backend
            failure the original text is returned unchanged with
            confidence 0.0.
        """
        if not text or not text.strip():
            return TranslationResult(
                original_text=text,
                translated_text=text,
                source_language=source_language or Language.ENGLISH,
                target_language=target_language,
                confidence=0.0
            )
        try:
            if self.backend == 'googletrans':
                return self._translate_with_googletrans(text, target_language, source_language)
            elif self.backend == 'textblob':
                return self._translate_with_textblob(text, target_language, source_language)
            else:
                # No translation available: pass the text through unchanged.
                return TranslationResult(
                    original_text=text,
                    translated_text=text,
                    source_language=source_language or Language.ENGLISH,
                    target_language=target_language,
                    confidence=0.0
                )
        except Exception as e:
            logger.error(f"Translation failed: {e}")
            return TranslationResult(
                original_text=text,
                translated_text=text,
                source_language=source_language or Language.ENGLISH,
                target_language=target_language,
                confidence=0.0
            )
    def _translate_with_googletrans(self,
                                    text: str,
                                    target_language: Language,
                                    source_language: Optional[Language]) -> TranslationResult:
        """Translate using Google Translate.

        Raises on backend error; the caller (translate) converts that into a
        pass-through result.
        """
        try:
            # 'auto' lets the service detect the source language itself.
            src_code = source_language.value if source_language else 'auto'
            dest_code = target_language.value
            result = self.translator.translate(text, src=src_code, dest=dest_code)
            # The service reports which source it actually used.
            detected_source = Language(result.src) if result.src != 'auto' else Language.ENGLISH
            confidence = 0.9  # Google Translate is generally reliable
            return TranslationResult(
                original_text=text,
                translated_text=result.text,
                source_language=detected_source,
                target_language=target_language,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"Google Translate error: {e}")
            raise
    def _translate_with_textblob(self,
                                 text: str,
                                 target_language: Language,
                                 source_language: Optional[Language]) -> TranslationResult:
        """Translate using TextBlob.

        Raises on backend error; the caller (translate) converts that into a
        pass-through result.
        """
        from textblob import TextBlob
        try:
            blob = TextBlob(text)
            if not source_language:
                # Auto-detect source language
                detected_lang = blob.detect_language()
                try:
                    source_language = Language(detected_lang)
                except ValueError:
                    source_language = Language.ENGLISH
            translated_blob = blob.translate(to=target_language.value)
            translated_text = str(translated_blob)
            # TextBlob gives no confidence; short texts get a lower estimate.
            confidence = 0.7 if len(text) > 10 else 0.5
            return TranslationResult(
                original_text=text,
                translated_text=translated_text,
                source_language=source_language,
                target_language=target_language,
                confidence=confidence
            )
        except Exception as e:
            logger.error(f"TextBlob translation error: {e}")
            raise
class MultiLanguageQueryProcessor:
    """Processes queries in multiple languages.

    Detects the question language, translates it into the primary processing
    language when needed, and can translate answers back out.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize multi-language query processor.

        Args:
            config: Processor configuration. Recognised keys:
                language_detection, translation, supported_languages,
                primary_language.
        """
        self.config = config or {}
        # Bug fix: read sub-configs from self.config (never the raw argument),
        # so constructing with config=None no longer raises AttributeError.
        self.language_detector = LanguageDetector(self.config.get('language_detection', {}))
        self.translator = TextTranslator(self.config.get('translation', {}))
        self.supported_languages = [Language(lang) for lang in self.config.get('supported_languages', ['en'])]
        self.primary_language = Language(self.config.get('primary_language', 'en'))

    async def process_multilingual_query(self, question: str) -> Dict[str, Any]:
        """Process a query in any supported language.

        Args:
            question: Question in any language

        Returns:
            Processing result with language information
        """
        # Step 1: Detect language
        detection_result = self.language_detector.detect_language(question)
        detected_language = detection_result.language
        logger.info(f"Detected language: {detected_language.value} "
                    f"(confidence: {detection_result.confidence:.2f})")
        # Step 2: Translate to primary language if needed
        translated_question = question
        translation_result = None
        if detected_language != self.primary_language:
            # Only trust the detection (and translate) above the threshold.
            if detection_result.confidence >= self.language_detector.confidence_threshold:
                translation_result = self.translator.translate(
                    question, self.primary_language, detected_language
                )
                translated_question = translation_result.translated_text
                logger.info(f"Translated question: {translated_question}")
            else:
                logger.warning(f"Low confidence language detection, processing in {self.primary_language.value}")
        # Step 3: Return processing information
        return {
            'original_question': question,
            'translated_question': translated_question,
            'detected_language': detected_language,
            'detection_confidence': detection_result.confidence,
            'translation_result': translation_result,
            'processing_language': self.primary_language,
            'alternative_languages': detection_result.alternative_languages
        }

    async def translate_answer(self,
                               answer: str,
                               target_language: Language) -> TranslationResult:
        """Translate answer back to target language.

        Args:
            answer: Answer in primary language
            target_language: Target language for answer

        Returns:
            Translation result
        """
        if target_language == self.primary_language:
            # No translation needed: echo the answer with full confidence.
            return TranslationResult(
                original_text=answer,
                translated_text=answer,
                source_language=self.primary_language,
                target_language=target_language,
                confidence=1.0
            )
        return self.translator.translate(answer, target_language, self.primary_language)

    def get_language_specific_ontology_terms(self,
                                             ontology_subset: Dict[str, Any],
                                             language: Language) -> Dict[str, Any]:
        """Get language-specific terms from ontology.

        Args:
            ontology_subset: Ontology subset
            language: Target language

        Returns:
            Ontology subset copy where every class/property entry gains a
            'language_labels' list. Labels may be plain strings (always kept)
            or dicts with 'language'/'value' keys (kept when language matches).
        """
        lang_code = language.value
        result = {}
        # Process classes
        if 'classes' in ontology_subset:
            result['classes'] = {}
            for class_id, class_def in ontology_subset['classes'].items():
                lang_labels = []
                if 'labels' in class_def:
                    for label in class_def['labels']:
                        if isinstance(label, dict) and label.get('language') == lang_code:
                            lang_labels.append(label['value'])
                        elif isinstance(label, str):
                            lang_labels.append(label)
                result['classes'][class_id] = {
                    **class_def,
                    'language_labels': lang_labels
                }
        # Process properties
        for prop_type in ['object_properties', 'datatype_properties']:
            if prop_type in ontology_subset:
                result[prop_type] = {}
                for prop_id, prop_def in ontology_subset[prop_type].items():
                    lang_labels = []
                    if 'labels' in prop_def:
                        for label in prop_def['labels']:
                            if isinstance(label, dict) and label.get('language') == lang_code:
                                lang_labels.append(label['value'])
                            elif isinstance(label, str):
                                lang_labels.append(label)
                    result[prop_type][prop_id] = {
                        **prop_def,
                        'language_labels': lang_labels
                    }
        return result

    def is_language_supported(self, language: Language) -> bool:
        """Check if language is supported.

        Args:
            language: Language to check

        Returns:
            True if language is supported
        """
        return language in self.supported_languages

    def get_supported_languages(self) -> List[Language]:
        """Get list of supported languages.

        Returns:
            A copy of the supported-language list (safe to mutate).
        """
        return self.supported_languages.copy()

    def add_language_support(self, language: Language):
        """Add support for a new language (idempotent).

        Args:
            language: Language to add support for
        """
        if language not in self.supported_languages:
            self.supported_languages.append(language)
            logger.info(f"Added support for language: {language.value}")

    def remove_language_support(self, language: Language):
        """Remove support for a language.

        The primary language can never be removed.

        Args:
            language: Language to remove support for
        """
        if language in self.supported_languages and language != self.primary_language:
            self.supported_languages.remove(language)
            logger.info(f"Removed support for language: {language.value}")
        else:
            logger.warning(f"Cannot remove primary language or unsupported language: {language.value}")
class LanguageSpecificTemplates:
    """Manages language-specific query and answer templates.

    English templates double as the fallback for any unsupported language.
    """

    def __init__(self):
        """Initialize language-specific templates."""
        # Question-word cues per question category.
        # Bug fix: the French 'retrieval' list previously contained an empty
        # string (a mis-encoded 'où'), which matches every input when used as
        # a substring pattern.
        self.question_templates = {
            Language.ENGLISH: {
                'count': ['how many', 'count of', 'number of'],
                'boolean': ['is', 'are', 'does', 'can', 'will'],
                'retrieval': ['what', 'which', 'who', 'where'],
                'factual': ['tell me about', 'describe', 'explain']
            },
            Language.SPANISH: {
                'count': ['cuántos', 'cuántas', 'número de', 'cantidad de'],
                'boolean': ['es', 'son', 'está', 'están', 'puede', 'pueden'],
                'retrieval': ['qué', 'cuál', 'cuáles', 'quién', 'dónde'],
                'factual': ['dime sobre', 'describe', 'explica']
            },
            Language.FRENCH: {
                'count': ['combien', 'nombre de', 'quantité de'],
                'boolean': ['est', 'sont', 'peut', 'peuvent'],
                'retrieval': ['que', 'quel', 'quelle', 'qui', 'où'],
                'factual': ['dis-moi sur', 'décris', 'explique']
            },
            Language.GERMAN: {
                'count': ['wie viele', 'anzahl der', 'zahl der'],
                'boolean': ['ist', 'sind', 'kann', 'können'],
                'retrieval': ['was', 'welche', 'wer', 'wo'],
                'factual': ['erzähl mir über', 'beschreibe', 'erkläre']
            }
        }
        # Answer templates with {placeholders} for str.format.
        self.answer_templates = {
            Language.ENGLISH: {
                'count': 'There are {count} {entity}.',
                'boolean_true': 'Yes, {statement}.',
                'boolean_false': 'No, {statement}.',
                'not_found': 'No information found.',
                'error': 'Sorry, I encountered an error.'
            },
            Language.SPANISH: {
                'count': 'Hay {count} {entity}.',
                'boolean_true': 'Sí, {statement}.',
                'boolean_false': 'No, {statement}.',
                'not_found': 'No se encontró información.',
                'error': 'Lo siento, encontré un error.'
            },
            Language.FRENCH: {
                'count': 'Il y a {count} {entity}.',
                'boolean_true': 'Oui, {statement}.',
                'boolean_false': 'Non, {statement}.',
                'not_found': 'Aucune information trouvée.',
                'error': 'Désolé, j\'ai rencontré une erreur.'
            },
            Language.GERMAN: {
                'count': 'Es gibt {count} {entity}.',
                'boolean_true': 'Ja, {statement}.',
                'boolean_false': 'Nein, {statement}.',
                'not_found': 'Keine Informationen gefunden.',
                'error': 'Entschuldigung, ich bin auf einen Fehler gestoßen.'
            }
        }

    def get_question_patterns(self, language: Language) -> Dict[str, List[str]]:
        """Get question patterns for a language.

        Args:
            language: Target language

        Returns:
            Dictionary of question patterns (English when unsupported)
        """
        return self.question_templates.get(language, self.question_templates[Language.ENGLISH])

    def get_answer_template(self, language: Language, template_type: str) -> str:
        """Get answer template for a language and type.

        Args:
            language: Target language
            template_type: Template type

        Returns:
            Answer template string (falls back to the 'error' template)
        """
        templates = self.answer_templates.get(language, self.answer_templates[Language.ENGLISH])
        return templates.get(template_type, templates.get('error', 'Error'))

    def format_answer(self,
                      language: Language,
                      template_type: str,
                      **kwargs) -> str:
        """Format answer using language-specific template.

        Args:
            language: Target language
            template_type: Template type
            **kwargs: Template variables

        Returns:
            Formatted answer; the language's error template when a
            placeholder variable is missing.
        """
        template = self.get_answer_template(language, template_type)
        try:
            return template.format(**kwargs)
        except KeyError as e:
            logger.error(f"Missing template variable: {e}")
            return self.get_answer_template(language, 'error')

View file

@ -0,0 +1,256 @@
"""
Ontology matcher for query system.
Identifies relevant ontology subsets for answering questions.
"""
import logging
from typing import List, Dict, Any, Set, Optional
from dataclasses import dataclass
from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader
from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder
from ...extract.kg.ontology.text_processor import TextSegment
from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset
from .question_analyzer import QuestionComponents, QuestionType
logger = logging.getLogger(__name__)
@dataclass
class QueryOntologySubset(OntologySubset):
    """Extended ontology subset for query processing."""
    # Additional properties for graph traversal (prop_id -> property definition dict)
    traversal_properties: Optional[Dict[str, Any]] = None
    # Inference rules for reasoning (each a dict with 'type'/'property'/'description')
    inference_rules: Optional[List[Dict[str, Any]]] = None
class OntologyMatcherForQueries(OntologySelector):
"""
Specialized ontology matcher for question answering.
Extends OntologySelector with query-specific logic.
"""
def __init__(self, ontology_embedder: OntologyEmbedder,
             ontology_loader: OntologyLoader,
             top_k: int = 15,  # Higher k for queries
             similarity_threshold: float = 0.6):  # Lower threshold for broader coverage
    """Initialize query-specific ontology matcher.

    Defaults are relaxed relative to the extraction-time selector so that
    question answering gets broader ontology coverage.

    Args:
        ontology_embedder: Embedder with vector store
        ontology_loader: Loader with ontology definitions
        top_k: Number of top results to retrieve
        similarity_threshold: Minimum similarity score
    """
    super().__init__(ontology_embedder, ontology_loader, top_k, similarity_threshold)
async def match_question_to_ontology(self,
                                     question_components: QuestionComponents,
                                     question_segments: List[str]) -> List[QueryOntologySubset]:
    """Match question components to relevant ontology elements.

    Args:
        question_components: Analyzed question components
        question_segments: Text segments from question

    Returns:
        List of query-optimized ontology subsets
    """
    # Wrap the raw strings so the parent selector can embed and rank them.
    segments = [
        TextSegment(text=chunk, type='question', position=idx)
        for idx, chunk in enumerate(question_segments)
    ]
    # Rank ontology subsets with the parent class, then enrich each one
    # with query-specific traversal hints and inference rules.
    base_subsets = await self.select_ontology_subset(segments)
    return [
        self._enhance_for_query(subset, question_components)
        for subset in base_subsets
    ]
def _enhance_for_query(self, subset: OntologySubset,
                       question_components: QuestionComponents) -> QueryOntologySubset:
    """Enhance ontology subset with query-specific elements.

    Args:
        subset: Base ontology subset
        question_components: Analyzed question components

    Returns:
        Enhanced query ontology subset
    """
    # Copy the base subset into the extended type, starting with empty
    # traversal/inference extras.
    enriched = QueryOntologySubset(
        ontology_id=subset.ontology_id,
        classes=dict(subset.classes),
        object_properties=dict(subset.object_properties),
        datatype_properties=dict(subset.datatype_properties),
        metadata=subset.metadata,
        relevance_score=subset.relevance_score,
        traversal_properties={},
        inference_rules=[]
    )
    # Three enrichment passes: traversal hints keyed to the question type,
    # neighbouring (inverse/sibling) properties, then reasoning rules.
    self._add_traversal_properties(enriched, question_components)
    self._add_related_properties(enriched)
    self._add_inference_rules(enriched, question_components)
    return enriched
def _add_traversal_properties(self, subset: QueryOntologySubset,
                              question_components: QuestionComponents):
    """Add properties useful for graph traversal.

    Extra properties go into subset.traversal_properties (not the main
    property maps), stored as plain dicts via __dict__, so downstream code
    can treat them as secondary hints.

    Args:
        subset: Query ontology subset to enhance
        question_components: Question analysis
    """
    ontology = self.loader.get_ontology(subset.ontology_id)
    if not ontology:
        return
    # For relationship questions, add all properties connecting mentioned classes
    if question_components.question_type == QuestionType.RELATIONSHIP:
        for prop_id, prop_def in ontology.object_properties.items():
            domain = prop_def.domain
            range_val = prop_def.range
            # Check if property connects relevant classes
            if domain in subset.classes or range_val in subset.classes:
                if prop_id not in subset.object_properties:
                    subset.traversal_properties[prop_id] = prop_def.__dict__
                    logger.debug(f"Added traversal property: {prop_id}")
    # For retrieval questions, add properties that might filter results
    elif question_components.question_type == QuestionType.RETRIEVAL:
        # Add all properties (object and datatype) with domains in our classes
        for prop_id, prop_def in ontology.object_properties.items():
            if prop_def.domain in subset.classes:
                if prop_id not in subset.object_properties:
                    subset.traversal_properties[prop_id] = prop_def.__dict__
        for prop_id, prop_def in ontology.datatype_properties.items():
            if prop_def.domain in subset.classes:
                if prop_id not in subset.datatype_properties:
                    subset.traversal_properties[prop_id] = prop_def.__dict__
    # For aggregation questions, ensure we have counting properties
    elif question_components.question_type == QuestionType.AGGREGATION:
        # Add properties that might be counted (name-based heuristic only)
        for prop_id, prop_def in ontology.datatype_properties.items():
            if 'count' in prop_id.lower() or 'number' in prop_id.lower():
                if prop_id not in subset.datatype_properties:
                    subset.traversal_properties[prop_id] = prop_def.__dict__
def _add_related_properties(self, subset: QueryOntologySubset):
    """Add properties related to already selected ones.

    Pulls in inverse properties (into object_properties) and a few
    "sibling" properties sharing a domain (into traversal_properties).

    Args:
        subset: Query ontology subset to enhance
    """
    ontology = self.loader.get_ontology(subset.ontology_id)
    if not ontology:
        return
    # Add inverse properties
    for prop_id in list(subset.object_properties.keys()):
        prop = ontology.object_properties.get(prop_id)
        if prop and prop.inverse_of:
            inverse_prop = ontology.object_properties.get(prop.inverse_of)
            if inverse_prop and prop.inverse_of not in subset.object_properties:
                # NOTE(review): stores the dict form (__dict__) while entries
                # copied from the base subset may not be dicts -- confirm
                # downstream consumers handle both shapes.
                subset.object_properties[prop.inverse_of] = inverse_prop.__dict__
                logger.debug(f"Added inverse property: {prop.inverse_of}")
    # Add sibling properties (same domain).
    # NOTE(review): this loop assumes subset.object_properties values are
    # dicts with a 'domain' key; verify the base selector produces dicts.
    domains_in_subset = set()
    for prop_def in subset.object_properties.values():
        if 'domain' in prop_def and prop_def['domain']:
            domains_in_subset.add(prop_def['domain'])
    for domain in domains_in_subset:
        for prop_id, prop_def in ontology.object_properties.items():
            if prop_def.domain == domain and prop_id not in subset.object_properties:
                # Add up to 3 sibling properties
                # (the cap counts everything already in traversal_properties)
                if len(subset.traversal_properties) < 3:
                    subset.traversal_properties[prop_id] = prop_def.__dict__
def _add_inference_rules(self, subset: QueryOntologySubset,
                         question_components: QuestionComponents):
    """Add inference rules for reasoning.

    Args:
        subset: Query ontology subset to enhance
        question_components: Question analysis
    """
    # Subclass links imply transitive rdfs:subClassOf reasoning.
    if any(class_def.get('subclass_of') for class_def in subset.classes.values()):
        subset.inference_rules.append({
            'type': 'transitivity',
            'property': 'rdfs:subClassOf',
            'description': 'Subclass relationships are transitive'
        })
    # Equivalence links are symmetric by definition.
    if any(class_def.get('equivalent_classes') for class_def in subset.classes.values()):
        subset.inference_rules.append({
            'type': 'symmetry',
            'property': 'owl:equivalentClass',
            'description': 'Equivalent class relationships are symmetric'
        })
    # One inverse rule per object property that declares an inverse.
    for prop_name, prop_def in subset.object_properties.items():
        inverse_name = prop_def.get('inverse_of')
        if inverse_name:
            subset.inference_rules.append({
                'type': 'inverse',
                'property': prop_name,
                'inverse': inverse_name,
                'description': f'{prop_name} is inverse of {inverse_name}'
            })
def expand_for_hierarchical_queries(self, subset: QueryOntologySubset) -> QueryOntologySubset:
    """Expand subset to include full class hierarchies.

    Args:
        subset: Query ontology subset

    Returns:
        Expanded subset with complete hierarchies
    """
    ontology = self.loader.get_ontology(subset.ontology_id)
    if not ontology:
        return subset
    pending = set()
    for anchor_id in list(subset.classes.keys()):
        # Every known ancestor of the anchor class.
        for ancestor_id in ontology.get_parent_classes(anchor_id):
            if ancestor_id not in subset.classes and ontology.get_class(ancestor_id):
                pending.add(ancestor_id)
        # Direct subclasses of the anchor class.
        # NOTE(review): only immediate children are collected, not the full
        # descendant tree -- confirm that depth is intended.
        for candidate_id, candidate in ontology.classes.items():
            if candidate.subclass_of == anchor_id and candidate_id not in subset.classes:
                pending.add(candidate_id)
    # Materialise the collected classes into the subset.
    for new_id in pending:
        definition = ontology.get_class(new_id)
        if definition:
            subset.classes[new_id] = definition.__dict__
    logger.debug(f"Expanded hierarchy: added {len(pending)} classes")
    return subset

View file

@ -0,0 +1,640 @@
"""
Query explanation system for OntoRAG.
Provides detailed explanations of how queries are processed and results are derived.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_generator import SPARQLQuery
from .cypher_generator import CypherQuery
from .sparql_cassandra import SPARQLResult
from .cypher_executor import CypherResult
logger = logging.getLogger(__name__)
@dataclass
class ExplanationStep:
    """Individual step in query explanation."""
    step_number: int             # 1-based position in the processing pipeline
    component: str               # name of the component that performed the step
    operation: str               # operation the component executed
    input_data: Dict[str, Any]   # inputs the step received
    output_data: Dict[str, Any]  # outputs the step produced
    explanation: str             # human-readable description of the step
    duration_ms: float           # wall-clock duration of the step in milliseconds
    success: bool                # whether the step completed successfully
    metadata: Dict[str, Any] = field(default_factory=dict)  # extra per-step details
@dataclass
class QueryExplanation:
    """Complete explanation of query processing."""
    query_id: str                           # unique identifier for this query run
    original_question: str                  # the user's question as asked
    processing_steps: List[ExplanationStep] # per-stage explanation steps, in order
    final_answer: str                       # the answer delivered to the user
    confidence_score: float                 # overall confidence in [0.0, 1.0]
    total_duration_ms: float                # total processing time in milliseconds
    ontologies_used: List[str]              # identifiers of ontologies consulted
    backend_used: str                       # storage backend that executed the query
    reasoning_chain: List[str]              # human-readable reasoning statements
    technical_details: Dict[str, Any]       # debugging/optimization metadata
    user_friendly_explanation: str          # plain-language narrative of the process
class QueryExplainer:
    """Generates explanations for query processing."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialize query explainer.

        Args:
            config: Explainer configuration
        """
        self.config = config or {}
        self.explanation_level = self.config.get('explanation_level', 'detailed')  # basic, detailed, technical
        self.include_technical_details = self.config.get('include_technical_details', True)
        self.max_reasoning_steps = self.config.get('max_reasoning_steps', 10)
        # Templates for different explanation types
        self.step_templates = {
            'question_analysis': {
                'basic': "I analyzed your question to understand what you're asking.",
                'detailed': "I analyzed your question '{question}' and identified it as a {question_type} query about {entities}.",
                'technical': "Question analysis: Type={question_type}, Entities={entities}, Keywords={keywords}, Expected answer={answer_type}"
            },
            'ontology_matching': {
                'basic': "I found relevant knowledge about {entities} in the available ontologies.",
                'detailed': "I searched through {ontology_count} ontologies and found {selected_elements} relevant concepts related to your question.",
                'technical': "Ontology matching: Selected {classes} classes, {properties} properties from {ontologies}"
            },
            'query_generation': {
                'basic': "I generated a query to search for the information.",
                'detailed': "I created a {query_type} query using {query_language} to search the {backend} database.",
                'technical': "Query generation: {query_language} query with {variables} variables, complexity score {complexity}"
            },
            'query_execution': {
                'basic': "I searched the database and found {result_count} results.",
                'detailed': "I executed the query against the {backend} database and retrieved {result_count} results in {duration}ms.",
                'technical': "Query execution: {backend} backend, {result_count} results, execution time {duration}ms"
            },
            'answer_generation': {
                'basic': "I generated a natural language answer from the results.",
                'detailed': "I processed {result_count} results and generated an answer with {confidence}% confidence.",
                'technical': "Answer generation: {result_count} input results, {generation_method} method, confidence {confidence}"
            }
        }
        # Templates for the human-readable reasoning chain
        self.reasoning_templates = {
            'entity_identification': "I identified '{entity}' as a key concept in your question.",
            'ontology_selection': "I selected the '{ontology}' ontology because it contains relevant information about {concepts}.",
            'query_strategy': "I chose a {strategy} query approach because {reason}.",
            'result_filtering': "I filtered the results to show only the most relevant {count} items.",
            'confidence_assessment': "I'm {confidence}% confident in this answer because {reasoning}."
        }

    def explain_query_processing(self,
                                 question: str,
                                 question_components: QuestionComponents,
                                 ontology_subsets: List[QueryOntologySubset],
                                 generated_query: Union[SPARQLQuery, CypherQuery],
                                 query_results: Union[SPARQLResult, CypherResult],
                                 final_answer: str,
                                 processing_metadata: Dict[str, Any]) -> QueryExplanation:
        """Generate comprehensive explanation of query processing.

        Args:
            question: Original question
            question_components: Analyzed question components
            ontology_subsets: Selected ontology subsets
            generated_query: Generated query
            query_results: Query execution results
            final_answer: Final generated answer
            processing_metadata: Processing metadata

        Returns:
            Complete query explanation
        """
        query_id = processing_metadata.get('query_id', f"query_{datetime.now().timestamp()}")
        # NOTE(review): start_time is read but never used below -- confirm
        # whether per-run timing was meant to be derived from it.
        start_time = processing_metadata.get('start_time', datetime.now())
        # Build explanation steps
        steps = []
        step_number = 1
        # Step 1: Question Analysis
        steps.append(self._explain_question_analysis(
            step_number, question, question_components
        ))
        step_number += 1
        # Step 2: Ontology Matching
        steps.append(self._explain_ontology_matching(
            step_number, question_components, ontology_subsets
        ))
        step_number += 1
        # Step 3: Query Generation
        steps.append(self._explain_query_generation(
            step_number, generated_query, processing_metadata
        ))
        step_number += 1
        # Step 4: Query Execution
        steps.append(self._explain_query_execution(
            step_number, generated_query, query_results, processing_metadata
        ))
        step_number += 1
        # Step 5: Answer Generation
        steps.append(self._explain_answer_generation(
            step_number, query_results, final_answer, processing_metadata
        ))
        # Build reasoning chain
        reasoning_chain = self._build_reasoning_chain(
            question_components, ontology_subsets, generated_query, processing_metadata
        )
        # Calculate overall confidence
        confidence_score = self._calculate_explanation_confidence(
            question_components, query_results, processing_metadata
        )
        # Generate user-friendly explanation
        user_friendly_explanation = self._generate_user_friendly_explanation(
            question, question_components, ontology_subsets, final_answer
        )
        # Calculate total duration
        total_duration = processing_metadata.get('total_duration_ms', 0)
        return QueryExplanation(
            query_id=query_id,
            original_question=question,
            processing_steps=steps,
            final_answer=final_answer,
            confidence_score=confidence_score,
            total_duration_ms=total_duration,
            # NOTE(review): elsewhere the subset exposes `ontology_id` as a
            # direct attribute; confirm `metadata` also carries 'ontology_id',
            # otherwise this always reports 'unknown'.
            ontologies_used=[subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets],
            backend_used=processing_metadata.get('backend_used', 'unknown'),
            reasoning_chain=reasoning_chain,
            technical_details=self._extract_technical_details(processing_metadata),
            user_friendly_explanation=user_friendly_explanation
        )

    def _explain_question_analysis(self,
                                   step_number: int,
                                   question: str,
                                   question_components: QuestionComponents) -> ExplanationStep:
        """Explain question analysis step."""
        template = self.step_templates['question_analysis'][self.explanation_level]
        if self.explanation_level == 'basic':
            explanation = template
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                question=question,
                question_type=question_components.question_type.value.replace('_', ' '),
                entities=', '.join(question_components.entities[:3])
            )
        else:  # technical
            explanation = template.format(
                question_type=question_components.question_type.value,
                entities=question_components.entities,
                keywords=question_components.keywords,
                answer_type=question_components.expected_answer_type
            )
        return ExplanationStep(
            step_number=step_number,
            component="question_analyzer",
            operation="analyze_question",
            input_data={"question": question},
            output_data={
                "question_type": question_components.question_type.value,
                "entities": question_components.entities,
                "keywords": question_components.keywords
            },
            explanation=explanation,
            duration_ms=0.0,  # Would be tracked in actual implementation
            success=True
        )

    def _explain_ontology_matching(self,
                                   step_number: int,
                                   question_components: QuestionComponents,
                                   ontology_subsets: List[QueryOntologySubset]) -> ExplanationStep:
        """Explain ontology matching step."""
        template = self.step_templates['ontology_matching'][self.explanation_level]
        total_elements = sum(
            len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties)
            for subset in ontology_subsets
        )
        if self.explanation_level == 'basic':
            explanation = template.format(
                entities=', '.join(question_components.entities[:3])
            )
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                ontology_count=len(ontology_subsets),
                selected_elements=total_elements
            )
        else:  # technical
            total_classes = sum(len(subset.classes) for subset in ontology_subsets)
            total_properties = sum(
                len(subset.object_properties) + len(subset.datatype_properties)
                for subset in ontology_subsets
            )
            ontology_names = [subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets]
            explanation = template.format(
                classes=total_classes,
                properties=total_properties,
                ontologies=', '.join(ontology_names)
            )
        return ExplanationStep(
            step_number=step_number,
            component="ontology_matcher",
            operation="select_relevant_subset",
            input_data={"entities": question_components.entities},
            output_data={
                "ontology_count": len(ontology_subsets),
                "total_elements": total_elements
            },
            explanation=explanation,
            duration_ms=0.0,
            success=True
        )

    def _explain_query_generation(self,
                                  step_number: int,
                                  generated_query: Union[SPARQLQuery, CypherQuery],
                                  metadata: Dict[str, Any]) -> ExplanationStep:
        """Explain query generation step."""
        template = self.step_templates['query_generation'][self.explanation_level]
        # Query language is inferred from the concrete query type
        query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
        backend = metadata.get('backend_used', 'unknown')
        if self.explanation_level == 'basic':
            explanation = template
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                query_type=generated_query.query_type,
                query_language=query_language,
                backend=backend
            )
        else:  # technical
            explanation = template.format(
                query_language=query_language,
                variables=len(generated_query.variables),
                complexity=f"{generated_query.complexity_score:.2f}"
            )
        return ExplanationStep(
            step_number=step_number,
            component="query_generator",
            operation="generate_query",
            input_data={"query_type": generated_query.query_type},
            output_data={
                "query_language": query_language,
                "variables": generated_query.variables,
                "complexity": generated_query.complexity_score
            },
            explanation=explanation,
            duration_ms=0.0,
            success=True,
            metadata={"generated_query": generated_query.query}
        )

    def _explain_query_execution(self,
                                 step_number: int,
                                 generated_query: Union[SPARQLQuery, CypherQuery],
                                 query_results: Union[SPARQLResult, CypherResult],
                                 metadata: Dict[str, Any]) -> ExplanationStep:
        """Explain query execution step."""
        template = self.step_templates['query_execution'][self.explanation_level]
        backend = metadata.get('backend_used', 'unknown')
        duration = getattr(query_results, 'execution_time', 0) * 1000  # Convert to ms
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else:  # CypherResult
            result_count = len(query_results.records)
        if self.explanation_level == 'basic':
            explanation = template.format(result_count=result_count)
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                backend=backend,
                result_count=result_count,
                duration=f"{duration:.1f}"
            )
        else:  # technical
            explanation = template.format(
                backend=backend,
                result_count=result_count,
                duration=f"{duration:.1f}"
            )
        return ExplanationStep(
            step_number=step_number,
            component="query_executor",
            operation="execute_query",
            input_data={"query": generated_query.query},
            output_data={
                "result_count": result_count,
                "execution_time_ms": duration
            },
            explanation=explanation,
            duration_ms=duration,
            # NOTE(review): len() is never negative, so this is always True --
            # confirm whether an execution-error flag was intended instead.
            success=result_count >= 0
        )

    def _explain_answer_generation(self,
                                   step_number: int,
                                   query_results: Union[SPARQLResult, CypherResult],
                                   final_answer: str,
                                   metadata: Dict[str, Any]) -> ExplanationStep:
        """Explain answer generation step."""
        template = self.step_templates['answer_generation'][self.explanation_level]
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else:  # CypherResult
            result_count = len(query_results.records)
        confidence = metadata.get('answer_confidence', 0.8) * 100  # Convert to percentage
        if self.explanation_level == 'basic':
            explanation = template
        elif self.explanation_level == 'detailed':
            explanation = template.format(
                result_count=result_count,
                confidence=f"{confidence:.0f}"
            )
        else:  # technical
            generation_method = metadata.get('generation_method', 'template_based')
            explanation = template.format(
                result_count=result_count,
                generation_method=generation_method,
                confidence=f"{confidence:.1f}"
            )
        return ExplanationStep(
            step_number=step_number,
            component="answer_generator",
            operation="generate_answer",
            input_data={"result_count": result_count},
            output_data={
                "answer": final_answer,
                "confidence": confidence / 100
            },
            explanation=explanation,
            duration_ms=0.0,
            success=bool(final_answer)
        )

    def _build_reasoning_chain(self,
                               question_components: QuestionComponents,
                               ontology_subsets: List[QueryOntologySubset],
                               generated_query: Union[SPARQLQuery, CypherQuery],
                               metadata: Dict[str, Any]) -> List[str]:
        """Build reasoning chain explaining the decision process."""
        reasoning = []
        # Entity identification reasoning
        if question_components.entities:
            for entity in question_components.entities[:3]:
                reasoning.append(
                    self.reasoning_templates['entity_identification'].format(entity=entity)
                )
        # Ontology selection reasoning
        if ontology_subsets:
            primary_ontology = ontology_subsets[0]
            ontology_id = primary_ontology.metadata.get('ontology_id', 'primary')
            concepts = list(primary_ontology.classes.keys())[:3]
            reasoning.append(
                self.reasoning_templates['ontology_selection'].format(
                    ontology=ontology_id,
                    concepts=', '.join(concepts)
                )
            )
        # Query strategy reasoning
        # NOTE(review): query_language is computed but never used in this
        # method -- confirm whether it was meant to appear in the reasoning.
        query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher"
        if question_components.question_type == QuestionType.AGGREGATION:
            strategy = "aggregation"
            reason = "you asked for a count or sum"
        elif question_components.question_type == QuestionType.BOOLEAN:
            strategy = "boolean"
            reason = "you asked a yes/no question"
        else:
            strategy = "retrieval"
            reason = "you asked for specific information"
        reasoning.append(
            self.reasoning_templates['query_strategy'].format(
                strategy=strategy,
                reason=reason
            )
        )
        # Confidence assessment
        confidence = metadata.get('answer_confidence', 0.8) * 100
        if confidence > 90:
            confidence_reason = "the query matched well with available data"
        elif confidence > 70:
            confidence_reason = "the query found relevant information with some uncertainty"
        else:
            confidence_reason = "the available data partially matches your question"
        reasoning.append(
            self.reasoning_templates['confidence_assessment'].format(
                confidence=f"{confidence:.0f}",
                reasoning=confidence_reason
            )
        )
        # Truncate to the configured maximum number of reasoning steps
        return reasoning[:self.max_reasoning_steps]

    def _calculate_explanation_confidence(self,
                                          question_components: QuestionComponents,
                                          query_results: Union[SPARQLResult, CypherResult],
                                          metadata: Dict[str, Any]) -> float:
        """Calculate confidence score for the explanation."""
        confidence = 0.8  # Base confidence
        # Adjust based on result count
        if isinstance(query_results, SPARQLResult):
            result_count = len(query_results.bindings)
        else:
            result_count = len(query_results.records)
        if result_count > 0:
            confidence += 0.1
        if result_count > 5:
            confidence += 0.05
        # Adjust based on question complexity
        if len(question_components.entities) > 0:
            confidence += 0.05
        # Adjust based on processing success
        if metadata.get('all_steps_successful', True):
            confidence += 0.05
        # Clamp to the valid [0, 1] range
        return min(confidence, 1.0)

    def _generate_user_friendly_explanation(self,
                                            question: str,
                                            question_components: QuestionComponents,
                                            ontology_subsets: List[QueryOntologySubset],
                                            final_answer: str) -> str:
        """Generate user-friendly explanation of the process."""
        explanation_parts = []
        # Introduction
        explanation_parts.append(f"To answer your question '{question}', I followed these steps:")
        # Process summary
        if question_components.question_type == QuestionType.AGGREGATION:
            explanation_parts.append("1. I recognized this as a counting or aggregation question")
        elif question_components.question_type == QuestionType.BOOLEAN:
            explanation_parts.append("1. I recognized this as a yes/no question")
        else:
            explanation_parts.append("1. I analyzed your question to understand what information you need")
        # Ontology usage
        if ontology_subsets:
            ontology_count = len(ontology_subsets)
            if ontology_count == 1:
                explanation_parts.append("2. I searched through the relevant knowledge base")
            else:
                explanation_parts.append(f"2. I searched through {ontology_count} knowledge bases")
        # Result processing
        explanation_parts.append("3. I found the relevant information and generated your answer")
        # Conclusion
        explanation_parts.append(f"The answer is: {final_answer}")
        return " ".join(explanation_parts)

    def _extract_technical_details(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract technical details for debugging and optimization."""
        return {
            'query_optimization': metadata.get('query_optimization', {}),
            'backend_performance': metadata.get('backend_performance', {}),
            'cache_usage': metadata.get('cache_usage', {}),
            'error_handling': metadata.get('error_handling', {}),
            'routing_decision': metadata.get('routing_decision', {})
        }

    def format_explanation_for_display(self,
                                       explanation: QueryExplanation,
                                       format_type: str = 'html') -> str:
        """Format explanation for display.

        Args:
            explanation: Query explanation
            format_type: Output format ('html', 'markdown', 'text')

        Returns:
            Formatted explanation
        """
        if format_type == 'html':
            return self._format_html_explanation(explanation)
        elif format_type == 'markdown':
            return self._format_markdown_explanation(explanation)
        else:
            # Any other value falls back to plain text
            return self._format_text_explanation(explanation)

    def _format_html_explanation(self, explanation: QueryExplanation) -> str:
        """Format explanation as HTML."""
        html_parts = [
            f"<h2>Query Explanation: {explanation.query_id}</h2>",
            f"<p><strong>Question:</strong> {explanation.original_question}</p>",
            f"<p><strong>Answer:</strong> {explanation.final_answer}</p>",
            f"<p><strong>Confidence:</strong> {explanation.confidence_score:.1%}</p>",
            "<h3>Processing Steps:</h3>",
            "<ol>"
        ]
        for step in explanation.processing_steps:
            html_parts.append(f"<li><strong>{step.component}</strong>: {step.explanation}</li>")
        html_parts.extend([
            "</ol>",
            "<h3>Reasoning:</h3>",
            "<ul>"
        ])
        for reasoning in explanation.reasoning_chain:
            html_parts.append(f"<li>{reasoning}</li>")
        html_parts.append("</ul>")
        return "".join(html_parts)

    def _format_markdown_explanation(self, explanation: QueryExplanation) -> str:
        """Format explanation as Markdown."""
        md_parts = [
            f"## Query Explanation: {explanation.query_id}",
            f"**Question:** {explanation.original_question}",
            f"**Answer:** {explanation.final_answer}",
            f"**Confidence:** {explanation.confidence_score:.1%}",
            "",
            "### Processing Steps:",
            ""
        ]
        for i, step in enumerate(explanation.processing_steps, 1):
            md_parts.append(f"{i}. **{step.component}**: {step.explanation}")
        md_parts.extend([
            "",
            "### Reasoning:",
            ""
        ])
        for reasoning in explanation.reasoning_chain:
            md_parts.append(f"- {reasoning}")
        return "\n".join(md_parts)

    def _format_text_explanation(self, explanation: QueryExplanation) -> str:
        """Format explanation as plain text."""
        text_parts = [
            f"Query Explanation: {explanation.query_id}",
            f"Question: {explanation.original_question}",
            f"Answer: {explanation.final_answer}",
            f"Confidence: {explanation.confidence_score:.1%}",
            "",
            "Processing Steps:",
        ]
        for i, step in enumerate(explanation.processing_steps, 1):
            text_parts.append(f"  {i}. {step.component}: {step.explanation}")
        text_parts.extend([
            "",
            "Reasoning:",
        ])
        for reasoning in explanation.reasoning_chain:
            text_parts.append(f"  - {reasoning}")
        return "\n".join(text_parts)

View file

@ -0,0 +1,519 @@
"""
Query optimization module for OntoRAG.
Optimizes SPARQL and Cypher queries for better performance and accuracy.
"""
import logging
from typing import Dict, Any, List, Optional, Union, Tuple
from dataclasses import dataclass
from enum import Enum
import re
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
from .sparql_generator import SPARQLQuery
from .cypher_generator import CypherQuery
logger = logging.getLogger(__name__)
class OptimizationStrategy(Enum):
    """Query optimization strategies."""
    PERFORMANCE = "performance"  # apply only the performance pass (limits, index hints)
    ACCURACY = "accuracy"        # apply only the accuracy pass (DISTINCT, type constraints)
    BALANCED = "balanced"        # apply both the performance and accuracy passes
@dataclass
class OptimizationHint:
    """Optimization hint for query processing."""
    strategy: OptimizationStrategy         # which optimization pass(es) to apply
    max_results: Optional[int] = None      # cap on result rows; adds LIMIT when set
    timeout_seconds: Optional[int] = None  # execution timeout, if enforced by the backend
    use_indices: bool = True               # whether index hints may be emitted
    enable_parallel: bool = False          # whether parallel execution is allowed
    cache_results: bool = True             # whether results may be cached
@dataclass
class QueryPlan:
    """Query execution plan with optimization metadata."""
    original_query: str            # query text before optimization
    optimized_query: str           # query text after optimization passes
    estimated_cost: float          # heuristic execution-cost estimate
    optimization_notes: List[str]  # human-readable notes about applied changes
    index_hints: List[str]         # suggested indices for the backend
    execution_order: List[str]     # planned clause execution order, if computed
class QueryOptimizer:
"""Optimizes SPARQL and Cypher queries for performance and accuracy."""
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """Initialize query optimizer.

    Args:
        config: Optimizer configuration
    """
    self.config = config or {}
    # Strategy used when the caller supplies no OptimizationHint
    self.default_strategy = OptimizationStrategy(
        self.config.get('default_strategy', 'balanced')
    )
    self.max_query_complexity = self.config.get('max_query_complexity', 10)
    self.enable_query_rewriting = self.config.get('enable_query_rewriting', True)
    # Performance thresholds
    self.large_result_threshold = self.config.get('large_result_threshold', 1000)
    self.complex_join_threshold = self.config.get('complex_join_threshold', 3)
def optimize_sparql(self,
                    sparql_query: SPARQLQuery,
                    question_components: QuestionComponents,
                    ontology_subset: QueryOntologySubset,
                    optimization_hint: Optional[OptimizationHint] = None) -> Tuple[SPARQLQuery, QueryPlan]:
    """Optimize SPARQL query.

    Args:
        sparql_query: Original SPARQL query
        question_components: Question analysis
        ontology_subset: Ontology subset
        optimization_hint: Optimization hints

    Returns:
        Optimized SPARQL query and execution plan
    """
    active_hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
    rewritten = sparql_query.query
    notes = []
    hints = []
    order = []
    wants_performance = active_hint.strategy in [OptimizationStrategy.PERFORMANCE,
                                                 OptimizationStrategy.BALANCED]
    wants_accuracy = active_hint.strategy in [OptimizationStrategy.ACCURACY,
                                              OptimizationStrategy.BALANCED]
    # Performance pass (LIMIT, OPTIONAL placement, index hints)
    if wants_performance:
        rewritten, perf_notes, perf_hints = self._optimize_sparql_performance(
            rewritten, question_components, ontology_subset, active_hint
        )
        notes += perf_notes
        hints += perf_hints
    # Accuracy pass (type constraints, DISTINCT)
    if wants_accuracy:
        rewritten, accuracy_notes = self._optimize_sparql_accuracy(
            rewritten, question_components, ontology_subset
        )
        notes += accuracy_notes
    plan = QueryPlan(
        original_query=sparql_query.query,
        optimized_query=rewritten,
        estimated_cost=self._estimate_sparql_cost(rewritten, ontology_subset),
        optimization_notes=notes,
        index_hints=hints,
        execution_order=order
    )
    result = SPARQLQuery(
        query=rewritten,
        variables=sparql_query.variables,
        query_type=sparql_query.query_type,
        explanation=f"Optimized: {sparql_query.explanation}",
        # Assume optimization reduces complexity
        complexity_score=min(sparql_query.complexity_score * 0.8, 1.0)
    )
    return result, plan
def optimize_cypher(self,
                    cypher_query: CypherQuery,
                    question_components: QuestionComponents,
                    ontology_subset: QueryOntologySubset,
                    optimization_hint: Optional[OptimizationHint] = None) -> Tuple[CypherQuery, QueryPlan]:
    """Optimize Cypher query.

    Args:
        cypher_query: Original Cypher query
        question_components: Question analysis
        ontology_subset: Ontology subset
        optimization_hint: Optimization hints

    Returns:
        Optimized Cypher query and execution plan
    """
    active_hint = optimization_hint or OptimizationHint(strategy=self.default_strategy)
    rewritten = cypher_query.query
    notes = []
    hints = []
    order = []
    wants_performance = active_hint.strategy in [OptimizationStrategy.PERFORMANCE,
                                                 OptimizationStrategy.BALANCED]
    wants_accuracy = active_hint.strategy in [OptimizationStrategy.ACCURACY,
                                              OptimizationStrategy.BALANCED]
    # Performance pass (LIMIT, parameterization advice, index hints)
    if wants_performance:
        rewritten, perf_notes, perf_hints = self._optimize_cypher_performance(
            rewritten, question_components, ontology_subset, active_hint
        )
        notes += perf_notes
        hints += perf_hints
    # Accuracy pass (DISTINCT, direction/null-check advice)
    if wants_accuracy:
        rewritten, accuracy_notes = self._optimize_cypher_accuracy(
            rewritten, question_components, ontology_subset
        )
        notes += accuracy_notes
    plan = QueryPlan(
        original_query=cypher_query.query,
        optimized_query=rewritten,
        estimated_cost=self._estimate_cypher_cost(rewritten, ontology_subset),
        optimization_notes=notes,
        index_hints=hints,
        execution_order=order
    )
    result = CypherQuery(
        query=rewritten,
        variables=cypher_query.variables,
        query_type=cypher_query.query_type,
        explanation=f"Optimized: {cypher_query.explanation}",
        # Assume optimization reduces complexity
        complexity_score=min(cypher_query.complexity_score * 0.8, 1.0)
    )
    return result, plan
def _optimize_sparql_performance(self,
                                 query: str,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
    """Apply performance optimizations to SPARQL query.

    Args:
        query: SPARQL query string
        question_components: Question analysis
        ontology_subset: Ontology subset
        hint: Optimization hints

    Returns:
        Optimized query, optimization notes, and index hints
    """
    optimized = query
    notes = []
    index_hints = []
    # Add LIMIT if not present and large results expected
    if hint.max_results and 'LIMIT' not in optimized.upper():
        optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
        notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
    # Optimize OPTIONAL clauses (move to end of the WHERE group).
    # NOTE: the pattern cannot handle nested braces inside an OPTIONAL
    # block; such clauses are left untouched.
    optional_pattern = re.compile(r'OPTIONAL\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL)
    optionals = optional_pattern.findall(optimized)
    if optionals:
        # Remove optionals from their current position
        for optional in optionals:
            optimized = optimized.replace(optional, '')
        # BUG FIX: re-insert the OPTIONAL clauses *inside* the WHERE group,
        # i.e. before its closing brace. The previous code inserted them
        # before ORDER BY / LIMIT (or at the end of the string), which is
        # outside the group pattern and produces invalid SPARQL.
        insert_point = optimized.rfind('}')
        if insert_point == -1:
            # Degenerate query with no group braces; append at the end
            insert_point = len(optimized.rstrip())
        moved = ''.join(f"\n    {optional}" for optional in optionals)
        optimized = optimized[:insert_point] + moved + "\n" + optimized[insert_point:]
        notes.append("Moved OPTIONAL clauses to end for better performance")
    # Add index hints for Cassandra
    if 'WHERE' in optimized.upper():
        # Suggest indices for common patterns
        if '?subject rdf:type' in optimized:
            index_hints.append("type_index")
        if 'rdfs:subClassOf' in optimized:
            index_hints.append("hierarchy_index")
    # Optimize FILTER clauses (move closer to variable bindings)
    filter_pattern = re.compile(r'FILTER\s*\([^)]+\)', re.IGNORECASE)
    if filter_pattern.findall(optimized):
        notes.append("FILTER clauses present - ensure they're positioned optimally")
    return optimized, notes, index_hints
def _optimize_sparql_accuracy(self,
                              query: str,
                              question_components: QuestionComponents,
                              ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
    """Apply accuracy optimizations to SPARQL query.

    Args:
        query: SPARQL query string
        question_components: Question analysis
        ontology_subset: Ontology subset

    Returns:
        Optimized query and optimization notes
    """
    optimized = query
    notes = []
    # Add missing namespace checks
    if question_components.question_type == QuestionType.RETRIEVAL:
        # Ensure we're not mixing namespaces inappropriately
        if 'http://' in optimized and '?' in optimized:
            notes.append("Verified namespace consistency for accuracy")
    # Add type constraints for better precision
    if '?entity' in optimized and 'rdf:type' not in optimized:
        # Find a good insertion point
        where_clause = re.search(r'WHERE\s*\{(.+)\}', optimized, re.DOTALL | re.IGNORECASE)
        if where_clause and ontology_subset.classes:
            # Add type constraint for the most relevant class
            main_class = list(ontology_subset.classes.keys())[0]
            type_constraint = f"\n    ?entity rdf:type :{main_class} ."
            # Insert immediately after the opening brace of WHERE
            where_start = where_clause.start(1)
            optimized = optimized[:where_start] + type_constraint + optimized[where_start:]
            notes.append(f"Added type constraint for {main_class} to improve accuracy")
    # Add DISTINCT if not present for retrieval queries.
    # BUG FIX: use a case-insensitive substitution (consistent with the
    # Cypher counterpart); the previous str.replace('SELECT ', ...) missed
    # lower-case 'select' even though the guard test was case-insensitive.
    if (question_components.question_type == QuestionType.RETRIEVAL and
            'DISTINCT' not in optimized.upper() and
            'SELECT' in optimized.upper()):
        optimized = re.sub(r'SELECT\s+', 'SELECT DISTINCT ', optimized, count=1, flags=re.IGNORECASE)
        notes.append("Added DISTINCT to eliminate duplicate results")
    return optimized, notes
def _optimize_cypher_performance(self,
                                 query: str,
                                 question_components: QuestionComponents,
                                 ontology_subset: QueryOntologySubset,
                                 hint: OptimizationHint) -> Tuple[str, List[str], List[str]]:
    """Apply performance optimizations to Cypher query.

    Args:
        query: Cypher query string
        question_components: Question analysis
        ontology_subset: Ontology subset
        hint: Optimization hints

    Returns:
        Optimized query, optimization notes, and index hints
    """
    optimized = query
    notes = []
    index_hints = []
    # Cap the result set when the hint asks for it and no LIMIT exists yet
    if hint.max_results and 'LIMIT' not in optimized.upper():
        optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}"
        notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets")
    # Literal values defeat query-plan caching; surface that as advice
    has_literals = ("'" in optimized) or ('"' in optimized)
    if has_literals:
        notes.append("Consider using parameters for literal values to enable query plan caching")
    # Suggest a label index when the query matches a labelled node
    label_match = re.search(r'MATCH \(n:(\w+)\)', optimized)
    if label_match:
        index_hints.append(f"node_label_index:{label_match.group(1)}")
    # Suggest property indices for accessed properties.
    # NOTE(review): this pattern also matches digits after a decimal point
    # in numeric literals (e.g. "0.8" -> "8") -- confirm acceptable.
    if 'WHERE' in optimized.upper() and '.' in optimized:
        accessed = set(re.findall(r'\.(\w+)', optimized))
        for prop_name in accessed:
            index_hints.append(f"property_index:{prop_name}")
    # Variable-length path advice
    if '-[' in optimized and '*' in optimized:
        notes.append("Variable length path detected - consider adding relationship type filters")
    # Early filtering advice
    if 'WHERE' in optimized.upper():
        notes.append("WHERE clauses present - ensure early filtering for performance")
    return optimized, notes, index_hints
def _optimize_cypher_accuracy(self,
                              query: str,
                              question_components: QuestionComponents,
                              ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]:
    """Apply accuracy-oriented adjustments to a Cypher query.

    Args:
        query: Cypher query string.
        question_components: Question analysis.
        ontology_subset: Ontology subset (not consulted here).

    Returns:
        Adjusted query and optimization notes.
    """
    adjusted = query
    remarks: List[str] = []

    # Retrieval questions should not surface duplicate rows.
    wants_distinct = (question_components.question_type == QuestionType.RETRIEVAL
                      and 'DISTINCT' not in adjusted.upper()
                      and 'RETURN' in adjusted.upper())
    if wants_distinct:
        adjusted = re.sub(r'RETURN\s+', 'RETURN DISTINCT ', adjusted,
                          count=1, flags=re.IGNORECASE)
        remarks.append("Added DISTINCT to eliminate duplicate results")

    # Relationship patterns plus extracted relationships: direction matters.
    if '-[' in adjusted and question_components.relationships:
        remarks.append("Verified relationship directions for semantic accuracy")

    # Optional matches can surface nulls in the projection.
    if '?' in adjusted or 'OPTIONAL' in adjusted.upper():
        remarks.append("Consider adding null checks for optional properties")

    return adjusted, remarks
def _estimate_sparql_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
"""Estimate execution cost for SPARQL query.
Args:
query: SPARQL query string
ontology_subset: Ontology subset
Returns:
Estimated cost (0.0 to 1.0)
"""
cost = 0.0
# Basic query complexity
cost += len(query.split('\n')) * 0.01
# Join complexity
triple_patterns = len(re.findall(r'\?\w+\s+\?\w+\s+\?\w+', query))
cost += triple_patterns * 0.1
# OPTIONAL clauses
optional_count = len(re.findall(r'OPTIONAL', query, re.IGNORECASE))
cost += optional_count * 0.15
# FILTER clauses
filter_count = len(re.findall(r'FILTER', query, re.IGNORECASE))
cost += filter_count * 0.1
# Property paths
path_count = len(re.findall(r'\*|\+', query))
cost += path_count * 0.2
# Ontology subset size impact
total_elements = (len(ontology_subset.classes) +
len(ontology_subset.object_properties) +
len(ontology_subset.datatype_properties))
cost += (total_elements / 100.0) * 0.1
return min(cost, 1.0)
def _estimate_cypher_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float:
"""Estimate execution cost for Cypher query.
Args:
query: Cypher query string
ontology_subset: Ontology subset
Returns:
Estimated cost (0.0 to 1.0)
"""
cost = 0.0
# Basic query complexity
cost += len(query.split('\n')) * 0.01
# Pattern complexity
match_count = len(re.findall(r'MATCH', query, re.IGNORECASE))
cost += match_count * 0.1
# Relationship traversals
rel_count = len(re.findall(r'-\[.*?\]-', query))
cost += rel_count * 0.1
# Variable length paths
var_path_count = len(re.findall(r'\*\d*\.\.', query))
cost += var_path_count * 0.3
# WHERE clauses
where_count = len(re.findall(r'WHERE', query, re.IGNORECASE))
cost += where_count * 0.05
# Aggregation functions
agg_count = len(re.findall(r'COUNT|SUM|AVG|MIN|MAX', query, re.IGNORECASE))
cost += agg_count * 0.1
# Ontology subset size impact
total_elements = (len(ontology_subset.classes) +
len(ontology_subset.object_properties) +
len(ontology_subset.datatype_properties))
cost += (total_elements / 100.0) * 0.1
return min(cost, 1.0)
def should_use_cache(self,
                     query: str,
                     question_components: QuestionComponents,
                     optimization_hint: OptimizationHint) -> bool:
    """Determine if query results should be cached.

    Args:
        query: Query string
        question_components: Question analysis
        optimization_hint: Optimization hints

    Returns:
        True if results should be cached
    """
    if not optimization_hint.cache_results:
        return False
    # Don't cache real-time or time-sensitive queries.  This guard must run
    # before the type-based rules below: previously it sat after them and was
    # unreachable for retrieval/factual/aggregation questions, so queries like
    # "find the latest orders" could be served stale cached results.
    if any(keyword in question_components.original_question.lower()
           for keyword in ['now', 'current', 'latest', 'recent']):
        return False
    # Cache simple retrieval and factual queries.
    if question_components.question_type in [QuestionType.RETRIEVAL, QuestionType.FACTUAL]:
        return True
    # Cache expensive aggregation queries.
    if (question_components.question_type == QuestionType.AGGREGATION and
            ('COUNT' in query.upper() or 'SUM' in query.upper())):
        return True
    return False
def get_cache_key(self,
                  query: str,
                  ontology_subset: QueryOntologySubset) -> str:
    """Derive a deterministic cache key for a query/ontology pair.

    Args:
        query: Query string.
        ontology_subset: Ontology subset used to answer the query.

    Returns:
        Hex md5 digest usable as a cache key.
    """
    import hashlib
    # Sorting makes the key independent of dict insertion order.
    class_names = sorted(ontology_subset.classes.keys())
    property_names = sorted(ontology_subset.object_properties.keys())
    fingerprint = "{}-{}-{}".format(query.strip(), class_names, property_names)
    return hashlib.md5(fingerprint.encode()).hexdigest()

View file

@ -0,0 +1,438 @@
"""
Main OntoRAG query service.
Orchestrates question analysis, ontology matching, query generation, execution, and answer generation.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from datetime import datetime
from ....flow.flow_processor import FlowProcessor
from ....tables.config import ConfigTableStore
from ...extract.kg.ontology.ontology_loader import OntologyLoader
from ...extract.kg.ontology.vector_store import InMemoryVectorStore
from .question_analyzer import QuestionAnalyzer, QuestionComponents
from .ontology_matcher import OntologyMatcher, QueryOntologySubset
from .backend_router import BackendRouter, QueryRoute, BackendType
from .sparql_generator import SPARQLGenerator, SPARQLQuery
from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult
from .cypher_generator import CypherGenerator, CypherQuery
from .cypher_executor import CypherExecutor, CypherResult
from .answer_generator import AnswerGenerator, GeneratedAnswer
logger = logging.getLogger(__name__)
@dataclass
class QueryRequest:
    """Query request from user.

    Fields are documented with inline comments (dataclass fields cannot
    carry docstrings).
    """
    # Natural-language question to answer.
    question: str
    # Optional free-text context supplied alongside the question.
    context: Optional[str] = None
    # Optional identifier restricting matching to one specific ontology.
    ontology_hint: Optional[str] = None
    # Maximum number of results to return.
    # NOTE(review): not read by OntoRAGQueryService.process in this file —
    # confirm it is consumed downstream (e.g. by the optimizer/backends).
    max_results: int = 10
    # Minimum confidence for accepting matches/answers.
    # NOTE(review): also not read in this file — confirm consumers.
    confidence_threshold: float = 0.7
@dataclass
class QueryResponse:
    """Complete query response.

    Aggregates the final answer plus every intermediate artifact of the
    pipeline so callers can inspect how the answer was produced.
    """
    # Natural-language answer text.
    answer: str
    # Combined confidence (min of routing and answer-generation confidence).
    confidence: float
    # End-to-end wall-clock processing time in seconds.
    execution_time: float
    # Decomposition of the original question.
    question_analysis: QuestionComponents
    # Ontology subsets judged relevant to the question.
    ontology_subsets: List[QueryOntologySubset]
    # Backend routing decision.
    query_route: QueryRoute
    # The generated backend query (SPARQL or Cypher, per the route).
    generated_query: Union[SPARQLQuery, CypherQuery]
    # Raw backend results before answer generation.
    raw_results: Union[SPARQLResult, CypherResult]
    # Facts cited in support of the answer.
    supporting_facts: List[str]
    # Extra diagnostics (backend used, counts, timings, routing reasoning).
    metadata: Dict[str, Any]
class OntoRAGQueryService(FlowProcessor):
    """Main OntoRAG query service orchestrating all components.

    Pipeline per request (see ``process``):
      1. analyse the question,
      2. load ontologies and select relevant subsets,
      3. route to a backend (Cassandra/SPARQL or a Cypher graph store),
      4. generate and execute the backend query,
      5. generate a natural-language answer.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize OntoRAG query service.

        Only records configuration and creates empty component slots; the
        real construction happens in the async ``init``.

        Args:
            config: Service configuration
        """
        super().__init__(config)
        self.config = config
        # Initialize components (all populated in init()).
        self.config_store = None
        self.ontology_loader = None
        self.vector_store = None
        self.question_analyzer = None
        self.ontology_matcher = None
        self.backend_router = None
        self.sparql_generator = None
        self.sparql_engine = None
        self.cypher_generator = None
        self.cypher_executor = None
        self.answer_generator = None
        # Cache for loaded ontologies
        # NOTE(review): never read or written in this file — confirm whether
        # ontology caching was meant to live here or in OntologyLoader.
        self.ontology_cache = {}

    async def init(self):
        """Initialize all components.

        Builds the config store, ontology loader, vector store, analyzers,
        router, generators and — depending on which backends the router
        reports enabled — the SPARQL engine and/or Cypher executor.
        """
        await super().init()
        # Initialize configuration store
        self.config_store = ConfigTableStore(self.config.get('config_store', {}))
        # Initialize ontology components
        self.ontology_loader = OntologyLoader(self.config_store)
        # Initialize vector store
        vector_config = self.config.get('vector_store', {})
        self.vector_store = InMemoryVectorStore.create(
            store_type=vector_config.get('type', 'numpy'),
            dimension=vector_config.get('dimension', 384),
            similarity_threshold=vector_config.get('similarity_threshold', 0.7)
        )
        # Initialize question analyzer
        # NOTE(review): QuestionAnalyzer in question_analyzer.py takes no
        # constructor arguments and exposes a synchronous analyze() — confirm
        # this service targets a different, prompt-driven analyzer variant.
        analyzer_config = self.config.get('question_analyzer', {})
        self.question_analyzer = QuestionAnalyzer(
            prompt_service=self.prompt_service,  # presumably from FlowProcessor — TODO confirm
            config=analyzer_config
        )
        # Initialize ontology matcher
        matcher_config = self.config.get('ontology_matcher', {})
        self.ontology_matcher = OntologyMatcher(
            vector_store=self.vector_store,
            embedding_service=self.embedding_service,  # presumably from FlowProcessor — TODO confirm
            config=matcher_config
        )
        # Initialize backend router
        router_config = self.config.get('backend_router', {})
        self.backend_router = BackendRouter(router_config)
        # Initialize query generators
        self.sparql_generator = SPARQLGenerator(prompt_service=self.prompt_service)
        self.cypher_generator = CypherGenerator(prompt_service=self.prompt_service)
        # Initialize executors
        # NOTE(review): sparql_config is collected but never used — the engine
        # is configured from the backend router instead; confirm intent.
        sparql_config = self.config.get('sparql_executor', {})
        if self.backend_router.is_backend_enabled(BackendType.CASSANDRA):
            cassandra_config = self.backend_router.get_backend_config(BackendType.CASSANDRA)
            if cassandra_config:
                self.sparql_engine = SPARQLCassandraEngine(cassandra_config)
                await self.sparql_engine.initialize()
        cypher_config = self.config.get('cypher_executor', {})
        # One CypherExecutor serves all enabled Cypher-speaking stores.
        enabled_graph_backends = [
            bt for bt in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB]
            if self.backend_router.is_backend_enabled(bt)
        ]
        if enabled_graph_backends:
            self.cypher_executor = CypherExecutor(cypher_config)
            await self.cypher_executor.initialize()
        # Initialize answer generator
        self.answer_generator = AnswerGenerator(prompt_service=self.prompt_service)
        logger.info("OntoRAG query service initialized")

    async def process(self, request: QueryRequest) -> QueryResponse:
        """Process a natural language query.

        Never raises: any failure is converted into a QueryResponse with
        confidence 0.0 and the error recorded in metadata.

        Args:
            request: Query request

        Returns:
            Complete query response
        """
        start_time = datetime.now()
        try:
            logger.info(f"Processing query: {request.question}")
            # Step 1: Analyze question
            question_components = await self.question_analyzer.analyze_question(
                request.question, context=request.context
            )
            logger.debug(f"Question analysis: {question_components.question_type}")
            # Step 2: Load and match ontologies
            ontology_subsets = await self._load_and_match_ontologies(
                question_components, request.ontology_hint
            )
            logger.debug(f"Found {len(ontology_subsets)} relevant ontology subsets")
            # Step 3: Route to appropriate backend
            query_route = self.backend_router.route_query(
                question_components, ontology_subsets
            )
            logger.debug(f"Routed to {query_route.backend_type.value} backend")
            # Step 4: Generate and execute query (language chosen by the route)
            if query_route.query_language == 'sparql':
                query_results = await self._execute_sparql_path(
                    question_components, ontology_subsets, query_route
                )
            else:  # cypher
                query_results = await self._execute_cypher_path(
                    question_components, ontology_subsets, query_route
                )
            # Step 5: Generate natural language answer
            generated_answer = await self.answer_generator.generate_answer(
                question_components,
                query_results['raw_results'],
                ontology_subsets[0] if ontology_subsets else None,
                query_route.backend_type.value
            )
            # Build response; overall confidence is the weaker of routing
            # confidence and answer-generation confidence.
            execution_time = (datetime.now() - start_time).total_seconds()
            response = QueryResponse(
                answer=generated_answer.answer,
                confidence=min(query_route.confidence, generated_answer.metadata.confidence),
                execution_time=execution_time,
                question_analysis=question_components,
                ontology_subsets=ontology_subsets,
                query_route=query_route,
                generated_query=query_results['generated_query'],
                raw_results=query_results['raw_results'],
                supporting_facts=generated_answer.supporting_facts,
                metadata={
                    'backend_used': query_route.backend_type.value,
                    'query_language': query_route.query_language,
                    'ontology_count': len(ontology_subsets),
                    'result_count': generated_answer.metadata.result_count,
                    'routing_reasoning': query_route.reasoning,
                    'generation_time': generated_answer.generation_time
                }
            )
            logger.info(f"Query processed successfully in {execution_time:.2f}s")
            return response
        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            execution_time = (datetime.now() - start_time).total_seconds()
            # Return error response instead of raising so callers always get
            # a QueryResponse.
            # NOTE(review): the QuestionComponents dataclass in
            # question_analyzer.py has no 'normalized_question' field, so
            # this construction would itself raise TypeError inside the
            # error handler — confirm which QuestionComponents definition is
            # in effect here.
            return QueryResponse(
                answer=f"I encountered an error processing your query: {str(e)}",
                confidence=0.0,
                execution_time=execution_time,
                question_analysis=QuestionComponents(
                    original_question=request.question,
                    normalized_question=request.question,
                    question_type=None,
                    entities=[], keywords=[], relationships=[], constraints=[],
                    aggregations=[], expected_answer_type="unknown"
                ),
                ontology_subsets=[],
                query_route=None,
                generated_query=None,
                raw_results=None,
                supporting_facts=[],
                metadata={'error': str(e), 'execution_time': execution_time}
            )

    async def _load_and_match_ontologies(self,
                                         question_components: QuestionComponents,
                                         ontology_hint: Optional[str] = None) -> List[QueryOntologySubset]:
        """Load ontologies and find relevant subsets.

        Best-effort: individual ontology load failures are logged and
        skipped; a total failure returns an empty list.

        Args:
            question_components: Analyzed question
            ontology_hint: Optional ontology hint

        Returns:
            List of relevant ontology subsets (possibly empty)
        """
        try:
            # Load available ontologies
            if ontology_hint:
                # Load specific ontology
                ontologies = [await self.ontology_loader.load_ontology(ontology_hint)]
            else:
                # Load all available ontologies
                available_ontologies = await self.ontology_loader.list_available_ontologies()
                ontologies = []
                for ontology_id in available_ontologies[:5]:  # Limit to 5 for performance
                    try:
                        ontology = await self.ontology_loader.load_ontology(ontology_id)
                        ontologies.append(ontology)
                    except Exception as e:
                        logger.warning(f"Failed to load ontology {ontology_id}: {e}")
            if not ontologies:
                logger.warning("No ontologies loaded")
                return []
            # Extract relevant subsets; keep only non-empty ones.
            ontology_subsets = []
            for ontology in ontologies:
                subset = await self.ontology_matcher.select_relevant_subset(
                    question_components, ontology
                )
                if subset and (subset.classes or subset.object_properties or subset.datatype_properties):
                    ontology_subsets.append(subset)
            return ontology_subsets
        except Exception as e:
            logger.error(f"Failed to load and match ontologies: {e}")
            return []

    async def _execute_sparql_path(self,
                                   question_components: QuestionComponents,
                                   ontology_subsets: List[QueryOntologySubset],
                                   query_route: QueryRoute) -> Dict[str, Any]:
        """Execute SPARQL query path.

        Args:
            question_components: Question analysis
            ontology_subsets: Ontology subsets (first one guides generation)
            query_route: Query route

        Returns:
            Dict with 'generated_query' and 'raw_results'

        Raises:
            RuntimeError: If the SPARQL engine was not initialized.
        """
        if not self.sparql_engine:
            raise RuntimeError("SPARQL engine not initialized")
        # Generate SPARQL query from the highest-ranked subset only.
        primary_subset = ontology_subsets[0] if ontology_subsets else None
        sparql_query = await self.sparql_generator.generate_sparql(
            question_components, primary_subset
        )
        logger.debug(f"Generated SPARQL: {sparql_query.query}")
        # Execute query
        # NOTE(review): execute_sparql is called synchronously (not awaited),
        # unlike the cypher path — confirm the engine API is sync.
        sparql_results = self.sparql_engine.execute_sparql(sparql_query.query)
        return {
            'generated_query': sparql_query,
            'raw_results': sparql_results
        }

    async def _execute_cypher_path(self,
                                   question_components: QuestionComponents,
                                   ontology_subsets: List[QueryOntologySubset],
                                   query_route: QueryRoute) -> Dict[str, Any]:
        """Execute Cypher query path.

        Args:
            question_components: Question analysis
            ontology_subsets: Ontology subsets (first one guides generation)
            query_route: Query route (selects the concrete graph database)

        Returns:
            Dict with 'generated_query' and 'raw_results'

        Raises:
            RuntimeError: If the Cypher executor was not initialized.
        """
        if not self.cypher_executor:
            raise RuntimeError("Cypher executor not initialized")
        # Generate Cypher query from the highest-ranked subset only.
        primary_subset = ontology_subsets[0] if ontology_subsets else None
        cypher_query = await self.cypher_generator.generate_cypher(
            question_components, primary_subset
        )
        logger.debug(f"Generated Cypher: {cypher_query.query}")
        # Execute query against the routed database (neo4j/memgraph/falkordb).
        database_type = query_route.backend_type.value
        cypher_results = await self.cypher_executor.execute_query(
            cypher_query.query, database_type=database_type
        )
        return {
            'generated_query': cypher_query,
            'raw_results': cypher_results
        }

    async def get_supported_backends(self) -> List[str]:
        """Get list of supported and enabled backends.

        Returns:
            List of backend names
        """
        return [bt.value for bt in self.backend_router.get_available_backends()]

    async def get_available_ontologies(self) -> List[str]:
        """Get list of available ontologies.

        Returns:
            List of ontology identifiers (empty if the loader is not ready)
        """
        if self.ontology_loader:
            return await self.ontology_loader.list_available_ontologies()
        return []

    async def health_check(self) -> Dict[str, Any]:
        """Perform health check on all components.

        Returns:
            Health status of all components; 'service' degrades to
            'degraded' if any probe raises.
        """
        health = {
            'service': 'healthy',
            'components': {},
            'backends': {},
            'ontologies': {}
        }
        try:
            # Check ontology loader
            if self.ontology_loader:
                ontologies = await self.ontology_loader.list_available_ontologies()
                health['components']['ontology_loader'] = 'healthy'
                health['ontologies']['count'] = len(ontologies)
            else:
                health['components']['ontology_loader'] = 'not_initialized'
            # Check vector store
            if self.vector_store:
                health['components']['vector_store'] = 'healthy'
                health['components']['vector_store_type'] = type(self.vector_store).__name__
            else:
                health['components']['vector_store'] = 'not_initialized'
            # Check backends: 'healthy' only if the matching executor exists.
            for backend_type in self.backend_router.get_available_backends():
                if backend_type == BackendType.CASSANDRA and self.sparql_engine:
                    health['backends']['cassandra'] = 'healthy'
                elif backend_type in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] and self.cypher_executor:
                    health['backends'][backend_type.value] = 'healthy'
                else:
                    health['backends'][backend_type.value] = 'configured_but_not_initialized'
        except Exception as e:
            health['service'] = 'degraded'
            health['error'] = str(e)
        return health

    async def close(self):
        """Close all connections and cleanup resources.

        Best-effort: errors during shutdown are logged, not raised.
        """
        try:
            if self.sparql_engine:
                self.sparql_engine.close()
            if self.cypher_executor:
                await self.cypher_executor.close()
            if self.config_store:
                # ConfigTableStore cleanup if needed
                pass
            logger.info("OntoRAG query service closed")
        except Exception as e:
            logger.error(f"Error closing OntoRAG query service: {e}")

View file

@ -0,0 +1,364 @@
"""
Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components.
"""
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class QuestionType(Enum):
    """Types of questions that can be asked."""
    FACTUAL = "factual"            # What is X?
    RETRIEVAL = "retrieval"        # Find all X
    AGGREGATION = "aggregation"    # How many X?
    COMPARISON = "comparison"      # Is X better than Y?
    RELATIONSHIP = "relationship"  # How is X related to Y?
    BOOLEAN = "boolean"            # Yes/no questions
    PROCESS = "process"            # How to do X?
    TEMPORAL = "temporal"          # When did X happen?
    SPATIAL = "spatial"            # Where is X?


@dataclass
class QuestionComponents:
    """Components extracted from a question."""
    original_question: str       # The question exactly as asked.
    question_type: QuestionType  # Classified question category.
    entities: List[str]          # Proper nouns and quoted phrases found.
    relationships: List[str]     # Relationship indicators ("has", "of" ...).
    constraints: List[str]       # Filtering constraints ("greater than 30").
    aggregations: List[str]      # Aggregation keywords ("count", "sum" ...).
    expected_answer_type: str    # e.g. 'number', 'boolean', 'list', 'text'.
    keywords: List[str]          # Stop-word-filtered content words.


class QuestionAnalyzer:
    """Analyzes natural language questions to extract semantic components.

    Purely regex/keyword driven — no model calls.  Public entry points are
    ``analyze`` and ``get_question_segments``.
    """

    def __init__(self):
        """Initialize the pattern tables used by the extraction passes."""
        # Regexes classifying a lower-cased question by type; checked in
        # declaration order, first match wins (see _identify_question_type).
        self.question_patterns = {
            QuestionType.FACTUAL: [
                r'^what\s+(?:is|are)',
                r'^who\s+(?:is|are)',
                r'^which\s+',
            ],
            QuestionType.RETRIEVAL: [
                r'^find\s+',
                r'^list\s+',
                r'^show\s+',
                r'^get\s+',
                r'^retrieve\s+',
            ],
            QuestionType.AGGREGATION: [
                r'^how\s+many',
                r'^count\s+',
                r'^what\s+(?:is|are)\s+the\s+(?:number|total|sum)',
            ],
            QuestionType.COMPARISON: [
                r'(?:better|worse|more|less|greater|smaller)\s+than',
                r'compare\s+',
                r'difference\s+between',
            ],
            QuestionType.RELATIONSHIP: [
                r'^how\s+(?:is|are).*related',
                r'relationship\s+between',
                r'connection\s+between',
            ],
            QuestionType.BOOLEAN: [
                r'^(?:is|are|was|were|do|does|did|can|could|will|would|should)',
                r'^has\s+',
                r'^have\s+',
            ],
            QuestionType.PROCESS: [
                r'^how\s+(?:to|do)',
                r'^explain\s+how',
            ],
            QuestionType.TEMPORAL: [
                r'^when\s+',
                r'what\s+time',
                r'what\s+date',
            ],
            QuestionType.SPATIAL: [
                r'^where\s+',
                r'location\s+of',
            ],
        }
        # Substrings indicating an aggregation operation.
        self.aggregation_keywords = [
            'count', 'sum', 'total', 'average', 'mean', 'median',
            'maximum', 'minimum', 'max', 'min', 'number of'
        ]
        # Patterns capturing filter constraints; single- and multi-group
        # patterns are both supported (see _extract_constraints).
        self.constraint_patterns = [
            r'(?:with|having|where)\s+(.+?)(?:\s+and|\s+or|$)',
            r'(?:greater|less|more|fewer)\s+than\s+(\d+)',
            r'(?:between|from)\s+(.+?)\s+(?:and|to)\s+(.+)',
            r'(?:before|after|since|until)\s+(.+)',
        ]

    def analyze(self, question: str) -> QuestionComponents:
        """Analyze a question to extract components.

        Args:
            question: Natural language question

        Returns:
            QuestionComponents with extracted information
        """
        # Normalize once; most extraction passes work on lower-case text.
        question_lower = question.lower().strip()
        question_type = self._identify_question_type(question_lower)
        # Entity extraction needs the original casing (proper nouns).
        entities = self._extract_entities(question)
        relationships = self._extract_relationships(question_lower)
        constraints = self._extract_constraints(question_lower)
        aggregations = self._extract_aggregations(question_lower)
        answer_type = self._determine_answer_type(question_type, aggregations)
        keywords = self._extract_keywords(question_lower)
        return QuestionComponents(
            original_question=question,
            question_type=question_type,
            entities=entities,
            relationships=relationships,
            constraints=constraints,
            aggregations=aggregations,
            expected_answer_type=answer_type,
            keywords=keywords
        )

    def _identify_question_type(self, question: str) -> QuestionType:
        """Identify the type of question.

        Args:
            question: Lowercase question text

        Returns:
            QuestionType enum value (FACTUAL when nothing matches)
        """
        for q_type, patterns in self.question_patterns.items():
            for pattern in patterns:
                if re.search(pattern, question):
                    return q_type
        # Default to factual
        return QuestionType.FACTUAL

    def _extract_entities(self, question: str) -> List[str]:
        """Extract potential entities from question.

        Heuristics: capitalized word runs (proper-noun candidates, so
        leading question words like "What" may also be captured) plus
        single- or double-quoted phrases.

        Args:
            question: Original question text

        Returns:
            List of entity strings (order preserved, de-duplicated)
        """
        entities = []
        # Consecutive capitalized words, e.g. "New York City".
        pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b'
        entities.extend(re.findall(pattern, question))
        # Quoted strings are explicit entity mentions.
        entities.extend(re.findall(r'"([^"]+)"', question))
        entities.extend(re.findall(r"'([^']+)'", question))
        # Remove duplicates while preserving order
        seen = set()
        unique_entities = []
        for entity in entities:
            if entity not in seen:
                seen.add(entity)
                unique_entities.append(entity)
        return unique_entities

    def _extract_relationships(self, question: str) -> List[str]:
        """Extract relationship indicators from question.

        Args:
            question: Lowercase question text

        Returns:
            De-duplicated list of relationship strings; patterns without a
            capture group contribute their whole match (e.g. "related to").
        """
        relationships = []
        rel_patterns = [
            r'(\w+)\s+(?:of|by|from|to|with|for)\s+',
            r'has\s+(\w+)',
            r'belongs?\s+to',
            r'(?:created|written|authored|owned)\s+by',
            r'related\s+to',
            r'connected\s+to',
            r'associated\s+with',
        ]
        for pattern in rel_patterns:
            relationships.extend(re.findall(pattern, question))
        # Drop very short captures (stop-word noise).
        relationships = [r for r in relationships if len(r) > 2]
        return list(set(relationships))

    def _extract_constraints(self, question: str) -> List[str]:
        """Extract constraints from question.

        Args:
            question: Lowercase question text

        Returns:
            List of constraint strings
        """
        constraints = []
        for pattern in self.constraint_patterns:
            # findall returns strings for single-group patterns and tuples
            # for multi-group ones; flatten both shapes.  (Previously only
            # the first tuple match was kept, silently dropping any further
            # matches of multi-group patterns.)
            for match in re.findall(pattern, question):
                if isinstance(match, tuple):
                    constraints.extend(match)
                else:
                    constraints.append(match)
        # Drop empty captures and surrounding whitespace.
        return [c.strip() for c in constraints if c and len(c.strip()) > 0]

    def _extract_aggregations(self, question: str) -> List[str]:
        """Extract aggregation operations from question.

        Args:
            question: Lowercase question text

        Returns:
            List of aggregation keywords present in the question
        """
        return [keyword for keyword in self.aggregation_keywords
                if keyword in question]

    def _determine_answer_type(self, question_type: QuestionType,
                               aggregations: List[str]) -> str:
        """Determine expected answer type.

        Aggregations force a numeric answer; otherwise the question type
        decides, with 'text' as the fallback.

        Args:
            question_type: Type of question
            aggregations: Aggregation operations found

        Returns:
            Expected answer type string
        """
        if aggregations:
            numeric_markers = ['count', 'number of', 'total',
                               'average', 'mean', 'median', 'sum']
            if any(a in numeric_markers for a in aggregations):
                return 'number'
        if question_type == QuestionType.BOOLEAN:
            return 'boolean'
        elif question_type == QuestionType.TEMPORAL:
            return 'datetime'
        elif question_type == QuestionType.SPATIAL:
            return 'location'
        elif question_type == QuestionType.RETRIEVAL:
            return 'list'
        elif question_type == QuestionType.COMPARISON:
            return 'comparison'
        else:
            return 'text'

    def _extract_keywords(self, question: str) -> List[str]:
        """Extract important keywords from question.

        Args:
            question: Lowercase question text

        Returns:
            Stop-word-filtered words longer than 2 chars, order preserved,
            de-duplicated.
        """
        stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
            'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are',
            'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do',
            'does', 'did', 'will', 'would', 'could', 'should', 'may',
            'might', 'must', 'can', 'shall', 'what', 'which', 'who',
            'when', 'where', 'why', 'how'
        }
        words = re.findall(r'\b\w+\b', question)
        keywords = [w for w in words if w not in stop_words and len(w) > 2]
        # Remove duplicates while preserving order
        seen = set()
        unique_keywords = []
        for kw in keywords:
            if kw not in seen:
                seen.add(kw)
                unique_keywords.append(kw)
        return unique_keywords

    def get_question_segments(self, question: str) -> List[str]:
        """Split question into segments for embedding.

        Produces the full question, its comma/semicolon clauses, extracted
        entities and keywords — de-duplicated, order preserved.

        Args:
            question: Question text

        Returns:
            List of question segments
        """
        segments = [question]
        # Split by clauses
        clauses = re.split(r'[,;]', question)
        segments.extend([c.strip() for c in clauses if len(c.strip()) > 3])
        # Extract key phrases
        components = self.analyze(question)
        segments.extend(components.entities)
        segments.extend(components.keywords)
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys(segments))

View file

@ -0,0 +1,481 @@
"""
SPARQL-Cassandra engine using Python rdflib.
Executes SPARQL queries against Cassandra using a custom Store implementation.
"""
import logging
from typing import Dict, Any, List, Optional, Iterator, Tuple
from dataclasses import dataclass
import json
# Try to import rdflib
try:
from rdflib import Graph, Namespace, URIRef, Literal, BNode
from rdflib.store import Store
from rdflib.plugins.sparql.processor import SPARQLResult
from rdflib.plugins.sparql import prepareQuery
from rdflib.term import Node
RDFLIB_AVAILABLE = True
except ImportError:
RDFLIB_AVAILABLE = False
# Try to import Cassandra driver
try:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.policies import DCAwareRoundRobinPolicy
CASSANDRA_AVAILABLE = True
except ImportError:
CASSANDRA_AVAILABLE = False
from ....tables.config import ConfigTableStore
logger = logging.getLogger(__name__)
@dataclass
class SPARQLResult:
    """Result from SPARQL query execution.

    NOTE(review): when rdflib is available this name shadows the
    ``SPARQLResult`` imported above from
    ``rdflib.plugins.sparql.processor`` — consider renaming one of them.
    """
    # One dict of variable->value per result row.
    bindings: List[Dict[str, Any]]
    # Names of the projected variables.
    variables: List[str]
    ask_result: Optional[bool] = None  # For ASK queries
    # Wall-clock execution time in seconds.
    execution_time: float = 0.0
    # Optional textual description of the execution plan.
    query_plan: Optional[str] = None
class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
    """Custom rdflib Store implementation for Cassandra.

    Triples live in one wide table partitioned by subject, with secondary
    indexes on predicate and object for the other access paths.  Only the
    subset of the rdflib Store API needed for query evaluation is
    implemented.

    All CQL statements use ``%s`` placeholders: the Cassandra Python driver
    only accepts ``?`` markers in *prepared* statements; plain
    ``session.execute`` with a parameter list requires ``%s``.
    """

    def __init__(self, cassandra_config: Dict[str, Any]):
        """Initialize Cassandra triple store.

        Args:
            cassandra_config: Cassandra connection configuration
                (host, port, keyspace, optional username/password).

        Raises:
            RuntimeError: If the Cassandra driver or rdflib is not installed.
        """
        if not CASSANDRA_AVAILABLE:
            raise RuntimeError("Cassandra driver not available")
        if not RDFLIB_AVAILABLE:
            raise RuntimeError("rdflib not available")
        super().__init__()
        self.cassandra_config = cassandra_config
        self.cluster = None
        self.session = None
        self.keyspace = cassandra_config.get('keyspace', 'trustgraph')
        # Fully qualified table names used in every CQL statement below.
        self.triple_table = f"{self.keyspace}.triples"
        self.metadata_table = f"{self.keyspace}.triple_metadata"

    def open(self, configuration=None, create=False):
        """Open connection to Cassandra.

        Args:
            configuration: Unused (rdflib Store API compatibility).
            create: When True, create the keyspace and tables first.

        Returns:
            True on success, False on connection failure.
        """
        try:
            # Optional password authentication.
            auth_provider = None
            if 'username' in self.cassandra_config and 'password' in self.cassandra_config:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_config['username'],
                    password=self.cassandra_config['password']
                )
            self.cluster = Cluster(
                [self.cassandra_config.get('host', 'localhost')],
                port=self.cassandra_config.get('port', 9042),
                auth_provider=auth_provider,
                load_balancing_policy=DCAwareRoundRobinPolicy()
            )
            self.session = self.cluster.connect()
            # Ensure keyspace exists before selecting it.
            if create:
                self._create_schema()
            self.session.set_keyspace(self.keyspace)
            logger.info(f"Connected to Cassandra cluster: {self.cassandra_config.get('host')}")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}")
            return False

    def close(self, commit_pending_transaction=True):
        """Close Cassandra connection (session first, then cluster)."""
        if self.session:
            self.session.shutdown()
        if self.cluster:
            self.cluster.shutdown()

    def _create_schema(self):
        """Create the Cassandra keyspace, triple table and indexes."""
        self.session.execute(f"""
            CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
            WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
        """)
        # Partitioned by subject with predicate/object clustering, so the
        # (s), (s,p) and (s,p,o) lookups are single-partition queries.
        self.session.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.triple_table} (
                subject text,
                predicate text,
                object text,
                object_datatype text,
                object_language text,
                is_literal boolean,
                graph_id text,
                PRIMARY KEY ((subject), predicate, object)
            )
        """)
        # Secondary indexes serve the predicate-only / object-only patterns.
        self.session.execute(f"""
            CREATE INDEX IF NOT EXISTS ON {self.triple_table} (predicate)
        """)
        self.session.execute(f"""
            CREATE INDEX IF NOT EXISTS ON {self.triple_table} (object)
        """)
        # Metadata table for graph information.  triple_count is a plain
        # bigint rather than a counter: Cassandra forbids mixing counter and
        # non-counter columns in one table, so the previous counter-based
        # definition could never be created.
        self.session.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.metadata_table} (
                graph_id text PRIMARY KEY,
                created timestamp,
                modified timestamp,
                triple_count bigint
            )
        """)

    def triples(self, triple_pattern, context=None):
        """Retrieve triples matching the given pattern.

        Args:
            triple_pattern: (subject, predicate, object) with None as wildcard
            context: Graph context (ignored)

        Yields:
            ((subject, predicate, object), context) pairs, as required by
            the rdflib Store API — rdflib's Graph unpacks two values per
            item, so yielding bare triples would break query evaluation.
        """
        if not self.session:
            return
        subject, predicate, object_val = triple_pattern
        for cql, params in self._pattern_to_cql(subject, predicate, object_val):
            try:
                rows = self.session.execute(cql, params)
                for row in rows:
                    yield self._row_to_triple(row), None
            except Exception as e:
                logger.error(f"Error executing CQL query: {e}")

    def _pattern_to_cql(self, subject, predicate, object_val) -> List[Tuple[str, List]]:
        """Convert a triple pattern to CQL queries.

        Args:
            subject: Subject node or None
            predicate: Predicate node or None
            object_val: Object node or None

        Returns:
            List of (CQL query, parameters) tuples
        """
        queries = []
        # None means wildcard; bound terms are stored as strings.
        s_str = str(subject) if subject else None
        p_str = str(predicate) if predicate else None
        o_str = str(object_val) if object_val else None
        if s_str and p_str and o_str:
            # Exact triple lookup (single partition).
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s AND predicate = %s AND object = %s"
            queries.append((cql, [s_str, p_str, o_str]))
        elif s_str and p_str:
            # Subject and predicate known.
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s AND predicate = %s"
            queries.append((cql, [s_str, p_str]))
        elif s_str and o_str:
            # Subject and object known, predicate unbound.  Previously this
            # case fell through to the subject-only query and yielded rows
            # whose object did not match the pattern.
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s AND object = %s ALLOW FILTERING"
            queries.append((cql, [s_str, o_str]))
        elif s_str:
            # Subject known.
            cql = f"SELECT * FROM {self.triple_table} WHERE subject = %s"
            queries.append((cql, [s_str]))
        elif p_str:
            # Predicate known (secondary-index scan).
            cql = f"SELECT * FROM {self.triple_table} WHERE predicate = %s ALLOW FILTERING"
            queries.append((cql, [p_str]))
        elif o_str:
            # Object known (secondary-index scan).
            cql = f"SELECT * FROM {self.triple_table} WHERE object = %s ALLOW FILTERING"
            queries.append((cql, [o_str]))
        else:
            # Full scan (should be avoided in production).
            cql = f"SELECT * FROM {self.triple_table}"
            queries.append((cql, []))
        return queries

    def _row_to_triple(self, row):
        """Convert a Cassandra row to an rdflib triple.

        Args:
            row: Cassandra row object

        Returns:
            (subject, predicate, object) tuple of rdflib nodes
        """
        # Heuristic: values not starting with 'http' are treated as blank
        # node identifiers rather than IRIs.
        subject = URIRef(row.subject) if row.subject.startswith('http') else BNode(row.subject)
        predicate = URIRef(row.predicate)
        if row.is_literal:
            # Reconstruct the literal with its datatype or language tag.
            if row.object_datatype:
                object_node = Literal(row.object, datatype=URIRef(row.object_datatype))
            elif row.object_language:
                object_node = Literal(row.object, lang=row.object_language)
            else:
                object_node = Literal(row.object)
        else:
            object_node = URIRef(row.object) if row.object.startswith('http') else BNode(row.object)
        return (subject, predicate, object_node)

    def add(self, triple, context=None, quoted=False):
        """Add a triple to the store.

        Args:
            triple: (subject, predicate, object) tuple of rdflib nodes
            context: Graph context; stored as graph_id ('default' if None)
            quoted: Ignored (no formula support)
        """
        if not self.session:
            return
        subject, predicate, object_val = triple
        # Flatten nodes to the table's string columns, preserving literal
        # datatype/language so _row_to_triple can reconstruct them.
        s_str = str(subject)
        p_str = str(predicate)
        is_literal = isinstance(object_val, Literal)
        o_str = str(object_val)
        o_datatype = str(object_val.datatype) if is_literal and object_val.datatype else None
        o_language = object_val.language if is_literal and object_val.language else None
        cql = f"""
            INSERT INTO {self.triple_table}
            (subject, predicate, object, object_datatype, object_language, is_literal, graph_id)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
        try:
            self.session.execute(cql, [
                s_str, p_str, o_str, o_datatype, o_language, is_literal,
                str(context) if context else 'default'
            ])
        except Exception as e:
            logger.error(f"Error adding triple: {e}")

    def remove(self, triple, context=None):
        """Remove triple(s) from the store.

        Args:
            triple: (subject, predicate, object); components may be None to
                act as wildcards, per the rdflib Store API.
            context: Graph context (ignored)
        """
        if not self.session:
            return
        subject, predicate, object_val = triple
        if subject is None or predicate is None or object_val is None:
            # Wildcard pattern: expand to the concrete matching triples and
            # delete each one — the DELETE below needs the full primary key.
            # (Previously a None component was stringified to 'None'.)
            for matched, _ctx in list(self.triples(triple)):
                self.remove(matched, context)
            return
        cql = f"""
            DELETE FROM {self.triple_table}
            WHERE subject = %s AND predicate = %s AND object = %s
        """
        try:
            self.session.execute(cql, [str(subject), str(predicate), str(object_val)])
        except Exception as e:
            logger.error(f"Error removing triple: {e}")

    def __len__(self, context=None):
        """Get number of triples in store.

        Args:
            context: Graph context (ignored)

        Returns:
            Number of triples (0 when disconnected or on error)
        """
        if not self.session:
            return 0
        try:
            cql = f"SELECT COUNT(*) FROM {self.triple_table}"
            result = self.session.execute(cql)
            return result.one().count
        except Exception as e:
            logger.error(f"Error counting triples: {e}")
            return 0
class SPARQLCassandraEngine:
    """SPARQL processor using a Cassandra backend.

    Wraps a CassandraTripleStore in an rdflib Graph so that rdflib's
    SPARQL engine performs query evaluation while Cassandra provides the
    underlying triple storage.
    """

    def __init__(self, cassandra_config: Dict[str, Any]):
        """Initialize SPARQL-Cassandra engine.

        Args:
            cassandra_config: Cassandra configuration.

        Raises:
            RuntimeError: If rdflib or the Cassandra driver is unavailable.
        """
        if not RDFLIB_AVAILABLE:
            raise RuntimeError("rdflib is required for SPARQL processing")
        if not CASSANDRA_AVAILABLE:
            raise RuntimeError("Cassandra driver is required")
        self.cassandra_config = cassandra_config
        self.store = CassandraTripleStore(cassandra_config)
        self.graph = Graph(store=self.store)
        # Common namespaces, bound to the graph for prefixed (de)serialization.
        self.namespaces = {
            'rdf': Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
            'rdfs': Namespace('http://www.w3.org/2000/01/rdf-schema#'),
            'owl': Namespace('http://www.w3.org/2002/07/owl#'),
            'xsd': Namespace('http://www.w3.org/2001/XMLSchema#'),
        }
        for prefix, namespace in self.namespaces.items():
            self.graph.bind(prefix, namespace)

    async def initialize(self, create_schema=False):
        """Open the underlying store.

        Args:
            create_schema: Whether to create the Cassandra schema.

        Raises:
            RuntimeError: If the Cassandra connection cannot be opened.
        """
        success = self.store.open(create=create_schema)
        if not success:
            raise RuntimeError("Failed to connect to Cassandra")
        logger.info("SPARQL-Cassandra engine initialized")

    def execute_sparql(self, sparql_query: str) -> SPARQLResult:
        """Execute a SPARQL query against Cassandra.

        Args:
            sparql_query: SPARQL query string (SELECT or ASK forms are
                formatted; other forms fall through the SELECT path).

        Returns:
            Query results.  On error an empty SPARQLResult is returned
            (note: indistinguishable from a genuinely empty result set).
        """
        import time
        start_time = time.time()
        try:
            # Prepare and execute the query via rdflib.
            prepared_query = prepareQuery(sparql_query)
            result = self.graph.query(prepared_query)
            execution_time = time.time() - start_time
            # BUG FIX: the previous check, sparql_query.startswith('ASK'),
            # misclassified ASK queries that begin with PREFIX/BASE
            # declarations.  rdflib reports the query form on the result
            # object, so rely on that instead.
            if getattr(result, 'type', None) == 'ASK':
                return SPARQLResult(
                    bindings=[],
                    variables=[],
                    ask_result=bool(result.askAnswer),
                    execution_time=execution_time
                )
            # SELECT (tabular) results.
            bindings = []
            variables = result.vars if hasattr(result, 'vars') else []
            for row in result:
                binding = {}
                for i, var in enumerate(variables):
                    if i < len(row):
                        binding[str(var)] = self._format_result_value(row[i])
                bindings.append(binding)
            return SPARQLResult(
                bindings=bindings,
                variables=[str(v) for v in variables],
                execution_time=execution_time
            )
        except Exception as e:
            logger.error(f"SPARQL execution error: {e}")
            return SPARQLResult(
                bindings=[],
                variables=[],
                execution_time=time.time() - start_time
            )

    def _format_result_value(self, value):
        """Format an RDF term as a SPARQL-JSON-style result dict.

        Args:
            value: RDF value (URIRef, Literal, BNode).

        Returns:
            Dict with 'type' and 'value' keys, plus 'datatype'/'language'
            for literals that carry them.
        """
        if isinstance(value, URIRef):
            return {'type': 'uri', 'value': str(value)}
        elif isinstance(value, Literal):
            result = {'type': 'literal', 'value': str(value)}
            if value.datatype:
                result['datatype'] = str(value.datatype)
            if value.language:
                result['language'] = value.language
            return result
        elif isinstance(value, BNode):
            return {'type': 'bnode', 'value': str(value)}
        else:
            return {'type': 'unknown', 'value': str(value)}

    def load_triples_from_store(self, config_store: ConfigTableStore):
        """Load triples from TrustGraph's storage into the RDF graph.

        NOTE(review): placeholder - the actual TrustGraph table layout is
        not visible here, so no loading is performed yet.

        Args:
            config_store: Configuration store with triples.
        """
        logger.info("Loading triples from TrustGraph store...")
        try:
            # Placeholder: querying the appropriate TrustGraph tables
            # still needs to be implemented.
            pass
        except Exception as e:
            logger.error(f"Error loading triples: {e}")

    def close(self):
        """Close the engine and its store connections."""
        if self.store:
            self.store.close()
        logger.info("SPARQL-Cassandra engine closed")

View file

@ -0,0 +1,487 @@
"""
SPARQL query generator for ontology-sensitive queries.
Converts natural language questions to SPARQL queries for Cassandra execution.
"""
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .question_analyzer import QuestionComponents, QuestionType
from .ontology_matcher import QueryOntologySubset
logger = logging.getLogger(__name__)
@dataclass
class SPARQLQuery:
    """Generated SPARQL query with metadata."""
    # Full SPARQL query text, including PREFIX declarations.
    query: str
    # Projected variable names, without the leading '?'.
    variables: List[str]
    query_type: str  # SELECT, ASK, CONSTRUCT, DESCRIBE
    # Human-readable summary of what the query retrieves.
    explanation: str
    # Heuristic difficulty estimate in [0.0, 1.0].
    complexity_score: float
class SPARQLGenerator:
"""Generates SPARQL queries from natural language questions using LLM assistance."""
    def __init__(self, prompt_service=None):
        """Initialize SPARQL generator.

        Args:
            prompt_service: Service for LLM-based query generation; when
                None, only template and fallback generation are available.
        """
        self.prompt_service = prompt_service
        # SPARQL query templates for common patterns.  Each value is a
        # str.format() template: single braces are substitution slots,
        # doubled braces ('{{' / '}}') are literal braces in the SPARQL.
        self.templates = {
            # All instances of a class, with optional rdfs:label.
            'simple_class_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?entity ?label WHERE {{
?entity rdf:type :{class_name} .
OPTIONAL {{ ?entity rdfs:label ?label }}
}}""",
            # Subject/object pairs connected by a property, subject typed.
            'property_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subject ?object WHERE {{
?subject :{property} ?object .
?subject rdf:type :{subject_class} .
}}""",
            # Transitive subclass hierarchy below a root class.
            'hierarchy_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subclass ?superclass WHERE {{
?subclass rdfs:subClassOf* ?superclass .
?superclass rdf:type :{root_class} .
}}""",
            # Instance count for a class, with optional FILTER constraints.
            'count_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
SELECT (COUNT(?entity) AS ?count) WHERE {{
?entity rdf:type :{class_name} .
{additional_constraints}
}}""",
            # ASK query for fact checking a single triple pattern.
            'boolean_query': """
PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
ASK {{
{triple_pattern}
}}"""
        }
async def generate_sparql(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> SPARQLQuery:
"""Generate SPARQL query for a question.
Args:
question_components: Analyzed question components
ontology_subset: Relevant ontology subset
Returns:
Generated SPARQL query
"""
# Try template-based generation first
template_query = self._try_template_generation(question_components, ontology_subset)
if template_query:
logger.debug("Generated SPARQL using template")
return template_query
# Fall back to LLM-based generation
if self.prompt_service:
llm_query = await self._generate_with_llm(question_components, ontology_subset)
if llm_query:
logger.debug("Generated SPARQL using LLM")
return llm_query
# Final fallback to simple pattern
logger.warning("Falling back to simple SPARQL pattern")
return self._generate_fallback_query(question_components, ontology_subset)
def _try_template_generation(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
"""Try to generate query using templates.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Generated query or None if no template matches
"""
namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
# Simple class query (What are the animals?)
if (question_components.question_type == QuestionType.RETRIEVAL and
len(question_components.entities) == 1 and
question_components.entities[0].lower() in [c.lower() for c in ontology_subset.classes]):
class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
if class_name:
query = self.templates['simple_class_query'].format(
namespace=namespace,
class_name=class_name
)
return SPARQLQuery(
query=query,
variables=['entity', 'label'],
query_type='SELECT',
explanation=f"Retrieve all instances of {class_name}",
complexity_score=0.3
)
# Count query (How many animals are there?)
if (question_components.question_type == QuestionType.AGGREGATION and
'count' in question_components.aggregations and
len(question_components.entities) >= 1):
class_name = self._find_matching_class(question_components.entities[0], ontology_subset)
if class_name:
query = self.templates['count_query'].format(
namespace=namespace,
class_name=class_name,
additional_constraints=self._build_constraints(question_components, ontology_subset)
)
return SPARQLQuery(
query=query,
variables=['count'],
query_type='SELECT',
explanation=f"Count instances of {class_name}",
complexity_score=0.4
)
# Boolean query (Is X a Y?)
if question_components.question_type == QuestionType.BOOLEAN:
triple_pattern = self._build_boolean_pattern(question_components, ontology_subset)
if triple_pattern:
query = self.templates['boolean_query'].format(
namespace=namespace,
triple_pattern=triple_pattern
)
return SPARQLQuery(
query=query,
variables=[],
query_type='ASK',
explanation="Boolean query for fact checking",
complexity_score=0.2
)
return None
async def _generate_with_llm(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]:
"""Generate SPARQL using LLM.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Generated query or None if failed
"""
try:
prompt = self._build_sparql_prompt(question_components, ontology_subset)
response = await self.prompt_service.generate_sparql(prompt=prompt)
if response and isinstance(response, dict):
query = response.get('query', '').strip()
if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')):
return SPARQLQuery(
query=query,
variables=self._extract_variables(query),
query_type=query.split()[0].upper(),
explanation=response.get('explanation', 'Generated by LLM'),
complexity_score=self._calculate_complexity(query)
)
except Exception as e:
logger.error(f"LLM SPARQL generation failed: {e}")
return None
    def _build_sparql_prompt(self,
                             question_components: QuestionComponents,
                             ontology_subset: QueryOntologySubset) -> str:
        """Build the LLM prompt for SPARQL generation.

        Args:
            question_components: Question analysis.
            ontology_subset: Ontology subset.

        Returns:
            Prompt string embedding the ontology vocabulary, generation
            rules, and hints derived from the question analysis.
        """
        namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
        # Render the ontology vocabulary the LLM is allowed to use.
        classes_str = self._format_classes_for_prompt(ontology_subset.classes, namespace)
        props_str = self._format_properties_for_prompt(
            ontology_subset.object_properties,
            ontology_subset.datatype_properties,
            namespace
        )
        prompt = f"""Generate a SPARQL query for the following question using the provided ontology.
QUESTION: {question_components.original_question}
ONTOLOGY NAMESPACE: {namespace}
AVAILABLE CLASSES:
{classes_str}
AVAILABLE PROPERTIES:
{props_str}
RULES:
- Use proper SPARQL syntax
- Include appropriate prefixes
- Use property paths for hierarchical queries (rdfs:subClassOf*)
- Add FILTER clauses for constraints
- Optimize for Cassandra backend
- Return both query and explanation
QUERY TYPE HINTS:
- Question type: {question_components.question_type.value}
- Expected answer: {question_components.expected_answer_type}
- Entities mentioned: {', '.join(question_components.entities)}
- Aggregations: {', '.join(question_components.aggregations)}
Generate a complete SPARQL query:"""
        return prompt
def _generate_fallback_query(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> SPARQLQuery:
"""Generate simple fallback query.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
Basic SPARQL query
"""
namespace = ontology_subset.metadata.get('namespace', 'http://example.org/')
# Very basic SELECT query
query = f"""PREFIX : <{namespace}>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?subject ?predicate ?object WHERE {{
?subject ?predicate ?object .
FILTER(CONTAINS(STR(?subject), "{question_components.keywords[0] if question_components.keywords else 'entity'}"))
}}
LIMIT 10"""
return SPARQLQuery(
query=query,
variables=['subject', 'predicate', 'object'],
query_type='SELECT',
explanation="Fallback query for basic pattern matching",
complexity_score=0.1
)
def _find_matching_class(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Find matching class in ontology subset.
Args:
entity: Entity string to match
ontology_subset: Ontology subset
Returns:
Matching class name or None
"""
entity_lower = entity.lower()
# Direct match
for class_id in ontology_subset.classes:
if class_id.lower() == entity_lower:
return class_id
# Label match
for class_id, class_def in ontology_subset.classes.items():
labels = class_def.get('labels', [])
for label in labels:
if isinstance(label, dict):
label_value = label.get('value', '').lower()
if label_value == entity_lower:
return class_id
# Partial match
for class_id in ontology_subset.classes:
if entity_lower in class_id.lower() or class_id.lower() in entity_lower:
return class_id
return None
def _build_constraints(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> str:
"""Build constraint clauses for SPARQL.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
SPARQL constraint string
"""
constraints = []
for constraint in question_components.constraints:
# Simple constraint patterns
if 'greater than' in constraint.lower():
# Extract number
import re
numbers = re.findall(r'\d+', constraint)
if numbers:
constraints.append(f"FILTER(?value > {numbers[0]})")
elif 'less than' in constraint.lower():
numbers = re.findall(r'\d+', constraint)
if numbers:
constraints.append(f"FILTER(?value < {numbers[0]})")
return '\n '.join(constraints)
def _build_boolean_pattern(self,
question_components: QuestionComponents,
ontology_subset: QueryOntologySubset) -> Optional[str]:
"""Build triple pattern for boolean queries.
Args:
question_components: Question analysis
ontology_subset: Ontology subset
Returns:
SPARQL triple pattern or None
"""
if len(question_components.entities) >= 2:
subject = question_components.entities[0]
object_val = question_components.entities[1]
# Try to find connecting property
for prop_id in ontology_subset.object_properties:
return f":{subject} :{prop_id} :{object_val} ."
# Fallback to type check
return f":{subject} rdf:type :{object_val} ."
return None
def _format_classes_for_prompt(self, classes: Dict[str, Any], namespace: str) -> str:
"""Format classes for prompt.
Args:
classes: Classes dictionary
namespace: Ontology namespace
Returns:
Formatted classes string
"""
if not classes:
return "None"
lines = []
for class_id, definition in classes.items():
comment = definition.get('comment', '')
parent = definition.get('subclass_of', 'Thing')
lines.append(f"- :{class_id} (subclass of :{parent}) - {comment}")
return '\n'.join(lines)
def _format_properties_for_prompt(self,
object_props: Dict[str, Any],
datatype_props: Dict[str, Any],
namespace: str) -> str:
"""Format properties for prompt.
Args:
object_props: Object properties
datatype_props: Datatype properties
namespace: Ontology namespace
Returns:
Formatted properties string
"""
lines = []
for prop_id, definition in object_props.items():
domain = definition.get('domain', 'Any')
range_val = definition.get('range', 'Any')
comment = definition.get('comment', '')
lines.append(f"- :{prop_id} (:{domain} -> :{range_val}) - {comment}")
for prop_id, definition in datatype_props.items():
domain = definition.get('domain', 'Any')
range_val = definition.get('range', 'xsd:string')
comment = definition.get('comment', '')
lines.append(f"- :{prop_id} (:{domain} -> {range_val}) - {comment}")
return '\n'.join(lines) if lines else "None"
def _extract_variables(self, query: str) -> List[str]:
"""Extract variables from SPARQL query.
Args:
query: SPARQL query string
Returns:
List of variable names
"""
import re
variables = re.findall(r'\?(\w+)', query)
return list(set(variables))
def _calculate_complexity(self, query: str) -> float:
"""Calculate complexity score for SPARQL query.
Args:
query: SPARQL query string
Returns:
Complexity score (0.0 to 1.0)
"""
complexity = 0.0
# Count different SPARQL features
query_upper = query.upper()
if 'JOIN' in query_upper or 'UNION' in query_upper:
complexity += 0.3
if 'FILTER' in query_upper:
complexity += 0.2
if 'OPTIONAL' in query_upper:
complexity += 0.1
if 'GROUP BY' in query_upper:
complexity += 0.2
if 'ORDER BY' in query_upper:
complexity += 0.1
if '*' in query: # Property paths
complexity += 0.1
# Count variables
variables = self._extract_variables(query)
complexity += len(variables) * 0.05
return min(complexity, 1.0)