Protect from null embeddings in cores (#626)

* Don't emit graph embeddings if there aren't any.

* Don't store graph embeddings in a knowledge store if there's an empty list.

* Translate between Cassandra's 'null' representing an empty list and an
empty list which is what the surrounding code wants (and stored in the
first place).

* Avoid emitting empty embedding lists

* Avoid output empty triple lists

* Fix tests
This commit is contained in:
cybermaggedon 2026-02-09 14:07:07 +00:00 committed by GitHub
parent e214eb4e02
commit ca626c8471
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 88 additions and 91 deletions

View file

@ -16,7 +16,7 @@ from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsPr
from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
@ -405,9 +405,14 @@ class TestKnowledgeGraphPipelineIntegration:
collection="test_collection", collection="test_collection",
metadata=[] metadata=[]
), ),
entities=[] entities=[
EntityEmbeddings(
entity=Value(value="http://example.org/entity", is_uri=True),
vectors=[[0.1, 0.2, 0.3]]
)
]
) )
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = sample_embeddings mock_msg.value.return_value = sample_embeddings
@ -496,12 +501,12 @@ class TestKnowledgeGraphPipelineIntegration:
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context) await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
# Assert # Assert
# Should still call producers but with empty results # Should NOT call producers with empty results (avoids Cassandra NULL issues)
triples_producer = mock_flow_context("triples") triples_producer = mock_flow_context("triples")
entity_contexts_producer = mock_flow_context("entity-contexts") entity_contexts_producer = mock_flow_context("entity-contexts")
triples_producer.send.assert_called_once() triples_producer.send.assert_not_called()
entity_contexts_producer.send.assert_called_once() entity_contexts_producer.send.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk): async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):

View file

@ -73,12 +73,13 @@ class Processor(FlowProcessor):
) )
) )
r = GraphEmbeddings( if entities:
metadata=v.metadata, r = GraphEmbeddings(
entities=entities, metadata=v.metadata,
) entities=entities,
)
await flow("output").send(r) await flow("output").send(r)
except Exception as e: except Exception as e:
logger.error("Exception occurred", exc_info=True) logger.error("Exception occurred", exc_info=True)

View file

@ -168,27 +168,29 @@ class Processor(FlowProcessor):
entities.append(ec) entities.append(ec)
await self.emit_triples( if triples:
flow("triples"), await self.emit_triples(
Metadata( flow("triples"),
id=v.metadata.id, Metadata(
metadata=[], id=v.metadata.id,
user=v.metadata.user, metadata=[],
collection=v.metadata.collection, user=v.metadata.user,
), collection=v.metadata.collection,
triples ),
) triples
)
await self.emit_ecs( if entities:
flow("entity-contexts"), await self.emit_ecs(
Metadata( flow("entity-contexts"),
id=v.metadata.id, Metadata(
metadata=[], id=v.metadata.id,
user=v.metadata.user, metadata=[],
collection=v.metadata.collection, user=v.metadata.user,
), collection=v.metadata.collection,
entities ),
) entities
)
except Exception as e: except Exception as e:
logger.error(f"Definitions extraction exception: {e}", exc_info=True) logger.error(f"Definitions extraction exception: {e}", exc_info=True)

View file

@ -274,17 +274,6 @@ class Processor(FlowProcessor):
if not ontology_subsets: if not ontology_subsets:
logger.warning("No relevant ontology elements found for chunk") logger.warning("No relevant ontology elements found for chunk")
# Emit empty outputs
await self.emit_triples(
flow("triples"),
v.metadata,
[]
)
await self.emit_entity_contexts(
flow("entity-contexts"),
v.metadata,
[]
)
return return
# Merge subsets if multiple ontologies matched # Merge subsets if multiple ontologies matched
@ -319,35 +308,26 @@ class Processor(FlowProcessor):
entity_contexts = self.build_entity_contexts(all_triples) entity_contexts = self.build_entity_contexts(all_triples)
# Emit all triples (extracted + ontology definitions) # Emit all triples (extracted + ontology definitions)
await self.emit_triples( if all_triples:
flow("triples"), await self.emit_triples(
v.metadata, flow("triples"),
all_triples v.metadata,
) all_triples
)
# Emit entity contexts # Emit entity contexts
await self.emit_entity_contexts( if entity_contexts:
flow("entity-contexts"), await self.emit_entity_contexts(
v.metadata, flow("entity-contexts"),
entity_contexts v.metadata,
) entity_contexts
)
logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples " logger.info(f"Extracted {len(triples)} content triples + {len(ontology_triples)} ontology triples "
f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts") f"= {len(all_triples)} total triples and {len(entity_contexts)} entity contexts")
except Exception as e: except Exception as e:
logger.error(f"OntoRAG extraction exception: {e}", exc_info=True) logger.error(f"OntoRAG extraction exception: {e}", exc_info=True)
# Emit empty outputs on error
await self.emit_triples(
flow("triples"),
v.metadata,
[]
)
await self.emit_entity_contexts(
flow("entity-contexts"),
v.metadata,
[]
)
async def extract_with_simplified_format( async def extract_with_simplified_format(
self, self,

View file

@ -181,16 +181,17 @@ class Processor(FlowProcessor):
o=Value(value=v.metadata.id, is_uri=True) o=Value(value=v.metadata.id, is_uri=True)
)) ))
await self.emit_triples( if triples:
flow("triples"), await self.emit_triples(
Metadata( flow("triples"),
id=v.metadata.id, Metadata(
metadata=[], id=v.metadata.id,
user=v.metadata.user, metadata=[],
collection=v.metadata.collection, user=v.metadata.user,
), collection=v.metadata.collection,
triples ),
) triples
)
except Exception as e: except Exception as e:
logger.error(f"Relationship extraction exception: {e}", exc_info=True) logger.error(f"Relationship extraction exception: {e}", exc_info=True)

View file

@ -64,12 +64,14 @@ class Processor(FlowProcessor):
async def on_triples(self, msg, consumer, flow): async def on_triples(self, msg, consumer, flow):
v = msg.value() v = msg.value()
await self.table_store.add_triples(v) if v.triples:
await self.table_store.add_triples(v)
async def on_graph_embeddings(self, msg, consumer, flow): async def on_graph_embeddings(self, msg, consumer, flow):
v = msg.value() v = msg.value()
await self.table_store.add_graph_embeddings(v) if v.entities:
await self.table_store.add_graph_embeddings(v)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -423,14 +423,17 @@ class KnowledgeTableStore:
else: else:
metadata = [] metadata = []
triples = [ if row[3]:
Triple( triples = [
s = Value(value = elt[0], is_uri = elt[1]), Triple(
p = Value(value = elt[2], is_uri = elt[3]), s = Value(value = elt[0], is_uri = elt[1]),
o = Value(value = elt[4], is_uri = elt[5]), p = Value(value = elt[2], is_uri = elt[3]),
) o = Value(value = elt[4], is_uri = elt[5]),
for elt in row[3] )
] for elt in row[3]
]
else:
triples = []
await receiver( await receiver(
Triples( Triples(
@ -479,13 +482,16 @@ class KnowledgeTableStore:
else: else:
metadata = [] metadata = []
entities = [ if row[3]:
EntityEmbeddings( entities = [
entity = Value(value = ent[0][0], is_uri = ent[0][1]), EntityEmbeddings(
vectors = ent[1] entity = Value(value = ent[0][0], is_uri = ent[0][1]),
) vectors = ent[1]
for ent in row[3] )
] for ent in row[3]
]
else:
entities = []
await receiver( await receiver(
GraphEmbeddings( GraphEmbeddings(