diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
index 403f27db..764600dc 100644
--- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
@@ -74,12 +74,12 @@ class Processor(FlowProcessor):
         # Register config handler for ontology updates
         self.register_config_handler(self.on_ontology_config)
 
-        # Initialize components
-        self.ontology_loader = None
-        self.ontology_embedder = None
+        # Shared components (not flow-specific)
+        self.ontology_loader = OntologyLoader()
         self.text_processor = TextProcessor()
-        self.ontology_selector = None
-        self.initialized = False
+
+        # Per-flow components (each flow gets its own embedder/vector store/selector)
+        self.flow_components = {}  # flow_id -> {embedder, vector_store, selector}
 
         # Configuration
         self.top_k = params.get("top_k", 10)
@@ -88,17 +88,27 @@ class Processor(FlowProcessor):
         # Track loaded ontology version
         self.current_ontology_version = None
         self.loaded_ontology_ids = set()
-        self.pending_config = None  # Store config until components initialized
 
-    async def initialize_components(self, flow):
-        """Initialize OntoRAG components."""
-        if self.initialized:
-            return
+    async def initialize_flow_components(self, flow):
+        """Initialize per-flow OntoRAG components.
+
+        Each flow gets its own vector store and embedder to support
+        different embedding models across flows.
+
+        Args:
+            flow: Flow object for this processing context
+
+        Returns:
+            flow_id: Identifier for this flow's components
+        """
+        # Use flow object as identifier
+        flow_id = id(flow)
+
+        if flow_id in self.flow_components:
+            return flow_id  # Already initialized for this flow
 
         try:
-            # Initialize ontology loader (no ConfigTableStore needed)
-            self.ontology_loader = OntologyLoader()
-            logger.info("Ontology loader initialized")
+            logger.info(f"Initializing components for flow {flow_id}")
 
             # Initialize vector store (FAISS only, no fallback)
             vector_store = InMemoryVectorStore(
@@ -109,37 +119,46 @@ class Processor(FlowProcessor):
             # Use embeddings client directly (no wrapper needed)
             embeddings_client = flow("embeddings-request")
 
-            self.ontology_embedder = OntologyEmbedder(
+            ontology_embedder = OntologyEmbedder(
                 embedding_service=embeddings_client,
                 vector_store=vector_store
             )
 
+            # Embed all loaded ontologies for this flow
+            if self.ontology_loader.get_all_ontologies():
+                logger.info(f"Embedding ontologies for flow {flow_id}")
+                for ont_id, ontology in self.ontology_loader.get_all_ontologies().items():
+                    await ontology_embedder.embed_ontology(ontology)
+                logger.info(f"Embedded {ontology_embedder.get_embedded_count()} ontology elements for flow {flow_id}")
+
             # Initialize ontology selector
-            self.ontology_selector = OntologySelector(
-                ontology_embedder=self.ontology_embedder,
+            ontology_selector = OntologySelector(
+                ontology_embedder=ontology_embedder,
                 ontology_loader=self.ontology_loader,
                 top_k=self.top_k,
                 similarity_threshold=self.similarity_threshold
            )
 
-            self.initialized = True
-            logger.info("OntoRAG components initialized successfully")
+            # Store flow-specific components
+            self.flow_components[flow_id] = {
+                'embedder': ontology_embedder,
+                'vector_store': vector_store,
+                'selector': ontology_selector
+            }
 
-            # Process pending config if available
-            if self.pending_config:
-                logger.info("Processing pending config from startup")
-                config, version = self.pending_config
-                self.pending_config = None
-                await self.on_ontology_config(config, version)
+            logger.info(f"Flow {flow_id} components initialized successfully")
+            return flow_id
 
         except Exception as e:
-            logger.error(f"Failed to initialize OntoRAG components: {e}", exc_info=True)
+            logger.error(f"Failed to initialize flow {flow_id} components: {e}", exc_info=True)
             raise
 
     async def on_ontology_config(self, config, version):
         """
         Handle ontology configuration updates from ConfigPush queue.
 
+        Parses and stores ontologies. Embedding happens per-flow on first message.
+
         Called automatically when:
         - Processor starts (gets full config history via start_of_messages=True)
         - Config service pushes updates (immediate event-driven notification)
@@ -161,12 +180,6 @@ class Processor(FlowProcessor):
             logger.warning("No 'ontology' section in config")
             return
 
-        # Check if components are initialized
-        if not self.ontology_loader:
-            logger.debug("Components not yet initialized, storing config for later processing")
-            self.pending_config = (config, version)
-            return
-
         ontology_configs = config["ontology"]
 
         # Parse ontology definitions
@@ -196,20 +209,10 @@ class Processor(FlowProcessor):
         # Update ontology loader's internal state
         self.ontology_loader.update_ontologies(ontologies)
 
-        # Re-embed changed ontologies
-        if self.ontology_embedder:
-            # Remove embeddings for deleted ontologies
-            for ont_id in removed_ids:
-                self.ontology_embedder.remove_ontology(ont_id)
-
-            # Embed new and updated ontologies
-            for ont_id in added_ids | updated_ids:
-                if ont_id in self.ontology_loader.get_all_ontologies():
-                    await self.ontology_embedder.embed_ontology(
-                        self.ontology_loader.get_ontology(ont_id)
-                    )
-
-            logger.info(f"Re-embedded ontologies, total elements: {self.ontology_embedder.get_embedded_count()}")
+        # Clear all flow components to force re-embedding with new ontologies
+        if added_ids or removed_ids or updated_ids:
+            logger.info("Clearing flow components to trigger re-embedding")
+            self.flow_components.clear()
 
         # Update tracking
         self.current_ontology_version = version
@@ -225,9 +228,9 @@ class Processor(FlowProcessor):
         v = msg.value()
         logger.info(f"Extracting ontology-based triples from {v.metadata.id}...")
 
-        # Initialize components if needed
-        if not self.initialized:
-            await self.initialize_components(flow)
+        # Initialize flow-specific components if needed
+        flow_id = await self.initialize_flow_components(flow)
+        components = self.flow_components[flow_id]
 
         chunk = v.chunk.decode("utf-8")
         logger.debug(f"Processing chunk: {chunk[:200]}...")
@@ -237,8 +240,8 @@ class Processor(FlowProcessor):
         segments = self.text_processor.process_chunk(chunk, extract_phrases=True)
         logger.debug(f"Split chunk into {len(segments)} segments")
 
-        # Select relevant ontology subset
-        ontology_subsets = await self.ontology_selector.select_ontology_subset(segments)
+        # Select relevant ontology subset (using flow-specific selector)
+        ontology_subsets = await components['selector'].select_ontology_subset(segments)
 
         if not ontology_subsets:
             logger.warning("No relevant ontology elements found for chunk")
@@ -252,7 +255,7 @@ class Processor(FlowProcessor):
         # Merge subsets if multiple ontologies matched
         if len(ontology_subsets) > 1:
-            ontology_subset = self.ontology_selector.merge_subsets(ontology_subsets)
+            ontology_subset = components['selector'].merge_subsets(ontology_subsets)
         else:
            ontology_subset = ontology_subsets[0]