trustgraph/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py
cybermaggedon f2ae0e8623
Embeddings API scores (#671)
- Put scores in all responses
- Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
2026-03-09 10:53:44 +00:00

113 lines
3 KiB
Python
Executable file

"""
Graph embeddings, calls the embeddings service to get embeddings for a
set of entity contexts. Input is entity plus textual context.
Output is entity plus embedding.
"""
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Identifier handed to Processor.launch() by run().
default_ident = "graph-embeddings"
# Fallback for the maximum number of entities per output message; used both
# as the Processor parameter default and as the --batch-size CLI default.
default_batch_size = 5
class Processor(FlowProcessor):
    """Flow processor that turns entity contexts into graph embeddings.

    Consumes ``EntityContexts`` messages on ``input``, asks the embeddings
    client (``embeddings-request`` / ``embeddings-response``) for one vector
    per entity context, and publishes ``GraphEmbeddings`` messages on
    ``output`` carrying at most ``batch_size`` entities each.
    """

    def __init__(self, **params):
        """Configure the processor and register its flow specifications.

        Args:
            **params: Processor parameters, forwarded to ``FlowProcessor``.
                ``id`` is forwarded explicitly (``None`` when absent).
                ``batch_size`` caps the number of entities per output
                message and falls back to ``default_batch_size``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        processor_id = params.get("id")

        batch_size = params.get("batch_size", default_batch_size)
        if batch_size is None:
            batch_size = default_batch_size
        # A zero step makes range() raise on every message and a negative
        # step yields an empty range (silently dropping all output), so
        # reject both up front instead of failing per message.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.batch_size = batch_size

        super().__init__(**params | {"id": processor_id})

        self.register_specification(
            ConsumerSpec(
                name="input",
                schema=EntityContexts,
                handler=self.on_message,
            )
        )

        self.register_specification(
            EmbeddingsClientSpec(
                request_name="embeddings-request",
                response_name="embeddings-response",
            )
        )

        self.register_specification(
            ProducerSpec(
                name="output",
                schema=GraphEmbeddings,
            )
        )

    async def on_message(self, msg, consumer, flow):
        """Embed the entity contexts of one message and publish the results.

        Args:
            msg: Incoming message; ``msg.value()`` must expose ``metadata``
                and ``entities`` (each entity having ``entity``, ``context``
                and ``chunk_id``).
            consumer: The consumer that delivered the message (unused here).
            flow: Callable used to look up the ``embeddings-request`` client
                and the ``output`` producer by name.

        Raises:
            RuntimeError: If the embeddings service returns a different
                number of vectors than contexts submitted.
            Exception: Any other failure is logged and re-raised; the
                original code marks this path as the retry path.
        """
        v = msg.value()
        logger.info("Indexing %s...", v.metadata.id)

        try:
            source_entities = list(v.entities or [])

            # Nothing to embed means nothing to publish; skip the service
            # round-trip entirely.
            if source_entities:
                # Single batch embedding call covering every context.
                vectors = list(
                    await flow("embeddings-request").embed(
                        texts=[item.context for item in source_entities]
                    )
                )

                # zip() would silently truncate on a mismatch and lose
                # entities, so fail loudly and let the message be retried.
                if len(vectors) != len(source_entities):
                    raise RuntimeError(
                        f"Embeddings service returned {len(vectors)} "
                        f"vectors for {len(source_entities)} contexts"
                    )

                embedded = [
                    EntityEmbeddings(
                        entity=item.entity,
                        vector=vector,
                        chunk_id=item.chunk_id,  # Provenance: source chunk
                    )
                    for item, vector in zip(source_entities, vectors)
                ]

                # Send in batches to avoid oversized messages.
                for start in range(0, len(embedded), self.batch_size):
                    await flow("output").send(
                        GraphEmbeddings(
                            metadata=v.metadata,
                            entities=embedded[start:start + self.batch_size],
                        )
                    )

        except Exception:
            logger.exception("Exception occurred")
            # Retry
            raise

        logger.info("Done.")

    @staticmethod
    def add_args(parser):
        """Register ``--batch-size`` plus the ``FlowProcessor`` arguments.

        Args:
            parser: Argument parser to extend (anything exposing
                ``add_argument``).
        """
        parser.add_argument(
            '--batch-size',
            type=int,
            default=default_batch_size,
            help=f'Maximum entities per output message (default: {default_batch_size})'
        )

        FlowProcessor.add_args(parser)
def run():
    """Launch the processor.

    Calls ``Processor.launch`` with ``default_ident`` and this module's
    docstring.
    """
    Processor.launch(default_ident, __doc__)