Update embeddings integration for new batch embeddings interfaces (#669)

* Fix vector extraction

* Fix embeddings integration
This commit is contained in:
cybermaggedon 2026-03-08 19:41:52 +00:00 committed by GitHub
parent 0a2ce47a88
commit 919b760c05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 55 additions and 56 deletions

View file

@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl:
logger.debug("Getting embeddings for row query...")
query_text = arguments.get("query")
vectors = await embeddings_client.embed(query_text)
all_vectors = await embeddings_client.embed([query_text])
vectors = all_vectors[0] if all_vectors else []
# Now query row embeddings
client = self.context("row-embeddings-query-request")

View file

@ -67,8 +67,7 @@ class Processor(FlowProcessor):
)
# vectors[0] is the vector set for the first (only) text
# vectors[0][0] is the first vector in that set
vectors = resp.vectors[0][0] if resp.vectors else []
vectors = resp.vectors[0] if resp.vectors else []
embeds = [
ChunkEmbeddings(

View file

@ -72,7 +72,7 @@ class Processor(FlowProcessor):
entities = [
EntityEmbeddings(
entity=entity.entity,
vectors=vectors[0], # First vector from the set
vectors=vectors, # Vector set for this entity
chunk_id=entity.chunk_id, # Provenance: source chunk
)
for entity, vectors in zip(v.entities, all_vectors)

View file

@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors[0] # First vector from the set
vectors=vectors # Vector set for this text
)
)

View file

@ -148,8 +148,8 @@ class Processor(FlowProcessor):
# Detect embedding dimension by embedding a test string
logger.info("Detecting embedding dimension from embeddings service...")
test_embedding_response = await embeddings_client.embed("test")
test_embedding = test_embedding_response[0] # Extract from [[vector]]
test_embedding_response = await embeddings_client.embed(["test"])
test_embedding = test_embedding_response[0][0] # Extract first vector from first text
dimension = len(test_embedding)
logger.info(f"Detected embedding dimension: {dimension}")

View file

@ -153,13 +153,11 @@ class OntologyEmbedder:
# Get embeddings for batch
texts = [elem['text'] for elem in batch]
try:
# Call embedding service for each text
# Note: embed() returns 2D array [[vector]], so extract first element
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Single batch embedding call
embeddings_response = await self.embedding_service.embed(texts)
# Extract vectors from responses (each is [[vector]])
embeddings_list = [resp[0] for resp in embeddings_responses]
# Extract first vector from each text's vector set
embeddings_list = [resp[0] for resp in embeddings_response]
# Convert to numpy array
embeddings = np.array(embeddings_list)
@ -218,9 +216,9 @@ class OntologyEmbedder:
return None
try:
# embed() returns 2D array [[vector]], extract first element
embedding_response = await self.embedding_service.embed(text)
return np.array(embedding_response[0])
# embed() with single text, extract first vector from first text
embedding_response = await self.embedding_service.embed([text])
return np.array(embedding_response[0][0])
except Exception as e:
logger.error(f"Failed to embed text: {e}")
return None
@ -239,11 +237,10 @@ class OntologyEmbedder:
return None
try:
# Call embed() for each text (returns [[vector]] per call)
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Extract first vector from each response
embeddings_list = [resp[0] for resp in embeddings_responses]
# Single batch embedding call
embeddings_response = await self.embedding_service.embed(texts)
# Extract first vector from each text's vector set
embeddings_list = [resp[0] for resp in embeddings_response]
return np.array(embeddings_list)
except Exception as e:
logger.error(f"Failed to embed texts: {e}")

View file

@ -24,12 +24,13 @@ class Query:
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed(query)
qembeds = await self.rag.embeddings_client.embed([query])
if self.verbose:
logger.debug("Embeddings computed")
return qembeds
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
async def get_docs(self, query):

View file

@ -72,12 +72,13 @@ class Query:
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed(query)
qembeds = await self.rag.embeddings_client.embed([query])
if self.verbose:
logger.debug("Done.")
return qembeds
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
async def get_entities(self, query):