diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index a0fec166..3ef3f40b 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -47,6 +47,39 @@ class Processor(DocumentEmbeddingsQueryService): } ) + self.last_index_name = None + + def ensure_index_exists(self, index_name, dim): + """Ensure index exists, create if it doesn't""" + if index_name != self.last_index_name: + if not self.pinecone.has_index(index_name): + try: + self.pinecone.create_index( + name=index_name, + dimension=dim, + metric="cosine", + spec=ServerlessSpec( + cloud="aws", + region="us-east-1", + ) + ) + logger.info(f"Created index: {index_name}") + + # Wait for index to be ready + import time + for i in range(0, 1000): + if self.pinecone.describe_index(index_name).status["ready"]: + break + time.sleep(1) + + if not self.pinecone.describe_index(index_name).status["ready"]: + raise RuntimeError("Gave up waiting for index creation") + + except Exception as e: + logger.error(f"Pinecone index creation failed: {e}") + raise e + self.last_index_name = index_name + async def query_document_embeddings(self, msg): try: @@ -65,6 +98,8 @@ class Processor(DocumentEmbeddingsQueryService): "d-" + msg.user + "-" + msg.collection + "-" + str(dim) ) + self.ensure_index_exists(index_name, dim) + index = self.pinecone.Index(index_name) results = index.query( diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index cedcaf52..bb07e063 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -38,6 +38,24 @@ class Processor(DocumentEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.last_collection = None + + def ensure_collection_exists(self, collection, dim): + """Ensure collection exists, create if it doesn't""" + if collection != self.last_collection: + if not self.qdrant.collection_exists(collection): + try: + self.qdrant.create_collection( + collection_name=collection, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + logger.info(f"Created collection: {collection}") + except Exception as e: + logger.error(f"Qdrant collection creation failed: {e}") + raise e + self.last_collection = collection async def query_document_embeddings(self, msg): @@ -53,6 +71,8 @@ class Processor(DocumentEmbeddingsQueryService): str(dim) ) + self.ensure_collection_exists(collection, dim) + search_result = self.qdrant.query_points( collection_name=collection, query=vec, diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 64a2bb10..6de08e4c 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -49,6 +49,39 @@ class Processor(GraphEmbeddingsQueryService): } ) + self.last_index_name = None + + def ensure_index_exists(self, index_name, dim): + """Ensure index exists, create if it doesn't""" + if index_name != self.last_index_name: + if not self.pinecone.has_index(index_name): + try: + self.pinecone.create_index( + name=index_name, + dimension=dim, + metric="cosine", + spec=ServerlessSpec( + cloud="aws", + region="us-east-1", + ) + ) + logger.info(f"Created index: {index_name}") + + # Wait for index to be ready + import time + for i in range(0, 1000): + if self.pinecone.describe_index(index_name).status["ready"]: + break + time.sleep(1) + + if not self.pinecone.describe_index(index_name).status["ready"]: + raise RuntimeError("Gave up waiting for index creation") + + except Exception as e: + logger.error(f"Pinecone index creation failed: {e}") + raise e + self.last_index_name = index_name + def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): return Value(value=ent, is_uri=True) @@ -74,6 +107,8 @@ class Processor(GraphEmbeddingsQueryService): "t-" + msg.user + "-" + msg.collection + "-" + str(dim) ) + self.ensure_index_exists(index_name, dim) + index = self.pinecone.Index(index_name) # Heuristic hack, get (2*limit), so that we have more chance diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 00e711db..756f619b 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -38,6 +38,24 @@ class Processor(GraphEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.last_collection = None + + def ensure_collection_exists(self, collection, dim): + """Ensure collection exists, create if it doesn't""" + if collection != self.last_collection: + if not self.qdrant.collection_exists(collection): + try: + self.qdrant.create_collection( + collection_name=collection, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + logger.info(f"Created collection: {collection}") + except Exception as e: + logger.error(f"Qdrant collection creation failed: {e}") + raise e + self.last_collection = collection def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -60,6 +78,8 @@ class Processor(GraphEmbeddingsQueryService): str(dim) ) + self.ensure_collection_exists(collection, dim) + # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) entities search_result = self.qdrant.query_points(