Implementation

This commit is contained in:
Cyber MacGeddon 2026-03-09 09:59:57 +00:00
parent 4fa7cc7d7c
commit 356d7f75ac
38 changed files with 503 additions and 460 deletions

View file

@ -193,12 +193,20 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock entity objects that have string representation
# Mock EntityMatch objects with entity that has string representation
mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
@ -216,9 +224,9 @@ class TestQuery:
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify graph embeddings client was called correctly (with extracted vectors)
# Verify graph embeddings client was called correctly (with extracted vector)
mock_graph_embeddings_client.query.assert_called_once_with(
vectors=test_vectors,
vector=test_vectors,
limit=25,
user="test_user",
collection="test_collection"

View file

@ -612,12 +612,12 @@ class AsyncFlowInstance:
print(f"{entity['name']}: {entity['score']}")
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -810,12 +810,12 @@ class AsyncFlowInstance:
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs
):
"""Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -602,13 +602,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query graph embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -648,13 +648,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query document embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -1362,13 +1362,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query row embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -649,12 +649,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -698,12 +698,12 @@ class SocketFlowInstance:
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -936,12 +936,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return resp.chunk_ids
# Return ChunkMatch objects with chunk_id and score
return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error",
message = str(e),
),
chunk_ids=[],
chunks=[],
)
await flow("response").send(r, properties={"id": id})

View file

@ -19,12 +19,12 @@ def to_value(x):
return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return [
to_value(v)
for v in resp.entities
]
# Return EntityMatch objects with entity and score
return resp.entities
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600
):
request = RowEmbeddingsRequest(
vectors=vectors,
vector=vector,
schema_name=schema_name,
user=user,
collection=collection,

View file

@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
)
def request(
self, vectors, user="trustgraph", collection="default",
self, vector, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout
vector=vector, limit=limit, timeout=timeout
).chunks

View file

@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
)
def request(
self, vectors, user="trustgraph", collection="default",
self, vector, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout
vector=vector, limit=limit, timeout=timeout
).entities

View file

@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
)
def request(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=300
):
kwargs = dict(
user=user, collection=collection,
vectors=vectors, schema_name=schema_name,
vector=vector, schema_name=schema_name,
limit=limit, timeout=timeout
)
if index_name:

View file

@ -10,18 +10,18 @@ from .primitives import ValueTranslator
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
return DocumentEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
)
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
return {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.chunk_ids is not None:
result["chunk_ids"] = list(obj.chunk_ids)
if obj.chunks is not None:
result["chunks"] = [
{
"chunk_id": chunk.chunk_id,
"score": chunk.score
}
for chunk in obj.chunks
]
return result
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
class GraphEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
return GraphEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
)
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
return {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
class GraphEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsResponse schema objects"""
def __init__(self):
self.value_translator = ValueTranslator()
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.entities is not None:
result["entities"] = [
self.value_translator.from_pulsar(entity)
for entity in obj.entities
{
"entity": self.value_translator.from_pulsar(match.entity),
"score": match.score
}
for match in obj.entities
]
return result
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
return RowEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
result = {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection,

View file

@ -11,7 +11,7 @@ from ..core.topic import topic
@dataclass
class EntityEmbeddings:
entity: Term | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
# Provenance: which chunk this embedding was derived from
chunk_id: str = ""
@ -28,7 +28,7 @@ class GraphEmbeddings:
@dataclass
class ChunkEmbeddings:
chunk_id: str = ""
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
# This is a 'batching' mechanism for the above data
@dataclass
@ -44,7 +44,7 @@ class DocumentEmbeddings:
@dataclass
class ObjectEmbeddings:
metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
name: str = ""
key_name: str = ""
id: str = ""
@ -56,7 +56,7 @@ class ObjectEmbeddings:
@dataclass
class StructuredObjectEmbedding:
metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
schema_name: str = ""
object_id: str = "" # Primary key value
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
@ -72,7 +72,7 @@ class RowIndexEmbedding:
index_name: str = "" # The indexed field name(s)
index_value: list[str] = field(default_factory=list) # The field value(s)
text: str = "" # Text that was embedded
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
@dataclass
class RowEmbeddings:

View file

@ -34,7 +34,7 @@ class EmbeddingsRequest:
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
vectors: list[list[float]] = field(default_factory=list)
############################################################################

View file

@ -9,15 +9,21 @@ from ..core.topic import topic
@dataclass
class GraphEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
limit: int = 0
user: str = ""
collection: str = ""
@dataclass
class EntityMatch:
"""A matching entity from a semantic search with similarity score"""
entity: Term | None = None
score: float = 0.0
@dataclass
class GraphEmbeddingsResponse:
error: Error | None = None
entities: list[Term] = field(default_factory=list)
entities: list[EntityMatch] = field(default_factory=list)
############################################################################
@ -44,15 +50,21 @@ class TriplesQueryResponse:
@dataclass
class DocumentEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
limit: int = 0
user: str = ""
collection: str = ""
@dataclass
class ChunkMatch:
"""A matching chunk from a semantic search with similarity score"""
chunk_id: str = ""
score: float = 0.0
@dataclass
class DocumentEmbeddingsResponse:
error: Error | None = None
chunk_ids: list[str] = field(default_factory=list)
chunks: list[ChunkMatch] = field(default_factory=list)
document_embeddings_request_queue = topic(
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
@ -76,7 +88,7 @@ class RowIndexMatch:
@dataclass
class RowEmbeddingsRequest:
"""Request for row embeddings semantic search"""
vectors: list[list[float]] = field(default_factory=list) # Query vectors
vector: list[float] = field(default_factory=list) # Query vector
limit: int = 10 # Max results to return
user: str = "" # User/keyspace
collection: str = "" # Collection name

View file

@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl:
query_text = arguments.get("query")
all_vectors = await embeddings_client.embed([query_text])
vectors = all_vectors[0] if all_vectors else []
vector = all_vectors[0] if all_vectors else []
# Now query row embeddings
client = self.context("row-embeddings-query-request")
@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl:
user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query(
vectors=vectors,
vector=vector,
schema_name=self.schema_name,
user=user,
collection=self.collection or "default",

View file

@ -66,13 +66,13 @@ class Processor(FlowProcessor):
)
)
# vectors[0] is the vector set for the first (only) text
vectors = resp.vectors[0] if resp.vectors else []
# vectors[0] is the vector for the first (only) text
vector = resp.vectors[0] if resp.vectors else []
embeds = [
ChunkEmbeddings(
chunk_id=v.document_id,
vectors=vectors,
vector=vector,
)
]

View file

@ -59,11 +59,8 @@ class Processor(EmbeddingsService):
# FastEmbed processes the full batch efficiently
vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text
return [
[v.tolist()]
for v in vecs
]
# Return list of vectors, one per input text
return [v.tolist() for v in vecs]
@staticmethod
def add_args(parser):

View file

@ -72,10 +72,10 @@ class Processor(FlowProcessor):
entities = [
EntityEmbeddings(
entity=entity.entity,
vectors=vectors, # Vector set for this entity
vector=vector,
chunk_id=entity.chunk_id, # Provenance: source chunk
)
for entity, vectors in zip(v.entities, all_vectors)
for entity, vector in zip(v.entities, all_vectors)
]
# Send in batches to avoid oversized messages

View file

@ -43,11 +43,8 @@ class Processor(EmbeddingsService):
input = texts
)
# Return list of vector sets, one per input text
return [
[embedding]
for embedding in embeds.embeddings
]
# Return list of vectors, one per input text
return list(embeds.embeddings)
@staticmethod
def add_args(parser):

View file

@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results with metadata
for text, (index_name, index_value), vectors in zip(
for text, (index_name, index_value), vector in zip(
texts, metadata, all_vectors
):
embeddings_list.append(
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors # Vector set for this text
vector=vector
)
)

View file

@ -7,7 +7,7 @@ of chunk_ids
import logging
from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
chunk_ids = []
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for vec in msg.vectors:
chunks = []
for r in resp:
chunk_id = r["entity"]["chunk_id"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for r in resp:
chunk_id = r["entity"]["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
return chunks
except Exception as e:

View file

@ -11,6 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import ChunkMatch
from .... base import DocumentEmbeddingsQueryService
# Module logger
@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
chunk_ids = []
dim = len(vec)
for vec in msg.vectors:
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
dim = len(vec)
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
index = self.pinecone.Index(index_name)
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
results = index.query(
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
index = self.pinecone.Index(index_name)
chunks = []
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
results = index.query(
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
return chunks
except Exception as e:

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService):
try:
chunk_ids = []
vec = msg.vector
if not vec:
return []
for vec in msg.vectors:
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
continue
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
chunks = []
for r in search_result:
chunk_id = r.payload["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
for r in search_result:
chunk_id = r.payload["chunk_id"]
chunk_ids.append(chunk_id)
return chunk_ids
return chunks
except Exception as e:

View file

@ -7,7 +7,7 @@ entities
import logging
from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
try:
entity_set = set()
entities = []
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
for vec in msg.vectors:
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
entity_set = set()
entities = []
for r in resp:
ent = r["entity"]["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
for r in resp:
ent = r["entity"]["entity"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
logger.debug("Send response...")
return entities

View file

@ -11,7 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
entity_set = set()
entities = []
for vec in msg.vectors:
for r in results.matches:
ent = r.metadata["entity"]
score = r.score if hasattr(r, 'score') else 0.0
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
return entities

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
entity_set = set()
entities = []
for vec in msg.vectors:
for r in search_result:
ent = r.payload["entity"]
score = r.score if hasattr(r, 'score') else 0.0
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, skipping this vector")
continue
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
for r in search_result:
ent = r.payload["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
logger.debug("Send response...")
return entities

View file

@ -93,7 +93,9 @@ class Processor(FlowProcessor):
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
matches = []
vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema
qdrant_collection = self.find_collection(
@ -105,47 +107,47 @@ class Processor(FlowProcessor):
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
)
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
return matches
for vec in request.vectors:
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""

View file

@ -37,26 +37,26 @@ class Query:
vectors = await self.get_vector(query)
if self.verbose:
logger.debug("Getting chunk_ids from embeddings store...")
logger.debug("Getting chunks from embeddings store...")
# Get chunk_ids from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query(
vectors, limit=self.doc_limit,
# Get chunk matches from embeddings store
chunk_matches = await self.rag.doc_embeddings_client.query(
vector=vectors, limit=self.doc_limit,
user=self.user, collection=self.collection,
)
if self.verbose:
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
# Fetch chunk content from Garage
docs = []
for chunk_id in chunk_ids:
if chunk_id:
for match in chunk_matches:
if match.chunk_id:
try:
content = await self.rag.fetch_chunk(chunk_id, self.user)
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
docs.append(content)
except Exception as e:
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
if self.verbose:
logger.debug("Documents fetched:")

View file

@ -87,14 +87,14 @@ class Query:
if self.verbose:
logger.debug("Getting entities...")
entities = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit,
entity_matches = await self.rag.graph_embeddings_client.query(
vector=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
)
entities = [
str(e)
for e in entities
str(e.entity)
for e in entity_matches
]
if self.verbose:

View file

@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if vec:
self.vecstore.insert(
vec, chunk_id,
message.metadata.user,

View file

@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "chunk_id": chunk_id },
}
]
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "chunk_id": chunk_id },
}
]
index.upsert(
vectors = records,
)
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):

View file

@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
)
@staticmethod
def add_args(parser):

View file

@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
entity_value = get_term_value(entity.entity)
if entity_value != "" and entity_value is not None:
for vec in entity.vectors:
vec = entity.vector
if vec:
self.vecstore.insert(
vec, entity_value,
message.metadata.user,

View file

@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
vec = entity.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
metadata = {"entity": entity_value}
if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id
metadata = {"entity": entity_value}
if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id
records = [
{
"id": vector_id,
"values": vec,
"metadata": metadata,
}
]
records = [
{
"id": vector_id,
"values": vec,
"metadata": metadata,
}
]
index.upsert(
vectors = records,
)
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):

View file

@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
vec = entity.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
)
@staticmethod
def add_args(parser):

View file

@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
qdrant_collection = None
for row_emb in embeddings.embeddings:
if not row_emb.vectors:
vector = row_emb.vector
if not vector:
logger.warning(
f"No vectors for index {row_emb.index_name} - skipping"
f"No vector for index {row_emb.index_name} - skipping"
)
continue
# Use first vector (there may be multiple from different models)
for vector in row_emb.vectors:
dimension = len(vector)
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
embeddings_written += 1
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")