trustgraph/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py
cybermaggedon f2ae0e8623
Embeddings API scores (#671)
- Put scores in all responses
- Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
2026-03-09 10:53:44 +00:00

102 lines
2.4 KiB
Python
Executable file

"""
Document embeddings, calls the embeddings service to get embeddings for a
chunk of text. Input is chunk of text plus metadata.
Output is chunk plus embedding.
"""
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... base import FlowProcessor, RequestResponseSpec, ConsumerSpec
from ... base import ProducerSpec
import logging
# Identifier that run() passes to Processor.launch for this processor.
default_ident = "document-embeddings"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Processor(FlowProcessor):
    """Flow processor that turns text chunks into embedding messages.

    Consumes ``Chunk`` messages on the ``input`` specification. It sends each
    chunk's text to the embeddings service through the
    ``embeddings-request`` / ``embeddings-response`` pair. It then publishes
    a ``DocumentEmbeddings`` message on the ``output`` specification.
    """

    def __init__(self, **params):
        """Register the consumer, request/response and producer specs.

        Args:
            **params: Processor configuration, forwarded to
                ``FlowProcessor``. An explicit ``id`` key is always
                forwarded (``None`` when the caller did not supply one).
        """

        # Named ``ident`` rather than ``id`` to avoid shadowing the builtin.
        ident = params.get("id")

        # Always forward an explicit "id" entry, even when it was absent.
        super().__init__(**(params | {"id": ident}))

        self.register_specification(
            ConsumerSpec(
                name = "input",
                schema = Chunk,
                handler = self.on_message,
            )
        )

        self.register_specification(
            RequestResponseSpec(
                request_name = "embeddings-request",
                request_schema = EmbeddingsRequest,
                response_name = "embeddings-response",
                response_schema = EmbeddingsResponse,
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "output",
                schema = DocumentEmbeddings
            )
        )

    async def on_message(self, msg, consumer, flow):
        """Embed a single chunk and publish the resulting embeddings.

        Args:
            msg: Incoming message; ``msg.value()`` yields the ``Chunk``.
            consumer: The consumer that delivered the message (not used).
            flow: Flow context used to look up the ``embeddings-request``
                client and the ``output`` producer by name.

        Raises:
            Exception: Any failure is logged with its traceback and
                re-raised unchanged so the caller can retry the message.
        """

        v = msg.value()
        logger.info("Indexing %s...", v.metadata.id)

        try:

            resp = await flow("embeddings-request").request(
                EmbeddingsRequest(
                    texts=[v.chunk]
                )
            )

            # Exactly one text was sent, so vectors[0] is its embedding.
            if resp.vectors:
                vector = resp.vectors[0]
            else:
                # Keep the existing fallback (publish an empty vector), but
                # make the missing embedding visible instead of silent.
                logger.warning(
                    "Embeddings service returned no vectors for %s",
                    v.metadata.id,
                )
                vector = []

            embeds = [
                ChunkEmbeddings(
                    chunk_id=v.document_id,
                    vector=vector,
                )
            ]

            r = DocumentEmbeddings(
                metadata=v.metadata,
                chunks=embeds,
            )

            await flow("output").send(r)

        except Exception:
            logger.exception("Exception occurred")
            # Retry: a bare raise propagates the original exception and
            # traceback without adding a re-raise frame.
            raise

        logger.info("Done.")

    @staticmethod
    def add_args(parser):
        """Add command-line arguments; only the ``FlowProcessor`` ones."""
        FlowProcessor.add_args(parser)
def run():
    """Entry point: launch this processor.

    Passes the default processor identifier and the module docstring to
    ``Processor.launch``.
    """
    ident = default_ident
    # Inside a function body, ``__doc__`` resolves to the module docstring.
    description = __doc__
    Processor.launch(ident, description)