mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-03 03:45:13 +02:00
Vector stores will create collections on query (#511)
This commit is contained in:
parent
314ce76b81
commit
6a1cc61e52
4 changed files with 110 additions and 0 deletions
|
|
@ -47,6 +47,39 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.last_index_name = None
|
||||||
|
|
||||||
|
def ensure_index_exists(self, index_name, dim, *, ready_attempts=1000,
                        poll_interval=1):
    """Ensure the Pinecone index ``index_name`` exists, creating it if needed.

    The name of the last index verified is remembered in
    ``self.last_index_name`` so that repeated calls for the same index do
    not hit the Pinecone API again.

    Args:
        index_name: Name of the Pinecone index to check or create.
        dim: Vector dimension used when the index has to be created.
        ready_attempts: Maximum number of sleeps while waiting for a newly
            created index to report ready (one final readiness check is
            made after the last sleep).
        poll_interval: Seconds to sleep between readiness checks.

    Raises:
        RuntimeError: If a newly created index is still not ready after
            the polling budget is used up.
        Exception: Any error raised while creating or polling the index
            is logged and re-raised unchanged.
    """
    # Function-scope import keeps this helper self-contained.
    import time

    # Same index as on the previous call: already verified, nothing to do.
    if index_name == self.last_index_name:
        return

    if not self.pinecone.has_index(index_name):
        try:
            # NOTE(review): cloud and region are hard-coded here.
            self.pinecone.create_index(
                name=index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1",
                ),
            )
            logger.info(f"Created index: {index_name}")

            # Wait for the index to be ready.  The readiness check is the
            # loop condition, so a successful wait costs exactly one
            # describe_index call per poll and no extra confirmation call.
            # NOTE(review): time.sleep blocks the calling thread; this
            # helper appears to be called from a coroutine -- confirm that
            # stalling the event loop here is acceptable.
            sleeps = 0
            while not self.pinecone.describe_index(index_name).status["ready"]:
                if sleeps >= ready_attempts:
                    raise RuntimeError("Gave up waiting for index creation")
                sleeps += 1
                time.sleep(poll_interval)

        except Exception as e:
            logger.error(f"Pinecone index creation failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    # Only reached once the index is known to exist.
    self.last_index_name = index_name
|
||||||
|
|
||||||
async def query_document_embeddings(self, msg):
|
async def query_document_embeddings(self, msg):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -65,6 +98,8 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.ensure_index_exists(index_name, dim)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
results = index.query(
|
results = index.query(
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,24 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
|
self.last_collection = None
|
||||||
|
|
||||||
|
def ensure_collection_exists(self, collection, dim):
    """Ensure the Qdrant collection ``collection`` exists, creating it if needed.

    The name of the last collection verified is remembered in
    ``self.last_collection`` so that repeated calls for the same
    collection do not hit the Qdrant API again.

    Args:
        collection: Name of the Qdrant collection to check or create.
        dim: Vector size used when the collection has to be created.

    Raises:
        Exception: Any error raised by collection creation is logged and
            re-raised unchanged.
    """
    # Same collection as on the previous call: already verified.
    if collection == self.last_collection:
        return

    if not self.qdrant.collection_exists(collection):
        try:
            self.qdrant.create_collection(
                collection_name=collection,
                vectors_config=VectorParams(
                    size=dim, distance=Distance.COSINE
                ),
            )
            logger.info(f"Created collection: {collection}")
        except Exception as e:
            logger.error(f"Qdrant collection creation failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    # Only reached once the collection is known to exist.
    self.last_collection = collection
|
||||||
|
|
||||||
async def query_document_embeddings(self, msg):
|
async def query_document_embeddings(self, msg):
|
||||||
|
|
||||||
|
|
@ -53,6 +71,8 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
str(dim)
|
str(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.ensure_collection_exists(collection, dim)
|
||||||
|
|
||||||
search_result = self.qdrant.query_points(
|
search_result = self.qdrant.query_points(
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
query=vec,
|
query=vec,
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,39 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.last_index_name = None
|
||||||
|
|
||||||
|
def ensure_index_exists(self, index_name, dim, *, ready_attempts=1000,
                        poll_interval=1):
    """Ensure the Pinecone index ``index_name`` exists, creating it if needed.

    ``self.last_index_name`` caches the most recently verified index name,
    so consecutive calls for the same index return without any API call.

    Args:
        index_name: Name of the Pinecone index to check or create.
        dim: Vector dimension used when the index has to be created.
        ready_attempts: Maximum number of sleeps while waiting for a newly
            created index to report ready (one final readiness check is
            made after the last sleep).
        poll_interval: Seconds to sleep between readiness checks.

    Raises:
        RuntimeError: If a newly created index is still not ready after
            the polling budget is used up.
        Exception: Any error raised while creating or polling the index
            is logged and re-raised unchanged.
    """
    # Function-scope import keeps this helper self-contained.
    import time

    # Guard clause: the previous call already verified this index.
    if index_name == self.last_index_name:
        return

    if not self.pinecone.has_index(index_name):
        try:
            # NOTE(review): cloud and region are hard-coded here.
            self.pinecone.create_index(
                name=index_name,
                dimension=dim,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1",
                ),
            )
            logger.info(f"Created index: {index_name}")

            # Poll until the index reports ready.  Using the readiness
            # check as the loop condition avoids a second, redundant
            # describe_index call once the index is ready.
            # NOTE(review): time.sleep blocks the calling thread; confirm
            # that is acceptable for the callers of this helper.
            sleeps = 0
            while not self.pinecone.describe_index(index_name).status["ready"]:
                if sleeps >= ready_attempts:
                    raise RuntimeError("Gave up waiting for index creation")
                sleeps += 1
                time.sleep(poll_interval)

        except Exception as e:
            logger.error(f"Pinecone index creation failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    # Only reached once the index is known to exist.
    self.last_index_name = index_name
|
||||||
|
|
||||||
def create_value(self, ent):
|
def create_value(self, ent):
|
||||||
if ent.startswith("http://") or ent.startswith("https://"):
|
if ent.startswith("http://") or ent.startswith("https://"):
|
||||||
return Value(value=ent, is_uri=True)
|
return Value(value=ent, is_uri=True)
|
||||||
|
|
@ -74,6 +107,8 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.ensure_index_exists(index_name, dim)
|
||||||
|
|
||||||
index = self.pinecone.Index(index_name)
|
index = self.pinecone.Index(index_name)
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,24 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
||||||
|
self.last_collection = None
|
||||||
|
|
||||||
|
def ensure_collection_exists(self, collection, dim):
    """Ensure the Qdrant collection ``collection`` exists, creating it if needed.

    ``self.last_collection`` caches the most recently verified collection
    name, so consecutive calls for the same collection return without any
    API call.

    Args:
        collection: Name of the Qdrant collection to check or create.
        dim: Vector size used when the collection has to be created.

    Raises:
        Exception: Any error raised by collection creation is logged and
            re-raised unchanged.
    """
    # Guard clause: the previous call already verified this collection.
    if collection == self.last_collection:
        return

    if not self.qdrant.collection_exists(collection):
        try:
            self.qdrant.create_collection(
                collection_name=collection,
                vectors_config=VectorParams(
                    size=dim, distance=Distance.COSINE
                ),
            )
            logger.info(f"Created collection: {collection}")
        except Exception as e:
            logger.error(f"Qdrant collection creation failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    # Only reached once the collection is known to exist.
    self.last_collection = collection
|
||||||
|
|
||||||
def create_value(self, ent):
|
def create_value(self, ent):
|
||||||
if ent.startswith("http://") or ent.startswith("https://"):
|
if ent.startswith("http://") or ent.startswith("https://"):
|
||||||
|
|
@ -60,6 +78,8 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
str(dim)
|
str(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.ensure_collection_exists(collection, dim)
|
||||||
|
|
||||||
# Heuristic hack, get (2*limit), so that we have more chance
|
# Heuristic hack, get (2*limit), so that we have more chance
|
||||||
# of getting (limit) entities
|
# of getting (limit) entities
|
||||||
search_result = self.qdrant.query_points(
|
search_result = self.qdrant.query_points(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue