mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Update embeddings integration for new batch embeddings interfaces (#669)
* Fix vector extraction * Fix embeddings integration
This commit is contained in:
parent
0a2ce47a88
commit
919b760c05
12 changed files with 55 additions and 56 deletions
|
|
@ -613,8 +613,8 @@ class AsyncFlowInstance:
|
|||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request_data = {
|
||||
"vectors": vectors,
|
||||
|
|
@ -626,20 +626,20 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("graph-embeddings", request_data)
|
||||
|
||||
async def embeddings(self, text: str, **kwargs: Any):
|
||||
async def embeddings(self, texts: list, **kwargs: Any):
|
||||
"""
|
||||
Generate embeddings for input text.
|
||||
Generate embeddings for input texts.
|
||||
|
||||
Converts text into a numerical vector representation using the flow's
|
||||
Converts texts into numerical vector representations using the flow's
|
||||
configured embedding model. Useful for semantic search and similarity
|
||||
comparisons.
|
||||
|
||||
Args:
|
||||
text: Input text to embed
|
||||
texts: List of input texts to embed
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
dict: Response containing embedding vector and metadata
|
||||
dict: Response containing embedding vectors
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
|
@ -647,12 +647,12 @@ class AsyncFlowInstance:
|
|||
flow = async_flow.id("default")
|
||||
|
||||
# Generate embeddings
|
||||
result = await flow.embeddings(text="Sample text to embed")
|
||||
vector = result.get("embedding")
|
||||
print(f"Embedding dimension: {len(vector)}")
|
||||
result = await flow.embeddings(texts=["Sample text to embed"])
|
||||
vectors = result.get("vectors")
|
||||
print(f"Embedding dimension: {len(vectors[0][0])}")
|
||||
```
|
||||
"""
|
||||
request_data = {"text": text}
|
||||
request_data = {"texts": texts}
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("embeddings", request_data)
|
||||
|
|
@ -811,8 +811,8 @@ class AsyncFlowInstance:
|
|||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request_data = {
|
||||
"vectors": vectors,
|
||||
|
|
|
|||
|
|
@ -283,8 +283,8 @@ class AsyncSocketFlowInstance:
|
|||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||
"""Query graph embeddings for semantic search"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
|
|
@ -296,9 +296,9 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
return await self.client._send_request("graph-embeddings", self.flow_id, request)
|
||||
|
||||
async def embeddings(self, text: str, **kwargs):
|
||||
async def embeddings(self, texts: list, **kwargs):
|
||||
"""Generate text embeddings"""
|
||||
request = {"text": text}
|
||||
request = {"texts": texts}
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("embeddings", self.flow_id, request)
|
||||
|
|
@ -353,8 +353,8 @@ class AsyncSocketFlowInstance:
|
|||
):
|
||||
"""Query row embeddings for semantic search on structured data"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
|
|
|
|||
|
|
@ -603,8 +603,8 @@ class FlowInstance:
|
|||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
# Query graph embeddings for semantic search
|
||||
input = {
|
||||
|
|
@ -649,8 +649,8 @@ class FlowInstance:
|
|||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
# Query document embeddings for semantic search
|
||||
input = {
|
||||
|
|
@ -1363,8 +1363,8 @@ class FlowInstance:
|
|||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
# Query row embeddings for semantic search
|
||||
input = {
|
||||
|
|
|
|||
|
|
@ -650,8 +650,8 @@ class SocketFlowInstance:
|
|||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
|
|
@ -699,8 +699,8 @@ class SocketFlowInstance:
|
|||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
|
|
@ -937,8 +937,8 @@ class SocketFlowInstance:
|
|||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
|
|
|
|||
|
|
@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl:
|
|||
logger.debug("Getting embeddings for row query...")
|
||||
|
||||
query_text = arguments.get("query")
|
||||
vectors = await embeddings_client.embed(query_text)
|
||||
all_vectors = await embeddings_client.embed([query_text])
|
||||
vectors = all_vectors[0] if all_vectors else []
|
||||
|
||||
# Now query row embeddings
|
||||
client = self.context("row-embeddings-query-request")
|
||||
|
|
|
|||
|
|
@ -67,8 +67,7 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
|
||||
# vectors[0] is the vector set for the first (only) text
|
||||
# vectors[0][0] is the first vector in that set
|
||||
vectors = resp.vectors[0][0] if resp.vectors else []
|
||||
vectors = resp.vectors[0] if resp.vectors else []
|
||||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class Processor(FlowProcessor):
|
|||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors[0], # First vector from the set
|
||||
vectors=vectors, # Vector set for this entity
|
||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||
)
|
||||
for entity, vectors in zip(v.entities, all_vectors)
|
||||
|
|
|
|||
|
|
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
index_name=index_name,
|
||||
index_value=index_value,
|
||||
text=text,
|
||||
vectors=vectors[0] # First vector from the set
|
||||
vectors=vectors # Vector set for this text
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -148,8 +148,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
# Detect embedding dimension by embedding a test string
|
||||
logger.info("Detecting embedding dimension from embeddings service...")
|
||||
test_embedding_response = await embeddings_client.embed("test")
|
||||
test_embedding = test_embedding_response[0] # Extract from [[vector]]
|
||||
test_embedding_response = await embeddings_client.embed(["test"])
|
||||
test_embedding = test_embedding_response[0][0] # Extract first vector from first text
|
||||
dimension = len(test_embedding)
|
||||
logger.info(f"Detected embedding dimension: {dimension}")
|
||||
|
||||
|
|
|
|||
|
|
@ -153,13 +153,11 @@ class OntologyEmbedder:
|
|||
# Get embeddings for batch
|
||||
texts = [elem['text'] for elem in batch]
|
||||
try:
|
||||
# Call embedding service for each text
|
||||
# Note: embed() returns 2D array [[vector]], so extract first element
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
# Single batch embedding call
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
|
||||
# Extract vectors from responses (each is [[vector]])
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
# Extract first vector from each text's vector set
|
||||
embeddings_list = [resp[0] for resp in embeddings_response]
|
||||
|
||||
# Convert to numpy array
|
||||
embeddings = np.array(embeddings_list)
|
||||
|
|
@ -218,9 +216,9 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# embed() returns 2D array [[vector]], extract first element
|
||||
embedding_response = await self.embedding_service.embed(text)
|
||||
return np.array(embedding_response[0])
|
||||
# embed() with single text, extract first vector from first text
|
||||
embedding_response = await self.embedding_service.embed([text])
|
||||
return np.array(embedding_response[0][0])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed text: {e}")
|
||||
return None
|
||||
|
|
@ -239,11 +237,10 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# Call embed() for each text (returns [[vector]] per call)
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
# Extract first vector from each response
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
# Single batch embedding call
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
# Extract first vector from each text's vector set
|
||||
embeddings_list = [resp[0] for resp in embeddings_response]
|
||||
return np.array(embeddings_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed texts: {e}")
|
||||
|
|
|
|||
|
|
@ -24,12 +24,13 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed([query])
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Embeddings computed")
|
||||
|
||||
return qembeds
|
||||
# Return the vector set for the first (only) text
|
||||
return qembeds[0] if qembeds else []
|
||||
|
||||
async def get_docs(self, query):
|
||||
|
||||
|
|
|
|||
|
|
@ -72,12 +72,13 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed([query])
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
return qembeds
|
||||
# Return the vector set for the first (only) text
|
||||
return qembeds[0] if qembeds else []
|
||||
|
||||
async def get_entities(self, query):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue