mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-22 13:55:12 +02:00
Implementation
This commit is contained in:
parent
4fa7cc7d7c
commit
356d7f75ac
38 changed files with 503 additions and 460 deletions
|
|
@ -193,12 +193,20 @@ class TestQuery:
|
||||||
test_vectors = [[0.1, 0.2, 0.3]]
|
test_vectors = [[0.1, 0.2, 0.3]]
|
||||||
mock_embeddings_client.embed.return_value = [test_vectors]
|
mock_embeddings_client.embed.return_value = [test_vectors]
|
||||||
|
|
||||||
# Mock entity objects that have string representation
|
# Mock EntityMatch objects with entity that has string representation
|
||||||
mock_entity1 = MagicMock()
|
mock_entity1 = MagicMock()
|
||||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||||
|
mock_match1 = MagicMock()
|
||||||
|
mock_match1.entity = mock_entity1
|
||||||
|
mock_match1.score = 0.95
|
||||||
|
|
||||||
mock_entity2 = MagicMock()
|
mock_entity2 = MagicMock()
|
||||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||||
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
|
mock_match2 = MagicMock()
|
||||||
|
mock_match2.entity = mock_entity2
|
||||||
|
mock_match2.score = 0.85
|
||||||
|
|
||||||
|
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -216,9 +224,9 @@ class TestQuery:
|
||||||
# Verify embeddings client was called (now expects list)
|
# Verify embeddings client was called (now expects list)
|
||||||
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
mock_embeddings_client.embed.assert_called_once_with([test_query])
|
||||||
|
|
||||||
# Verify graph embeddings client was called correctly (with extracted vectors)
|
# Verify graph embeddings client was called correctly (with extracted vector)
|
||||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||||
vectors=test_vectors,
|
vector=test_vectors,
|
||||||
limit=25,
|
limit=25,
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
|
|
|
||||||
|
|
@ -612,12 +612,12 @@ class AsyncFlowInstance:
|
||||||
print(f"{entity['name']}: {entity['score']}")
|
print(f"{entity['name']}: {entity['score']}")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -810,12 +810,12 @@ class AsyncFlowInstance:
|
||||||
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
|
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
|
||||||
|
|
||||||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||||
"""Query graph embeddings for semantic search"""
|
"""Query graph embeddings for semantic search"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
|
||||||
limit: int = 10, **kwargs
|
limit: int = 10, **kwargs
|
||||||
):
|
):
|
||||||
"""Query row embeddings for semantic search on structured data"""
|
"""Query row embeddings for semantic search on structured data"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = await self.embeddings(texts=[text])
|
emb_result = await self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -602,13 +602,13 @@ class FlowInstance:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
# Query graph embeddings for semantic search
|
# Query graph embeddings for semantic search
|
||||||
input = {
|
input = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -648,13 +648,13 @@ class FlowInstance:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
# Query document embeddings for semantic search
|
# Query document embeddings for semantic search
|
||||||
input = {
|
input = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -1362,13 +1362,13 @@ class FlowInstance:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
# Query row embeddings for semantic search
|
# Query row embeddings for semantic search
|
||||||
input = {
|
input = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -649,12 +649,12 @@ class SocketFlowInstance:
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -698,12 +698,12 @@ class SocketFlowInstance:
|
||||||
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
|
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
"limit": limit
|
"limit": limit
|
||||||
|
|
@ -936,12 +936,12 @@ class SocketFlowInstance:
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# First convert text to embeddings vectors
|
# First convert text to embedding vector
|
||||||
emb_result = self.embeddings(texts=[text])
|
emb_result = self.embeddings(texts=[text])
|
||||||
vectors = emb_result.get("vectors", [[]])[0]
|
vector = emb_result.get("vectors", [[]])[0]
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"vectors": vectors,
|
"vector": vector,
|
||||||
"schema_name": schema_name,
|
"schema_name": schema_name,
|
||||||
"user": user,
|
"user": user,
|
||||||
"collection": collection,
|
"collection": collection,
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DocumentEmbeddingsClient(RequestResponse):
|
class DocumentEmbeddingsClient(RequestResponse):
|
||||||
async def query(self, vectors, limit=20, user="trustgraph",
|
async def query(self, vector, limit=20, user="trustgraph",
|
||||||
collection="default", timeout=30):
|
collection="default", timeout=30):
|
||||||
|
|
||||||
resp = await self.request(
|
resp = await self.request(
|
||||||
DocumentEmbeddingsRequest(
|
DocumentEmbeddingsRequest(
|
||||||
vectors = vectors,
|
vector = vector,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
user = user,
|
user = user,
|
||||||
collection = collection
|
collection = collection
|
||||||
|
|
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
|
||||||
if resp.error:
|
if resp.error:
|
||||||
raise RuntimeError(resp.error.message)
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
return resp.chunk_ids
|
# Return ChunkMatch objects with chunk_id and score
|
||||||
|
return resp.chunks
|
||||||
|
|
||||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||||
docs = await self.query_document_embeddings(request)
|
docs = await self.query_document_embeddings(request)
|
||||||
|
|
||||||
logger.debug("Sending document embeddings query response...")
|
logger.debug("Sending document embeddings query response...")
|
||||||
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
|
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
||||||
logger.debug("Document embeddings query request completed")
|
logger.debug("Document embeddings query request completed")
|
||||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||||
type = "document-embeddings-query-error",
|
type = "document-embeddings-query-error",
|
||||||
message = str(e),
|
message = str(e),
|
||||||
),
|
),
|
||||||
chunk_ids=[],
|
chunks=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await flow("response").send(r, properties={"id": id})
|
await flow("response").send(r, properties={"id": id})
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,12 @@ def to_value(x):
|
||||||
return Literal(x.value or x.iri)
|
return Literal(x.value or x.iri)
|
||||||
|
|
||||||
class GraphEmbeddingsClient(RequestResponse):
|
class GraphEmbeddingsClient(RequestResponse):
|
||||||
async def query(self, vectors, limit=20, user="trustgraph",
|
async def query(self, vector, limit=20, user="trustgraph",
|
||||||
collection="default", timeout=30):
|
collection="default", timeout=30):
|
||||||
|
|
||||||
resp = await self.request(
|
resp = await self.request(
|
||||||
GraphEmbeddingsRequest(
|
GraphEmbeddingsRequest(
|
||||||
vectors = vectors,
|
vector = vector,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
user = user,
|
user = user,
|
||||||
collection = collection
|
collection = collection
|
||||||
|
|
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
|
||||||
if resp.error:
|
if resp.error:
|
||||||
raise RuntimeError(resp.error.message)
|
raise RuntimeError(resp.error.message)
|
||||||
|
|
||||||
return [
|
# Return EntityMatch objects with entity and score
|
||||||
to_value(v)
|
return resp.entities
|
||||||
for v in resp.entities
|
|
||||||
]
|
|
||||||
|
|
||||||
class GraphEmbeddingsClientSpec(RequestResponseSpec):
|
class GraphEmbeddingsClientSpec(RequestResponseSpec):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||||
|
|
||||||
class RowEmbeddingsQueryClient(RequestResponse):
|
class RowEmbeddingsQueryClient(RequestResponse):
|
||||||
async def row_embeddings_query(
|
async def row_embeddings_query(
|
||||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
self, vector, schema_name, user="trustgraph", collection="default",
|
||||||
index_name=None, limit=10, timeout=600
|
index_name=None, limit=10, timeout=600
|
||||||
):
|
):
|
||||||
request = RowEmbeddingsRequest(
|
request = RowEmbeddingsRequest(
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
schema_name=schema_name,
|
schema_name=schema_name,
|
||||||
user=user,
|
user=user,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self, vectors, user="trustgraph", collection="default",
|
self, vector, user="trustgraph", collection="default",
|
||||||
limit=10, timeout=300
|
limit=10, timeout=300
|
||||||
):
|
):
|
||||||
return self.call(
|
return self.call(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, limit=limit, timeout=timeout
|
vector=vector, limit=limit, timeout=timeout
|
||||||
).chunks
|
).chunks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self, vectors, user="trustgraph", collection="default",
|
self, vector, user="trustgraph", collection="default",
|
||||||
limit=10, timeout=300
|
limit=10, timeout=300
|
||||||
):
|
):
|
||||||
return self.call(
|
return self.call(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, limit=limit, timeout=timeout
|
vector=vector, limit=limit, timeout=timeout
|
||||||
).entities
|
).entities
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
self, vector, schema_name, user="trustgraph", collection="default",
|
||||||
index_name=None, limit=10, timeout=300
|
index_name=None, limit=10, timeout=300
|
||||||
):
|
):
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
user=user, collection=collection,
|
user=user, collection=collection,
|
||||||
vectors=vectors, schema_name=schema_name,
|
vector=vector, schema_name=schema_name,
|
||||||
limit=limit, timeout=timeout
|
limit=limit, timeout=timeout
|
||||||
)
|
)
|
||||||
if index_name:
|
if index_name:
|
||||||
|
|
|
||||||
|
|
@ -10,18 +10,18 @@ from .primitives import ValueTranslator
|
||||||
|
|
||||||
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
||||||
return DocumentEmbeddingsRequest(
|
return DocumentEmbeddingsRequest(
|
||||||
vectors=data["vectors"],
|
vector=data["vector"],
|
||||||
limit=int(data.get("limit", 10)),
|
limit=int(data.get("limit", 10)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default")
|
collection=data.get("collection", "default")
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"vectors": obj.vectors,
|
"vector": obj.vector,
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection
|
"collection": obj.collection
|
||||||
|
|
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
||||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if obj.chunk_ids is not None:
|
if obj.chunks is not None:
|
||||||
result["chunk_ids"] = list(obj.chunk_ids)
|
result["chunks"] = [
|
||||||
|
{
|
||||||
|
"chunk_id": chunk.chunk_id,
|
||||||
|
"score": chunk.score
|
||||||
|
}
|
||||||
|
for chunk in obj.chunks
|
||||||
|
]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
"""Returns (response_dict, is_final)"""
|
"""Returns (response_dict, is_final)"""
|
||||||
return self.from_pulsar(obj), True
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
|
|
||||||
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
"""Translator for GraphEmbeddingsRequest schema objects"""
|
"""Translator for GraphEmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
||||||
return GraphEmbeddingsRequest(
|
return GraphEmbeddingsRequest(
|
||||||
vectors=data["vectors"],
|
vector=data["vector"],
|
||||||
limit=int(data.get("limit", 10)),
|
limit=int(data.get("limit", 10)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default")
|
collection=data.get("collection", "default")
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"vectors": obj.vectors,
|
"vector": obj.vector,
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection
|
"collection": obj.collection
|
||||||
|
|
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
"""Translator for GraphEmbeddingsResponse schema objects"""
|
"""Translator for GraphEmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.value_translator = ValueTranslator()
|
self.value_translator = ValueTranslator()
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
||||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if obj.entities is not None:
|
if obj.entities is not None:
|
||||||
result["entities"] = [
|
result["entities"] = [
|
||||||
self.value_translator.from_pulsar(entity)
|
{
|
||||||
for entity in obj.entities
|
"entity": self.value_translator.from_pulsar(match.entity),
|
||||||
|
"score": match.score
|
||||||
|
}
|
||||||
|
for match in obj.entities
|
||||||
]
|
]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
"""Returns (response_dict, is_final)"""
|
"""Returns (response_dict, is_final)"""
|
||||||
return self.from_pulsar(obj), True
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
||||||
return RowEmbeddingsRequest(
|
return RowEmbeddingsRequest(
|
||||||
vectors=data["vectors"],
|
vector=data["vector"],
|
||||||
limit=int(data.get("limit", 10)),
|
limit=int(data.get("limit", 10)),
|
||||||
user=data.get("user", "trustgraph"),
|
user=data.get("user", "trustgraph"),
|
||||||
collection=data.get("collection", "default"),
|
collection=data.get("collection", "default"),
|
||||||
|
|
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
|
||||||
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
result = {
|
result = {
|
||||||
"vectors": obj.vectors,
|
"vector": obj.vector,
|
||||||
"limit": obj.limit,
|
"limit": obj.limit,
|
||||||
"user": obj.user,
|
"user": obj.user,
|
||||||
"collection": obj.collection,
|
"collection": obj.collection,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from ..core.topic import topic
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityEmbeddings:
|
class EntityEmbeddings:
|
||||||
entity: Term | None = None
|
entity: Term | None = None
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
# Provenance: which chunk this embedding was derived from
|
# Provenance: which chunk this embedding was derived from
|
||||||
chunk_id: str = ""
|
chunk_id: str = ""
|
||||||
|
|
||||||
|
|
@ -28,7 +28,7 @@ class GraphEmbeddings:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChunkEmbeddings:
|
class ChunkEmbeddings:
|
||||||
chunk_id: str = ""
|
chunk_id: str = ""
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
|
|
||||||
# This is a 'batching' mechanism for the above data
|
# This is a 'batching' mechanism for the above data
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -44,7 +44,7 @@ class DocumentEmbeddings:
|
||||||
@dataclass
|
@dataclass
|
||||||
class ObjectEmbeddings:
|
class ObjectEmbeddings:
|
||||||
metadata: Metadata | None = None
|
metadata: Metadata | None = None
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
name: str = ""
|
name: str = ""
|
||||||
key_name: str = ""
|
key_name: str = ""
|
||||||
id: str = ""
|
id: str = ""
|
||||||
|
|
@ -56,7 +56,7 @@ class ObjectEmbeddings:
|
||||||
@dataclass
|
@dataclass
|
||||||
class StructuredObjectEmbedding:
|
class StructuredObjectEmbedding:
|
||||||
metadata: Metadata | None = None
|
metadata: Metadata | None = None
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
schema_name: str = ""
|
schema_name: str = ""
|
||||||
object_id: str = "" # Primary key value
|
object_id: str = "" # Primary key value
|
||||||
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
|
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
|
||||||
|
|
@ -72,7 +72,7 @@ class RowIndexEmbedding:
|
||||||
index_name: str = "" # The indexed field name(s)
|
index_name: str = "" # The indexed field name(s)
|
||||||
index_value: list[str] = field(default_factory=list) # The field value(s)
|
index_value: list[str] = field(default_factory=list) # The field value(s)
|
||||||
text: str = "" # Text that was embedded
|
text: str = "" # Text that was embedded
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RowEmbeddings:
|
class RowEmbeddings:
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class EmbeddingsRequest:
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingsResponse:
|
class EmbeddingsResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
vectors: list[list[list[float]]] = field(default_factory=list)
|
vectors: list[list[float]] = field(default_factory=list)
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,15 +9,21 @@ from ..core.topic import topic
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphEmbeddingsRequest:
|
class GraphEmbeddingsRequest:
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
limit: int = 0
|
limit: int = 0
|
||||||
user: str = ""
|
user: str = ""
|
||||||
collection: str = ""
|
collection: str = ""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EntityMatch:
|
||||||
|
"""A matching entity from a semantic search with similarity score"""
|
||||||
|
entity: Term | None = None
|
||||||
|
score: float = 0.0
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphEmbeddingsResponse:
|
class GraphEmbeddingsResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
entities: list[Term] = field(default_factory=list)
|
entities: list[EntityMatch] = field(default_factory=list)
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
@ -44,15 +50,21 @@ class TriplesQueryResponse:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DocumentEmbeddingsRequest:
|
class DocumentEmbeddingsRequest:
|
||||||
vectors: list[list[float]] = field(default_factory=list)
|
vector: list[float] = field(default_factory=list)
|
||||||
limit: int = 0
|
limit: int = 0
|
||||||
user: str = ""
|
user: str = ""
|
||||||
collection: str = ""
|
collection: str = ""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkMatch:
|
||||||
|
"""A matching chunk from a semantic search with similarity score"""
|
||||||
|
chunk_id: str = ""
|
||||||
|
score: float = 0.0
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DocumentEmbeddingsResponse:
|
class DocumentEmbeddingsResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
chunk_ids: list[str] = field(default_factory=list)
|
chunks: list[ChunkMatch] = field(default_factory=list)
|
||||||
|
|
||||||
document_embeddings_request_queue = topic(
|
document_embeddings_request_queue = topic(
|
||||||
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
|
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
|
||||||
|
|
@ -76,7 +88,7 @@ class RowIndexMatch:
|
||||||
@dataclass
|
@dataclass
|
||||||
class RowEmbeddingsRequest:
|
class RowEmbeddingsRequest:
|
||||||
"""Request for row embeddings semantic search"""
|
"""Request for row embeddings semantic search"""
|
||||||
vectors: list[list[float]] = field(default_factory=list) # Query vectors
|
vector: list[float] = field(default_factory=list) # Query vector
|
||||||
limit: int = 10 # Max results to return
|
limit: int = 10 # Max results to return
|
||||||
user: str = "" # User/keyspace
|
user: str = "" # User/keyspace
|
||||||
collection: str = "" # Collection name
|
collection: str = "" # Collection name
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,7 @@ class RowEmbeddingsQueryImpl:
|
||||||
|
|
||||||
query_text = arguments.get("query")
|
query_text = arguments.get("query")
|
||||||
all_vectors = await embeddings_client.embed([query_text])
|
all_vectors = await embeddings_client.embed([query_text])
|
||||||
vectors = all_vectors[0] if all_vectors else []
|
vector = all_vectors[0] if all_vectors else []
|
||||||
|
|
||||||
# Now query row embeddings
|
# Now query row embeddings
|
||||||
client = self.context("row-embeddings-query-request")
|
client = self.context("row-embeddings-query-request")
|
||||||
|
|
@ -165,7 +165,7 @@ class RowEmbeddingsQueryImpl:
|
||||||
user = getattr(client, '_current_user', self.user or "trustgraph")
|
user = getattr(client, '_current_user', self.user or "trustgraph")
|
||||||
|
|
||||||
matches = await client.row_embeddings_query(
|
matches = await client.row_embeddings_query(
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
schema_name=self.schema_name,
|
schema_name=self.schema_name,
|
||||||
user=user,
|
user=user,
|
||||||
collection=self.collection or "default",
|
collection=self.collection or "default",
|
||||||
|
|
|
||||||
|
|
@ -66,13 +66,13 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# vectors[0] is the vector set for the first (only) text
|
# vectors[0] is the vector for the first (only) text
|
||||||
vectors = resp.vectors[0] if resp.vectors else []
|
vector = resp.vectors[0] if resp.vectors else []
|
||||||
|
|
||||||
embeds = [
|
embeds = [
|
||||||
ChunkEmbeddings(
|
ChunkEmbeddings(
|
||||||
chunk_id=v.document_id,
|
chunk_id=v.document_id,
|
||||||
vectors=vectors,
|
vector=vector,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,11 +59,8 @@ class Processor(EmbeddingsService):
|
||||||
# FastEmbed processes the full batch efficiently
|
# FastEmbed processes the full batch efficiently
|
||||||
vecs = list(self.embeddings.embed(texts))
|
vecs = list(self.embeddings.embed(texts))
|
||||||
|
|
||||||
# Return list of vector sets, one per input text
|
# Return list of vectors, one per input text
|
||||||
return [
|
return [v.tolist() for v in vecs]
|
||||||
[v.tolist()]
|
|
||||||
for v in vecs
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -72,10 +72,10 @@ class Processor(FlowProcessor):
|
||||||
entities = [
|
entities = [
|
||||||
EntityEmbeddings(
|
EntityEmbeddings(
|
||||||
entity=entity.entity,
|
entity=entity.entity,
|
||||||
vectors=vectors, # Vector set for this entity
|
vector=vector,
|
||||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||||
)
|
)
|
||||||
for entity, vectors in zip(v.entities, all_vectors)
|
for entity, vector in zip(v.entities, all_vectors)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Send in batches to avoid oversized messages
|
# Send in batches to avoid oversized messages
|
||||||
|
|
|
||||||
|
|
@ -43,11 +43,8 @@ class Processor(EmbeddingsService):
|
||||||
input = texts
|
input = texts
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return list of vector sets, one per input text
|
# Return list of vectors, one per input text
|
||||||
return [
|
return list(embeds.embeddings)
|
||||||
[embedding]
|
|
||||||
for embedding in embeds.embeddings
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -208,7 +208,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
||||||
|
|
||||||
# Pair results with metadata
|
# Pair results with metadata
|
||||||
for text, (index_name, index_value), vectors in zip(
|
for text, (index_name, index_value), vector in zip(
|
||||||
texts, metadata, all_vectors
|
texts, metadata, all_vectors
|
||||||
):
|
):
|
||||||
embeddings_list.append(
|
embeddings_list.append(
|
||||||
|
|
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
index_value=index_value,
|
index_value=index_value,
|
||||||
text=text,
|
text=text,
|
||||||
vectors=vectors # Vector set for this text
|
vector=vector
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ of chunk_ids
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .... direct.milvus_doc_embeddings import DocVectors
|
from .... direct.milvus_doc_embeddings import DocVectors
|
||||||
from .... schema import DocumentEmbeddingsResponse
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -35,26 +35,33 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
chunk_ids = []
|
resp = self.vecstore.search(
|
||||||
|
vec,
|
||||||
|
msg.user,
|
||||||
|
msg.collection,
|
||||||
|
limit=msg.limit
|
||||||
|
)
|
||||||
|
|
||||||
for vec in msg.vectors:
|
chunks = []
|
||||||
|
for r in resp:
|
||||||
|
chunk_id = r["entity"]["chunk_id"]
|
||||||
|
# Milvus returns distance, convert to similarity score
|
||||||
|
distance = r.get("distance", 0.0)
|
||||||
|
score = 1.0 - distance if distance else 0.0
|
||||||
|
chunks.append(ChunkMatch(
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
resp = self.vecstore.search(
|
return chunks
|
||||||
vec,
|
|
||||||
msg.user,
|
|
||||||
msg.collection,
|
|
||||||
limit=msg.limit
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in resp:
|
|
||||||
chunk_id = r["entity"]["chunk_id"]
|
|
||||||
chunk_ids.append(chunk_id)
|
|
||||||
|
|
||||||
return chunk_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import os
|
||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||||
|
|
||||||
|
from .... schema import ChunkMatch
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
|
|
@ -51,38 +52,43 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
chunk_ids = []
|
dim = len(vec)
|
||||||
|
|
||||||
for vec in msg.vectors:
|
# Use dimension suffix in index name
|
||||||
|
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
|
||||||
|
|
||||||
dim = len(vec)
|
# Check if index exists - return empty if not
|
||||||
|
if not self.pinecone.has_index(index_name):
|
||||||
|
logger.info(f"Index {index_name} does not exist")
|
||||||
|
return []
|
||||||
|
|
||||||
# Use dimension suffix in index name
|
index = self.pinecone.Index(index_name)
|
||||||
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
|
|
||||||
|
|
||||||
# Check if index exists - skip if not
|
results = index.query(
|
||||||
if not self.pinecone.has_index(index_name):
|
vector=vec,
|
||||||
logger.info(f"Index {index_name} does not exist, skipping this vector")
|
top_k=msg.limit,
|
||||||
continue
|
include_values=False,
|
||||||
|
include_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
chunks = []
|
||||||
|
for r in results.matches:
|
||||||
|
chunk_id = r.metadata["chunk_id"]
|
||||||
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
|
chunks.append(ChunkMatch(
|
||||||
|
chunk_id=chunk_id,
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
results = index.query(
|
return chunks
|
||||||
vector=vec,
|
|
||||||
top_k=msg.limit,
|
|
||||||
include_values=False,
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in results.matches:
|
|
||||||
chunk_id = r.metadata["chunk_id"]
|
|
||||||
chunk_ids.append(chunk_id)
|
|
||||||
|
|
||||||
return chunk_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
|
|
||||||
from .... schema import DocumentEmbeddingsResponse
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -69,31 +69,36 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
chunk_ids = []
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
# Use dimension suffix in collection name
|
||||||
|
dim = len(vec)
|
||||||
|
collection = f"d_{msg.user}_{msg.collection}_{dim}"
|
||||||
|
|
||||||
# Use dimension suffix in collection name
|
# Check if collection exists - return empty if not
|
||||||
dim = len(vec)
|
if not self.collection_exists(collection):
|
||||||
collection = f"d_{msg.user}_{msg.collection}_{dim}"
|
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||||
|
return []
|
||||||
|
|
||||||
# Check if collection exists - return empty if not
|
search_result = self.qdrant.query_points(
|
||||||
if not self.collection_exists(collection):
|
collection_name=collection,
|
||||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
query=vec,
|
||||||
continue
|
limit=msg.limit,
|
||||||
|
with_payload=True,
|
||||||
|
).points
|
||||||
|
|
||||||
search_result = self.qdrant.query_points(
|
chunks = []
|
||||||
collection_name=collection,
|
for r in search_result:
|
||||||
query=vec,
|
chunk_id = r.payload["chunk_id"]
|
||||||
limit=msg.limit,
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
with_payload=True,
|
chunks.append(ChunkMatch(
|
||||||
).points
|
chunk_id=chunk_id,
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
for r in search_result:
|
return chunks
|
||||||
chunk_id = r.payload["chunk_id"]
|
|
||||||
chunk_ids.append(chunk_id)
|
|
||||||
|
|
||||||
return chunk_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ entities
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||||
from .... schema import GraphEmbeddingsResponse
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
entity_set = set()
|
vec = msg.vector
|
||||||
entities = []
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
resp = self.vecstore.search(
|
||||||
|
vec,
|
||||||
|
msg.user,
|
||||||
|
msg.collection,
|
||||||
|
limit=msg.limit * 2
|
||||||
|
)
|
||||||
|
|
||||||
resp = self.vecstore.search(
|
entity_set = set()
|
||||||
vec,
|
entities = []
|
||||||
msg.user,
|
|
||||||
msg.collection,
|
|
||||||
limit=msg.limit * 2
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in resp:
|
for r in resp:
|
||||||
ent = r["entity"]["entity"]
|
ent = r["entity"]["entity"]
|
||||||
|
# Milvus returns distance, convert to similarity score
|
||||||
# De-dupe entities
|
distance = r.get("distance", 0.0)
|
||||||
if ent not in entity_set:
|
score = 1.0 - distance if distance else 0.0
|
||||||
entity_set.add(ent)
|
|
||||||
entities.append(ent)
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# De-dupe entities, keep highest score
|
||||||
if len(entity_set) >= msg.limit: break
|
if ent not in entity_set:
|
||||||
|
entity_set.add(ent)
|
||||||
|
entities.append(EntityMatch(
|
||||||
|
entity=self.create_value(ent),
|
||||||
|
score=score,
|
||||||
|
))
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# Keep adding entities until limit
|
||||||
if len(entity_set) >= msg.limit: break
|
if len(entities) >= msg.limit:
|
||||||
|
break
|
||||||
ents2 = []
|
|
||||||
|
|
||||||
for ent in entities:
|
|
||||||
ents2.append(self.create_value(ent))
|
|
||||||
|
|
||||||
entities = ents2
|
|
||||||
|
|
||||||
logger.debug("Send response...")
|
logger.debug("Send response...")
|
||||||
return entities
|
return entities
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import os
|
||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||||
|
|
||||||
from .... schema import GraphEmbeddingsResponse
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Handle zero limit case
|
# Handle zero limit case
|
||||||
if msg.limit <= 0:
|
if msg.limit <= 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
dim = len(vec)
|
||||||
|
|
||||||
|
# Use dimension suffix in index name
|
||||||
|
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
|
||||||
|
|
||||||
|
# Check if index exists - return empty if not
|
||||||
|
if not self.pinecone.has_index(index_name):
|
||||||
|
logger.info(f"Index {index_name} does not exist")
|
||||||
|
return []
|
||||||
|
|
||||||
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
|
# of getting (limit) unique entities
|
||||||
|
results = index.query(
|
||||||
|
vector=vec,
|
||||||
|
top_k=msg.limit * 2,
|
||||||
|
include_values=False,
|
||||||
|
include_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
entity_set = set()
|
entity_set = set()
|
||||||
entities = []
|
entities = []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
for r in results.matches:
|
||||||
|
ent = r.metadata["entity"]
|
||||||
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
|
|
||||||
dim = len(vec)
|
# De-dupe entities, keep highest score
|
||||||
|
if ent not in entity_set:
|
||||||
# Use dimension suffix in index name
|
entity_set.add(ent)
|
||||||
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
|
entities.append(EntityMatch(
|
||||||
|
entity=self.create_value(ent),
|
||||||
# Check if index exists - skip if not
|
score=score,
|
||||||
if not self.pinecone.has_index(index_name):
|
))
|
||||||
logger.info(f"Index {index_name} does not exist, skipping this vector")
|
|
||||||
continue
|
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
|
||||||
# of getting (limit) entities
|
|
||||||
results = index.query(
|
|
||||||
vector=vec,
|
|
||||||
top_k=msg.limit * 2,
|
|
||||||
include_values=False,
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for r in results.matches:
|
|
||||||
|
|
||||||
ent = r.metadata["entity"]
|
|
||||||
|
|
||||||
# De-dupe entities
|
|
||||||
if ent not in entity_set:
|
|
||||||
entity_set.add(ent)
|
|
||||||
entities.append(ent)
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
|
||||||
if len(entity_set) >= msg.limit: break
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# Keep adding entities until limit
|
||||||
if len(entity_set) >= msg.limit: break
|
if len(entities) >= msg.limit:
|
||||||
|
break
|
||||||
ents2 = []
|
|
||||||
|
|
||||||
for ent in entities:
|
|
||||||
ents2.append(self.create_value(ent))
|
|
||||||
|
|
||||||
entities = ents2
|
|
||||||
|
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
|
||||||
from qdrant_client.models import PointStruct
|
from qdrant_client.models import PointStruct
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
|
|
||||||
from .... schema import GraphEmbeddingsResponse
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
|
||||||
|
|
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vec = msg.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use dimension suffix in collection name
|
||||||
|
dim = len(vec)
|
||||||
|
collection = f"t_{msg.user}_{msg.collection}_{dim}"
|
||||||
|
|
||||||
|
# Check if collection exists - return empty if not
|
||||||
|
if not self.collection_exists(collection):
|
||||||
|
logger.info(f"Collection {collection} does not exist")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
|
# of getting (limit) unique entities
|
||||||
|
search_result = self.qdrant.query_points(
|
||||||
|
collection_name=collection,
|
||||||
|
query=vec,
|
||||||
|
limit=msg.limit * 2,
|
||||||
|
with_payload=True,
|
||||||
|
).points
|
||||||
|
|
||||||
entity_set = set()
|
entity_set = set()
|
||||||
entities = []
|
entities = []
|
||||||
|
|
||||||
for vec in msg.vectors:
|
for r in search_result:
|
||||||
|
ent = r.payload["entity"]
|
||||||
|
score = r.score if hasattr(r, 'score') else 0.0
|
||||||
|
|
||||||
# Use dimension suffix in collection name
|
# De-dupe entities, keep highest score
|
||||||
dim = len(vec)
|
if ent not in entity_set:
|
||||||
collection = f"t_{msg.user}_{msg.collection}_{dim}"
|
entity_set.add(ent)
|
||||||
|
entities.append(EntityMatch(
|
||||||
# Check if collection exists - return empty if not
|
entity=self.create_value(ent),
|
||||||
if not self.collection_exists(collection):
|
score=score,
|
||||||
logger.info(f"Collection {collection} does not exist, skipping this vector")
|
))
|
||||||
continue
|
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
|
||||||
# of getting (limit) entities
|
|
||||||
search_result = self.qdrant.query_points(
|
|
||||||
collection_name=collection,
|
|
||||||
query=vec,
|
|
||||||
limit=msg.limit * 2,
|
|
||||||
with_payload=True,
|
|
||||||
).points
|
|
||||||
|
|
||||||
for r in search_result:
|
|
||||||
ent = r.payload["entity"]
|
|
||||||
|
|
||||||
# De-dupe entities
|
|
||||||
if ent not in entity_set:
|
|
||||||
entity_set.add(ent)
|
|
||||||
entities.append(ent)
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
|
||||||
if len(entity_set) >= msg.limit: break
|
|
||||||
|
|
||||||
# Keep adding entities until limit
|
# Keep adding entities until limit
|
||||||
if len(entity_set) >= msg.limit: break
|
if len(entities) >= msg.limit:
|
||||||
|
break
|
||||||
ents2 = []
|
|
||||||
|
|
||||||
for ent in entities:
|
|
||||||
ents2.append(self.create_value(ent))
|
|
||||||
|
|
||||||
entities = ents2
|
|
||||||
|
|
||||||
logger.debug("Send response...")
|
logger.debug("Send response...")
|
||||||
return entities
|
return entities
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,9 @@ class Processor(FlowProcessor):
|
||||||
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
|
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
|
||||||
"""Execute row embeddings query"""
|
"""Execute row embeddings query"""
|
||||||
|
|
||||||
matches = []
|
vec = request.vector
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
|
||||||
# Find the collection for this user/collection/schema
|
# Find the collection for this user/collection/schema
|
||||||
qdrant_collection = self.find_collection(
|
qdrant_collection = self.find_collection(
|
||||||
|
|
@ -105,47 +107,47 @@ class Processor(FlowProcessor):
|
||||||
f"No Qdrant collection found for "
|
f"No Qdrant collection found for "
|
||||||
f"{request.user}/{request.collection}/{request.schema_name}"
|
f"{request.user}/{request.collection}/{request.schema_name}"
|
||||||
)
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build optional filter for index_name
|
||||||
|
query_filter = None
|
||||||
|
if request.index_name:
|
||||||
|
query_filter = Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="index_name",
|
||||||
|
match=MatchValue(value=request.index_name)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query Qdrant
|
||||||
|
search_result = self.qdrant.query_points(
|
||||||
|
collection_name=qdrant_collection,
|
||||||
|
query=vec,
|
||||||
|
limit=request.limit,
|
||||||
|
with_payload=True,
|
||||||
|
query_filter=query_filter,
|
||||||
|
).points
|
||||||
|
|
||||||
|
# Convert to RowIndexMatch objects
|
||||||
|
matches = []
|
||||||
|
for point in search_result:
|
||||||
|
payload = point.payload or {}
|
||||||
|
match = RowIndexMatch(
|
||||||
|
index_name=payload.get("index_name", ""),
|
||||||
|
index_value=payload.get("index_value", []),
|
||||||
|
text=payload.get("text", ""),
|
||||||
|
score=point.score if hasattr(point, 'score') else 0.0
|
||||||
|
)
|
||||||
|
matches.append(match)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
for vec in request.vectors:
|
except Exception as e:
|
||||||
try:
|
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
|
||||||
# Build optional filter for index_name
|
raise
|
||||||
query_filter = None
|
|
||||||
if request.index_name:
|
|
||||||
query_filter = Filter(
|
|
||||||
must=[
|
|
||||||
FieldCondition(
|
|
||||||
key="index_name",
|
|
||||||
match=MatchValue(value=request.index_name)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query Qdrant
|
|
||||||
search_result = self.qdrant.query_points(
|
|
||||||
collection_name=qdrant_collection,
|
|
||||||
query=vec,
|
|
||||||
limit=request.limit,
|
|
||||||
with_payload=True,
|
|
||||||
query_filter=query_filter,
|
|
||||||
).points
|
|
||||||
|
|
||||||
# Convert to RowIndexMatch objects
|
|
||||||
for point in search_result:
|
|
||||||
payload = point.payload or {}
|
|
||||||
match = RowIndexMatch(
|
|
||||||
index_name=payload.get("index_name", ""),
|
|
||||||
index_value=payload.get("index_value", []),
|
|
||||||
text=payload.get("text", ""),
|
|
||||||
score=point.score if hasattr(point, 'score') else 0.0
|
|
||||||
)
|
|
||||||
matches.append(match)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
async def on_message(self, msg, consumer, flow):
|
async def on_message(self, msg, consumer, flow):
|
||||||
"""Handle incoming query request"""
|
"""Handle incoming query request"""
|
||||||
|
|
|
||||||
|
|
@ -37,26 +37,26 @@ class Query:
|
||||||
vectors = await self.get_vector(query)
|
vectors = await self.get_vector(query)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Getting chunk_ids from embeddings store...")
|
logger.debug("Getting chunks from embeddings store...")
|
||||||
|
|
||||||
# Get chunk_ids from embeddings store
|
# Get chunk matches from embeddings store
|
||||||
chunk_ids = await self.rag.doc_embeddings_client.query(
|
chunk_matches = await self.rag.doc_embeddings_client.query(
|
||||||
vectors, limit=self.doc_limit,
|
vector=vectors, limit=self.doc_limit,
|
||||||
user=self.user, collection=self.collection,
|
user=self.user, collection=self.collection,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug(f"Got {len(chunk_ids)} chunk_ids, fetching content from Garage...")
|
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
|
||||||
|
|
||||||
# Fetch chunk content from Garage
|
# Fetch chunk content from Garage
|
||||||
docs = []
|
docs = []
|
||||||
for chunk_id in chunk_ids:
|
for match in chunk_matches:
|
||||||
if chunk_id:
|
if match.chunk_id:
|
||||||
try:
|
try:
|
||||||
content = await self.rag.fetch_chunk(chunk_id, self.user)
|
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
|
||||||
docs.append(content)
|
docs.append(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to fetch chunk {chunk_id}: {e}")
|
logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Documents fetched:")
|
logger.debug("Documents fetched:")
|
||||||
|
|
|
||||||
|
|
@ -87,14 +87,14 @@ class Query:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Getting entities...")
|
logger.debug("Getting entities...")
|
||||||
|
|
||||||
entities = await self.rag.graph_embeddings_client.query(
|
entity_matches = await self.rag.graph_embeddings_client.query(
|
||||||
vectors=vectors, limit=self.entity_limit,
|
vector=vectors, limit=self.entity_limit,
|
||||||
user=self.user, collection=self.collection,
|
user=self.user, collection=self.collection,
|
||||||
)
|
)
|
||||||
|
|
||||||
entities = [
|
entities = [
|
||||||
str(e)
|
str(e.entity)
|
||||||
for e in entities
|
for e in entity_matches
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if chunk_id == "":
|
if chunk_id == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
vec = emb.vector
|
||||||
|
if vec:
|
||||||
self.vecstore.insert(
|
self.vecstore.insert(
|
||||||
vec, chunk_id,
|
vec, chunk_id,
|
||||||
message.metadata.user,
|
message.metadata.user,
|
||||||
|
|
|
||||||
|
|
@ -105,35 +105,37 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if chunk_id == "":
|
if chunk_id == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
vec = emb.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create index name with dimension suffix for lazy creation
|
# Create index name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
index_name = (
|
index_name = (
|
||||||
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||||
if not self.pinecone.has_index(index_name):
|
if not self.pinecone.has_index(index_name):
|
||||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||||
self.create_index(index_name, dim)
|
self.create_index(index_name, dim)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
# Generate unique ID for each vector
|
# Generate unique ID for each vector
|
||||||
vector_id = str(uuid.uuid4())
|
vector_id = str(uuid.uuid4())
|
||||||
|
|
||||||
records = [
|
records = [
|
||||||
{
|
{
|
||||||
"id": vector_id,
|
"id": vector_id,
|
||||||
"values": vec,
|
"values": vec,
|
||||||
"metadata": { "chunk_id": chunk_id },
|
"metadata": { "chunk_id": chunk_id },
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
index.upsert(
|
index.upsert(
|
||||||
vectors = records,
|
vectors = records,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -56,38 +56,40 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
if chunk_id == "":
|
if chunk_id == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
vec = emb.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create collection name with dimension suffix for lazy creation
|
# Create collection name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = (
|
collection = (
|
||||||
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||||
if not self.qdrant.collection_exists(collection):
|
if not self.qdrant.collection_exists(collection):
|
||||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||||
self.qdrant.create_collection(
|
self.qdrant.create_collection(
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim,
|
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.qdrant.upsert(
|
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
points=[
|
vectors_config=VectorParams(
|
||||||
PointStruct(
|
size=dim,
|
||||||
id=str(uuid.uuid4()),
|
distance=Distance.COSINE
|
||||||
vector=vec,
|
)
|
||||||
payload={
|
|
||||||
"chunk_id": chunk_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.qdrant.upsert(
|
||||||
|
collection_name=collection,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
vector=vec,
|
||||||
|
payload={
|
||||||
|
"chunk_id": chunk_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
entity_value = get_term_value(entity.entity)
|
entity_value = get_term_value(entity.entity)
|
||||||
|
|
||||||
if entity_value != "" and entity_value is not None:
|
if entity_value != "" and entity_value is not None:
|
||||||
for vec in entity.vectors:
|
vec = entity.vector
|
||||||
|
if vec:
|
||||||
self.vecstore.insert(
|
self.vecstore.insert(
|
||||||
vec, entity_value,
|
vec, entity_value,
|
||||||
message.metadata.user,
|
message.metadata.user,
|
||||||
|
|
|
||||||
|
|
@ -119,39 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
if entity_value == "" or entity_value is None:
|
if entity_value == "" or entity_value is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in entity.vectors:
|
vec = entity.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create index name with dimension suffix for lazy creation
|
# Create index name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
index_name = (
|
index_name = (
|
||||||
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||||
if not self.pinecone.has_index(index_name):
|
if not self.pinecone.has_index(index_name):
|
||||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||||
self.create_index(index_name, dim)
|
self.create_index(index_name, dim)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
# Generate unique ID for each vector
|
# Generate unique ID for each vector
|
||||||
vector_id = str(uuid.uuid4())
|
vector_id = str(uuid.uuid4())
|
||||||
|
|
||||||
metadata = {"entity": entity_value}
|
metadata = {"entity": entity_value}
|
||||||
if entity.chunk_id:
|
if entity.chunk_id:
|
||||||
metadata["chunk_id"] = entity.chunk_id
|
metadata["chunk_id"] = entity.chunk_id
|
||||||
|
|
||||||
records = [
|
records = [
|
||||||
{
|
{
|
||||||
"id": vector_id,
|
"id": vector_id,
|
||||||
"values": vec,
|
"values": vec,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
index.upsert(
|
index.upsert(
|
||||||
vectors = records,
|
vectors = records,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -71,42 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
if entity_value == "" or entity_value is None:
|
if entity_value == "" or entity_value is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for vec in entity.vectors:
|
vec = entity.vector
|
||||||
|
if not vec:
|
||||||
|
continue
|
||||||
|
|
||||||
# Create collection name with dimension suffix for lazy creation
|
# Create collection name with dimension suffix for lazy creation
|
||||||
dim = len(vec)
|
dim = len(vec)
|
||||||
collection = (
|
collection = (
|
||||||
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||||
if not self.qdrant.collection_exists(collection):
|
if not self.qdrant.collection_exists(collection):
|
||||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||||
self.qdrant.create_collection(
|
self.qdrant.create_collection(
|
||||||
collection_name=collection,
|
|
||||||
vectors_config=VectorParams(
|
|
||||||
size=dim,
|
|
||||||
distance=Distance.COSINE
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"entity": entity_value,
|
|
||||||
}
|
|
||||||
if entity.chunk_id:
|
|
||||||
payload["chunk_id"] = entity.chunk_id
|
|
||||||
|
|
||||||
self.qdrant.upsert(
|
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
points=[
|
vectors_config=VectorParams(
|
||||||
PointStruct(
|
size=dim,
|
||||||
id=str(uuid.uuid4()),
|
distance=Distance.COSINE
|
||||||
vector=vec,
|
)
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"entity": entity_value,
|
||||||
|
}
|
||||||
|
if entity.chunk_id:
|
||||||
|
payload["chunk_id"] = entity.chunk_id
|
||||||
|
|
||||||
|
self.qdrant.upsert(
|
||||||
|
collection_name=collection,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
vector=vec,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
qdrant_collection = None
|
qdrant_collection = None
|
||||||
|
|
||||||
for row_emb in embeddings.embeddings:
|
for row_emb in embeddings.embeddings:
|
||||||
if not row_emb.vectors:
|
vector = row_emb.vector
|
||||||
|
if not vector:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No vectors for index {row_emb.index_name} - skipping"
|
f"No vector for index {row_emb.index_name} - skipping"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Use first vector (there may be multiple from different models)
|
dimension = len(vector)
|
||||||
for vector in row_emb.vectors:
|
|
||||||
dimension = len(vector)
|
|
||||||
|
|
||||||
# Create/get collection name (lazily on first vector)
|
# Create/get collection name (lazily on first vector)
|
||||||
if qdrant_collection is None:
|
if qdrant_collection is None:
|
||||||
qdrant_collection = self.get_collection_name(
|
qdrant_collection = self.get_collection_name(
|
||||||
user, collection, schema_name, dimension
|
user, collection, schema_name, dimension
|
||||||
)
|
|
||||||
self.ensure_collection(qdrant_collection, dimension)
|
|
||||||
|
|
||||||
# Write to Qdrant
|
|
||||||
self.qdrant.upsert(
|
|
||||||
collection_name=qdrant_collection,
|
|
||||||
points=[
|
|
||||||
PointStruct(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
vector=vector,
|
|
||||||
payload={
|
|
||||||
"index_name": row_emb.index_name,
|
|
||||||
"index_value": row_emb.index_value,
|
|
||||||
"text": row_emb.text
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
embeddings_written += 1
|
self.ensure_collection(qdrant_collection, dimension)
|
||||||
|
|
||||||
|
# Write to Qdrant
|
||||||
|
self.qdrant.upsert(
|
||||||
|
collection_name=qdrant_collection,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
vector=vector,
|
||||||
|
payload={
|
||||||
|
"index_name": row_emb.index_name,
|
||||||
|
"index_value": row_emb.index_value,
|
||||||
|
"text": row_emb.text
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
embeddings_written += 1
|
||||||
|
|
||||||
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue