mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 17:39:39 +02:00
Update embeddings integration for new batch embeddings interfaces (#669)
* Fix vector extraction * Fix embeddings integration
This commit is contained in:
parent
0a2ce47a88
commit
919b760c05
12 changed files with 55 additions and 56 deletions
|
|
@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl:
|
|||
logger.debug("Getting embeddings for row query...")
|
||||
|
||||
query_text = arguments.get("query")
|
||||
vectors = await embeddings_client.embed(query_text)
|
||||
all_vectors = await embeddings_client.embed([query_text])
|
||||
vectors = all_vectors[0] if all_vectors else []
|
||||
|
||||
# Now query row embeddings
|
||||
client = self.context("row-embeddings-query-request")
|
||||
|
|
|
|||
|
|
@ -67,8 +67,7 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
|
||||
# vectors[0] is the vector set for the first (only) text
|
||||
# vectors[0][0] is the first vector in that set
|
||||
vectors = resp.vectors[0][0] if resp.vectors else []
|
||||
vectors = resp.vectors[0] if resp.vectors else []
|
||||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class Processor(FlowProcessor):
|
|||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors[0], # First vector from the set
|
||||
vectors=vectors, # Vector set for this entity
|
||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||
)
|
||||
for entity, vectors in zip(v.entities, all_vectors)
|
||||
|
|
|
|||
|
|
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
index_name=index_name,
|
||||
index_value=index_value,
|
||||
text=text,
|
||||
vectors=vectors[0] # First vector from the set
|
||||
vectors=vectors # Vector set for this text
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -148,8 +148,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
# Detect embedding dimension by embedding a test string
|
||||
logger.info("Detecting embedding dimension from embeddings service...")
|
||||
test_embedding_response = await embeddings_client.embed("test")
|
||||
test_embedding = test_embedding_response[0] # Extract from [[vector]]
|
||||
test_embedding_response = await embeddings_client.embed(["test"])
|
||||
test_embedding = test_embedding_response[0][0] # Extract first vector from first text
|
||||
dimension = len(test_embedding)
|
||||
logger.info(f"Detected embedding dimension: {dimension}")
|
||||
|
||||
|
|
|
|||
|
|
@ -153,13 +153,11 @@ class OntologyEmbedder:
|
|||
# Get embeddings for batch
|
||||
texts = [elem['text'] for elem in batch]
|
||||
try:
|
||||
# Call embedding service for each text
|
||||
# Note: embed() returns 2D array [[vector]], so extract first element
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
# Single batch embedding call
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
|
||||
# Extract vectors from responses (each is [[vector]])
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
# Extract first vector from each text's vector set
|
||||
embeddings_list = [resp[0] for resp in embeddings_response]
|
||||
|
||||
# Convert to numpy array
|
||||
embeddings = np.array(embeddings_list)
|
||||
|
|
@ -218,9 +216,9 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# embed() returns 2D array [[vector]], extract first element
|
||||
embedding_response = await self.embedding_service.embed(text)
|
||||
return np.array(embedding_response[0])
|
||||
# embed() with single text, extract first vector from first text
|
||||
embedding_response = await self.embedding_service.embed([text])
|
||||
return np.array(embedding_response[0][0])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed text: {e}")
|
||||
return None
|
||||
|
|
@ -239,11 +237,10 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# Call embed() for each text (returns [[vector]] per call)
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
# Extract first vector from each response
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
# Single batch embedding call
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
# Extract first vector from each text's vector set
|
||||
embeddings_list = [resp[0] for resp in embeddings_response]
|
||||
return np.array(embeddings_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed texts: {e}")
|
||||
|
|
|
|||
|
|
@ -24,12 +24,13 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed([query])
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Embeddings computed")
|
||||
|
||||
return qembeds
|
||||
# Return the vector set for the first (only) text
|
||||
return qembeds[0] if qembeds else []
|
||||
|
||||
async def get_docs(self, query):
|
||||
|
||||
|
|
|
|||
|
|
@ -72,12 +72,13 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed([query])
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
return qembeds
|
||||
# Return the vector set for the first (only) text
|
||||
return qembeds[0] if qembeds else []
|
||||
|
||||
async def get_entities(self, query):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue