Feature/pinecone integration (#170)

* Added Pinecone for GE write & query

* Add templates

* Doc embedding support
This commit is contained in:
cybermaggedon 2024-11-22 23:48:21 +00:00 committed by GitHub
parent ae1264f5c4
commit 319f9ac04a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 842 additions and 0 deletions

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . hf import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,142 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks. Pinecone implementation.
"""
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import uuid
import os
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.api_key = params.get("api_key", default_api_key)
if self.url:
self.pinecone = PineconeGRPC(
api_key = self.api_key,
host = self.url
)
else:
self.pinecone = Pinecone(api_key = self.api_key)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": DocumentEmbeddingsRequest,
"output_schema": DocumentEmbeddingsResponse,
"url": self.url,
}
)
def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
chunks = []
for vec in v.vectors:
dim = len(vec)
index_name = (
"d-" + v.user + "-" + str(dim)
)
index = self.pinecone.Index(index_name)
results = index.query(
namespace=v.collection,
vector=vec,
top_k=v.limit,
include_values=False,
include_metadata=True
)
search_result = self.client.query_points(
collection_name=collection,
query=vec,
limit=v.limit,
with_payload=True,
).points
for r in results.matches:
doc = r.metadata["doc"]
chunks.add(doc)
print("Send response...", flush=True)
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = DocumentEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
documents=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Milvus store URI (default: {default_store_uri})'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . hf import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,156 @@
"""
Graph embeddings query service. Input is vector, output is list of
entities. Pinecone implementation.
"""
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import uuid
import os
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.api_key = params.get("api_key", default_api_key)
if self.url:
self.pinecone = PineconeGRPC(
api_key = self.api_key,
host = self.url
)
else:
self.pinecone = Pinecone(api_key = self.api_key)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddingsRequest,
"output_schema": GraphEmbeddingsResponse,
"url": self.url,
}
)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
else:
return Value(value=ent, is_uri=False)
def handle(self, msg):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
entities = set()
for vec in v.vectors:
dim = len(vec)
index_name = (
"t-" + v.user + "-" + str(dim)
)
index = self.pinecone.Index(index_name)
results = index.query(
namespace=v.collection,
vector=vec,
top_k=v.limit,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
entities.add(ent)
# Convert set to list
entities = list(entities)
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
print("Send response...", flush=True)
r = GraphEmbeddingsResponse(entities=entities, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = GraphEmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
entities=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-a', '--api-key',
default=default_api_key,
help='Pinecone API key. (default from PINECONE_API_KEY)'
)
parser.add_argument(
'-u', '--url',
help='Pinecone URL. If unspecified, serverless is used'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . write import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,167 @@
"""
Accepts entity/vector pairs and writes them to a Qdrant store.
"""
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import time
import uuid
import os
from .... schema import ChunkEmbeddings
from .... schema import chunk_embeddings_ingest_queue
from .... log_level import LogLevel
from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.cloud = params.get("cloud", default_cloud)
self.region = params.get("region", default_region)
self.api_key = params.get("api_key", default_api_key)
if self.api_key is None:
raise RuntimeError("Pinecone API key must be specified")
if self.url:
self.pinecone = PineconeGRPC(
api_key = self.api_key,
host = self.url
)
else:
self.pinecone = Pinecone(api_key = self.api_key)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"url": self.url,
}
)
self.last_index_name = None
def handle(self, msg):
v = msg.value()
chunk = v.chunk.decode("utf-8")
if chunk == "": return
for vec in v.vectors:
dim = len(vec)
collection = (
"d-" + v.metadata.user + "-" + str(dim)
)
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
for i in range(0, 1000):
if self.pinecone.describe_index(
index_name
).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
except Exception as e:
print("Pinecone index creation failed")
raise e
print(f"Index {index_name} created", flush=True)
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
records = [
{
"id": id,
"values": vec,
"metadata": { "doc": chunk },
}
]
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-a', '--api-key',
default=default_api_key,
help='Pinecone API key. (default from PINECONE_API_KEY)'
)
parser.add_argument(
'-u', '--url',
help='Pinecone URL. If unspecified, serverless is used'
)
parser.add_argument(
'--cloud',
default=default_cloud,
help=f'Pinecone cloud, (default: {default_cloud}'
)
parser.add_argument(
'--region',
default=default_region,
help=f'Pinecone region, (default: {default_region}'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . write import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,167 @@
"""
Accepts entity/vector pairs and writes them to a Pinecone store.
"""
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import time
import uuid
import os
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
self.url = params.get("url", None)
self.cloud = params.get("cloud", default_cloud)
self.region = params.get("region", default_region)
self.api_key = params.get("api_key", default_api_key)
if self.api_key is None:
raise RuntimeError("Pinecone API key must be specified")
if self.url:
self.pinecone = PineconeGRPC(
api_key = self.api_key,
host = self.url
)
else:
self.pinecone = Pinecone(api_key = self.api_key)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
"url": self.url,
}
)
self.last_index_name = None
def handle(self, msg):
v = msg.value()
id = str(uuid.uuid4())
if v.entity.value == "" or v.entity.value is None: return
for vec in v.vectors:
dim = len(vec)
index_name = (
"t-" + v.metadata.user + "-" + str(dim)
)
if index_name != self.last_index_name:
if not self.pinecone.has_index(index_name):
try:
self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
for i in range(0, 1000):
if self.pinecone.describe_index(
index_name
).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
except Exception as e:
print("Pinecone index creation failed")
raise e
print(f"Index {index_name} created", flush=True)
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
records = [
{
"id": id,
"values": vec,
"metadata": { "entity": v.entity.value },
}
]
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-a', '--api-key',
default=default_api_key,
help='Pinecone API key. (default from PINECONE_API_KEY)'
)
parser.add_argument(
'-u', '--url',
help='Pinecone URL. If unspecified, serverless is used'
)
parser.add_argument(
'--cloud',
default=default_cloud,
help=f'Pinecone cloud, (default: {default_cloud}'
)
parser.add_argument(
'--region',
default=default_region,
help=f'Pinecone region, (default: {default_region}'
)
def run():
Processor.start(module, __doc__)