Refactor to use config event queue

Cyber MacGeddon 2025-11-12 16:05:10 +00:00
parent ef0f2b6837
commit dfd7ad3a56
6 changed files with 251 additions and 498 deletions

View file

@@ -30,6 +30,7 @@ dependencies = [
"minio",
"mistralai",
"neo4j",
"nltk",
"ollama",
"openai",
"pinecone[grpc]",

View file

@@ -12,8 +12,7 @@ from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... tables.config import ConfigTableStore
from .... base import PromptClientSpec, EmbeddingsClientSpec
from .ontology_loader import OntologyLoader
from .ontology_embedder import OntologyEmbedder
@@ -58,6 +57,13 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name="embeddings-request",
response_name="embeddings-response"
)
)
self.register_specification(
ProducerSpec(
name="triples",
@@ -65,6 +71,9 @@ class Processor(FlowProcessor):
)
)
# Register config handler for ontology updates
self.register_config_handler(self.on_ontology_config)
# Initialize components
self.ontology_loader = None
self.ontology_embedder = None
@@ -75,13 +84,10 @@
# Configuration
self.top_k = params.get("top_k", 10)
self.similarity_threshold = params.get("similarity_threshold", 0.7)
self.refresh_interval = params.get("ontology_refresh_interval", 300)
# Cassandra configuration for config store
self.cassandra_host = params.get("cassandra_host", "localhost")
self.cassandra_username = params.get("cassandra_username", "cassandra")
self.cassandra_password = params.get("cassandra_password", "cassandra")
self.cassandra_keyspace = params.get("cassandra_keyspace", "trustgraph")
# Track loaded ontology version
self.current_ontology_version = None
self.loaded_ontology_ids = set()
async def initialize_components(self, flow):
"""Initialize OntoRAG components."""
@@ -89,38 +95,24 @@
return
try:
# Create configuration store
config_store = ConfigTableStore(
self.cassandra_host,
self.cassandra_username,
self.cassandra_password,
self.cassandra_keyspace
)
# Initialize ontology loader (no ConfigTableStore needed)
self.ontology_loader = OntologyLoader()
logger.info("Ontology loader initialized")
# Initialize ontology loader
self.ontology_loader = OntologyLoader(config_store)
ontologies = await self.ontology_loader.load_ontologies()
logger.info(f"Loaded {len(ontologies)} ontologies")
# Initialize vector store
vector_store = InMemoryVectorStore.create(
# Initialize vector store (FAISS only, no fallback)
vector_store = InMemoryVectorStore(
dimension=1536, # text-embedding-3-small
prefer_faiss=True,
index_type='flat'
)
# Initialize ontology embedder with embedding service wrapper
embedding_service = EmbeddingServiceWrapper(flow)
# Use embeddings client directly (no wrapper needed)
embeddings_client = flow("embeddings-request")
self.ontology_embedder = OntologyEmbedder(
embedding_service=embedding_service,
embedding_service=embeddings_client,
vector_store=vector_store
)
# Embed all ontologies
if ontologies:
await self.ontology_embedder.embed_ontologies(ontologies)
logger.info(f"Embedded {self.ontology_embedder.get_embedded_count()} ontology elements")
# Initialize ontology selector
self.ontology_selector = OntologySelector(
ontology_embedder=self.ontology_embedder,
@@ -132,28 +124,90 @@
self.initialized = True
logger.info("OntoRAG components initialized successfully")
# Schedule periodic refresh
asyncio.create_task(self.refresh_ontologies_periodically())
# NOTE: Ontologies will be loaded via on_ontology_config() handler
# when ConfigPush messages arrive (including initial config on startup)
except Exception as e:
logger.error(f"Failed to initialize OntoRAG components: {e}", exc_info=True)
raise
async def refresh_ontologies_periodically(self):
"""Periodically refresh ontologies from configuration."""
while True:
await asyncio.sleep(self.refresh_interval)
try:
logger.info("Refreshing ontologies...")
ontologies = await self.ontology_loader.refresh_ontologies()
if ontologies:
# Re-embed new ontologies
for ont_id in ontologies:
if not self.ontology_embedder.is_ontology_embedded(ont_id):
await self.ontology_embedder.embed_ontology(ontologies[ont_id])
logger.info("Ontology refresh complete")
except Exception as e:
logger.error(f"Error refreshing ontologies: {e}", exc_info=True)
async def on_ontology_config(self, config, version):
"""
Handle ontology configuration updates from ConfigPush queue.
Called automatically when:
- Processor starts (gets full config history via start_of_messages=True)
- Config service pushes updates (immediate event-driven notification)
Args:
config: Full configuration map - config[type][key] = value
version: Config version number (monotonically increasing)
"""
try:
logger.info(f"Received ontology config update, version={version}")
# Skip if we've already processed this version
if version == self.current_ontology_version:
logger.debug(f"Already at version {version}, skipping")
return
# Extract ontology configurations
if "ontology" not in config:
logger.warning("No 'ontology' section in config")
return
ontology_configs = config["ontology"]
# Parse ontology definitions
ontologies = {}
for ont_id, ont_json in ontology_configs.items():
try:
ontologies[ont_id] = json.loads(ont_json)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse ontology '{ont_id}': {e}")
continue
logger.info(f"Loaded {len(ontologies)} ontology definitions")
# Determine what changed (for incremental updates)
new_ids = set(ontologies.keys())
added_ids = new_ids - self.loaded_ontology_ids
removed_ids = self.loaded_ontology_ids - new_ids
updated_ids = new_ids & self.loaded_ontology_ids # May have changed content
if added_ids:
logger.info(f"New ontologies: {added_ids}")
if removed_ids:
logger.info(f"Removed ontologies: {removed_ids}")
if updated_ids:
logger.info(f"Updated ontologies: {updated_ids}")
# Update ontology loader's internal state
self.ontology_loader.update_ontologies(ontologies)
# Re-embed changed ontologies
if self.ontology_embedder:
# Remove embeddings for deleted ontologies
for ont_id in removed_ids:
self.ontology_embedder.remove_ontology(ont_id)
# Embed new and updated ontologies
for ont_id in added_ids | updated_ids:
if ont_id in self.ontology_loader.get_all_ontologies():
await self.ontology_embedder.embed_ontology(
self.ontology_loader.get_ontology(ont_id)
)
logger.info(f"Re-embedded ontologies, total elements: {self.ontology_embedder.get_embedded_count()}")
# Update tracking
self.current_ontology_version = version
self.loaded_ontology_ids = new_ids
logger.info(f"Ontology config update complete, version={version}")
except Exception as e:
logger.error(f"Failed to process ontology config: {e}", exc_info=True)
async def on_message(self, msg, consumer, flow):
"""Process incoming chunk message."""
@@ -403,71 +457,9 @@ TRIPLES (JSON array):"""
default=0.7,
help='Similarity threshold for ontology matching (default: 0.7)'
)
parser.add_argument(
'--ontology-refresh-interval',
type=int,
default=300,
help='Ontology refresh interval in seconds (default: 300)'
)
parser.add_argument(
'--cassandra-host',
type=str,
default='localhost',
help='Cassandra host (default: localhost)'
)
parser.add_argument(
'--cassandra-username',
type=str,
default='cassandra',
help='Cassandra username (default: cassandra)'
)
parser.add_argument(
'--cassandra-password',
type=str,
default='cassandra',
help='Cassandra password (default: cassandra)'
)
parser.add_argument(
'--cassandra-keyspace',
type=str,
default='trustgraph',
help='Cassandra keyspace (default: trustgraph)'
)
FlowProcessor.add_args(parser)
class EmbeddingServiceWrapper:
"""Wrapper to adapt flow prompt service to embedding service interface."""
def __init__(self, flow):
self.flow = flow
async def embed(self, text: str):
"""Generate embedding for single text."""
try:
response = await self.flow("prompt-request").get_embedding(text=text)
return response
except Exception as e:
logger.error(f"Embedding service error: {e}")
return None
async def embed_batch(self, texts: List[str]):
"""Generate embeddings for multiple texts."""
try:
# Process in parallel for better performance
tasks = [self.embed(text) for text in texts]
embeddings = await asyncio.gather(*tasks)
# Filter out None values and convert to array
import numpy as np
valid_embeddings = [e for e in embeddings if e is not None]
if valid_embeddings:
return np.array(valid_embeddings)
return None
except Exception as e:
logger.error(f"Batch embedding service error: {e}")
return None
def run():
"""Launch the OntoRAG extraction service."""
Processor.launch(default_ident, __doc__)
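For illustration, a minimal sketch of the payload the new handler consumes, assuming the section/key layout used by on_ontology_config() above. Only the "ontology" section name and the classes/objectProperties/datatypeProperties keys come from this commit; the ontology contents and version number are made up, and `processor` stands for an instance of the Processor above.

import json

config = {
    "ontology": {
        "movies": json.dumps({
            "metadata": {"name": "Movies"},
            "classes": {
                "Film": {"label": "Film", "comment": "A motion picture"}
            },
            "objectProperties": {
                "directedBy": {"label": "directed by", "domain": "Film", "range": "Person"}
            },
            "datatypeProperties": {
                "releaseYear": {"label": "release year", "domain": "Film", "range": "xsd:gYear"}
            }
        })
    }
}

# FlowProcessor invokes registered handlers with (config, version);
# calling the handler directly like this is equivalent in a test.
await processor.on_ontology_config(config, version=1)

Replaying the same version is a no-op, and a later push that omits "movies" lands in removed_ids and triggers remove_ontology().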

View file

@@ -9,7 +9,7 @@ from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from .ontology_loader import Ontology, OntologyClass, OntologyProperty
from .vector_store import VectorStore, InMemoryVectorStore
from .vector_store import InMemoryVectorStore
logger = logging.getLogger(__name__)
@@ -27,15 +27,15 @@ class OntologyElementMetadata:
class OntologyEmbedder:
"""Generates embeddings for ontology elements and stores them in vector store."""
def __init__(self, embedding_service=None, vector_store: Optional[VectorStore] = None):
def __init__(self, embedding_service=None, vector_store: Optional[InMemoryVectorStore] = None):
"""Initialize the ontology embedder.
Args:
embedding_service: Service for generating embeddings
vector_store: Vector store instance (defaults to InMemoryVectorStore)
vector_store: Vector store instance (InMemoryVectorStore)
"""
self.embedding_service = embedding_service
self.vector_store = vector_store or InMemoryVectorStore.create()
self.vector_store = vector_store or InMemoryVectorStore()
self.embedded_ontologies = set()
def _create_text_representation(self, element_id: str, element: Any,
@@ -232,6 +232,25 @@ class OntologyEmbedder:
logger.error(f"Failed to embed texts: {e}")
return None
def remove_ontology(self, ontology_id: str):
"""Remove all embeddings for a specific ontology.
Note: FAISS doesn't support efficient selective deletion, so fully removing
vectors would require rebuilding the index; currently only the tracking set is updated.
Args:
ontology_id: ID of ontology to remove
"""
if ontology_id not in self.embedded_ontologies:
logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove")
return
# FAISS doesn't support selective deletion, so we'd need to rebuild the index
# For now, just remove from tracking set
# TODO: Implement index rebuilding if selective removal is needed
self.embedded_ontologies.discard(ontology_id)
logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)")
def clear_embeddings(self, ontology_id: Optional[str] = None):
"""Clear embeddings from vector store.
@@ -240,15 +259,13 @@
Otherwise, clear all embeddings
"""
if ontology_id:
# Would need to implement selective clearing in vector store
# For now, log warning
logger.warning(f"Selective clearing not implemented, would clear {ontology_id}")
self.remove_ontology(ontology_id)
else:
self.vector_store.clear()
self.embedded_ontologies.clear()
logger.info("Cleared all embeddings from vector store")
def get_vector_store(self) -> VectorStore:
def get_vector_store(self) -> InMemoryVectorStore:
"""Get the vector store instance.
Returns:

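The TODO above could be implemented by rebuilding the flat index. A rough sketch, not part of this commit: it assumes an inner-product flat index (faiss.IndexFlatIP) whose vectors can be read back with reconstruct_n(), parallel ids/metadata lists as in the store below, and that metadata dicts carry an "ontology_id" key.

import faiss
import numpy as np

def rebuild_without(store, ontology_id: str):
    # Read every vector back out of the flat index.
    vectors = store.index.reconstruct_n(0, store.index.ntotal)
    keep = [i for i, meta in enumerate(store.metadata)
            if meta.get("ontology_id") != ontology_id]
    # Build a fresh index holding only the surviving vectors.
    new_index = faiss.IndexFlatIP(store.dimension)
    if keep:
        new_index.add(np.asarray(vectors[keep], dtype="float32"))
    store.index = new_index
    store.ids = [store.ids[i] for i in keep]
    store.metadata = [store.metadata[i] for i in keep]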
View file

@@ -158,84 +158,61 @@ class Ontology:
class OntologyLoader:
"""Loads and manages ontologies from configuration service."""
"""Manages ontologies received via event-driven config updates.
def __init__(self, config_store=None):
"""Initialize the ontology loader.
No direct database access - receives ontologies via config handler.
"""
def __init__(self):
"""Initialize empty ontology store."""
self.ontologies: Dict[str, Ontology] = {}
def update_ontologies(self, ontology_configs: Dict[str, Any]):
"""Update ontology definitions from config.
Args:
config_store: Configuration store instance (injected dependency)
ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts)
"""
self.config_store = config_store
self.ontologies: Dict[str, Ontology] = {}
self.refresh_interval = 300 # Default 5 minutes
self.ontologies.clear()
async def load_ontologies(self) -> Dict[str, Ontology]:
"""Load all ontologies from configuration service.
for ont_id, ont_data in ontology_configs.items():
try:
# Parse classes
classes = {}
for class_id, class_data in ont_data.get('classes', {}).items():
classes[class_id] = OntologyClass.from_dict(class_id, class_data)
Returns:
Dictionary of ontology ID to Ontology objects
"""
if not self.config_store:
logger.warning("No configuration store available, returning empty ontologies")
return {}
# Parse object properties
object_props = {}
for prop_id, prop_data in ont_data.get('objectProperties', {}).items():
object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
try:
# Get all ontology configurations
ontology_configs = await self.config_store.get("ontology").values()
# Parse datatype properties
datatype_props = {}
for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items():
datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
for ont_id, ont_data in ontology_configs.items():
try:
# Parse JSON if string
if isinstance(ont_data, str):
ont_data = json.loads(ont_data)
# Create ontology
ontology = Ontology(
id=ont_id,
metadata=ont_data.get('metadata', {}),
classes=classes,
object_properties=object_props,
datatype_properties=datatype_props
)
# Parse classes
classes = {}
for class_id, class_data in ont_data.get('classes', {}).items():
classes[class_id] = OntologyClass.from_dict(class_id, class_data)
# Validate structure
issues = ontology.validate_structure()
if issues:
logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
# Parse object properties
object_props = {}
for prop_id, prop_data in ont_data.get('objectProperties', {}).items():
object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
self.ontologies[ont_id] = ontology
logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, "
f"{len(object_props)} object properties, "
f"{len(datatype_props)} datatype properties")
# Parse datatype properties
datatype_props = {}
for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items():
datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data)
# Create ontology
ontology = Ontology(
id=ont_id,
metadata=ont_data.get('metadata', {}),
classes=classes,
object_properties=object_props,
datatype_properties=datatype_props
)
# Validate structure
issues = ontology.validate_structure()
if issues:
logger.warning(f"Ontology {ont_id} has validation issues: {issues}")
self.ontologies[ont_id] = ontology
logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, "
f"{len(object_props)} object properties, "
f"{len(datatype_props)} datatype properties")
except Exception as e:
logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Failed to load ontologies from config: {e}", exc_info=True)
return self.ontologies
async def refresh_ontologies(self):
"""Refresh ontologies from configuration service."""
logger.info("Refreshing ontologies...")
return await self.load_ontologies()
except Exception as e:
logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True)
def get_ontology(self, ont_id: str) -> Optional[Ontology]:
"""Get a specific ontology by ID.
@@ -256,6 +233,14 @@ class OntologyLoader:
"""
return self.ontologies
def list_ontology_ids(self) -> List[str]:
"""Get list of loaded ontology IDs.
Returns:
List of ontology IDs
"""
return list(self.ontologies.keys())
def clear(self):
"""Clear all loaded ontologies."""
self.ontologies.clear()
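A short usage sketch of the rewritten loader. The key names mirror update_ontologies() above; the ontology definition is made up and may not exercise every field from_dict() expects.

loader = OntologyLoader()
loader.update_ontologies({
    "movies": {
        "metadata": {"name": "Movies"},
        "classes": {"Film": {"label": "Film"}},
        "objectProperties": {},
        "datatypeProperties": {}
    }
})
print(loader.list_ontology_ids())   # ['movies']
movies = loader.get_ontology("movies")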

View file

@@ -7,38 +7,26 @@ import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import nltk
from nltk.corpus import stopwords
logger = logging.getLogger(__name__)
# Try to import NLTK for advanced text processing
# Ensure required NLTK data is downloaded
try:
import nltk
NLTK_AVAILABLE = True
# Try to ensure required NLTK data is downloaded
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
try:
nltk.download('punkt', quiet=True)
except:
pass
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
try:
nltk.download('averaged_perceptron_tagger', quiet=True)
except:
pass
try:
nltk.data.find('corpora/stopwords')
except LookupError:
try:
nltk.download('stopwords', quiet=True)
except:
pass
except ImportError:
NLTK_AVAILABLE = False
logger.warning("NLTK not available, using basic text processing")
nltk.data.find('tokenizers/punkt')
except LookupError:
nltk.download('punkt', quiet=True)
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
nltk.download('averaged_perceptron_tagger', quiet=True)
try:
nltk.data.find('corpora/stopwords')
except LookupError:
nltk.download('stopwords', quiet=True)
@dataclass
@@ -52,18 +40,12 @@ class TextSegment:
class SentenceSplitter:
"""Splits text into sentences using available NLP tools."""
"""Splits text into sentences using NLTK."""
def __init__(self):
"""Initialize sentence splitter."""
self.use_nltk = NLTK_AVAILABLE
if self.use_nltk:
try:
self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
logger.info("Using NLTK sentence tokenizer")
except:
self.use_nltk = False
logger.warning("NLTK punkt tokenizer not available, using regex")
self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
logger.info("Using NLTK sentence tokenizer")
def split(self, text: str) -> List[str]:
"""Split text into sentences.
@@ -74,35 +56,16 @@ class SentenceSplitter:
Returns:
List of sentences
"""
if self.use_nltk:
try:
sentences = self.sent_detector.tokenize(text)
return sentences
except Exception as e:
logger.warning(f"NLTK sentence splitting failed: {e}, falling back to regex")
# Fallback to regex-based splitting
# Simple sentence boundary detection
sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text)
# Filter out empty sentences
sentences = [s.strip() for s in sentences if s.strip()]
sentences = self.sent_detector.tokenize(text)
return sentences
class PhraseExtractor:
"""Extracts meaningful phrases from sentences."""
"""Extracts meaningful phrases from sentences using NLTK."""
def __init__(self):
"""Initialize phrase extractor."""
self.use_nltk = NLTK_AVAILABLE
if self.use_nltk:
try:
# Test that POS tagger is available
nltk.pos_tag(['test'])
logger.info("Using NLTK phrase extraction")
except:
self.use_nltk = False
logger.warning("NLTK POS tagger not available, using basic extraction")
logger.info("Using NLTK phrase extraction")
def extract(self, sentence: str) -> List[Dict[str, str]]:
"""Extract phrases from a sentence.
@@ -115,104 +78,49 @@
"""
phrases = []
if self.use_nltk:
try:
phrases.extend(self._extract_nltk_phrases(sentence))
except Exception as e:
logger.warning(f"NLTK phrase extraction failed: {e}, using basic extraction")
phrases.extend(self._extract_basic_phrases(sentence))
else:
phrases.extend(self._extract_basic_phrases(sentence))
# Tokenize and POS tag
tokens = nltk.word_tokenize(sentence)
pos_tags = nltk.pos_tag(tokens)
return phrases
# Extract noun phrases (simple pattern)
noun_phrase = []
for word, pos in pos_tags:
if pos.startswith('NN') or pos.startswith('JJ'):
noun_phrase.append(word)
elif noun_phrase:
if len(noun_phrase) > 1:
phrases.append({
'text': ' '.join(noun_phrase),
'type': 'noun_phrase'
})
noun_phrase = []
def _extract_nltk_phrases(self, sentence: str) -> List[Dict[str, str]]:
"""Extract phrases using NLTK.
# Add last noun phrase if exists
if noun_phrase and len(noun_phrase) > 1:
phrases.append({
'text': ' '.join(noun_phrase),
'type': 'noun_phrase'
})
Args:
sentence: Sentence to process
# Extract verb phrases (simple pattern)
verb_phrase = []
for word, pos in pos_tags:
if pos.startswith('VB') or pos.startswith('RB'):
verb_phrase.append(word)
elif verb_phrase:
if len(verb_phrase) > 1:
phrases.append({
'text': ' '.join(verb_phrase),
'type': 'verb_phrase'
})
verb_phrase = []
Returns:
List of phrases with types
"""
phrases = []
try:
# Tokenize and POS tag
tokens = nltk.word_tokenize(sentence)
pos_tags = nltk.pos_tag(tokens)
# Extract noun phrases (simple pattern)
noun_phrase = []
for word, pos in pos_tags:
if pos.startswith('NN') or pos.startswith('JJ'):
noun_phrase.append(word)
elif noun_phrase:
if len(noun_phrase) > 1:
phrases.append({
'text': ' '.join(noun_phrase),
'type': 'noun_phrase'
})
noun_phrase = []
# Add last noun phrase if exists
if noun_phrase and len(noun_phrase) > 1:
phrases.append({
'text': ' '.join(noun_phrase),
'type': 'noun_phrase'
})
# Extract verb phrases (simple pattern)
verb_phrase = []
for word, pos in pos_tags:
if pos.startswith('VB') or pos.startswith('RB'):
verb_phrase.append(word)
elif verb_phrase:
if len(verb_phrase) > 1:
phrases.append({
'text': ' '.join(verb_phrase),
'type': 'verb_phrase'
})
verb_phrase = []
# Add last verb phrase if exists
if verb_phrase and len(verb_phrase) > 1:
phrases.append({
'text': ' '.join(verb_phrase),
'type': 'verb_phrase'
})
except Exception as e:
logger.error(f"Error in NLTK phrase extraction: {e}")
return phrases
def _extract_basic_phrases(self, sentence: str) -> List[Dict[str, str]]:
"""Extract phrases using basic regex patterns.
Args:
sentence: Sentence to process
Returns:
List of phrases with types
"""
phrases = []
# Extract quoted phrases
quoted = re.findall(r'"([^"]+)"', sentence)
for q in quoted:
phrases.append({'text': q, 'type': 'phrase'})
# Extract parenthetical phrases
parens = re.findall(r'\(([^)]+)\)', sentence)
for p in parens:
phrases.append({'text': p, 'type': 'phrase'})
# Extract capitalized sequences (potential entities)
caps = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', sentence)
for c in caps:
if len(c.split()) > 1: # Multi-word entities
phrases.append({'text': c, 'type': 'noun_phrase'})
# Add last verb phrase if exists
if verb_phrase and len(verb_phrase) > 1:
phrases.append({
'text': ' '.join(verb_phrase),
'type': 'verb_phrase'
})
return phrases
@@ -279,21 +187,8 @@ class TextProcessor:
# Split on word boundaries
words = re.findall(r'\b\w+\b', text.lower())
# Filter common stop words (basic list)
stop_words = {
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be',
'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
'would', 'could', 'should', 'may', 'might', 'must', 'can', 'shall'
}
# Use NLTK stopwords if available
if NLTK_AVAILABLE:
try:
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))
except:
pass
# Use NLTK stopwords
stop_words = set(stopwords.words('english'))
# Filter stopwords and short words
terms = [w for w in words if w not in stop_words and len(w) > 2]
@@ -322,4 +217,4 @@ class TextProcessor:
# Normalize quotes
text = text.replace('“', '"').replace('”', '"')
text = text.replace('‘', "'").replace('’', "'")
return text
return text
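Taken together, a minimal sketch of the now NLTK-only pipeline. The sample text and output are illustrative; the punkt and tagger models must already be present, which the module-level downloads above take care of.

splitter = SentenceSplitter()
extractor = PhraseExtractor()

text = "Marie Curie studied radioactive decay. The Nobel Committee honoured her twice."
for sentence in splitter.split(text):
    print(extractor.extract(sentence))
# e.g. [{'text': 'Marie Curie', 'type': 'noun_phrase'},
#       {'text': 'radioactive decay', 'type': 'noun_phrase'}]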

View file

@@ -1,23 +1,16 @@
"""
Vector store implementations for OntoRAG system.
Provides both FAISS and NumPy-based vector storage for ontology embeddings.
Vector store implementation for OntoRAG system.
Provides FAISS-based vector storage for ontology embeddings.
"""
import logging
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from typing import List, Dict, Any
from dataclasses import dataclass
import faiss
logger = logging.getLogger(__name__)
# Try to import FAISS, fall back to NumPy implementation if not available
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger.warning("FAISS not available, using NumPy implementation")
@dataclass
class SearchResult:
@@ -27,34 +20,8 @@
metadata: Dict[str, Any]
class VectorStore:
"""Abstract base class for vector stores."""
def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
"""Add single embedding with metadata."""
raise NotImplementedError
def add_batch(self, ids: List[str], embeddings: np.ndarray,
metadata_list: List[Dict[str, Any]]):
"""Batch add for initial ontology loading."""
raise NotImplementedError
def search(self, embedding: np.ndarray, top_k: int = 10,
threshold: float = 0.0) -> List[SearchResult]:
"""Search for similar vectors."""
raise NotImplementedError
def clear(self):
"""Reset the store."""
raise NotImplementedError
def size(self) -> int:
"""Return number of stored vectors."""
raise NotImplementedError
class FAISSVectorStore(VectorStore):
"""FAISS-based vector store implementation."""
class InMemoryVectorStore:
"""FAISS-based vector store implementation for ontology embeddings."""
def __init__(self, dimension: int = 1536, index_type: str = 'flat'):
"""Initialize FAISS vector store.
@@ -63,9 +30,6 @@ class FAISSVectorStore(VectorStore):
dimension: Embedding dimension (1536 for text-embedding-3-small)
index_type: 'flat' for exact search, 'ivf' for larger datasets
"""
if not FAISS_AVAILABLE:
raise RuntimeError("FAISS is not installed")
self.dimension = dimension
self.metadata = []
self.ids = []
@@ -141,107 +105,6 @@
return self.index.ntotal
class SimpleVectorStore(VectorStore):
"""NumPy-based vector store implementation for development/small deployments."""
def __init__(self):
"""Initialize simple NumPy-based vector store."""
self.embeddings = []
self.metadata = []
self.ids = []
logger.info("Created SimpleVectorStore (NumPy implementation)")
def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]):
"""Add single embedding with metadata."""
# Normalize for cosine similarity
normalized = embedding / np.linalg.norm(embedding)
self.embeddings.append(normalized)
self.metadata.append(metadata)
self.ids.append(id)
def add_batch(self, ids: List[str], embeddings: np.ndarray,
metadata_list: List[Dict[str, Any]]):
"""Batch add for initial ontology loading."""
# Normalize all embeddings
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
# Avoid division by zero
norms = np.where(norms == 0, 1, norms)
normalized = embeddings / norms
for i in range(len(ids)):
self.embeddings.append(normalized[i])
self.metadata.append(metadata_list[i])
self.ids.append(ids[i])
logger.debug(f"Added batch of {len(ids)} embeddings to simple store")
def search(self, embedding: np.ndarray, top_k: int = 10,
threshold: float = 0.0) -> List[SearchResult]:
"""Search for similar vectors using cosine similarity."""
if not self.embeddings:
return []
# Normalize query embedding
embedding = embedding / np.linalg.norm(embedding)
# Compute cosine similarities
embeddings_array = np.array(self.embeddings)
similarities = np.dot(embeddings_array, embedding)
# Get top-k indices
top_k = min(top_k, len(self.embeddings))
top_indices = np.argsort(similarities)[::-1][:top_k]
# Build results
results = []
for idx in top_indices:
if similarities[idx] >= threshold:
results.append(SearchResult(
id=self.ids[idx],
score=float(similarities[idx]),
metadata=self.metadata[idx]
))
return results
def clear(self):
"""Reset the store."""
self.embeddings = []
self.metadata = []
self.ids = []
logger.info("Cleared simple vector store")
def size(self) -> int:
"""Return number of stored vectors."""
return len(self.embeddings)
class InMemoryVectorStore:
"""Factory class to create appropriate vector store based on availability."""
@staticmethod
def create(dimension: int = 1536, prefer_faiss: bool = True,
index_type: str = 'flat') -> VectorStore:
"""Create a vector store instance.
Args:
dimension: Embedding dimension
prefer_faiss: Whether to prefer FAISS if available
index_type: Type of FAISS index ('flat' or 'ivf')
Returns:
VectorStore instance (FAISS or Simple)
"""
if prefer_faiss and FAISS_AVAILABLE:
try:
return FAISSVectorStore(dimension, index_type)
except Exception as e:
logger.warning(f"Failed to create FAISS store: {e}, falling back to NumPy")
return SimpleVectorStore()
else:
return SimpleVectorStore()
# Utility functions for vector operations
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
@@ -264,4 +127,4 @@ def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray:
# Compute dot product
similarities = np.dot(queries_norm, targets_norm.T)
return similarities
return similarities
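As a sanity check on the FAISS-only design: inner product over L2-normalized vectors is exactly cosine similarity, which is what the deleted NumPy store computed. A standalone sketch; faiss.IndexFlatIP is an assumption here, since the commit doesn't show the index construction behind index_type='flat'.

import faiss
import numpy as np

dim = 8
vecs = np.random.rand(5, dim).astype("float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)   # normalize rows

index = faiss.IndexFlatIP(dim)   # exact inner-product search
index.add(vecs)

scores, ids = index.search(vecs[:1], 3)
print(scores[0], ids[0])   # top hit is the query itself with score ~1.0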