Implementation

This commit is contained in:
Cyber MacGeddon 2026-03-09 09:59:57 +00:00
parent 4fa7cc7d7c
commit 356d7f75ac
38 changed files with 503 additions and 460 deletions

View file

@ -193,12 +193,20 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]] test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors] mock_embeddings_client.embed.return_value = [test_vectors]
# Mock entity objects that have string representation # Mock EntityMatch objects with entity that has string representation
mock_entity1 = MagicMock() mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1") mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock() mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2") mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2] mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query # Initialize Query
query = Query( query = Query(
@ -216,9 +224,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list) # Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query]) mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify graph embeddings client was called correctly (with extracted vectors) # Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with( mock_graph_embeddings_client.query.assert_called_once_with(
vectors=test_vectors, vector=test_vectors,
limit=25, limit=25,
user="test_user", user="test_user",
collection="test_collection" collection="test_collection"

View file

@ -612,12 +612,12 @@ class AsyncFlowInstance:
print(f"{entity['name']}: {entity['score']}") print(f"{entity['name']}: {entity['score']}")
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request_data = { request_data = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -810,12 +810,12 @@ class AsyncFlowInstance:
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})") print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request_data = { request_data = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search""" """Query graph embeddings for semantic search"""
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs limit: int = 10, **kwargs
): ):
"""Query row embeddings for semantic search on structured data""" """Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = await self.embeddings(texts=[text]) emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -602,13 +602,13 @@ class FlowInstance:
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
# Query graph embeddings for semantic search # Query graph embeddings for semantic search
input = { input = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -648,13 +648,13 @@ class FlowInstance:
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
# Query document embeddings for semantic search # Query document embeddings for semantic search
input = { input = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -1362,13 +1362,13 @@ class FlowInstance:
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
# Query row embeddings for semantic search # Query row embeddings for semantic search
input = { input = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -649,12 +649,12 @@ class SocketFlowInstance:
) )
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -698,12 +698,12 @@ class SocketFlowInstance:
# results contains {"chunk_ids": ["doc1/p0/c0", ...]} # results contains {"chunk_ids": ["doc1/p0/c0", ...]}
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"user": user, "user": user,
"collection": collection, "collection": collection,
"limit": limit "limit": limit
@ -936,12 +936,12 @@ class SocketFlowInstance:
) )
``` ```
""" """
# First convert text to embeddings vectors # First convert text to embedding vector
emb_result = self.embeddings(texts=[text]) emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0] vector = emb_result.get("vectors", [[]])[0]
request = { request = {
"vectors": vectors, "vector": vector,
"schema_name": schema_name, "schema_name": schema_name,
"user": user, "user": user,
"collection": collection, "collection": collection,

View file

@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse): class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph", async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30): collection="default", timeout=30):
resp = await self.request( resp = await self.request(
DocumentEmbeddingsRequest( DocumentEmbeddingsRequest(
vectors = vectors, vector = vector,
limit = limit, limit = limit,
user = user, user = user,
collection = collection collection = collection
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
return resp.chunk_ids # Return ChunkMatch objects with chunk_id and score
return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec): class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request) docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...") logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None) r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed") logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error", type = "document-embeddings-query-error",
message = str(e), message = str(e),
), ),
chunk_ids=[], chunks=[],
) )
await flow("response").send(r, properties={"id": id}) await flow("response").send(r, properties={"id": id})

View file

@ -19,12 +19,12 @@ def to_value(x):
return Literal(x.value or x.iri) return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse): class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph", async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30): collection="default", timeout=30):
resp = await self.request( resp = await self.request(
GraphEmbeddingsRequest( GraphEmbeddingsRequest(
vectors = vectors, vector = vector,
limit = limit, limit = limit,
user = user, user = user,
collection = collection collection = collection
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
if resp.error: if resp.error:
raise RuntimeError(resp.error.message) raise RuntimeError(resp.error.message)
return [ # Return EntityMatch objects with entity and score
to_value(v) return resp.entities
for v in resp.entities
]
class GraphEmbeddingsClientSpec(RequestResponseSpec): class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__( def __init__(

View file

@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse): class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query( async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default", self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600 index_name=None, limit=10, timeout=600
): ):
request = RowEmbeddingsRequest( request = RowEmbeddingsRequest(
vectors=vectors, vector=vector,
schema_name=schema_name, schema_name=schema_name,
user=user, user=user,
collection=collection, collection=collection,

View file

@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
) )
def request( def request(
self, vectors, user="trustgraph", collection="default", self, vector, user="trustgraph", collection="default",
limit=10, timeout=300 limit=10, timeout=300
): ):
return self.call( return self.call(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout vector=vector, limit=limit, timeout=timeout
).chunks ).chunks

View file

@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
) )
def request( def request(
self, vectors, user="trustgraph", collection="default", self, vector, user="trustgraph", collection="default",
limit=10, timeout=300 limit=10, timeout=300
): ):
return self.call( return self.call(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout vector=vector, limit=limit, timeout=timeout
).entities ).entities

View file

@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
) )
def request( def request(
self, vectors, schema_name, user="trustgraph", collection="default", self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=300 index_name=None, limit=10, timeout=300
): ):
kwargs = dict( kwargs = dict(
user=user, collection=collection, user=user, collection=collection,
vectors=vectors, schema_name=schema_name, vector=vector, schema_name=schema_name,
limit=limit, timeout=timeout limit=limit, timeout=timeout
) )
if index_name: if index_name:

View file

@ -10,18 +10,18 @@ from .primitives import ValueTranslator
class DocumentEmbeddingsRequestTranslator(MessageTranslator): class DocumentEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsRequest schema objects""" """Translator for DocumentEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
return DocumentEmbeddingsRequest( return DocumentEmbeddingsRequest(
vectors=data["vectors"], vector=data["vector"],
limit=int(data.get("limit", 10)), limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"), user=data.get("user", "trustgraph"),
collection=data.get("collection", "default") collection=data.get("collection", "default")
) )
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
return { return {
"vectors": obj.vectors, "vector": obj.vector,
"limit": obj.limit, "limit": obj.limit,
"user": obj.user, "user": obj.user,
"collection": obj.collection "collection": obj.collection
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
class DocumentEmbeddingsResponseTranslator(MessageTranslator): class DocumentEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsResponse schema objects""" """Translator for DocumentEmbeddingsResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed") raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.chunk_ids is not None: if obj.chunks is not None:
result["chunk_ids"] = list(obj.chunk_ids) result["chunks"] = [
{
"chunk_id": chunk.chunk_id,
"score": chunk.score
}
for chunk in obj.chunks
]
return result return result
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)""" """Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True return self.from_pulsar(obj), True
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
class GraphEmbeddingsRequestTranslator(MessageTranslator): class GraphEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsRequest schema objects""" """Translator for GraphEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
return GraphEmbeddingsRequest( return GraphEmbeddingsRequest(
vectors=data["vectors"], vector=data["vector"],
limit=int(data.get("limit", 10)), limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"), user=data.get("user", "trustgraph"),
collection=data.get("collection", "default") collection=data.get("collection", "default")
) )
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
return { return {
"vectors": obj.vectors, "vector": obj.vector,
"limit": obj.limit, "limit": obj.limit,
"user": obj.user, "user": obj.user,
"collection": obj.collection "collection": obj.collection
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
class GraphEmbeddingsResponseTranslator(MessageTranslator): class GraphEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsResponse schema objects""" """Translator for GraphEmbeddingsResponse schema objects"""
def __init__(self): def __init__(self):
self.value_translator = ValueTranslator() self.value_translator = ValueTranslator()
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed") raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
result = {} result = {}
if obj.entities is not None: if obj.entities is not None:
result["entities"] = [ result["entities"] = [
self.value_translator.from_pulsar(entity) {
for entity in obj.entities "entity": self.value_translator.from_pulsar(match.entity),
"score": match.score
}
for match in obj.entities
] ]
return result return result
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)""" """Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True return self.from_pulsar(obj), True
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest: def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
return RowEmbeddingsRequest( return RowEmbeddingsRequest(
vectors=data["vectors"], vector=data["vector"],
limit=int(data.get("limit", 10)), limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"), user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"), collection=data.get("collection", "default"),
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]: def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
result = { result = {
"vectors": obj.vectors, "vector": obj.vector,
"limit": obj.limit, "limit": obj.limit,
"user": obj.user, "user": obj.user,
"collection": obj.collection, "collection": obj.collection,

View file

@ -11,7 +11,7 @@ from ..core.topic import topic
@dataclass @dataclass
class EntityEmbeddings: class EntityEmbeddings:
entity: Term | None = None entity: Term | None = None
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
# Provenance: which chunk this embedding was derived from # Provenance: which chunk this embedding was derived from
chunk_id: str = "" chunk_id: str = ""
@ -28,7 +28,7 @@ class GraphEmbeddings:
@dataclass @dataclass
class ChunkEmbeddings: class ChunkEmbeddings:
chunk_id: str = "" chunk_id: str = ""
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
# This is a 'batching' mechanism for the above data # This is a 'batching' mechanism for the above data
@dataclass @dataclass
@ -44,7 +44,7 @@ class DocumentEmbeddings:
@dataclass @dataclass
class ObjectEmbeddings: class ObjectEmbeddings:
metadata: Metadata | None = None metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
name: str = "" name: str = ""
key_name: str = "" key_name: str = ""
id: str = "" id: str = ""
@ -56,7 +56,7 @@ class ObjectEmbeddings:
@dataclass @dataclass
class StructuredObjectEmbedding: class StructuredObjectEmbedding:
metadata: Metadata | None = None metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
schema_name: str = "" schema_name: str = ""
object_id: str = "" # Primary key value object_id: str = "" # Primary key value
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
@ -72,7 +72,7 @@ class RowIndexEmbedding:
index_name: str = "" # The indexed field name(s) index_name: str = "" # The indexed field name(s)
index_value: list[str] = field(default_factory=list) # The field value(s) index_value: list[str] = field(default_factory=list) # The field value(s)
text: str = "" # Text that was embedded text: str = "" # Text that was embedded
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
@dataclass @dataclass
class RowEmbeddings: class RowEmbeddings:

View file

@ -34,7 +34,7 @@ class EmbeddingsRequest:
@dataclass @dataclass
class EmbeddingsResponse: class EmbeddingsResponse:
error: Error | None = None error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list) vectors: list[list[float]] = field(default_factory=list)
############################################################################ ############################################################################

View file

@ -9,15 +9,21 @@ from ..core.topic import topic
@dataclass @dataclass
class GraphEmbeddingsRequest: class GraphEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
limit: int = 0 limit: int = 0
user: str = "" user: str = ""
collection: str = "" collection: str = ""
@dataclass
class EntityMatch:
"""A matching entity from a semantic search with similarity score"""
entity: Term | None = None
score: float = 0.0
@dataclass @dataclass
class GraphEmbeddingsResponse: class GraphEmbeddingsResponse:
error: Error | None = None error: Error | None = None
entities: list[Term] = field(default_factory=list) entities: list[EntityMatch] = field(default_factory=list)
############################################################################ ############################################################################
@ -44,15 +50,21 @@ class TriplesQueryResponse:
@dataclass @dataclass
class DocumentEmbeddingsRequest: class DocumentEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list) vector: list[float] = field(default_factory=list)
limit: int = 0 limit: int = 0
user: str = "" user: str = ""
collection: str = "" collection: str = ""
@dataclass
class ChunkMatch:
"""A matching chunk from a semantic search with similarity score"""
chunk_id: str = ""
score: float = 0.0
@dataclass @dataclass
class DocumentEmbeddingsResponse: class DocumentEmbeddingsResponse:
error: Error | None = None error: Error | None = None
chunk_ids: list[str] = field(default_factory=list) chunks: list[ChunkMatch] = field(default_factory=list)
document_embeddings_request_queue = topic( document_embeddings_request_queue = topic(
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
@ -76,7 +88,7 @@ class RowIndexMatch:
@dataclass @dataclass
class RowEmbeddingsRequest: class RowEmbeddingsRequest:
"""Request for row embeddings semantic search""" """Request for row embeddings semantic search"""
vectors: list[list[float]] = field(default_factory=list) # Query vectors vector: list[float] = field(default_factory=list) # Query vector
limit: int = 10 # Max results to return limit: int = 10 # Max results to return
user: str = "" # User/keyspace user: str = "" # User/keyspace
collection: str = "" # Collection name collection: str = "" # Collection name

View file

@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl:
query_text = arguments.get("query") query_text = arguments.get("query")
all_vectors = await embeddings_client.embed([query_text]) all_vectors = await embeddings_client.embed([query_text])
vectors = all_vectors[0] if all_vectors else [] vector = all_vectors[0] if all_vectors else []
# Now query row embeddings # Now query row embeddings
client = self.context("row-embeddings-query-request") client = self.context("row-embeddings-query-request")
@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl:
user = getattr(client, '_current_user', self.user or "trustgraph") user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query( matches = await client.row_embeddings_query(
vectors=vectors, vector=vector,
schema_name=self.schema_name, schema_name=self.schema_name,
user=user, user=user,
collection=self.collection or "default", collection=self.collection or "default",

View file

@ -66,13 +66,13 @@ class Processor(FlowProcessor):
) )
) )
# vectors[0] is the vector set for the first (only) text # vectors[0] is the vector for the first (only) text
vectors = resp.vectors[0] if resp.vectors else [] vector = resp.vectors[0] if resp.vectors else []
embeds = [ embeds = [
ChunkEmbeddings( ChunkEmbeddings(
chunk_id=v.document_id, chunk_id=v.document_id,
vectors=vectors, vector=vector,
) )
] ]

View file

@ -59,11 +59,8 @@ class Processor(EmbeddingsService):
# FastEmbed processes the full batch efficiently # FastEmbed processes the full batch efficiently
vecs = list(self.embeddings.embed(texts)) vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text # Return list of vectors, one per input text
return [ return [v.tolist() for v in vecs]
[v.tolist()]
for v in vecs
]
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -72,10 +72,10 @@ class Processor(FlowProcessor):
entities = [ entities = [
EntityEmbeddings( EntityEmbeddings(
entity=entity.entity, entity=entity.entity,
vectors=vectors, # Vector set for this entity vector=vector,
chunk_id=entity.chunk_id, # Provenance: source chunk chunk_id=entity.chunk_id, # Provenance: source chunk
) )
for entity, vectors in zip(v.entities, all_vectors) for entity, vector in zip(v.entities, all_vectors)
] ]
# Send in batches to avoid oversized messages # Send in batches to avoid oversized messages

View file

@ -43,11 +43,8 @@ class Processor(EmbeddingsService):
input = texts input = texts
) )
# Return list of vector sets, one per input text # Return list of vectors, one per input text
return [ return list(embeds.embeddings)
[embedding]
for embedding in embeds.embeddings
]
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
all_vectors = await flow("embeddings-request").embed(texts=texts) all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results with metadata # Pair results with metadata
for text, (index_name, index_value), vectors in zip( for text, (index_name, index_value), vector in zip(
texts, metadata, all_vectors texts, metadata, all_vectors
): ):
embeddings_list.append( embeddings_list.append(
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
index_name=index_name, index_name=index_name,
index_value=index_value, index_value=index_value,
text=text, text=text,
vectors=vectors # Vector set for this text vector=vector
) )
) )

View file

@ -7,7 +7,7 @@ of chunk_ids
import logging import logging
from .... direct.milvus_doc_embeddings import DocVectors from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsResponse from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error from .... schema import Error
from .... base import DocumentEmbeddingsQueryService from .... base import DocumentEmbeddingsQueryService
@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
chunk_ids = [] resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for vec in msg.vectors: chunks = []
for r in resp:
chunk_id = r["entity"]["chunk_id"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
resp = self.vecstore.search( return chunks
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for r in resp:
chunk_id = r["entity"]["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
except Exception as e: except Exception as e:

View file

@ -11,6 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import ChunkMatch
from .... base import DocumentEmbeddingsQueryService from .... base import DocumentEmbeddingsQueryService
# Module logger # Module logger
@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
chunk_ids = [] dim = len(vec)
for vec in msg.vectors: # Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
dim = len(vec) # Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
# Use dimension suffix in index name index = self.pinecone.Index(index_name)
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - skip if not results = index.query(
if not self.pinecone.has_index(index_name): vector=vec,
logger.info(f"Index {index_name} does not exist, skipping this vector") top_k=msg.limit,
continue include_values=False,
include_metadata=True
)
index = self.pinecone.Index(index_name) chunks = []
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
results = index.query( return chunks
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
except Exception as e: except Exception as e:

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error from .... schema import Error
from .... base import DocumentEmbeddingsQueryService from .... base import DocumentEmbeddingsQueryService
@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService):
try: try:
chunk_ids = [] vec = msg.vector
if not vec:
return []
for vec in msg.vectors: # Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Use dimension suffix in collection name # Check if collection exists - return empty if not
dim = len(vec) if not self.collection_exists(collection):
collection = f"d_{msg.user}_{msg.collection}_{dim}" logger.info(f"Collection {collection} does not exist, returning empty results")
return []
# Check if collection exists - return empty if not search_result = self.qdrant.query_points(
if not self.collection_exists(collection): collection_name=collection,
logger.info(f"Collection {collection} does not exist, returning empty results") query=vec,
continue limit=msg.limit,
with_payload=True,
).points
search_result = self.qdrant.query_points( chunks = []
collection_name=collection, for r in search_result:
query=vec, chunk_id = r.payload["chunk_id"]
limit=msg.limit, score = r.score if hasattr(r, 'score') else 0.0
with_payload=True, chunks.append(ChunkMatch(
).points chunk_id=chunk_id,
score=score,
))
for r in search_result: return chunks
chunk_id = r.payload["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
except Exception as e: except Exception as e:

View file

@ -7,7 +7,7 @@ entities
import logging import logging
from .... direct.milvus_graph_embeddings import EntityVectors from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsResponse from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService from .... base import GraphEmbeddingsQueryService
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
try: try:
entity_set = set() vec = msg.vector
entities = [] if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
for vec in msg.vectors: resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
resp = self.vecstore.search( entity_set = set()
vec, entities = []
msg.user,
msg.collection,
limit=msg.limit * 2
)
for r in resp: for r in resp:
ent = r["entity"]["entity"] ent = r["entity"]["entity"]
# Milvus returns distance, convert to similarity score
# De-dupe entities distance = r.get("distance", 0.0)
if ent not in entity_set: score = 1.0 - distance if distance else 0.0
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit # De-dupe entities, keep highest score
if len(entity_set) >= msg.limit: break if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit # Keep adding entities until limit
if len(entity_set) >= msg.limit: break if len(entities) >= msg.limit:
break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
logger.debug("Send response...") logger.debug("Send response...")
return entities return entities

View file

@ -11,7 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import GraphEmbeddingsResponse from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService from .... base import GraphEmbeddingsQueryService
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case # Handle zero limit case
if msg.limit <= 0: if msg.limit <= 0:
return [] return []
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
entity_set = set() entity_set = set()
entities = [] entities = []
for vec in msg.vectors: for r in results.matches:
ent = r.metadata["entity"]
score = r.score if hasattr(r, 'score') else 0.0
dim = len(vec) # De-dupe entities, keep highest score
if ent not in entity_set:
# Use dimension suffix in index name entity_set.add(ent)
index_name = f"t-{msg.user}-{msg.collection}-{dim}" entities.append(EntityMatch(
entity=self.create_value(ent),
# Check if index exists - skip if not score=score,
if not self.pinecone.has_index(index_name): ))
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# Keep adding entities until limit # Keep adding entities until limit
if len(entity_set) >= msg.limit: break if len(entities) >= msg.limit:
break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
return entities return entities

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService from .... base import GraphEmbeddingsQueryService
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
try: try:
vec = msg.vector
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
entity_set = set() entity_set = set()
entities = [] entities = []
for vec in msg.vectors: for r in search_result:
ent = r.payload["entity"]
score = r.score if hasattr(r, 'score') else 0.0
# Use dimension suffix in collection name # De-dupe entities, keep highest score
dim = len(vec) if ent not in entity_set:
collection = f"t_{msg.user}_{msg.collection}_{dim}" entity_set.add(ent)
entities.append(EntityMatch(
# Check if collection exists - return empty if not entity=self.create_value(ent),
if not self.collection_exists(collection): score=score,
logger.info(f"Collection {collection} does not exist, skipping this vector") ))
continue
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
for r in search_result:
ent = r.payload["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# Keep adding entities until limit # Keep adding entities until limit
if len(entity_set) >= msg.limit: break if len(entities) >= msg.limit:
break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
logger.debug("Send response...") logger.debug("Send response...")
return entities return entities

View file

@ -93,7 +93,9 @@ class Processor(FlowProcessor):
async def query_row_embeddings(self, request: RowEmbeddingsRequest): async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query""" """Execute row embeddings query"""
matches = [] vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema # Find the collection for this user/collection/schema
qdrant_collection = self.find_collection( qdrant_collection = self.find_collection(
@ -105,47 +107,47 @@ class Processor(FlowProcessor):
f"No Qdrant collection found for " f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}" f"{request.user}/{request.collection}/{request.schema_name}"
) )
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
return matches return matches
for vec in request.vectors: except Exception as e:
try: logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
# Build optional filter for index_name raise
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
async def on_message(self, msg, consumer, flow): async def on_message(self, msg, consumer, flow):
"""Handle incoming query request""" """Handle incoming query request"""

View file

@ -37,26 +37,26 @@ class Query:
vectors = await self.get_vector(query) vectors = await self.get_vector(query)
if self.verbose: if self.verbose:
logger.debug("Getting chunk_ids from embeddings store...") logger.debug("Getting chunks from embeddings store...")
# Get chunk_ids from embeddings store # Get chunk matches from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query( chunk_matches = await self.rag.doc_embeddings_client.query(
vectors, limit=self.doc_limit, vector=vectors, limit=self.doc_limit,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
) )
if self.verbose: if self.verbose:
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...") logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
# Fetch chunk content from Garage # Fetch chunk content from Garage
docs = [] docs = []
for chunk_id in chunk_ids: for match in chunk_matches:
if chunk_id: if match.chunk_id:
try: try:
content = await self.rag.fetch_chunk(chunk_id, self.user) content = await self.rag.fetch_chunk(match.chunk_id, self.user)
docs.append(content) docs.append(content)
except Exception as e: except Exception as e:
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}") logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
if self.verbose: if self.verbose:
logger.debug("Documents fetched:") logger.debug("Documents fetched:")

View file

@ -87,14 +87,14 @@ class Query:
if self.verbose: if self.verbose:
logger.debug("Getting entities...") logger.debug("Getting entities...")
entities = await self.rag.graph_embeddings_client.query( entity_matches = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit, vector=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection, user=self.user, collection=self.collection,
) )
entities = [ entities = [
str(e) str(e.entity)
for e in entities for e in entity_matches
] ]
if self.verbose: if self.verbose:

View file

@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "": if chunk_id == "":
continue continue
for vec in emb.vectors: vec = emb.vector
if vec:
self.vecstore.insert( self.vecstore.insert(
vec, chunk_id, vec, chunk_id,
message.metadata.user, message.metadata.user,

View file

@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "": if chunk_id == "":
continue continue
for vec in emb.vectors: vec = emb.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation # Create index name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
index_name = ( index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
) )
# Lazily create index if it doesn't exist (but only if authorized in config) # Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name): if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim) self.create_index(index_name, dim)
index = self.pinecone.Index(index_name) index = self.pinecone.Index(index_name)
# Generate unique ID for each vector # Generate unique ID for each vector
vector_id = str(uuid.uuid4()) vector_id = str(uuid.uuid4())
records = [ records = [
{ {
"id": vector_id, "id": vector_id,
"values": vec, "values": vec,
"metadata": { "chunk_id": chunk_id }, "metadata": { "chunk_id": chunk_id },
} }
] ]
index.upsert( index.upsert(
vectors = records, vectors = records,
) )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "": if chunk_id == "":
continue continue
for vec in emb.vectors: vec = emb.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation # Create collection name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
collection = ( collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
) )
# Lazily create collection if it doesn't exist (but only if authorized in config) # Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection): if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection( self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection, collection_name=collection,
points=[ vectors_config=VectorParams(
PointStruct( size=dim,
id=str(uuid.uuid4()), distance=Distance.COSINE
vector=vec, )
payload={
"chunk_id": chunk_id,
}
)
]
) )
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
entity_value = get_term_value(entity.entity) entity_value = get_term_value(entity.entity)
if entity_value != "" and entity_value is not None: if entity_value != "" and entity_value is not None:
for vec in entity.vectors: vec = entity.vector
if vec:
self.vecstore.insert( self.vecstore.insert(
vec, entity_value, vec, entity_value,
message.metadata.user, message.metadata.user,

View file

@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None: if entity_value == "" or entity_value is None:
continue continue
for vec in entity.vectors: vec = entity.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation # Create index name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
index_name = ( index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
) )
# Lazily create index if it doesn't exist (but only if authorized in config) # Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name): if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim) self.create_index(index_name, dim)
index = self.pinecone.Index(index_name) index = self.pinecone.Index(index_name)
# Generate unique ID for each vector # Generate unique ID for each vector
vector_id = str(uuid.uuid4()) vector_id = str(uuid.uuid4())
metadata = {"entity": entity_value} metadata = {"entity": entity_value}
if entity.chunk_id: if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id metadata["chunk_id"] = entity.chunk_id
records = [ records = [
{ {
"id": vector_id, "id": vector_id,
"values": vec, "values": vec,
"metadata": metadata, "metadata": metadata,
} }
] ]
index.upsert( index.upsert(
vectors = records, vectors = records,
) )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None: if entity_value == "" or entity_value is None:
continue continue
for vec in entity.vectors: vec = entity.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation # Create collection name with dimension suffix for lazy creation
dim = len(vec) dim = len(vec)
collection = ( collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
) )
# Lazily create collection if it doesn't exist (but only if authorized in config) # Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection): if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection( self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection, collection_name=collection,
points=[ vectors_config=VectorParams(
PointStruct( size=dim,
id=str(uuid.uuid4()), distance=Distance.COSINE
vector=vec, )
payload=payload,
)
]
) )
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
qdrant_collection = None qdrant_collection = None
for row_emb in embeddings.embeddings: for row_emb in embeddings.embeddings:
if not row_emb.vectors: vector = row_emb.vector
if not vector:
logger.warning( logger.warning(
f"No vectors for index {row_emb.index_name} - skipping" f"No vector for index {row_emb.index_name} - skipping"
) )
continue continue
# Use first vector (there may be multiple from different models) dimension = len(vector)
for vector in row_emb.vectors:
dimension = len(vector)
# Create/get collection name (lazily on first vector) # Create/get collection name (lazily on first vector)
if qdrant_collection is None: if qdrant_collection is None:
qdrant_collection = self.get_collection_name( qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
) )
embeddings_written += 1 self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")