Update embeddings integration for new batch embeddings interfaces (#669)

* Fix vector extraction

* Fix embeddings integration
This commit is contained in:
cybermaggedon 2026-03-08 19:41:52 +00:00 committed by GitHub
parent 0a2ce47a88
commit 919b760c05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 55 additions and 56 deletions

View file

@@ -613,8 +613,8 @@ class AsyncFlowInstance:
```
"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
@@ -626,20 +626,20 @@ class AsyncFlowInstance:
return await self.request("graph-embeddings", request_data)
async def embeddings(self, text: str, **kwargs: Any):
async def embeddings(self, texts: list, **kwargs: Any):
"""
Generate embeddings for input text.
Generate embeddings for input texts.
Converts text into a numerical vector representation using the flow's
Converts texts into numerical vector representations using the flow's
configured embedding model. Useful for semantic search and similarity
comparisons.
Args:
text: Input text to embed
texts: List of input texts to embed
**kwargs: Additional service-specific parameters
Returns:
dict: Response containing embedding vector and metadata
dict: Response containing embedding vectors
Example:
```python
@@ -647,12 +647,12 @@ class AsyncFlowInstance:
flow = async_flow.id("default")
# Generate embeddings
result = await flow.embeddings(text="Sample text to embed")
vector = result.get("embedding")
print(f"Embedding dimension: {len(vector)}")
result = await flow.embeddings(texts=["Sample text to embed"])
vectors = result.get("vectors")
print(f"Embedding dimension: {len(vectors[0][0])}")
```
"""
request_data = {"text": text}
request_data = {"texts": texts}
request_data.update(kwargs)
return await self.request("embeddings", request_data)
@@ -811,8 +811,8 @@ class AsyncFlowInstance:
```
"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,

View file

@@ -283,8 +283,8 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
@@ -296,9 +296,9 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("graph-embeddings", self.flow_id, request)
async def embeddings(self, text: str, **kwargs):
async def embeddings(self, texts: list, **kwargs):
"""Generate text embeddings"""
request = {"text": text}
request = {"texts": texts}
request.update(kwargs)
return await self.client._send_request("embeddings", self.flow_id, request)
@@ -353,8 +353,8 @@ class AsyncSocketFlowInstance:
):
"""Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,

View file

@@ -603,8 +603,8 @@ class FlowInstance:
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
# Query graph embeddings for semantic search
input = {
@@ -649,8 +649,8 @@ class FlowInstance:
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
# Query document embeddings for semantic search
input = {
@@ -1363,8 +1363,8 @@ class FlowInstance:
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
# Query row embeddings for semantic search
input = {

View file

@@ -650,8 +650,8 @@ class SocketFlowInstance:
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
@@ -699,8 +699,8 @@ class SocketFlowInstance:
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
@@ -937,8 +937,8 @@ class SocketFlowInstance:
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,

View file

@@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl:
logger.debug("Getting embeddings for row query...")
query_text = arguments.get("query")
vectors = await embeddings_client.embed(query_text)
all_vectors = await embeddings_client.embed([query_text])
vectors = all_vectors[0] if all_vectors else []
# Now query row embeddings
client = self.context("row-embeddings-query-request")

View file

@@ -67,8 +67,7 @@ class Processor(FlowProcessor):
)
# vectors[0] is the vector set for the first (only) text
# vectors[0][0] is the first vector in that set
vectors = resp.vectors[0][0] if resp.vectors else []
vectors = resp.vectors[0] if resp.vectors else []
embeds = [
ChunkEmbeddings(

View file

@@ -72,7 +72,7 @@ class Processor(FlowProcessor):
entities = [
EntityEmbeddings(
entity=entity.entity,
vectors=vectors[0], # First vector from the set
vectors=vectors, # Vector set for this entity
chunk_id=entity.chunk_id, # Provenance: source chunk
)
for entity, vectors in zip(v.entities, all_vectors)

View file

@@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors[0] # First vector from the set
vectors=vectors # Vector set for this text
)
)

View file

@@ -148,8 +148,8 @@ class Processor(FlowProcessor):
# Detect embedding dimension by embedding a test string
logger.info("Detecting embedding dimension from embeddings service...")
test_embedding_response = await embeddings_client.embed("test")
test_embedding = test_embedding_response[0] # Extract from [[vector]]
test_embedding_response = await embeddings_client.embed(["test"])
test_embedding = test_embedding_response[0][0] # Extract first vector from first text
dimension = len(test_embedding)
logger.info(f"Detected embedding dimension: {dimension}")

View file

@@ -153,13 +153,11 @@ class OntologyEmbedder:
# Get embeddings for batch
texts = [elem['text'] for elem in batch]
try:
# Call embedding service for each text
# Note: embed() returns 2D array [[vector]], so extract first element
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Single batch embedding call
embeddings_response = await self.embedding_service.embed(texts)
# Extract vectors from responses (each is [[vector]])
embeddings_list = [resp[0] for resp in embeddings_responses]
# Extract first vector from each text's vector set
embeddings_list = [resp[0] for resp in embeddings_response]
# Convert to numpy array
embeddings = np.array(embeddings_list)
@@ -218,9 +216,9 @@ class OntologyEmbedder:
return None
try:
# embed() returns 2D array [[vector]], extract first element
embedding_response = await self.embedding_service.embed(text)
return np.array(embedding_response[0])
# embed() with single text, extract first vector from first text
embedding_response = await self.embedding_service.embed([text])
return np.array(embedding_response[0][0])
except Exception as e:
logger.error(f"Failed to embed text: {e}")
return None
@@ -239,11 +237,10 @@ class OntologyEmbedder:
return None
try:
# Call embed() for each text (returns [[vector]] per call)
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Extract first vector from each response
embeddings_list = [resp[0] for resp in embeddings_responses]
# Single batch embedding call
embeddings_response = await self.embedding_service.embed(texts)
# Extract first vector from each text's vector set
embeddings_list = [resp[0] for resp in embeddings_response]
return np.array(embeddings_list)
except Exception as e:
logger.error(f"Failed to embed texts: {e}")

View file

@@ -24,12 +24,13 @@ class Query:
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed(query)
qembeds = await self.rag.embeddings_client.embed([query])
if self.verbose:
logger.debug("Embeddings computed")
return qembeds
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
async def get_docs(self, query):

View file

@@ -72,12 +72,13 @@ class Query:
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed(query)
qembeds = await self.rag.embeddings_client.embed([query])
if self.verbose:
logger.debug("Done.")
return qembeds
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
async def get_entities(self, query):