diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 9bc8354a..452ebddf 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "minio", "mistralai", "neo4j", + "nltk", "ollama", "openai", "pinecone[grpc]", diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index b0942dc2..10247464 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -12,8 +12,7 @@ from .... schema import Chunk, Triple, Triples, Metadata, Value from .... schema import PromptRequest, PromptResponse from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base import PromptClientSpec -from .... tables.config import ConfigTableStore +from .... base import PromptClientSpec, EmbeddingsClientSpec from .ontology_loader import OntologyLoader from .ontology_embedder import OntologyEmbedder @@ -58,6 +57,13 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + EmbeddingsClientSpec( + request_name="embeddings-request", + response_name="embeddings-response" + ) + ) + self.register_specification( ProducerSpec( name="triples", @@ -65,6 +71,9 @@ class Processor(FlowProcessor): ) ) + # Register config handler for ontology updates + self.register_config_handler(self.on_ontology_config) + # Initialize components self.ontology_loader = None self.ontology_embedder = None @@ -75,13 +84,10 @@ class Processor(FlowProcessor): # Configuration self.top_k = params.get("top_k", 10) self.similarity_threshold = params.get("similarity_threshold", 0.7) - self.refresh_interval = params.get("ontology_refresh_interval", 300) - # Cassandra configuration for config store - self.cassandra_host = params.get("cassandra_host", "localhost") - self.cassandra_username = params.get("cassandra_username", "cassandra") - self.cassandra_password = params.get("cassandra_password", "cassandra") - self.cassandra_keyspace = params.get("cassandra_keyspace", "trustgraph") + # Track loaded ontology version + self.current_ontology_version = None + self.loaded_ontology_ids = set() async def initialize_components(self, flow): """Initialize OntoRAG components.""" @@ -89,38 +95,24 @@ class Processor(FlowProcessor): return try: - # Create configuration store - config_store = ConfigTableStore( - self.cassandra_host, - self.cassandra_username, - self.cassandra_password, - self.cassandra_keyspace - ) + # Initialize ontology loader (no ConfigTableStore needed) + self.ontology_loader = OntologyLoader() + logger.info("Ontology loader initialized") - # Initialize ontology loader - self.ontology_loader = OntologyLoader(config_store) - ontologies = await self.ontology_loader.load_ontologies() - logger.info(f"Loaded {len(ontologies)} ontologies") - - # Initialize vector store - vector_store = InMemoryVectorStore.create( + # Initialize vector store (FAISS only, no fallback) + vector_store = InMemoryVectorStore( dimension=1536, # text-embedding-3-small - prefer_faiss=True, index_type='flat' ) - # Initialize ontology embedder with embedding service wrapper - embedding_service = EmbeddingServiceWrapper(flow) + # Use embeddings client directly (no wrapper needed) + embeddings_client = flow("embeddings-request") + self.ontology_embedder = OntologyEmbedder( - embedding_service=embedding_service, + embedding_service=embeddings_client, 
vector_store=vector_store ) - # Embed all ontologies - if ontologies: - await self.ontology_embedder.embed_ontologies(ontologies) - logger.info(f"Embedded {self.ontology_embedder.get_embedded_count()} ontology elements") - # Initialize ontology selector self.ontology_selector = OntologySelector( ontology_embedder=self.ontology_embedder, @@ -132,28 +124,90 @@ class Processor(FlowProcessor): self.initialized = True logger.info("OntoRAG components initialized successfully") - # Schedule periodic refresh - asyncio.create_task(self.refresh_ontologies_periodically()) + # NOTE: Ontologies will be loaded via on_ontology_config() handler + # when ConfigPush messages arrive (including initial config on startup) except Exception as e: logger.error(f"Failed to initialize OntoRAG components: {e}", exc_info=True) raise - async def refresh_ontologies_periodically(self): - """Periodically refresh ontologies from configuration.""" - while True: - await asyncio.sleep(self.refresh_interval) - try: - logger.info("Refreshing ontologies...") - ontologies = await self.ontology_loader.refresh_ontologies() - if ontologies: - # Re-embed new ontologies - for ont_id in ontologies: - if not self.ontology_embedder.is_ontology_embedded(ont_id): - await self.ontology_embedder.embed_ontology(ontologies[ont_id]) - logger.info("Ontology refresh complete") - except Exception as e: - logger.error(f"Error refreshing ontologies: {e}", exc_info=True) + async def on_ontology_config(self, config, version): + """ + Handle ontology configuration updates from ConfigPush queue. + + Called automatically when: + - Processor starts (gets full config history via start_of_messages=True) + - Config service pushes updates (immediate event-driven notification) + + Args: + config: Full configuration map - config[type][key] = value + version: Config version number (monotonically increasing) + """ + try: + logger.info(f"Received ontology config update, version={version}") + + # Skip if we've already processed this version + if version == self.current_ontology_version: + logger.debug(f"Already at version {version}, skipping") + return + + # Extract ontology configurations + if "ontology" not in config: + logger.warning("No 'ontology' section in config") + return + + ontology_configs = config["ontology"] + + # Parse ontology definitions + ontologies = {} + for ont_id, ont_json in ontology_configs.items(): + try: + ontologies[ont_id] = json.loads(ont_json) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse ontology '{ont_id}': {e}") + continue + + logger.info(f"Loaded {len(ontologies)} ontology definitions") + + # Determine what changed (for incremental updates) + new_ids = set(ontologies.keys()) + added_ids = new_ids - self.loaded_ontology_ids + removed_ids = self.loaded_ontology_ids - new_ids + updated_ids = new_ids & self.loaded_ontology_ids # May have changed content + + if added_ids: + logger.info(f"New ontologies: {added_ids}") + if removed_ids: + logger.info(f"Removed ontologies: {removed_ids}") + if updated_ids: + logger.info(f"Updated ontologies: {updated_ids}") + + # Update ontology loader's internal state + self.ontology_loader.update_ontologies(ontologies) + + # Re-embed changed ontologies + if self.ontology_embedder: + # Remove embeddings for deleted ontologies + for ont_id in removed_ids: + self.ontology_embedder.remove_ontology(ont_id) + + # Embed new and updated ontologies + for ont_id in added_ids | updated_ids: + if ont_id in self.ontology_loader.get_all_ontologies(): + await 
self.ontology_embedder.embed_ontology( + self.ontology_loader.get_ontology(ont_id) + ) + + logger.info(f"Re-embedded ontologies, total elements: {self.ontology_embedder.get_embedded_count()}") + + # Update tracking + self.current_ontology_version = version + self.loaded_ontology_ids = new_ids + + logger.info(f"Ontology config update complete, version={version}") + + except Exception as e: + logger.error(f"Failed to process ontology config: {e}", exc_info=True) async def on_message(self, msg, consumer, flow): """Process incoming chunk message.""" @@ -403,71 +457,9 @@ TRIPLES (JSON array):""" default=0.7, help='Similarity threshold for ontology matching (default: 0.7)' ) - parser.add_argument( - '--ontology-refresh-interval', - type=int, - default=300, - help='Ontology refresh interval in seconds (default: 300)' - ) - parser.add_argument( - '--cassandra-host', - type=str, - default='localhost', - help='Cassandra host (default: localhost)' - ) - parser.add_argument( - '--cassandra-username', - type=str, - default='cassandra', - help='Cassandra username (default: cassandra)' - ) - parser.add_argument( - '--cassandra-password', - type=str, - default='cassandra', - help='Cassandra password (default: cassandra)' - ) - parser.add_argument( - '--cassandra-keyspace', - type=str, - default='trustgraph', - help='Cassandra keyspace (default: trustgraph)' - ) FlowProcessor.add_args(parser) -class EmbeddingServiceWrapper: - """Wrapper to adapt flow prompt service to embedding service interface.""" - - def __init__(self, flow): - self.flow = flow - - async def embed(self, text: str): - """Generate embedding for single text.""" - try: - response = await self.flow("prompt-request").get_embedding(text=text) - return response - except Exception as e: - logger.error(f"Embedding service error: {e}") - return None - - async def embed_batch(self, texts: List[str]): - """Generate embeddings for multiple texts.""" - try: - # Process in parallel for better performance - tasks = [self.embed(text) for text in texts] - embeddings = await asyncio.gather(*tasks) - # Filter out None values and convert to array - import numpy as np - valid_embeddings = [e for e in embeddings if e is not None] - if valid_embeddings: - return np.array(valid_embeddings) - return None - except Exception as e: - logger.error(f"Batch embedding service error: {e}") - return None - - def run(): """Launch the OntoRAG extraction service.""" Processor.launch(default_ident, __doc__) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py index 402f3b7a..cbb0f0bf 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py @@ -9,7 +9,7 @@ from typing import Dict, List, Any, Optional from dataclasses import dataclass from .ontology_loader import Ontology, OntologyClass, OntologyProperty -from .vector_store import VectorStore, InMemoryVectorStore +from .vector_store import InMemoryVectorStore logger = logging.getLogger(__name__) @@ -27,15 +27,15 @@ class OntologyElementMetadata: class OntologyEmbedder: """Generates embeddings for ontology elements and stores them in vector store.""" - def __init__(self, embedding_service=None, vector_store: Optional[VectorStore] = None): + def __init__(self, embedding_service=None, vector_store: Optional[InMemoryVectorStore] = None): """Initialize the ontology embedder. 
Args: embedding_service: Service for generating embeddings - vector_store: Vector store instance (defaults to InMemoryVectorStore) + vector_store: Vector store instance (InMemoryVectorStore) """ self.embedding_service = embedding_service - self.vector_store = vector_store or InMemoryVectorStore.create() + self.vector_store = vector_store or InMemoryVectorStore() self.embedded_ontologies = set() def _create_text_representation(self, element_id: str, element: Any, @@ -232,6 +232,25 @@ class OntologyEmbedder: logger.error(f"Failed to embed texts: {e}") return None + def remove_ontology(self, ontology_id: str): + """Remove all embeddings for a specific ontology. + + Note: FAISS doesn't support efficient deletion, so this currently + requires rebuilding the entire index without the removed ontology. + + Args: + ontology_id: ID of ontology to remove + """ + if ontology_id not in self.embedded_ontologies: + logger.debug(f"Ontology '{ontology_id}' not embedded, nothing to remove") + return + + # FAISS doesn't support selective deletion, so we'd need to rebuild the index + # For now, just remove from tracking set + # TODO: Implement index rebuilding if selective removal is needed + self.embedded_ontologies.discard(ontology_id) + logger.info(f"Removed ontology '{ontology_id}' from embedded set (note: vectors still in store)") + def clear_embeddings(self, ontology_id: Optional[str] = None): """Clear embeddings from vector store. @@ -240,15 +259,13 @@ class OntologyEmbedder: Otherwise, clear all embeddings """ if ontology_id: - # Would need to implement selective clearing in vector store - # For now, log warning - logger.warning(f"Selective clearing not implemented, would clear {ontology_id}") + self.remove_ontology(ontology_id) else: self.vector_store.clear() self.embedded_ontologies.clear() logger.info("Cleared all embeddings from vector store") - def get_vector_store(self) -> VectorStore: + def get_vector_store(self) -> InMemoryVectorStore: """Get the vector store instance. Returns: diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py index 2dc53003..710108b6 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py @@ -158,84 +158,61 @@ class Ontology: class OntologyLoader: - """Loads and manages ontologies from configuration service.""" + """Manages ontologies received via event-driven config updates. - def __init__(self, config_store=None): - """Initialize the ontology loader. + No direct database access - receives ontologies via config handler. + """ + + def __init__(self): + """Initialize empty ontology store.""" + self.ontologies: Dict[str, Ontology] = {} + + def update_ontologies(self, ontology_configs: Dict[str, Any]): + """Update ontology definitions from config. Args: - config_store: Configuration store instance (injected dependency) + ontology_configs: Dict mapping ontology_id -> ontology_definition (parsed dicts) """ - self.config_store = config_store - self.ontologies: Dict[str, Ontology] = {} - self.refresh_interval = 300 # Default 5 minutes + self.ontologies.clear() - async def load_ontologies(self) -> Dict[str, Ontology]: - """Load all ontologies from configuration service. 
+ for ont_id, ont_data in ontology_configs.items(): + try: + # Parse classes + classes = {} + for class_id, class_data in ont_data.get('classes', {}).items(): + classes[class_id] = OntologyClass.from_dict(class_id, class_data) - Returns: - Dictionary of ontology ID to Ontology objects - """ - if not self.config_store: - logger.warning("No configuration store available, returning empty ontologies") - return {} + # Parse object properties + object_props = {} + for prop_id, prop_data in ont_data.get('objectProperties', {}).items(): + object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) - try: - # Get all ontology configurations - ontology_configs = await self.config_store.get("ontology").values() + # Parse datatype properties + datatype_props = {} + for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items(): + datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) - for ont_id, ont_data in ontology_configs.items(): - try: - # Parse JSON if string - if isinstance(ont_data, str): - ont_data = json.loads(ont_data) + # Create ontology + ontology = Ontology( + id=ont_id, + metadata=ont_data.get('metadata', {}), + classes=classes, + object_properties=object_props, + datatype_properties=datatype_props + ) - # Parse classes - classes = {} - for class_id, class_data in ont_data.get('classes', {}).items(): - classes[class_id] = OntologyClass.from_dict(class_id, class_data) + # Validate structure + issues = ontology.validate_structure() + if issues: + logger.warning(f"Ontology {ont_id} has validation issues: {issues}") - # Parse object properties - object_props = {} - for prop_id, prop_data in ont_data.get('objectProperties', {}).items(): - object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) + self.ontologies[ont_id] = ontology + logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, " + f"{len(object_props)} object properties, " + f"{len(datatype_props)} datatype properties") - # Parse datatype properties - datatype_props = {} - for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items(): - datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) - - # Create ontology - ontology = Ontology( - id=ont_id, - metadata=ont_data.get('metadata', {}), - classes=classes, - object_properties=object_props, - datatype_properties=datatype_props - ) - - # Validate structure - issues = ontology.validate_structure() - if issues: - logger.warning(f"Ontology {ont_id} has validation issues: {issues}") - - self.ontologies[ont_id] = ontology - logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, " - f"{len(object_props)} object properties, " - f"{len(datatype_props)} datatype properties") - - except Exception as e: - logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True) - - except Exception as e: - logger.error(f"Failed to load ontologies from config: {e}", exc_info=True) - - return self.ontologies - - async def refresh_ontologies(self): - """Refresh ontologies from configuration service.""" - logger.info("Refreshing ontologies...") - return await self.load_ontologies() + except Exception as e: + logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True) def get_ontology(self, ont_id: str) -> Optional[Ontology]: """Get a specific ontology by ID. @@ -256,6 +233,14 @@ class OntologyLoader: """ return self.ontologies + def list_ontology_ids(self) -> List[str]: + """Get list of loaded ontology IDs. 
+ + Returns: + List of ontology IDs + """ + return list(self.ontologies.keys()) + def clear(self): """Clear all loaded ontologies.""" self.ontologies.clear() diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py b/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py index 98563bba..e6c92f98 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py @@ -7,38 +7,26 @@ import logging import re from typing import List, Dict, Any, Optional from dataclasses import dataclass +import nltk +from nltk.corpus import stopwords logger = logging.getLogger(__name__) -# Try to import NLTK for advanced text processing +# Ensure required NLTK data is downloaded try: - import nltk - NLTK_AVAILABLE = True - # Try to ensure required NLTK data is downloaded - try: - nltk.data.find('tokenizers/punkt') - except LookupError: - try: - nltk.download('punkt', quiet=True) - except: - pass - try: - nltk.data.find('taggers/averaged_perceptron_tagger') - except LookupError: - try: - nltk.download('averaged_perceptron_tagger', quiet=True) - except: - pass - try: - nltk.data.find('corpora/stopwords') - except LookupError: - try: - nltk.download('stopwords', quiet=True) - except: - pass -except ImportError: - NLTK_AVAILABLE = False - logger.warning("NLTK not available, using basic text processing") + nltk.data.find('tokenizers/punkt') +except LookupError: + nltk.download('punkt', quiet=True) + +try: + nltk.data.find('taggers/averaged_perceptron_tagger') +except LookupError: + nltk.download('averaged_perceptron_tagger', quiet=True) + +try: + nltk.data.find('corpora/stopwords') +except LookupError: + nltk.download('stopwords', quiet=True) @dataclass @@ -52,18 +40,12 @@ class TextSegment: class SentenceSplitter: - """Splits text into sentences using available NLP tools.""" + """Splits text into sentences using NLTK.""" def __init__(self): """Initialize sentence splitter.""" - self.use_nltk = NLTK_AVAILABLE - if self.use_nltk: - try: - self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') - logger.info("Using NLTK sentence tokenizer") - except: - self.use_nltk = False - logger.warning("NLTK punkt tokenizer not available, using regex") + self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') + logger.info("Using NLTK sentence tokenizer") def split(self, text: str) -> List[str]: """Split text into sentences. 
@@ -74,35 +56,16 @@ class SentenceSplitter: Returns: List of sentences """ - if self.use_nltk: - try: - sentences = self.sent_detector.tokenize(text) - return sentences - except Exception as e: - logger.warning(f"NLTK sentence splitting failed: {e}, falling back to regex") - - # Fallback to regex-based splitting - # Simple sentence boundary detection - sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text) - # Filter out empty sentences - sentences = [s.strip() for s in sentences if s.strip()] + sentences = self.sent_detector.tokenize(text) return sentences class PhraseExtractor: - """Extracts meaningful phrases from sentences.""" + """Extracts meaningful phrases from sentences using NLTK.""" def __init__(self): """Initialize phrase extractor.""" - self.use_nltk = NLTK_AVAILABLE - if self.use_nltk: - try: - # Test that POS tagger is available - nltk.pos_tag(['test']) - logger.info("Using NLTK phrase extraction") - except: - self.use_nltk = False - logger.warning("NLTK POS tagger not available, using basic extraction") + logger.info("Using NLTK phrase extraction") def extract(self, sentence: str) -> List[Dict[str, str]]: """Extract phrases from a sentence. @@ -115,104 +78,49 @@ class PhraseExtractor: """ phrases = [] - if self.use_nltk: - try: - phrases.extend(self._extract_nltk_phrases(sentence)) - except Exception as e: - logger.warning(f"NLTK phrase extraction failed: {e}, using basic extraction") - phrases.extend(self._extract_basic_phrases(sentence)) - else: - phrases.extend(self._extract_basic_phrases(sentence)) + # Tokenize and POS tag + tokens = nltk.word_tokenize(sentence) + pos_tags = nltk.pos_tag(tokens) - return phrases + # Extract noun phrases (simple pattern) + noun_phrase = [] + for word, pos in pos_tags: + if pos.startswith('NN') or pos.startswith('JJ'): + noun_phrase.append(word) + elif noun_phrase: + if len(noun_phrase) > 1: + phrases.append({ + 'text': ' '.join(noun_phrase), + 'type': 'noun_phrase' + }) + noun_phrase = [] - def _extract_nltk_phrases(self, sentence: str) -> List[Dict[str, str]]: - """Extract phrases using NLTK. 
+ # Add last noun phrase if exists + if noun_phrase and len(noun_phrase) > 1: + phrases.append({ + 'text': ' '.join(noun_phrase), + 'type': 'noun_phrase' + }) - Args: - sentence: Sentence to process + # Extract verb phrases (simple pattern) + verb_phrase = [] + for word, pos in pos_tags: + if pos.startswith('VB') or pos.startswith('RB'): + verb_phrase.append(word) + elif verb_phrase: + if len(verb_phrase) > 1: + phrases.append({ + 'text': ' '.join(verb_phrase), + 'type': 'verb_phrase' + }) + verb_phrase = [] - Returns: - List of phrases with types - """ - phrases = [] - - try: - # Tokenize and POS tag - tokens = nltk.word_tokenize(sentence) - pos_tags = nltk.pos_tag(tokens) - - # Extract noun phrases (simple pattern) - noun_phrase = [] - for word, pos in pos_tags: - if pos.startswith('NN') or pos.startswith('JJ'): - noun_phrase.append(word) - elif noun_phrase: - if len(noun_phrase) > 1: - phrases.append({ - 'text': ' '.join(noun_phrase), - 'type': 'noun_phrase' - }) - noun_phrase = [] - - # Add last noun phrase if exists - if noun_phrase and len(noun_phrase) > 1: - phrases.append({ - 'text': ' '.join(noun_phrase), - 'type': 'noun_phrase' - }) - - # Extract verb phrases (simple pattern) - verb_phrase = [] - for word, pos in pos_tags: - if pos.startswith('VB') or pos.startswith('RB'): - verb_phrase.append(word) - elif verb_phrase: - if len(verb_phrase) > 1: - phrases.append({ - 'text': ' '.join(verb_phrase), - 'type': 'verb_phrase' - }) - verb_phrase = [] - - # Add last verb phrase if exists - if verb_phrase and len(verb_phrase) > 1: - phrases.append({ - 'text': ' '.join(verb_phrase), - 'type': 'verb_phrase' - }) - - except Exception as e: - logger.error(f"Error in NLTK phrase extraction: {e}") - - return phrases - - def _extract_basic_phrases(self, sentence: str) -> List[Dict[str, str]]: - """Extract phrases using basic regex patterns. 
- - Args: - sentence: Sentence to process - - Returns: - List of phrases with types - """ - phrases = [] - - # Extract quoted phrases - quoted = re.findall(r'"([^"]+)"', sentence) - for q in quoted: - phrases.append({'text': q, 'type': 'phrase'}) - - # Extract parenthetical phrases - parens = re.findall(r'\(([^)]+)\)', sentence) - for p in parens: - phrases.append({'text': p, 'type': 'phrase'}) - - # Extract capitalized sequences (potential entities) - caps = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', sentence) - for c in caps: - if len(c.split()) > 1: # Multi-word entities - phrases.append({'text': c, 'type': 'noun_phrase'}) + # Add last verb phrase if exists + if verb_phrase and len(verb_phrase) > 1: + phrases.append({ + 'text': ' '.join(verb_phrase), + 'type': 'verb_phrase' + }) return phrases @@ -279,21 +187,8 @@ class TextProcessor: # Split on word boundaries words = re.findall(r'\b\w+\b', text.lower()) - # Filter common stop words (basic list) - stop_words = { - 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', - 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be', - 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', - 'would', 'could', 'should', 'may', 'might', 'must', 'can', 'shall' - } - - # Use NLTK stopwords if available - if NLTK_AVAILABLE: - try: - from nltk.corpus import stopwords - stop_words = set(stopwords.words('english')) - except: - pass + # Use NLTK stopwords + stop_words = set(stopwords.words('english')) # Filter stopwords and short words terms = [w for w in words if w not in stop_words and len(w) > 2] @@ -322,4 +217,4 @@ class TextProcessor: # Normalize quotes text = text.replace('"', '"').replace('"', '"') text = text.replace(''', "'").replace(''', "'") - return text \ No newline at end of file + return text diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py b/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py index 42c3dc7f..6f456861 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py @@ -1,23 +1,16 @@ """ -Vector store implementations for OntoRAG system. -Provides both FAISS and NumPy-based vector storage for ontology embeddings. +Vector store implementation for OntoRAG system. +Provides FAISS-based vector storage for ontology embeddings. 
""" import logging import numpy as np -from typing import List, Dict, Any, Optional, Tuple +from typing import List, Dict, Any from dataclasses import dataclass +import faiss logger = logging.getLogger(__name__) -# Try to import FAISS, fall back to NumPy implementation if not available -try: - import faiss - FAISS_AVAILABLE = True -except ImportError: - FAISS_AVAILABLE = False - logger.warning("FAISS not available, using NumPy implementation") - @dataclass class SearchResult: @@ -27,34 +20,8 @@ class SearchResult: metadata: Dict[str, Any] -class VectorStore: - """Abstract base class for vector stores.""" - - def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]): - """Add single embedding with metadata.""" - raise NotImplementedError - - def add_batch(self, ids: List[str], embeddings: np.ndarray, - metadata_list: List[Dict[str, Any]]): - """Batch add for initial ontology loading.""" - raise NotImplementedError - - def search(self, embedding: np.ndarray, top_k: int = 10, - threshold: float = 0.0) -> List[SearchResult]: - """Search for similar vectors.""" - raise NotImplementedError - - def clear(self): - """Reset the store.""" - raise NotImplementedError - - def size(self) -> int: - """Return number of stored vectors.""" - raise NotImplementedError - - -class FAISSVectorStore(VectorStore): - """FAISS-based vector store implementation.""" +class InMemoryVectorStore: + """FAISS-based vector store implementation for ontology embeddings.""" def __init__(self, dimension: int = 1536, index_type: str = 'flat'): """Initialize FAISS vector store. @@ -63,9 +30,6 @@ class FAISSVectorStore(VectorStore): dimension: Embedding dimension (1536 for text-embedding-3-small) index_type: 'flat' for exact search, 'ivf' for larger datasets """ - if not FAISS_AVAILABLE: - raise RuntimeError("FAISS is not installed") - self.dimension = dimension self.metadata = [] self.ids = [] @@ -141,107 +105,6 @@ class FAISSVectorStore(VectorStore): return self.index.ntotal -class SimpleVectorStore(VectorStore): - """NumPy-based vector store implementation for development/small deployments.""" - - def __init__(self): - """Initialize simple NumPy-based vector store.""" - self.embeddings = [] - self.metadata = [] - self.ids = [] - logger.info("Created SimpleVectorStore (NumPy implementation)") - - def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]): - """Add single embedding with metadata.""" - # Normalize for cosine similarity - normalized = embedding / np.linalg.norm(embedding) - self.embeddings.append(normalized) - self.metadata.append(metadata) - self.ids.append(id) - - def add_batch(self, ids: List[str], embeddings: np.ndarray, - metadata_list: List[Dict[str, Any]]): - """Batch add for initial ontology loading.""" - # Normalize all embeddings - norms = np.linalg.norm(embeddings, axis=1, keepdims=True) - # Avoid division by zero - norms = np.where(norms == 0, 1, norms) - normalized = embeddings / norms - - for i in range(len(ids)): - self.embeddings.append(normalized[i]) - self.metadata.append(metadata_list[i]) - self.ids.append(ids[i]) - - logger.debug(f"Added batch of {len(ids)} embeddings to simple store") - - def search(self, embedding: np.ndarray, top_k: int = 10, - threshold: float = 0.0) -> List[SearchResult]: - """Search for similar vectors using cosine similarity.""" - if not self.embeddings: - return [] - - # Normalize query embedding - embedding = embedding / np.linalg.norm(embedding) - - # Compute cosine similarities - embeddings_array = np.array(self.embeddings) - similarities 
= np.dot(embeddings_array, embedding) - - # Get top-k indices - top_k = min(top_k, len(self.embeddings)) - top_indices = np.argsort(similarities)[::-1][:top_k] - - # Build results - results = [] - for idx in top_indices: - if similarities[idx] >= threshold: - results.append(SearchResult( - id=self.ids[idx], - score=float(similarities[idx]), - metadata=self.metadata[idx] - )) - - return results - - def clear(self): - """Reset the store.""" - self.embeddings = [] - self.metadata = [] - self.ids = [] - logger.info("Cleared simple vector store") - - def size(self) -> int: - """Return number of stored vectors.""" - return len(self.embeddings) - - -class InMemoryVectorStore: - """Factory class to create appropriate vector store based on availability.""" - - @staticmethod - def create(dimension: int = 1536, prefer_faiss: bool = True, - index_type: str = 'flat') -> VectorStore: - """Create a vector store instance. - - Args: - dimension: Embedding dimension - prefer_faiss: Whether to prefer FAISS if available - index_type: Type of FAISS index ('flat' or 'ivf') - - Returns: - VectorStore instance (FAISS or Simple) - """ - if prefer_faiss and FAISS_AVAILABLE: - try: - return FAISSVectorStore(dimension, index_type) - except Exception as e: - logger.warning(f"Failed to create FAISS store: {e}, falling back to NumPy") - return SimpleVectorStore() - else: - return SimpleVectorStore() - - # Utility functions for vector operations def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: """Compute cosine similarity between two vectors.""" @@ -264,4 +127,4 @@ def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndar # Compute dot product similarities = np.dot(queries_norm, targets_norm.T) - return similarities \ No newline at end of file + return similarities
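
For context on the event-driven flow this patch introduces, here is a minimal sketch of the configuration shape that on_ontology_config() receives from the config push and of feeding it into the reworked OntologyLoader. The ontology id, metadata values, and empty class/property maps below are illustrative only; the inner class and property schemas are omitted because they are defined by OntologyClass.from_dict / OntologyProperty.from_dict, not by this sketch.

    import json

    from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader

    # The config push delivers config[type][key] = value; ontology
    # definitions arrive under the "ontology" type as JSON strings.
    config = {
        "ontology": {
            "example-ontology": json.dumps({           # illustrative id and content
                "metadata": {"name": "Example ontology"},
                "classes": {},                          # schema set by OntologyClass.from_dict
                "objectProperties": {},                 # schema set by OntologyProperty.from_dict
                "datatypeProperties": {},
            })
        }
    }

    # What the handler does after JSON-decoding each entry:
    loader = OntologyLoader()
    loader.update_ontologies(
        {ont_id: json.loads(raw) for ont_id, raw in config["ontology"].items()}
    )
    print(loader.list_ontology_ids())   # ['example-ontology']

Because the handler is registered with register_config_handler(), the processor receives the full config on startup and then each subsequent push, replacing the old Cassandra-backed polling loop.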
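
Similarly, a minimal sketch of the consolidated InMemoryVectorStore (now FAISS-only), assuming faiss and numpy are installed. The dimension, ids, vectors, and metadata are made up for illustration; the processor itself constructs the store with dimension=1536 for text-embedding-3-small.

    import numpy as np

    from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore

    # Small dimension purely for illustration.
    store = InMemoryVectorStore(dimension=8, index_type='flat')

    rng = np.random.default_rng(0)
    ids = ["ont1:ClassA", "ont1:hasPart"]               # illustrative element ids
    vectors = rng.random((2, 8), dtype=np.float32)      # stand-in embeddings
    metadata = [
        {"ontology": "ont1", "element_type": "class"},
        {"ontology": "ont1", "element_type": "object_property"},
    ]

    store.add_batch(ids, vectors, metadata)

    query = rng.random(8, dtype=np.float32)
    for result in store.search(query, top_k=2, threshold=0.0):
        print(result.id, result.score, result.metadata)

    print("stored vectors:", store.size())

With the NumPy fallback removed, faiss is imported unconditionally at module load, so it must be installed wherever vector_store.py is used; as remove_ontology() notes, per-ontology deletion currently only updates the tracking set and would need an index rebuild to drop the vectors themselves.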