diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 15b0c82d..e763d089 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -193,12 +193,20 @@ class TestQuery: test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = [test_vectors] - # Mock entity objects that have string representation + # Mock EntityMatch objects with entity that has string representation mock_entity1 = MagicMock() mock_entity1.__str__ = MagicMock(return_value="entity1") + mock_match1 = MagicMock() + mock_match1.entity = mock_entity1 + mock_match1.score = 0.95 + mock_entity2 = MagicMock() mock_entity2.__str__ = MagicMock(return_value="entity2") - mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2] + mock_match2 = MagicMock() + mock_match2.entity = mock_entity2 + mock_match2.score = 0.85 + + mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2] # Initialize Query query = Query( @@ -216,9 +224,9 @@ class TestQuery: # Verify embeddings client was called (now expects list) mock_embeddings_client.embed.assert_called_once_with([test_query]) - # Verify graph embeddings client was called correctly (with extracted vectors) + # Verify graph embeddings client was called correctly (with extracted vector) mock_graph_embeddings_client.query.assert_called_once_with( - vectors=test_vectors, + vector=test_vectors, limit=25, user="test_user", collection="test_collection" diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index b4d7aac7..2ff37307 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -612,12 +612,12 @@ class AsyncFlowInstance: print(f"{entity['name']}: {entity['score']}") ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request_data = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -810,12 +810,12 @@ class AsyncFlowInstance: print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})") ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request_data = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 843c5979..99938d5b 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -282,12 +282,12 @@ class AsyncSocketFlowInstance: async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): """Query graph embeddings for semantic search""" - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -352,12 +352,12 @@ class AsyncSocketFlowInstance: limit: int = 10, **kwargs ): """Query row embeddings for semantic search on structured data""" - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = await self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index f2dad323..49e2f9fa 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -602,13 +602,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] # Query graph embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -648,13 +648,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] # Query document embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -1362,13 +1362,13 @@ class FlowInstance: ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] # Query row embeddings for semantic search input = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index d68d9e98..113ebe35 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -649,12 +649,12 @@ class SocketFlowInstance: ) ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -698,12 +698,12 @@ class SocketFlowInstance: # results contains {"chunk_ids": ["doc1/p0/c0", ...]} ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "user": user, "collection": collection, "limit": limit @@ -936,12 +936,12 @@ class SocketFlowInstance: ) ``` """ - # First convert text to embeddings vectors + # First convert text to embedding vector emb_result = self.embeddings(texts=[text]) - vectors = emb_result.get("vectors", [[]])[0] + vector = emb_result.get("vectors", [[]])[0] request = { - "vectors": vectors, + "vector": vector, "schema_name": schema_name, "user": user, "collection": collection, diff --git a/trustgraph-base/trustgraph/base/document_embeddings_client.py b/trustgraph-base/trustgraph/base/document_embeddings_client.py index d403ff21..dd985eab 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_client.py @@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal logger = logging.getLogger(__name__) class DocumentEmbeddingsClient(RequestResponse): - async def query(self, vectors, limit=20, user="trustgraph", + async def query(self, vector, limit=20, user="trustgraph", collection="default", timeout=30): resp = await self.request( DocumentEmbeddingsRequest( - vectors = vectors, + vector = vector, limit = limit, user = user, collection = collection @@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return resp.chunk_ids + # Return ChunkMatch objects with chunk_id and score + return resp.chunks class DocumentEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index 013847d4..b8979776 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): docs = await self.query_document_embeddings(request) logger.debug("Sending document embeddings query response...") - r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None) + r = DocumentEmbeddingsResponse(chunks=docs, error=None) await flow("response").send(r, properties={"id": id}) logger.debug("Document embeddings query request completed") @@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor): type = "document-embeddings-query-error", message = str(e), ), - chunk_ids=[], + chunks=[], ) await flow("response").send(r, properties={"id": id}) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_client.py b/trustgraph-base/trustgraph/base/graph_embeddings_client.py index 07eb2bc7..fec82378 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_client.py @@ -19,12 +19,12 @@ def to_value(x): return Literal(x.value or x.iri) class GraphEmbeddingsClient(RequestResponse): - async def query(self, vectors, limit=20, user="trustgraph", + async def query(self, vector, limit=20, user="trustgraph", collection="default", timeout=30): resp = await self.request( GraphEmbeddingsRequest( - vectors = vectors, + vector = vector, limit = limit, user = user, collection = collection @@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - return [ - to_value(v) - for v in resp.entities - ] + # Return EntityMatch objects with entity and score + return resp.entities class GraphEmbeddingsClientSpec(RequestResponseSpec): def __init__( diff --git a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py index 0141da31..811adf40 100644 --- a/trustgraph-base/trustgraph/base/row_embeddings_query_client.py +++ b/trustgraph-base/trustgraph/base/row_embeddings_query_client.py @@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse class RowEmbeddingsQueryClient(RequestResponse): async def row_embeddings_query( - self, vectors, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, user="trustgraph", collection="default", index_name=None, limit=10, timeout=600 ): request = RowEmbeddingsRequest( - vectors=vectors, + vector=vector, schema_name=schema_name, user=user, collection=collection, diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index 124cf3c8..1ab47aab 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient): ) def request( - self, vectors, user="trustgraph", collection="default", + self, vector, user="trustgraph", collection="default", limit=10, timeout=300 ): return self.call( user=user, collection=collection, - vectors=vectors, limit=limit, timeout=timeout + vector=vector, limit=limit, timeout=timeout ).chunks diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index 1a7a9512..f85c91ee 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient): ) def request( - self, vectors, user="trustgraph", collection="default", + self, vector, user="trustgraph", collection="default", limit=10, timeout=300 ): return self.call( user=user, collection=collection, - vectors=vectors, limit=limit, timeout=timeout + vector=vector, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/row_embeddings_client.py b/trustgraph-base/trustgraph/clients/row_embeddings_client.py index 4f911e3c..19d4b338 100644 --- a/trustgraph-base/trustgraph/clients/row_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/row_embeddings_client.py @@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient): ) def request( - self, vectors, schema_name, user="trustgraph", collection="default", + self, vector, schema_name, user="trustgraph", collection="default", index_name=None, limit=10, timeout=300 ): kwargs = dict( user=user, collection=collection, - vectors=vectors, schema_name=schema_name, + vector=vector, schema_name=schema_name, limit=limit, timeout=timeout ) if index_name: diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py index cc5f1534..f10ca4c6 100644 --- a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -10,18 +10,18 @@ from .primitives import ValueTranslator class DocumentEmbeddingsRequestTranslator(MessageTranslator): """Translator for DocumentEmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: return DocumentEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) - + def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: return { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection @@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator): class DocumentEmbeddingsResponseTranslator(MessageTranslator): """Translator for DocumentEmbeddingsResponse schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: result = {} - if obj.chunk_ids is not None: - result["chunk_ids"] = list(obj.chunk_ids) + if obj.chunks is not None: + result["chunks"] = [ + { + "chunk_id": chunk.chunk_id, + "score": chunk.score + } + for chunk in obj.chunks + ] return result - + def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True @@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator): class GraphEmbeddingsRequestTranslator(MessageTranslator): """Translator for GraphEmbeddingsRequest schema objects""" - + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: return GraphEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default") ) - + def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: return { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection @@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator): class GraphEmbeddingsResponseTranslator(MessageTranslator): """Translator for GraphEmbeddingsResponse schema objects""" - + def __init__(self): self.value_translator = ValueTranslator() - + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - + def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: result = {} - + if obj.entities is not None: result["entities"] = [ - self.value_translator.from_pulsar(entity) - for entity in obj.entities + { + "entity": self.value_translator.from_pulsar(match.entity), + "score": match.score + } + for match in obj.entities ] - + return result - + def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" return self.from_pulsar(obj), True @@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: return RowEmbeddingsRequest( - vectors=data["vectors"], + vector=data["vector"], limit=int(data.get("limit", 10)), user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), @@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator): def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: result = { - "vectors": obj.vectors, + "vector": obj.vector, "limit": obj.limit, "user": obj.user, "collection": obj.collection, diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index c7d5b775..a8bae35c 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -11,7 +11,7 @@ from ..core.topic import topic @dataclass class EntityEmbeddings: entity: Term | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) # Provenance: which chunk this embedding was derived from chunk_id: str = "" @@ -28,7 +28,7 @@ class GraphEmbeddings: @dataclass class ChunkEmbeddings: chunk_id: str = "" - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) # This is a 'batching' mechanism for the above data @dataclass @@ -44,7 +44,7 @@ class DocumentEmbeddings: @dataclass class ObjectEmbeddings: metadata: Metadata | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) name: str = "" key_name: str = "" id: str = "" @@ -56,7 +56,7 @@ class ObjectEmbeddings: @dataclass class StructuredObjectEmbedding: metadata: Metadata | None = None - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) schema_name: str = "" object_id: str = "" # Primary key value field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings @@ -72,7 +72,7 @@ class RowIndexEmbedding: index_name: str = "" # The indexed field name(s) index_value: list[str] = field(default_factory=list) # The field value(s) text: str = "" # Text that was embedded - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) @dataclass class RowEmbeddings: diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index a9d19e51..681638c3 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -34,7 +34,7 @@ class EmbeddingsRequest: @dataclass class EmbeddingsResponse: error: Error | None = None - vectors: list[list[list[float]]] = field(default_factory=list) + vectors: list[list[float]] = field(default_factory=list) ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 68857e07..67caa2be 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -9,15 +9,21 @@ from ..core.topic import topic @dataclass class GraphEmbeddingsRequest: - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) limit: int = 0 user: str = "" collection: str = "" +@dataclass +class EntityMatch: + """A matching entity from a semantic search with similarity score""" + entity: Term | None = None + score: float = 0.0 + @dataclass class GraphEmbeddingsResponse: error: Error | None = None - entities: list[Term] = field(default_factory=list) + entities: list[EntityMatch] = field(default_factory=list) ############################################################################ @@ -44,15 +50,21 @@ class TriplesQueryResponse: @dataclass class DocumentEmbeddingsRequest: - vectors: list[list[float]] = field(default_factory=list) + vector: list[float] = field(default_factory=list) limit: int = 0 user: str = "" collection: str = "" +@dataclass +class ChunkMatch: + """A matching chunk from a semantic search with similarity score""" + chunk_id: str = "" + score: float = 0.0 + @dataclass class DocumentEmbeddingsResponse: error: Error | None = None - chunk_ids: list[str] = field(default_factory=list) + chunks: list[ChunkMatch] = field(default_factory=list) document_embeddings_request_queue = topic( "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' @@ -76,7 +88,7 @@ class RowIndexMatch: @dataclass class RowEmbeddingsRequest: """Request for row embeddings semantic search""" - vectors: list[list[float]] = field(default_factory=list) # Query vectors + vector: list[float] = field(default_factory=list) # Query vector limit: int = 10 # Max results to return user: str = "" # User/keyspace collection: str = "" # Collection name diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 71fe7409..441c8f38 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl: query_text = arguments.get("query") all_vectors = await embeddings_client.embed([query_text]) - vectors = all_vectors[0] if all_vectors else [] + vector = all_vectors[0] if all_vectors else [] # Now query row embeddings client = self.context("row-embeddings-query-request") @@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl: user = getattr(client, '_current_user', self.user or "trustgraph") matches = await client.row_embeddings_query( - vectors=vectors, + vector=vector, schema_name=self.schema_name, user=user, collection=self.collection or "default", diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index f038a9b5..16ca1ad9 100755 --- a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -66,13 +66,13 @@ class Processor(FlowProcessor): ) ) - # vectors[0] is the vector set for the first (only) text - vectors = resp.vectors[0] if resp.vectors else [] + # vectors[0] is the vector for the first (only) text + vector = resp.vectors[0] if resp.vectors else [] embeds = [ ChunkEmbeddings( chunk_id=v.document_id, - vectors=vectors, + vector=vector, ) ] diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index ac2c6f49..1a03ac9f 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -59,11 +59,8 @@ class Processor(EmbeddingsService): # FastEmbed processes the full batch efficiently vecs = list(self.embeddings.embed(texts)) - # Return list of vector sets, one per input text - return [ - [v.tolist()] - for v in vecs - ] + # Return list of vectors, one per input text + return [v.tolist() for v in vecs] @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index e83d608b..3b441bd6 100755 --- a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -72,10 +72,10 @@ class Processor(FlowProcessor): entities = [ EntityEmbeddings( entity=entity.entity, - vectors=vectors, # Vector set for this entity + vector=vector, chunk_id=entity.chunk_id, # Provenance: source chunk ) - for entity, vectors in zip(v.entities, all_vectors) + for entity, vector in zip(v.entities, all_vectors) ] # Send in batches to avoid oversized messages diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index c95850e2..a65b4ff7 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -43,11 +43,8 @@ class Processor(EmbeddingsService): input = texts ) - # Return list of vector sets, one per input text - return [ - [embedding] - for embedding in embeds.embeddings - ] + # Return list of vectors, one per input text + return list(embeds.embeddings) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py index f81e4374..1365cb14 100644 --- a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): all_vectors = await flow("embeddings-request").embed(texts=texts) # Pair results with metadata - for text, (index_name, index_value), vectors in zip( + for text, (index_name, index_value), vector in zip( texts, metadata, all_vectors ): embeddings_list.append( @@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): index_name=index_name, index_value=index_value, text=text, - vectors=vectors # Vector set for this text + vector=vector ) ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 6d897b71..98350961 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -7,7 +7,7 @@ of chunk_ids import logging from .... direct.milvus_doc_embeddings import DocVectors -from .... schema import DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService @@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] - chunk_ids = [] + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit + ) - for vec in msg.vectors: + chunks = [] + for r in resp: + chunk_id = r["entity"]["chunk_id"] + # Milvus returns distance, convert to similarity score + distance = r.get("distance", 0.0) + score = 1.0 - distance if distance else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) - resp = self.vecstore.search( - vec, - msg.user, - msg.collection, - limit=msg.limit - ) - - for r in resp: - chunk_id = r["entity"]["chunk_id"] - chunk_ids.append(chunk_id) - - return chunk_ids + return chunks except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 41857ab0..406f979c 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -11,6 +11,7 @@ import os from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig +from .... schema import ChunkMatch from .... base import DocumentEmbeddingsQueryService # Module logger @@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] - chunk_ids = [] + dim = len(vec) - for vec in msg.vectors: + # Use dimension suffix in index name + index_name = f"d-{msg.user}-{msg.collection}-{dim}" - dim = len(vec) + # Check if index exists - return empty if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist") + return [] - # Use dimension suffix in index name - index_name = f"d-{msg.user}-{msg.collection}-{dim}" + index = self.pinecone.Index(index_name) - # Check if index exists - skip if not - if not self.pinecone.has_index(index_name): - logger.info(f"Index {index_name} does not exist, skipping this vector") - continue + results = index.query( + vector=vec, + top_k=msg.limit, + include_values=False, + include_metadata=True + ) - index = self.pinecone.Index(index_name) + chunks = [] + for r in results.matches: + chunk_id = r.metadata["chunk_id"] + score = r.score if hasattr(r, 'score') else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) - results = index.query( - vector=vec, - top_k=msg.limit, - include_values=False, - include_metadata=True - ) - - for r in results.matches: - chunk_id = r.metadata["chunk_id"] - chunk_ids.append(chunk_id) - - return chunk_ids + return chunks except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 562023c7..f056b1c1 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -from .... schema import DocumentEmbeddingsResponse +from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService @@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService): try: - chunk_ids = [] + vec = msg.vector + if not vec: + return [] - for vec in msg.vectors: + # Use dimension suffix in collection name + dim = len(vec) + collection = f"d_{msg.user}_{msg.collection}_{dim}" - # Use dimension suffix in collection name - dim = len(vec) - collection = f"d_{msg.user}_{msg.collection}_{dim}" + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist, returning empty results") + return [] - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, returning empty results") - continue + search_result = self.qdrant.query_points( + collection_name=collection, + query=vec, + limit=msg.limit, + with_payload=True, + ).points - search_result = self.qdrant.query_points( - collection_name=collection, - query=vec, - limit=msg.limit, - with_payload=True, - ).points + chunks = [] + for r in search_result: + chunk_id = r.payload["chunk_id"] + score = r.score if hasattr(r, 'score') else 0.0 + chunks.append(ChunkMatch( + chunk_id=chunk_id, + score=score, + )) - for r in search_result: - chunk_id = r.payload["chunk_id"] - chunk_ids.append(chunk_id) - - return chunk_ids + return chunks except Exception as e: diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index c5cdb6d8..94eee387 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -7,7 +7,7 @@ entities import logging from .... direct.milvus_graph_embeddings import EntityVectors -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService): try: - entity_set = set() - entities = [] + vec = msg.vector + if not vec: + return [] # Handle zero limit case if msg.limit <= 0: return [] - for vec in msg.vectors: + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit * 2 + ) - resp = self.vecstore.search( - vec, - msg.user, - msg.collection, - limit=msg.limit * 2 - ) + entity_set = set() + entities = [] - for r in resp: - ent = r["entity"]["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) + for r in resp: + ent = r["entity"]["entity"] + # Milvus returns distance, convert to similarity score + distance = r.get("distance", 0.0) + score = 1.0 - distance if distance else 0.0 - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break logger.debug("Send response...") return entities diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 5882f21c..ca443a6f 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -11,7 +11,7 @@ import os from pinecone import Pinecone, ServerlessSpec from pinecone.grpc import PineconeGRPC, GRPCClientConfig -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + # Handle zero limit case if msg.limit <= 0: return [] + dim = len(vec) + + # Use dimension suffix in index name + index_name = f"t-{msg.user}-{msg.collection}-{dim}" + + # Check if index exists - return empty if not + if not self.pinecone.has_index(index_name): + logger.info(f"Index {index_name} does not exist") + return [] + + index = self.pinecone.Index(index_name) + + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) unique entities + results = index.query( + vector=vec, + top_k=msg.limit * 2, + include_values=False, + include_metadata=True + ) + entity_set = set() entities = [] - for vec in msg.vectors: + for r in results.matches: + ent = r.metadata["entity"] + score = r.score if hasattr(r, 'score') else 0.0 - dim = len(vec) - - # Use dimension suffix in index name - index_name = f"t-{msg.user}-{msg.collection}-{dim}" - - # Check if index exists - skip if not - if not self.pinecone.has_index(index_name): - logger.info(f"Index {index_name} does not exist, skipping this vector") - continue - - index = self.pinecone.Index(index_name) - - # Heuristic hack, get (2*limit), so that we have more chance - # of getting (limit) entities - results = index.query( - vector=vec, - top_k=msg.limit * 2, - include_values=False, - include_metadata=True - ) - - for r in results.matches: - - ent = r.metadata["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) - - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break return entities diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index a76059ef..df93ad8b 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -from .... schema import GraphEmbeddingsResponse +from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService @@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService): try: + vec = msg.vector + if not vec: + return [] + + # Use dimension suffix in collection name + dim = len(vec) + collection = f"t_{msg.user}_{msg.collection}_{dim}" + + # Check if collection exists - return empty if not + if not self.collection_exists(collection): + logger.info(f"Collection {collection} does not exist") + return [] + + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) unique entities + search_result = self.qdrant.query_points( + collection_name=collection, + query=vec, + limit=msg.limit * 2, + with_payload=True, + ).points + entity_set = set() entities = [] - for vec in msg.vectors: + for r in search_result: + ent = r.payload["entity"] + score = r.score if hasattr(r, 'score') else 0.0 - # Use dimension suffix in collection name - dim = len(vec) - collection = f"t_{msg.user}_{msg.collection}_{dim}" - - # Check if collection exists - return empty if not - if not self.collection_exists(collection): - logger.info(f"Collection {collection} does not exist, skipping this vector") - continue - - # Heuristic hack, get (2*limit), so that we have more chance - # of getting (limit) entities - search_result = self.qdrant.query_points( - collection_name=collection, - query=vec, - limit=msg.limit * 2, - with_payload=True, - ).points - - for r in search_result: - ent = r.payload["entity"] - - # De-dupe entities - if ent not in entity_set: - entity_set.add(ent) - entities.append(ent) - - # Keep adding entities until limit - if len(entity_set) >= msg.limit: break + # De-dupe entities, keep highest score + if ent not in entity_set: + entity_set.add(ent) + entities.append(EntityMatch( + entity=self.create_value(ent), + score=score, + )) # Keep adding entities until limit - if len(entity_set) >= msg.limit: break - - ents2 = [] - - for ent in entities: - ents2.append(self.create_value(ent)) - - entities = ents2 + if len(entities) >= msg.limit: + break logger.debug("Send response...") return entities diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 7ed6192f..307899d6 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -93,7 +93,9 @@ class Processor(FlowProcessor): async def query_row_embeddings(self, request: RowEmbeddingsRequest): """Execute row embeddings query""" - matches = [] + vec = request.vector + if not vec: + return [] # Find the collection for this user/collection/schema qdrant_collection = self.find_collection( @@ -105,47 +107,47 @@ class Processor(FlowProcessor): f"No Qdrant collection found for " f"{request.user}/{request.collection}/{request.schema_name}" ) + return [] + + try: + # Build optional filter for index_name + query_filter = None + if request.index_name: + query_filter = Filter( + must=[ + FieldCondition( + key="index_name", + match=MatchValue(value=request.index_name) + ) + ] + ) + + # Query Qdrant + search_result = self.qdrant.query_points( + collection_name=qdrant_collection, + query=vec, + limit=request.limit, + with_payload=True, + query_filter=query_filter, + ).points + + # Convert to RowIndexMatch objects + matches = [] + for point in search_result: + payload = point.payload or {} + match = RowIndexMatch( + index_name=payload.get("index_name", ""), + index_value=payload.get("index_value", []), + text=payload.get("text", ""), + score=point.score if hasattr(point, 'score') else 0.0 + ) + matches.append(match) + return matches - for vec in request.vectors: - try: - # Build optional filter for index_name - query_filter = None - if request.index_name: - query_filter = Filter( - must=[ - FieldCondition( - key="index_name", - match=MatchValue(value=request.index_name) - ) - ] - ) - - # Query Qdrant - search_result = self.qdrant.query_points( - collection_name=qdrant_collection, - query=vec, - limit=request.limit, - with_payload=True, - query_filter=query_filter, - ).points - - # Convert to RowIndexMatch objects - for point in search_result: - payload = point.payload or {} - match = RowIndexMatch( - index_name=payload.get("index_name", ""), - index_value=payload.get("index_value", []), - text=payload.get("text", ""), - score=point.score if hasattr(point, 'score') else 0.0 - ) - matches.append(match) - - except Exception as e: - logger.error(f"Failed to query Qdrant: {e}", exc_info=True) - raise - - return matches + except Exception as e: + logger.error(f"Failed to query Qdrant: {e}", exc_info=True) + raise async def on_message(self, msg, consumer, flow): """Handle incoming query request""" diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 6402010a..5e77f733 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -37,26 +37,26 @@ class Query: vectors = await self.get_vector(query) if self.verbose: - logger.debug("Getting chunk_ids from embeddings store...") + logger.debug("Getting chunks from embeddings store...") - # Get chunk_ids from embeddings store - chunk_ids = await self.rag.doc_embeddings_client.query( - vectors, limit=self.doc_limit, + # Get chunk matches from embeddings store + chunk_matches = await self.rag.doc_embeddings_client.query( + vector=vectors, limit=self.doc_limit, user=self.user, collection=self.collection, ) if self.verbose: - logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...") + logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...") # Fetch chunk content from Garage docs = [] - for chunk_id in chunk_ids: - if chunk_id: + for match in chunk_matches: + if match.chunk_id: try: - content = await self.rag.fetch_chunk(chunk_id, self.user) + content = await self.rag.fetch_chunk(match.chunk_id, self.user) docs.append(content) except Exception as e: - logger.warning(f"Failed to fetch chunk {chunk_id}: {e}") + logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}") if self.verbose: logger.debug("Documents fetched:") diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 21d5aed1..2bf6b2ea 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -87,14 +87,14 @@ class Query: if self.verbose: logger.debug("Getting entities...") - entities = await self.rag.graph_embeddings_client.query( - vectors=vectors, limit=self.entity_limit, + entity_matches = await self.rag.graph_embeddings_client.query( + vector=vectors, limit=self.entity_limit, user=self.user, collection=self.collection, ) entities = [ - str(e) - for e in entities + str(e.entity) + for e in entity_matches ] if self.verbose: diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index a4ff0838..e282f876 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if chunk_id == "": continue - for vec in emb.vectors: + vec = emb.vector + if vec: self.vecstore.insert( vec, chunk_id, message.metadata.user, diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index f6393053..ea091d35 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if chunk_id == "": continue - for vec in emb.vectors: + vec = emb.vector + if not vec: + continue - # Create index name with dimension suffix for lazy creation - dim = len(vec) - index_name = ( - f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" - ) + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) - # Lazily create index if it doesn't exist (but only if authorized in config) - if not self.pinecone.has_index(index_name): - logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") - self.create_index(index_name, dim) + # Lazily create index if it doesn't exist (but only if authorized in config) + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) - index = self.pinecone.Index(index_name) + index = self.pinecone.Index(index_name) - # Generate unique ID for each vector - vector_id = str(uuid.uuid4()) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - records = [ - { - "id": vector_id, - "values": vec, - "metadata": { "chunk_id": chunk_id }, - } - ] + records = [ + { + "id": vector_id, + "values": vec, + "metadata": { "chunk_id": chunk_id }, + } + ] - index.upsert( - vectors = records, - ) + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 21ea9a98..a87f2128 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if chunk_id == "": continue - for vec in emb.vectors: + vec = emb.vector + if not vec: + continue - # Create collection name with dimension suffix for lazy creation - dim = len(vec) - collection = ( - f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" - ) + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) - - self.qdrant.upsert( + # Lazily create collection if it doesn't exist (but only if authorized in config) + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "chunk_id": chunk_id, - } - ) - ] + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) ) + self.qdrant.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "chunk_id": chunk_id, + } + ) + ] + ) + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 8e1c4485..0f27adf9 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): entity_value = get_term_value(entity.entity) if entity_value != "" and entity_value is not None: - for vec in entity.vectors: + vec = entity.vector + if vec: self.vecstore.insert( vec, entity_value, message.metadata.user, diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index f4de7f82..d907e873 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity_value == "" or entity_value is None: continue - for vec in entity.vectors: + vec = entity.vector + if not vec: + continue - # Create index name with dimension suffix for lazy creation - dim = len(vec) - index_name = ( - f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" - ) + # Create index name with dimension suffix for lazy creation + dim = len(vec) + index_name = ( + f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" + ) - # Lazily create index if it doesn't exist (but only if authorized in config) - if not self.pinecone.has_index(index_name): - logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") - self.create_index(index_name, dim) + # Lazily create index if it doesn't exist (but only if authorized in config) + if not self.pinecone.has_index(index_name): + logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") + self.create_index(index_name, dim) - index = self.pinecone.Index(index_name) + index = self.pinecone.Index(index_name) - # Generate unique ID for each vector - vector_id = str(uuid.uuid4()) + # Generate unique ID for each vector + vector_id = str(uuid.uuid4()) - metadata = {"entity": entity_value} - if entity.chunk_id: - metadata["chunk_id"] = entity.chunk_id + metadata = {"entity": entity_value} + if entity.chunk_id: + metadata["chunk_id"] = entity.chunk_id - records = [ - { - "id": vector_id, - "values": vec, - "metadata": metadata, - } - ] + records = [ + { + "id": vector_id, + "values": vec, + "metadata": metadata, + } + ] - index.upsert( - vectors = records, - ) + index.upsert( + vectors = records, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 4877ae96..f887d487 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity_value == "" or entity_value is None: continue - for vec in entity.vectors: + vec = entity.vector + if not vec: + continue - # Create collection name with dimension suffix for lazy creation - dim = len(vec) - collection = ( - f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" - ) + # Create collection name with dimension suffix for lazy creation + dim = len(vec) + collection = ( + f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" + ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) - - payload = { - "entity": entity_value, - } - if entity.chunk_id: - payload["chunk_id"] = entity.chunk_id - - self.qdrant.upsert( + # Lazily create collection if it doesn't exist (but only if authorized in config) + if not self.qdrant.collection_exists(collection): + logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") + self.qdrant.create_collection( collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload=payload, - ) - ] + vectors_config=VectorParams( + size=dim, + distance=Distance.COSINE + ) ) + payload = { + "entity": entity_value, + } + if entity.chunk_id: + payload["chunk_id"] = entity.chunk_id + + self.qdrant.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload=payload, + ) + ] + ) + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 29848c4c..42e59012 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor): qdrant_collection = None for row_emb in embeddings.embeddings: - if not row_emb.vectors: + vector = row_emb.vector + if not vector: logger.warning( - f"No vectors for index {row_emb.index_name} - skipping" + f"No vector for index {row_emb.index_name} - skipping" ) continue - # Use first vector (there may be multiple from different models) - for vector in row_emb.vectors: - dimension = len(vector) + dimension = len(vector) - # Create/get collection name (lazily on first vector) - if qdrant_collection is None: - qdrant_collection = self.get_collection_name( - user, collection, schema_name, dimension - ) - self.ensure_collection(qdrant_collection, dimension) - - # Write to Qdrant - self.qdrant.upsert( - collection_name=qdrant_collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vector, - payload={ - "index_name": row_emb.index_name, - "index_value": row_emb.index_value, - "text": row_emb.text - } - ) - ] + # Create/get collection name (lazily on first vector) + if qdrant_collection is None: + qdrant_collection = self.get_collection_name( + user, collection, schema_name, dimension ) - embeddings_written += 1 + self.ensure_collection(qdrant_collection, dimension) + + # Write to Qdrant + self.qdrant.upsert( + collection_name=qdrant_collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "index_name": row_emb.index_name, + "index_value": row_emb.index_value, + "text": row_emb.text + } + ) + ] + ) + embeddings_written += 1 logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")