Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-05-10 15:52:36 +02:00
Refactor to use config event queue
parent ef0f2b6837, commit dfd7ad3a56
6 changed files with 251 additions and 498 deletions
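Overview: the refactor replaces the old polling model (a periodic refresh task re-reading ontologies from Cassandra through ConfigTableStore) with an event-driven one: the processor registers a config handler, and the config event queue invokes it with the full configuration map and a version number, both at startup and on every push. A minimal sketch of that pattern, reduced from the diff below (the surrounding FlowProcessor machinery is assumed, not shown):

    class Processor(FlowProcessor):

        def __init__(self, **params):
            super().__init__(**params)
            # Invoked on startup (initial config replay) and on each ConfigPush
            self.register_config_handler(self.on_ontology_config)

        async def on_ontology_config(self, config, version):
            # config[type][key] = value; ontologies arrive as JSON strings
            for ont_id, ont_json in config.get("ontology", {}).items():
                ...  # parse, diff against what is loaded, re-embed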
@@ -30,6 +30,7 @@ dependencies = [
     "minio",
     "mistralai",
     "neo4j",
+    "nltk",
     "ollama",
     "openai",
     "pinecone[grpc]",
@@ -12,8 +12,7 @@ from .... schema import Chunk, Triple, Triples, Metadata, Value
 from .... schema import PromptRequest, PromptResponse
 from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL
 from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from .... base import PromptClientSpec
-from .... tables.config import ConfigTableStore
+from .... base import PromptClientSpec, EmbeddingsClientSpec
 
 from .ontology_loader import OntologyLoader
 from .ontology_embedder import OntologyEmbedder
@@ -58,6 +57,13 @@ class Processor(FlowProcessor):
             )
         )
 
+        self.register_specification(
+            EmbeddingsClientSpec(
+                request_name="embeddings-request",
+                response_name="embeddings-response"
+            )
+        )
+
         self.register_specification(
             ProducerSpec(
                 name="triples",
@@ -65,6 +71,9 @@ class Processor(FlowProcessor):
             )
         )
 
+        # Register config handler for ontology updates
+        self.register_config_handler(self.on_ontology_config)
+
         # Initialize components
         self.ontology_loader = None
         self.ontology_embedder = None
@@ -75,13 +84,10 @@ class Processor(FlowProcessor):
         # Configuration
         self.top_k = params.get("top_k", 10)
         self.similarity_threshold = params.get("similarity_threshold", 0.7)
-        self.refresh_interval = params.get("ontology_refresh_interval", 300)
 
-        # Cassandra configuration for config store
-        self.cassandra_host = params.get("cassandra_host", "localhost")
-        self.cassandra_username = params.get("cassandra_username", "cassandra")
-        self.cassandra_password = params.get("cassandra_password", "cassandra")
-        self.cassandra_keyspace = params.get("cassandra_keyspace", "trustgraph")
+        # Track loaded ontology version
+        self.current_ontology_version = None
+        self.loaded_ontology_ids = set()
 
     async def initialize_components(self, flow):
        """Initialize OntoRAG components."""
@@ -89,38 +95,24 @@ class Processor(FlowProcessor):
             return
 
         try:
-            # Create configuration store
-            config_store = ConfigTableStore(
-                self.cassandra_host,
-                self.cassandra_username,
-                self.cassandra_password,
-                self.cassandra_keyspace
-            )
-
-            # Initialize ontology loader
-            self.ontology_loader = OntologyLoader(config_store)
-            ontologies = await self.ontology_loader.load_ontologies()
-            logger.info(f"Loaded {len(ontologies)} ontologies")
+            # Initialize ontology loader (no ConfigTableStore needed)
+            self.ontology_loader = OntologyLoader()
+            logger.info("Ontology loader initialized")
 
-            # Initialize vector store
-            vector_store = InMemoryVectorStore.create(
+            # Initialize vector store (FAISS only, no fallback)
+            vector_store = InMemoryVectorStore(
                 dimension=1536,  # text-embedding-3-small
-                prefer_faiss=True,
                 index_type='flat'
             )
 
-            # Initialize ontology embedder with embedding service wrapper
-            embedding_service = EmbeddingServiceWrapper(flow)
+            # Use embeddings client directly (no wrapper needed)
+            embeddings_client = flow("embeddings-request")
 
             self.ontology_embedder = OntologyEmbedder(
-                embedding_service=embedding_service,
+                embedding_service=embeddings_client,
                 vector_store=vector_store
             )
 
-            # Embed all ontologies
-            if ontologies:
-                await self.ontology_embedder.embed_ontologies(ontologies)
-                logger.info(f"Embedded {self.ontology_embedder.get_embedded_count()} ontology elements")
-
             # Initialize ontology selector
             self.ontology_selector = OntologySelector(
                 ontology_embedder=self.ontology_embedder,
@@ -132,28 +124,90 @@ class Processor(FlowProcessor):
             self.initialized = True
             logger.info("OntoRAG components initialized successfully")
 
-            # Schedule periodic refresh
-            asyncio.create_task(self.refresh_ontologies_periodically())
+            # NOTE: Ontologies will be loaded via on_ontology_config() handler
+            # when ConfigPush messages arrive (including initial config on startup)
 
         except Exception as e:
             logger.error(f"Failed to initialize OntoRAG components: {e}", exc_info=True)
             raise
 
-    async def refresh_ontologies_periodically(self):
-        """Periodically refresh ontologies from configuration."""
-        while True:
-            await asyncio.sleep(self.refresh_interval)
-            try:
-                logger.info("Refreshing ontologies...")
-                ontologies = await self.ontology_loader.refresh_ontologies()
-                if ontologies:
-                    # Re-embed new ontologies
-                    for ont_id in ontologies:
-                        if not self.ontology_embedder.is_ontology_embedded(ont_id):
-                            await self.ontology_embedder.embed_ontology(ontologies[ont_id])
-                logger.info("Ontology refresh complete")
-            except Exception as e:
-                logger.error(f"Error refreshing ontologies: {e}", exc_info=True)
+    async def on_ontology_config(self, config, version):
+        """
+        Handle ontology configuration updates from ConfigPush queue.
+
+        Called automatically when:
+        - Processor starts (gets full config history via start_of_messages=True)
+        - Config service pushes updates (immediate event-driven notification)
+
+        Args:
+            config: Full configuration map - config[type][key] = value
+            version: Config version number (monotonically increasing)
+        """
+        try:
+            logger.info(f"Received ontology config update, version={version}")
+
+            # Skip if we've already processed this version
+            if version == self.current_ontology_version:
+                logger.debug(f"Already at version {version}, skipping")
+                return
+
+            # Extract ontology configurations
+            if "ontology" not in config:
+                logger.warning("No 'ontology' section in config")
+                return
+
+            ontology_configs = config["ontology"]
+
+            # Parse ontology definitions
+            ontologies = {}
+            for ont_id, ont_json in ontology_configs.items():
+                try:
+                    ontologies[ont_id] = json.loads(ont_json)
+                except json.JSONDecodeError as e:
+                    logger.error(f"Failed to parse ontology '{ont_id}': {e}")
+                    continue
+
+            logger.info(f"Loaded {len(ontologies)} ontology definitions")
+
+            # Determine what changed (for incremental updates)
+            new_ids = set(ontologies.keys())
+            added_ids = new_ids - self.loaded_ontology_ids
+            removed_ids = self.loaded_ontology_ids - new_ids
+            updated_ids = new_ids & self.loaded_ontology_ids  # May have changed content
+
+            if added_ids:
+                logger.info(f"New ontologies: {added_ids}")
+            if removed_ids:
+                logger.info(f"Removed ontologies: {removed_ids}")
+            if updated_ids:
+                logger.info(f"Updated ontologies: {updated_ids}")
+
+            # Update ontology loader's internal state
+            self.ontology_loader.update_ontologies(ontologies)
+
+            # Re-embed changed ontologies
+            if self.ontology_embedder:
+                # Remove embeddings for deleted ontologies
+                for ont_id in removed_ids:
+                    self.ontology_embedder.remove_ontology(ont_id)
+
+                # Embed new and updated ontologies
+                for ont_id in added_ids | updated_ids:
+                    if ont_id in self.ontology_loader.get_all_ontologies():
+                        await self.ontology_embedder.embed_ontology(
+                            self.ontology_loader.get_ontology(ont_id)
+                        )
+
+                logger.info(f"Re-embedded ontologies, total elements: {self.ontology_embedder.get_embedded_count()}")
+
+            # Update tracking
+            self.current_ontology_version = version
+            self.loaded_ontology_ids = new_ids
+
+            logger.info(f"Ontology config update complete, version={version}")
+
+        except Exception as e:
+            logger.error(f"Failed to process ontology config: {e}", exc_info=True)
 
     async def on_message(self, msg, consumer, flow):
         """Process incoming chunk message."""
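Note: the change detection in on_ontology_config is plain set arithmetic over ontology IDs. A small worked example (IDs hypothetical):

    loaded_ontology_ids = {"medical", "finance"}   # state before this push
    new_ids = {"medical", "legal"}                 # IDs in the incoming config

    added_ids = new_ids - loaded_ontology_ids      # {"legal"}: embed fresh
    removed_ids = loaded_ontology_ids - new_ids    # {"finance"}: drop embeddings
    updated_ids = new_ids & loaded_ontology_ids    # {"medical"}: re-embed, content may differ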
@@ -403,71 +457,9 @@ TRIPLES (JSON array):"""
             default=0.7,
             help='Similarity threshold for ontology matching (default: 0.7)'
         )
-        parser.add_argument(
-            '--ontology-refresh-interval',
-            type=int,
-            default=300,
-            help='Ontology refresh interval in seconds (default: 300)'
-        )
-        parser.add_argument(
-            '--cassandra-host',
-            type=str,
-            default='localhost',
-            help='Cassandra host (default: localhost)'
-        )
-        parser.add_argument(
-            '--cassandra-username',
-            type=str,
-            default='cassandra',
-            help='Cassandra username (default: cassandra)'
-        )
-        parser.add_argument(
-            '--cassandra-password',
-            type=str,
-            default='cassandra',
-            help='Cassandra password (default: cassandra)'
-        )
-        parser.add_argument(
-            '--cassandra-keyspace',
-            type=str,
-            default='trustgraph',
-            help='Cassandra keyspace (default: trustgraph)'
-        )
         FlowProcessor.add_args(parser)
 
 
-class EmbeddingServiceWrapper:
-    """Wrapper to adapt flow prompt service to embedding service interface."""
-
-    def __init__(self, flow):
-        self.flow = flow
-
-    async def embed(self, text: str):
-        """Generate embedding for single text."""
-        try:
-            response = await self.flow("prompt-request").get_embedding(text=text)
-            return response
-        except Exception as e:
-            logger.error(f"Embedding service error: {e}")
-            return None
-
-    async def embed_batch(self, texts: List[str]):
-        """Generate embeddings for multiple texts."""
-        try:
-            # Process in parallel for better performance
-            tasks = [self.embed(text) for text in texts]
-            embeddings = await asyncio.gather(*tasks)
-            # Filter out None values and convert to array
-            import numpy as np
-            valid_embeddings = [e for e in embeddings if e is not None]
-            if valid_embeddings:
-                return np.array(valid_embeddings)
-            return None
-        except Exception as e:
-            logger.error(f"Batch embedding service error: {e}")
-            return None
-
-
 def run():
     """Launch the OntoRAG extraction service."""
     Processor.launch(default_ident, __doc__)
@@ -9,7 +9,7 @@ from typing import Dict, List, Any, Optional
 from dataclasses import dataclass
 
 from .ontology_loader import Ontology, OntologyClass, OntologyProperty
-from .vector_store import VectorStore, InMemoryVectorStore
+from .vector_store import InMemoryVectorStore
 
 logger = logging.getLogger(__name__)
 
@@ -27,15 +27,15 @@ class OntologyElementMetadata:
 class OntologyEmbedder:
     """Generates embeddings for ontology elements and stores them in vector store."""
 
-    def __init__(self, embedding_service=None, vector_store: Optional[VectorStore] = None):
+    def __init__(self, embedding_service=None, vector_store: Optional[InMemoryVectorStore] = None):
         """Initialize the ontology embedder.
 
         Args:
             embedding_service: Service for generating embeddings
-            vector_store: Vector store instance (defaults to InMemoryVectorStore)
+            vector_store: Vector store instance (InMemoryVectorStore)
         """
         self.embedding_service = embedding_service
-        self.vector_store = vector_store or InMemoryVectorStore.create()
+        self.vector_store = vector_store or InMemoryVectorStore()
         self.embedded_ontologies = set()
 
     def _create_text_representation(self, element_id: str, element: Any,
@@ -232,6 +232,25 @@ class OntologyEmbedder:
             logger.error(f"Failed to embed texts: {e}")
             return None
 
+    def remove_ontology(self, ontology_id: str):
+        """Remove all embeddings for a specific ontology.
+
+        Note: FAISS doesn't support efficient deletion, so this currently
+        requires rebuilding the entire index without the removed ontology.
+
+        Args:
+            ontology_id: ID of ontology to remove
+        """
+        if ontology_id not in self.embedded_ontologies:
+            logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove")
+            return
+
+        # FAISS doesn't support selective deletion, so we'd need to rebuild the index
+        # For now, just remove from tracking set
+        # TODO: Implement index rebuilding if selective removal is needed
+        self.embedded_ontologies.discard(ontology_id)
+        logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)")
+
     def clear_embeddings(self, ontology_id: Optional[str] = None):
         """Clear embeddings from vector store.
 
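Note: remove_ontology only updates the tracking set; the vectors stay in the index until the TODO above is done. One way the rebuild could go, as a sketch only: it assumes the store keeps parallel ids/metadata lists next to a flat FAISS index (attributes the diff shows elsewhere) and that each metadata entry carries an ontology_id field, which is an assumption here:

    import numpy as np
    import faiss

    def rebuild_without(store, ontology_id):
        # Flat indexes support reconstruct_n, so stored vectors are recoverable
        vectors = store.index.reconstruct_n(0, store.index.ntotal)
        keep = [i for i, meta in enumerate(store.metadata)
                if meta.get("ontology_id") != ontology_id]  # field name assumed
        # IndexFlatIP shown here; match whatever metric the store actually uses
        new_index = faiss.IndexFlatIP(store.dimension)
        if keep:
            new_index.add(np.asarray(vectors[keep], dtype=np.float32))
        store.index = new_index
        store.ids = [store.ids[i] for i in keep]
        store.metadata = [store.metadata[i] for i in keep]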
@@ -240,15 +259,13 @@ class OntologyEmbedder:
             Otherwise, clear all embeddings
         """
         if ontology_id:
-            # Would need to implement selective clearing in vector store
-            # For now, log warning
-            logger.warning(f"Selective clearing not implemented, would clear {ontology_id}")
+            self.remove_ontology(ontology_id)
         else:
             self.vector_store.clear()
             self.embedded_ontologies.clear()
             logger.info("Cleared all embeddings from vector store")
 
-    def get_vector_store(self) -> VectorStore:
+    def get_vector_store(self) -> InMemoryVectorStore:
         """Get the vector store instance.
 
         Returns:
@@ -158,84 +158,61 @@ class Ontology:
 
 
 class OntologyLoader:
-    """Loads and manages ontologies from configuration service."""
+    """Manages ontologies received via event-driven config updates.
 
-    def __init__(self, config_store=None):
-        """Initialize the ontology loader.
+    No direct database access - receives ontologies via config handler.
+    """
+
+    def __init__(self):
+        """Initialize empty ontology store."""
+        self.ontologies: Dict[str, Ontology] = {}
+
+    def update_ontologies(self, ontology_configs: Dict[str, Any]):
+        """Update ontology definitions from config.
 
         Args:
-            config_store: Configuration store instance (injected dependency)
+            ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts)
         """
-        self.config_store = config_store
-        self.ontologies: Dict[str, Ontology] = {}
-        self.refresh_interval = 300  # Default 5 minutes
+        self.ontologies.clear()
 
-    async def load_ontologies(self) -> Dict[str, Ontology]:
-        """Load all ontologies from configuration service.
+        for ont_id, ont_data in ontology_configs.items():
+            try:
+                # Parse classes
+                classes = {}
+                for class_id, class_data in ont_data.get('classes', {}).items():
+                    classes[class_id] = OntologyClass.from_dict(class_id, class_data)
 
-        Returns:
-            Dictionary of ontology ID to Ontology objects
-        """
-        if not self.config_store:
-            logger.warning("No configuration store available, returning empty ontologies")
-            return {}
+                # Parse object properties
+                object_props = {}
+                for prop_id, prop_data in ont_data.get('objectProperties', {}).items():
+                    object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
 
-        try:
-            # Get all ontology configurations
-            ontology_configs = await self.config_store.get("ontology").values()
+                # Parse datatype properties
+                datatype_props = {}
+                for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items():
+                    datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
 
-            for ont_id, ont_data in ontology_configs.items():
-                try:
-                    # Parse JSON if string
-                    if isinstance(ont_data, str):
-                        ont_data = json.loads(ont_data)
+                # Create ontology
+                ontology = Ontology(
+                    id=ont_id,
+                    metadata=ont_data.get('metadata', {}),
+                    classes=classes,
+                    object_properties=object_props,
+                    datatype_properties=datatype_props
+                )
 
-                    # Parse classes
-                    classes = {}
-                    for class_id, class_data in ont_data.get('classes', {}).items():
-                        classes[class_id] = OntologyClass.from_dict(class_id, class_data)
+                # Validate structure
+                issues = ontology.validate_structure()
+                if issues:
+                    logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
 
-                    # Parse object properties
-                    object_props = {}
-                    for prop_id, prop_data in ont_data.get('objectProperties', {}).items():
-                        object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
+                self.ontologies[ont_id] = ontology
+                logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, "
+                            f"{len(object_props)} object properties, "
+                            f"{len(datatype_props)} datatype properties")
 
-                    # Parse datatype properties
-                    datatype_props = {}
-                    for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items():
-                        datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
-
-                    # Create ontology
-                    ontology = Ontology(
-                        id=ont_id,
-                        metadata=ont_data.get('metadata', {}),
-                        classes=classes,
-                        object_properties=object_props,
-                        datatype_properties=datatype_props
-                    )
-
-                    # Validate structure
-                    issues = ontology.validate_structure()
-                    if issues:
-                        logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
-
-                    self.ontologies[ont_id] = ontology
-                    logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, "
-                                f"{len(object_props)} object properties, "
-                                f"{len(datatype_props)} datatype properties")
-
-                except Exception as e:
-                    logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)
-
-        except Exception as e:
-            logger.error(f"Failed to load ontologies from config: {e}", exc_info=True)
-
-        return self.ontologies
-
-    async def refresh_ontologies(self):
-        """Refresh ontologies from configuration service."""
-        logger.info("Refreshing ontologies...")
-        return await self.load_ontologies()
+            except Exception as e:
+                logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)
 
     def get_ontology(self, ont_id: str) -> Optional[Ontology]:
         """Get a specific ontology by ID.
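Note: update_ontologies receives already-parsed dicts keyed by ontology ID, using the section names the parser reads (metadata, classes, objectProperties, datatypeProperties). A minimal illustration; the ID and field values are invented, and the inner structures are whatever OntologyClass.from_dict and OntologyProperty.from_dict expect, which this diff does not show:

    loader = OntologyLoader()
    loader.update_ontologies({
        "medical": {                          # hypothetical ontology ID
            "metadata": {"name": "Medical ontology"},
            "classes": {"Patient": {}},       # per-class fields omitted
            "objectProperties": {"treats": {}},
            "datatypeProperties": {"dosage": {}},
        },
    })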
@@ -256,6 +233,14 @@ class OntologyLoader:
         """
         return self.ontologies
 
+    def list_ontology_ids(self) -> List[str]:
+        """Get list of loaded ontology IDs.
+
+        Returns:
+            List of ontology IDs
+        """
+        return list(self.ontologies.keys())
+
     def clear(self):
         """Clear all loaded ontologies."""
         self.ontologies.clear()
@@ -7,38 +7,26 @@ import logging
 import re
 from typing import List, Dict, Any, Optional
 from dataclasses import dataclass
+import nltk
+from nltk.corpus import stopwords
 
 logger = logging.getLogger(__name__)
 
-# Try to import NLTK for advanced text processing
+# Ensure required NLTK data is downloaded
 try:
-    import nltk
-    NLTK_AVAILABLE = True
-    # Try to ensure required NLTK data is downloaded
-    try:
-        nltk.data.find('tokenizers/punkt')
-    except LookupError:
-        try:
-            nltk.download('punkt', quiet=True)
-        except:
-            pass
-    try:
-        nltk.data.find('taggers/averaged_perceptron_tagger')
-    except LookupError:
-        try:
-            nltk.download('averaged_perceptron_tagger', quiet=True)
-        except:
-            pass
-    try:
-        nltk.data.find('corpora/stopwords')
-    except LookupError:
-        try:
-            nltk.download('stopwords', quiet=True)
-        except:
-            pass
-except ImportError:
-    NLTK_AVAILABLE = False
-    logger.warning("NLTK not available, using basic text processing")
+    nltk.data.find('tokenizers/punkt')
+except LookupError:
+    nltk.download('punkt', quiet=True)
+
+try:
+    nltk.data.find('taggers/averaged_perceptron_tagger')
+except LookupError:
+    nltk.download('averaged_perceptron_tagger', quiet=True)
+
+try:
+    nltk.data.find('corpora/stopwords')
+except LookupError:
+    nltk.download('stopwords', quiet=True)
 
 
 @dataclass
@@ -52,18 +40,12 @@ class TextSegment:
 
 
 class SentenceSplitter:
-    """Splits text into sentences using available NLP tools."""
+    """Splits text into sentences using NLTK."""
 
     def __init__(self):
         """Initialize sentence splitter."""
-        self.use_nltk = NLTK_AVAILABLE
-        if self.use_nltk:
-            try:
-                self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
-                logger.info("Using NLTK sentence tokenizer")
-            except:
-                self.use_nltk = False
-                logger.warning("NLTK punkt tokenizer not available, using regex")
+        self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
+        logger.info("Using NLTK sentence tokenizer")
 
     def split(self, text: str) -> List[str]:
         """Split text into sentences.
@@ -74,35 +56,16 @@
         Returns:
             List of sentences
         """
-        if self.use_nltk:
-            try:
-                sentences = self.sent_detector.tokenize(text)
-                return sentences
-            except Exception as e:
-                logger.warning(f"NLTK sentence splitting failed: {e}, falling back to regex")
-
-        # Fallback to regex-based splitting
-        # Simple sentence boundary detection
-        sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text)
-        # Filter out empty sentences
-        sentences = [s.strip() for s in sentences if s.strip()]
+        sentences = self.sent_detector.tokenize(text)
         return sentences
 
 
 class PhraseExtractor:
-    """Extracts meaningful phrases from sentences."""
+    """Extracts meaningful phrases from sentences using NLTK."""
 
     def __init__(self):
         """Initialize phrase extractor."""
-        self.use_nltk = NLTK_AVAILABLE
-        if self.use_nltk:
-            try:
-                # Test that POS tagger is available
-                nltk.pos_tag(['test'])
-                logger.info("Using NLTK phrase extraction")
-            except:
-                self.use_nltk = False
-                logger.warning("NLTK POS tagger not available, using basic extraction")
+        logger.info("Using NLTK phrase extraction")
 
     def extract(self, sentence: str) -> List[Dict[str, str]]:
         """Extract phrases from a sentence.
@@ -115,104 +78,49 @@
         """
         phrases = []
 
-        if self.use_nltk:
-            try:
-                phrases.extend(self._extract_nltk_phrases(sentence))
-            except Exception as e:
-                logger.warning(f"NLTK phrase extraction failed: {e}, using basic extraction")
-                phrases.extend(self._extract_basic_phrases(sentence))
-        else:
-            phrases.extend(self._extract_basic_phrases(sentence))
-
-        return phrases
-
-    def _extract_nltk_phrases(self, sentence: str) -> List[Dict[str, str]]:
-        """Extract phrases using NLTK.
-
-        Args:
-            sentence: Sentence to process
-
-        Returns:
-            List of phrases with types
-        """
-        phrases = []
-
-        try:
-            # Tokenize and POS tag
-            tokens = nltk.word_tokenize(sentence)
-            pos_tags = nltk.pos_tag(tokens)
-
-            # Extract noun phrases (simple pattern)
-            noun_phrase = []
-            for word, pos in pos_tags:
-                if pos.startswith('NN') or pos.startswith('JJ'):
-                    noun_phrase.append(word)
-                elif noun_phrase:
-                    if len(noun_phrase) > 1:
-                        phrases.append({
-                            'text': ' '.join(noun_phrase),
-                            'type': 'noun_phrase'
-                        })
-                    noun_phrase = []
-
-            # Add last noun phrase if exists
-            if noun_phrase and len(noun_phrase) > 1:
-                phrases.append({
-                    'text': ' '.join(noun_phrase),
-                    'type': 'noun_phrase'
-                })
-
-            # Extract verb phrases (simple pattern)
-            verb_phrase = []
-            for word, pos in pos_tags:
-                if pos.startswith('VB') or pos.startswith('RB'):
-                    verb_phrase.append(word)
-                elif verb_phrase:
-                    if len(verb_phrase) > 1:
-                        phrases.append({
-                            'text': ' '.join(verb_phrase),
-                            'type': 'verb_phrase'
-                        })
-                    verb_phrase = []
-
-            # Add last verb phrase if exists
-            if verb_phrase and len(verb_phrase) > 1:
-                phrases.append({
-                    'text': ' '.join(verb_phrase),
-                    'type': 'verb_phrase'
-                })
-
-        except Exception as e:
-            logger.error(f"Error in NLTK phrase extraction: {e}")
-
-        return phrases
-
-    def _extract_basic_phrases(self, sentence: str) -> List[Dict[str, str]]:
-        """Extract phrases using basic regex patterns.
-
-        Args:
-            sentence: Sentence to process
-
-        Returns:
-            List of phrases with types
-        """
-        phrases = []
-
-        # Extract quoted phrases
-        quoted = re.findall(r'"([^"]+)"', sentence)
-        for q in quoted:
-            phrases.append({'text': q, 'type': 'phrase'})
-
-        # Extract parenthetical phrases
-        parens = re.findall(r'\(([^)]+)\)', sentence)
-        for p in parens:
-            phrases.append({'text': p, 'type': 'phrase'})
-
-        # Extract capitalized sequences (potential entities)
-        caps = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', sentence)
-        for c in caps:
-            if len(c.split()) > 1:  # Multi-word entities
-                phrases.append({'text': c, 'type': 'noun_phrase'})
+        # Tokenize and POS tag
+        tokens = nltk.word_tokenize(sentence)
+        pos_tags = nltk.pos_tag(tokens)
+
+        # Extract noun phrases (simple pattern)
+        noun_phrase = []
+        for word, pos in pos_tags:
+            if pos.startswith('NN') or pos.startswith('JJ'):
+                noun_phrase.append(word)
+            elif noun_phrase:
+                if len(noun_phrase) > 1:
+                    phrases.append({
+                        'text': ' '.join(noun_phrase),
+                        'type': 'noun_phrase'
+                    })
+                noun_phrase = []
+
+        # Add last noun phrase if exists
+        if noun_phrase and len(noun_phrase) > 1:
+            phrases.append({
+                'text': ' '.join(noun_phrase),
+                'type': 'noun_phrase'
+            })
+
+        # Extract verb phrases (simple pattern)
+        verb_phrase = []
+        for word, pos in pos_tags:
+            if pos.startswith('VB') or pos.startswith('RB'):
+                verb_phrase.append(word)
+            elif verb_phrase:
+                if len(verb_phrase) > 1:
+                    phrases.append({
+                        'text': ' '.join(verb_phrase),
+                        'type': 'verb_phrase'
+                    })
+                verb_phrase = []
+
+        # Add last verb phrase if exists
+        if verb_phrase and len(verb_phrase) > 1:
+            phrases.append({
+                'text': ' '.join(verb_phrase),
+                'type': 'verb_phrase'
+            })
 
         return phrases
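Note: the inlined extractor collects maximal runs of NN*/JJ* tags as noun phrases and VB*/RB* runs as verb phrases, keeping only multi-word runs. Roughly, for a tagging like "underlying/JJ vector/NN store/NN":

    extractor = PhraseExtractor()
    phrases = extractor.extract("The underlying vector store indexes ontology elements quickly.")
    # Expect entries such as {'text': 'underlying vector store', 'type': 'noun_phrase'};
    # exact output depends on the POS tagger, and single-word runs are skipped.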
@@ -279,21 +187,8 @@ class TextProcessor:
         # Split on word boundaries
         words = re.findall(r'\b\w+\b', text.lower())
 
-        # Filter common stop words (basic list)
-        stop_words = {
-            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
-            'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be',
-            'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
-            'would', 'could', 'should', 'may', 'might', 'must', 'can', 'shall'
-        }
-
-        # Use NLTK stopwords if available
-        if NLTK_AVAILABLE:
-            try:
-                from nltk.corpus import stopwords
-                stop_words = set(stopwords.words('english'))
-            except:
-                pass
+        # Use NLTK stopwords
+        stop_words = set(stopwords.words('english'))
 
         # Filter stopwords and short words
         terms = [w for w in words if w not in stop_words and len(w) > 2]
@@ -322,4 +217,4 @@ class TextProcessor:
         # Normalize quotes
         text = text.replace('“', '"').replace('”', '"')
         text = text.replace('’', "'").replace('‘', "'")
-        return text
\ No newline at end of file
+        return text
@@ -1,23 +1,16 @@
 """
-Vector store implementations for OntoRAG system.
-Provides both FAISS and NumPy-based vector storage for ontology embeddings.
+Vector store implementation for OntoRAG system.
+Provides FAISS-based vector storage for ontology embeddings.
 """
 
 import logging
 import numpy as np
-from typing import List, Dict, Any, Optional, Tuple
+from typing import List, Dict, Any
 from dataclasses import dataclass
+import faiss
 
 logger = logging.getLogger(__name__)
 
-# Try to import FAISS, fall back to NumPy implementation if not available
-try:
-    import faiss
-    FAISS_AVAILABLE = True
-except ImportError:
-    FAISS_AVAILABLE = False
-    logger.warning("FAISS not available, using NumPy implementation")
-
 
 @dataclass
 class SearchResult:
@@ -27,34 +20,8 @@ class SearchResult:
     metadata: Dict[str, Any]
 
 
-class VectorStore:
-    """Abstract base class for vector stores."""
-
-    def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
-        """Add single embedding with metadata."""
-        raise NotImplementedError
-
-    def add_batch(self, ids: List[str], embeddings: np.ndarray,
-                  metadata_list: List[Dict[str, Any]]):
-        """Batch add for initial ontology loading."""
-        raise NotImplementedError
-
-    def search(self, embedding: np.ndarray, top_k: int = 10,
-               threshold: float = 0.0) -> List[SearchResult]:
-        """Search for similar vectors."""
-        raise NotImplementedError
-
-    def clear(self):
-        """Reset the store."""
-        raise NotImplementedError
-
-    def size(self) -> int:
-        """Return number of stored vectors."""
-        raise NotImplementedError
-
-
-class FAISSVectorStore(VectorStore):
-    """FAISS-based vector store implementation."""
+class InMemoryVectorStore:
+    """FAISS-based vector store implementation for ontology embeddings."""
 
     def __init__(self, dimension: int = 1536, index_type: str = 'flat'):
         """Initialize FAISS vector store.
@@ -63,9 +30,6 @@ class FAISSVectorStore(VectorStore):
             dimension: Embedding dimension (1536 for text-embedding-3-small)
             index_type: 'flat' for exact search, 'ivf' for larger datasets
         """
-        if not FAISS_AVAILABLE:
-            raise RuntimeError("FAISS is not installed")
-
         self.dimension = dimension
         self.metadata = []
         self.ids = []
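Note: with FAISS now a hard dependency, 'flat' means exact (brute-force) search. The usual way to get cosine similarity out of FAISS is to L2-normalize the vectors and use an inner-product index, since the dot product of unit vectors is the cosine. A standalone sketch, separate from the class above:

    import numpy as np
    import faiss

    dim = 1536
    index = faiss.IndexFlatIP(dim)             # exact inner-product search

    vecs = np.random.rand(100, dim).astype('float32')
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # unit length: IP == cosine
    index.add(vecs)

    scores, ids = index.search(vecs[:1], 10)   # top-10 cosine scores for one query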
@@ -141,107 +105,6 @@
         return self.index.ntotal
 
 
-class SimpleVectorStore(VectorStore):
-    """NumPy-based vector store implementation for development/small deployments."""
-
-    def __init__(self):
-        """Initialize simple NumPy-based vector store."""
-        self.embeddings = []
-        self.metadata = []
-        self.ids = []
-        logger.info("Created SimpleVectorStore (NumPy implementation)")
-
-    def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
-        """Add single embedding with metadata."""
-        # Normalize for cosine similarity
-        normalized = embedding / np.linalg.norm(embedding)
-        self.embeddings.append(normalized)
-        self.metadata.append(metadata)
-        self.ids.append(id)
-
-    def add_batch(self, ids: List[str], embeddings: np.ndarray,
-                  metadata_list: List[Dict[str, Any]]):
-        """Batch add for initial ontology loading."""
-        # Normalize all embeddings
-        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-        # Avoid division by zero
-        norms = np.where(norms == 0, 1, norms)
-        normalized = embeddings / norms
-
-        for i in range(len(ids)):
-            self.embeddings.append(normalized[i])
-            self.metadata.append(metadata_list[i])
-            self.ids.append(ids[i])
-
-        logger.debug(f"Added batch of {len(ids)} embeddings to simple store")
-
-    def search(self, embedding: np.ndarray, top_k: int = 10,
-               threshold: float = 0.0) -> List[SearchResult]:
-        """Search for similar vectors using cosine similarity."""
-        if not self.embeddings:
-            return []
-
-        # Normalize query embedding
-        embedding = embedding / np.linalg.norm(embedding)
-
-        # Compute cosine similarities
-        embeddings_array = np.array(self.embeddings)
-        similarities = np.dot(embeddings_array, embedding)
-
-        # Get top-k indices
-        top_k = min(top_k, len(self.embeddings))
-        top_indices = np.argsort(similarities)[::-1][:top_k]
-
-        # Build results
-        results = []
-        for idx in top_indices:
-            if similarities[idx] >= threshold:
-                results.append(SearchResult(
-                    id=self.ids[idx],
-                    score=float(similarities[idx]),
-                    metadata=self.metadata[idx]
-                ))
-
-        return results
-
-    def clear(self):
-        """Reset the store."""
-        self.embeddings = []
-        self.metadata = []
-        self.ids = []
-        logger.info("Cleared simple vector store")
-
-    def size(self) -> int:
-        """Return number of stored vectors."""
-        return len(self.embeddings)
-
-
-class InMemoryVectorStore:
-    """Factory class to create appropriate vector store based on availability."""
-
-    @staticmethod
-    def create(dimension: int = 1536, prefer_faiss: bool = True,
-               index_type: str = 'flat') -> VectorStore:
-        """Create a vector store instance.
-
-        Args:
-            dimension: Embedding dimension
-            prefer_faiss: Whether to prefer FAISS if available
-            index_type: Type of FAISS index ('flat' or 'ivf')
-
-        Returns:
-            VectorStore instance (FAISS or Simple)
-        """
-        if prefer_faiss and FAISS_AVAILABLE:
-            try:
-                return FAISSVectorStore(dimension, index_type)
-            except Exception as e:
-                logger.warning(f"Failed to create FAISS store: {e}, falling back to NumPy")
-                return SimpleVectorStore()
-        else:
-            return SimpleVectorStore()
-
-
 # Utility functions for vector operations
 def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
     """Compute cosine similarity between two vectors."""
@@ -264,4 +127,4 @@ def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray:
 
     # Compute dot product
     similarities = np.dot(queries_norm, targets_norm.T)
-    return similarities
\ No newline at end of file
+    return similarities