diff --git a/templates/components/document-rag.jsonnet b/templates/components/document-rag.jsonnet index 0a68dd52..11dc9c13 100644 --- a/templates/components/document-rag.jsonnet +++ b/templates/components/document-rag.jsonnet @@ -39,5 +39,35 @@ local prompts = import "prompts/mixtral.jsonnet"; }, + "document-embeddings" +: { + + create:: function(engine) + + local container = + engine.container("document-embeddings") + .with_image(images.trustgraph) + .with_command([ + "document-embeddings", + "-p", + url.pulsar, + ]) + .with_limits("1.0", "512M") + .with_reservations("0.5", "512M"); + + local containerSet = engine.containers( + "document-embeddings", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + } diff --git a/templates/components/graph-rag.jsonnet b/templates/components/graph-rag.jsonnet index 860152c9..eb72754e 100644 --- a/templates/components/graph-rag.jsonnet +++ b/templates/components/graph-rag.jsonnet @@ -138,5 +138,35 @@ local url = import "values/url.jsonnet"; }, + "graph-embeddings" +: { + + create:: function(engine) + + local container = + engine.container("graph-embeddings") + .with_image(images.trustgraph) + .with_command([ + "graph-embeddings", + "-p", + url.pulsar, + ]) + .with_limits("1.0", "512M") + .with_reservations("0.5", "512M"); + + local containerSet = engine.containers( + "graph-embeddings", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + } diff --git a/templates/components/trustgraph.jsonnet b/templates/components/trustgraph.jsonnet index 31ae420e..541beeed 100644 --- a/templates/components/trustgraph.jsonnet +++ b/templates/components/trustgraph.jsonnet @@ -119,36 +119,6 @@ local prompt = import "prompt-template.jsonnet"; }, - "vectorize" +: { - - create:: function(engine) - - local container = - engine.container("vectorize") - .with_image(images.trustgraph) - .with_command([ - "embeddings-vectorize", - "-p", - url.pulsar, - ]) - .with_limits("1.0", "512M") - .with_reservations("0.5", "512M"); - - local containerSet = engine.containers( - "vectorize", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - "metering" +: { create:: function(engine) diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index de96499c..24207f32 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -131,6 +131,35 @@ class Api: except: raise ProtocolException(f"Response not formatted correctly") + def document_rag(self, question): + + # The input consists of a question + input = { + "query": question + } + + url = f"{self.url}document-rag" + + # Invoke the API, input is passed as JSON + resp = requests.post(url, json=input) + + # Should be a 200 status code + if resp.status_code != 200: + raise ProtocolException(f"Status code {resp.status_code}") + + try: + # Parse the response as JSON + object = resp.json() + except: + raise ProtocolException(f"Expected JSON response") + + self.check_error(resp) + + try: + return object["response"] + except: + raise ProtocolException(f"Response not formatted correctly") + def embeddings(self, text): # The input consists of a text block diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index d432991d..5b6d324e 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -38,8 +38,12 @@ class DocumentEmbeddingsClient(BaseClient): output_schema=DocumentEmbeddingsResponse, ) - def request(self, vectors, limit=10, timeout=300): + def request( + self, vectors, user="trustgraph", collection="default", + limit=10, timeout=300 + ): return self.call( + user=user, collection=collection, vectors=vectors, limit=limit, timeout=timeout ).documents diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py index 38add83d..fd0049ee 100644 --- a/trustgraph-base/trustgraph/schema/documents.py +++ b/trustgraph-base/trustgraph/schema/documents.py @@ -35,11 +35,28 @@ chunk_ingest_queue = topic('chunk-load') ############################################################################ +# Document embeddings are embeddings associated with a chunk + +class ChunkEmbeddings(Record): + chunk = Bytes() + vectors = Array(Array(Double())) + +# This is a 'batching' mechanism for the above data +class DocumentEmbeddings(Record): + metadata = Metadata() + chunks = Array(ChunkEmbeddings()) + +document_embeddings_store_queue = topic('document-embeddings-store') + +############################################################################ + # Doc embeddings query class DocumentEmbeddingsRequest(Record): vectors = Array(Array(Double())) limit = Integer() + user = String() + collection = String() class DocumentEmbeddingsResponse(Record): error = Error() diff --git a/trustgraph-flow/scripts/document-embeddings b/trustgraph-flow/scripts/document-embeddings new file mode 100755 index 00000000..26bb85b0 --- /dev/null +++ b/trustgraph-flow/scripts/document-embeddings @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.embeddings.document_embeddings import run + +run() + diff --git a/trustgraph-flow/scripts/embeddings-vectorize b/trustgraph-flow/scripts/embeddings-vectorize deleted file mode 100755 index 3de1e3a9..00000000 --- a/trustgraph-flow/scripts/embeddings-vectorize +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python3 - -from trustgraph.embeddings.vectorize import run - -run() - diff --git a/trustgraph-flow/scripts/graph-embeddings b/trustgraph-flow/scripts/graph-embeddings new file mode 100755 index 00000000..29b1fbf4 --- /dev/null +++ b/trustgraph-flow/scripts/graph-embeddings @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.embeddings.graph_embeddings import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index b7eab434..83e4f4f7 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -63,29 +63,30 @@ setuptools.setup( "falkordb", ], scripts=[ - "scripts/api-gateway", "scripts/agent-manager-react", + "scripts/api-gateway", "scripts/chunker-recursive", "scripts/chunker-token", "scripts/de-query-milvus", - "scripts/de-query-qdrant", "scripts/de-query-pinecone", + "scripts/de-query-qdrant", "scripts/de-write-milvus", - "scripts/de-write-qdrant", "scripts/de-write-pinecone", + "scripts/de-write-qdrant", + "scripts/document-embeddings", "scripts/document-rag", "scripts/embeddings-ollama", - "scripts/embeddings-vectorize", "scripts/ge-query-milvus", "scripts/ge-query-pinecone", "scripts/ge-query-qdrant", "scripts/ge-write-milvus", "scripts/ge-write-pinecone", "scripts/ge-write-qdrant", + "scripts/graph-embeddings", "scripts/graph-rag", "scripts/kg-extract-definitions", - "scripts/kg-extract-topics", "scripts/kg-extract-relationships", + "scripts/kg-extract-topics", "scripts/metering", "scripts/object-extract-row", "scripts/oe-write-milvus", @@ -103,13 +104,13 @@ setuptools.setup( "scripts/text-completion-ollama", "scripts/text-completion-openai", "scripts/triples-query-cassandra", - "scripts/triples-query-neo4j", - "scripts/triples-query-memgraph", "scripts/triples-query-falkordb", + "scripts/triples-query-memgraph", + "scripts/triples-query-neo4j", "scripts/triples-write-cassandra", - "scripts/triples-write-neo4j", - "scripts/triples-write-memgraph", "scripts/triples-write-falkordb", + "scripts/triples-write-memgraph", + "scripts/triples-write-neo4j", "scripts/wikipedia-lookup", ] ) diff --git a/trustgraph-flow/trustgraph/document_rag.py b/trustgraph-flow/trustgraph/document_rag.py index f3c8b158..86298783 100644 --- a/trustgraph-flow/trustgraph/document_rag.py +++ b/trustgraph-flow/trustgraph/document_rag.py @@ -16,6 +16,44 @@ from . schema import document_embeddings_response_queue LABEL="http://www.w3.org/2000/01/rdf-schema#label" DEFINITION="http://www.w3.org/2004/02/skos/core#definition" +class Query: + + def __init__(self, rag, user, collection, verbose): + self.rag = rag + self.user = user + self.collection = collection + self.verbose = verbose + + def get_vector(self, query): + + if self.verbose: + print("Compute embeddings...", flush=True) + + qembeds = self.rag.embeddings.request(query) + + if self.verbose: + print("Done.", flush=True) + + return qembeds + + def get_docs(self, query): + + vectors = self.get_vector(query) + + if self.verbose: + print("Get entities...", flush=True) + + docs = self.rag.de_client.request( + vectors, limit=self.rag.doc_limit + ) + + if self.verbose: + print("Docs:", flush=True) + for doc in docs: + print(doc, flush=True) + + return docs + class DocumentRag: def __init__( @@ -55,7 +93,7 @@ class DocumentRag: print("Initialising...", flush=True) # FIXME: Configurable - self.entity_limit = 20 + self.doc_limit = 20 self.de_client = DocumentEmbeddingsClient( pulsar_host=pulsar_host, @@ -81,42 +119,16 @@ class DocumentRag: if self.verbose: print("Initialised", flush=True) - def get_vector(self, query): - - if self.verbose: - print("Compute embeddings...", flush=True) - - qembeds = self.embeddings.request(query) - - if self.verbose: - print("Done.", flush=True) - - return qembeds - - def get_docs(self, query): - - vectors = self.get_vector(query) - - if self.verbose: - print("Get entities...", flush=True) - - docs = self.de_client.request( - vectors, self.entity_limit - ) - - if self.verbose: - print("Docs:", flush=True) - for doc in docs: - print(doc, flush=True) - - return docs - - def query(self, query): + def query(self, query, user="trustgraph", collection="default"): if self.verbose: print("Construct prompt...", flush=True) - docs = self.get_docs(query) + q = Query( + rag=self, user=user, collection=collection, verbose=self.verbose + ) + + docs = q.get_docs(query) if self.verbose: print("Invoke LLM...", flush=True) diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/__init__.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__init__.py new file mode 100644 index 00000000..40d505a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__init__.py @@ -0,0 +1,3 @@ + +from . embeddings import * + diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__main__.py similarity index 57% rename from trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py rename to trustgraph-flow/trustgraph/embeddings/document_embeddings/__main__.py index a578de8a..a48cc4d0 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__main__.py @@ -1,5 +1,5 @@ -from . vectorize import run +from . embeddings import run if __name__ == '__main__': run() diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py new file mode 100755 index 00000000..745ab4db --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -0,0 +1,109 @@ + +""" +Document embeddings, calls the embeddings service to get embeddings for a +chunk of text. Input is chunk of text plus metadata. +Output is chunk plus embedding. +""" + +from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings +from ... schema import chunk_ingest_queue +from ... schema import document_embeddings_store_queue +from ... schema import embeddings_request_queue, embeddings_response_queue +from ... clients.embeddings_client import EmbeddingsClient +from ... log_level import LogLevel +from ... base import ConsumerProducer + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = chunk_ingest_queue +default_output_queue = document_embeddings_store_queue +default_subscriber = module + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + emb_request_queue = params.get( + "embeddings_request_queue", embeddings_request_queue + ) + emb_response_queue = params.get( + "embeddings_response_queue", embeddings_response_queue + ) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "embeddings_request_queue": emb_request_queue, + "embeddings_response_queue": emb_response_queue, + "subscriber": subscriber, + "input_schema": Chunk, + "output_schema": DocumentEmbeddings, + } + ) + + self.embeddings = EmbeddingsClient( + pulsar_host=self.pulsar_host, + input_queue=emb_request_queue, + output_queue=emb_response_queue, + subscriber=module + "-emb", + ) + + def handle(self, msg): + + v = msg.value() + print(f"Indexing {v.metadata.id}...", flush=True) + + try: + + vectors = self.embeddings.request(v.chunk) + + embeds = [ + ChunkEmbeddings( + chunk=v.chunk, + vectors=vectors, + ) + ] + + r = DocumentEmbeddings( + metadata=v.metadata, + chunks=embeds, + ) + + self.producer.send(r) + + except Exception as e: + print("Exception:", e, flush=True) + + # Retry + raise e + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '--embeddings-request-queue', + default=embeddings_request_queue, + help=f'Embeddings request queue (default: {embeddings_request_queue})', + ) + + parser.add_argument( + '--embeddings-response-queue', + default=embeddings_response_queue, + help=f'Embeddings request queue (default: {embeddings_response_queue})', + ) + +def run(): + + Processor.start(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__init__.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__init__.py new file mode 100644 index 00000000..40d505a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__init__.py @@ -0,0 +1,3 @@ + +from . embeddings import * + diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__main__.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__main__.py new file mode 100755 index 00000000..a48cc4d0 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__main__.py @@ -0,0 +1,6 @@ + +from . embeddings import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py similarity index 92% rename from trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py rename to trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py index 5630a7b5..e4d1646e 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -1,7 +1,8 @@ """ -Vectorizer, calls the embeddings service to get embeddings for a chunk. -Input is text chunk, output is chunk and vectors. +Graph embeddings, calls the embeddings service to get embeddings for a +set of entity contexts. Input is entity plus textual context. +Output is entity plus embedding. """ from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings @@ -51,11 +52,6 @@ class Processor(ConsumerProducer): subscriber=module + "-emb", ) - def emit(self, rec, vectors): - - r = GraphEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors) - self.producer.send(r) - def handle(self, msg): v = msg.value() diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py b/trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py deleted file mode 100644 index 31596b8c..00000000 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - -from . vectorize import * - diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index a260b631..d6306ac6 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -31,6 +31,7 @@ from . subscriber import Subscriber from . text_completion import TextCompletionRequestor from . prompt import PromptRequestor from . graph_rag import GraphRagRequestor +from . document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . embeddings import EmbeddingsRequestor @@ -91,6 +92,10 @@ class Api: pulsar_host=self.pulsar_host, timeout=self.timeout, auth = self.auth, ), + "document-rag": DocumentRagRequestor( + pulsar_host=self.pulsar_host, timeout=self.timeout, + auth = self.auth, + ), "triples-query": TriplesQueryRequestor( pulsar_host=self.pulsar_host, timeout=self.timeout, auth = self.auth, @@ -140,6 +145,10 @@ class Api: endpoint_path = "/api/v1/graph-rag", auth=self.auth, requestor = self.services["graph-rag"], ), + ServiceEndpoint( + endpoint_path = "/api/v1/document-rag", auth=self.auth, + requestor = self.services["document-rag"], + ), ServiceEndpoint( endpoint_path = "/api/v1/triples-query", auth=self.auth, requestor = self.services["triples-query"], diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 00f9d5b5..bfa6c123 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -3,15 +3,16 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ -from .... schema import ChunkEmbeddings -from .... schema import chunk_embeddings_ingest_queue -from .... log_level import LogLevel from .... direct.milvus_doc_embeddings import DocVectors + +from .... schema import DocumentEmbeddings +from .... schema import document_embeddings_store_queue +from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = document_embeddings_store_queue default_subscriber = module default_store_uri = 'http://localhost:19530' @@ -27,7 +28,7 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": DocumentEmbeddings, "store_uri": store_uri, } ) @@ -38,11 +39,16 @@ class Processor(Consumer): v = msg.value() - chunk = v.chunk.decode("utf-8") + for emb in v.chunks: - if v.chunk != "" and v.chunk is not None: - for vec in v.vectors: - self.vecstore.insert(vec, chunk) + chunk = emb.chunk.decode("utf-8") + if chunk == "" or chunk is None: continue + + for vec in emb.vectors: + + if chunk != "" and v.chunk is not None: + for vec in v.vectors: + self.vecstore.insert(vec, chunk) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 24cfcb78..c59ecd7b 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -11,14 +11,14 @@ import time import uuid import os -from .... schema import ChunkEmbeddings -from .... schema import chunk_embeddings_ingest_queue +from .... schema import DocumentEmbeddings +from .... schema import document_embeddings_store_queue from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = document_embeddings_store_queue default_subscriber = module default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" @@ -54,7 +54,7 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": DocumentEmbeddings, "url": self.url, } ) @@ -65,71 +65,74 @@ class Processor(Consumer): v = msg.value() - chunk = v.chunk.decode("utf-8") + for emb in v.chunks: - if chunk == "": return + chunk = emb.chunk.decode("utf-8") + if chunk == "" or chunk is None: continue - for vec in v.vectors: + for vec in emb.vectors: - dim = len(vec) - collection = ( - "d-" + v.metadata.user + "-" + str(dim) - ) + for vec in v.vectors: - if index_name != self.last_index_name: + dim = len(vec) + collection = ( + "d-" + v.metadata.user + "-" + str(dim) + ) - if not self.pinecone.has_index(index_name): + if index_name != self.last_index_name: - try: + if not self.pinecone.has_index(index_name): - self.pinecone.create_index( - name = index_name, - dimension = dim, - metric = "cosine", - spec = ServerlessSpec( - cloud = self.cloud, - region = self.region, - ) - ) + try: - for i in range(0, 1000): + self.pinecone.create_index( + name = index_name, + dimension = dim, + metric = "cosine", + spec = ServerlessSpec( + cloud = self.cloud, + region = self.region, + ) + ) - if self.pinecone.describe_index( - index_name - ).status["ready"]: - break + for i in range(0, 1000): - time.sleep(1) + if self.pinecone.describe_index( + index_name + ).status["ready"]: + break - if not self.pinecone.describe_index( - index_name - ).status["ready"]: - raise RuntimeError( - "Gave up waiting for index creation" - ) + time.sleep(1) - except Exception as e: - print("Pinecone index creation failed") - raise e + if not self.pinecone.describe_index( + index_name + ).status["ready"]: + raise RuntimeError( + "Gave up waiting for index creation" + ) - print(f"Index {index_name} created", flush=True) + except Exception as e: + print("Pinecone index creation failed") + raise e - self.last_index_name = index_name + print(f"Index {index_name} created", flush=True) - index = self.pinecone.Index(index_name) + self.last_index_name = index_name - records = [ - { - "id": id, - "values": vec, - "metadata": { "doc": chunk }, - } - ] + index = self.pinecone.Index(index_name) - index.upsert( - vectors = records, - namespace = v.metadata.collection, - ) + records = [ + { + "id": id, + "values": vec, + "metadata": { "doc": chunk }, + } + ] + + index.upsert( + vectors = records, + namespace = v.metadata.collection, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 813c4f29..f852e03b 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -8,14 +8,14 @@ from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams import uuid -from .... schema import ChunkEmbeddings -from .... schema import chunk_embeddings_ingest_queue +from .... schema import DocumentEmbeddings +from .... schema import document_embeddings_store_queue from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = document_embeddings_store_queue default_subscriber = module default_store_uri = 'http://localhost:6333' @@ -31,7 +31,7 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": DocumentEmbeddings, "store_uri": store_uri, } ) @@ -44,47 +44,48 @@ class Processor(Consumer): v = msg.value() - chunk = v.chunk.decode("utf-8") + for emb in v.chunks: - if chunk == "": return + chunk = emb.chunk.decode("utf-8") + if chunk == "": return - for vec in v.vectors: + for vec in emb.vectors: - dim = len(vec) - collection = ( - "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + - str(dim) - ) + dim = len(vec) + collection = ( + "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + + str(dim) + ) - if collection != self.last_collection: + if collection != self.last_collection: - if not self.client.collection_exists(collection): + if not self.client.collection_exists(collection): - try: - self.client.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.DOT - ), + try: + self.client.create_collection( + collection_name=collection, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + except Exception as e: + print("Qdrant collection creation failed") + raise e + + self.last_collection = collection + + self.client.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "doc": chunk, + } ) - except Exception as e: - print("Qdrant collection creation failed") - raise e - - self.last_collection = collection - - self.client.upsert( - collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "doc": chunk, - } - ) - ] - ) + ] + ) @staticmethod def add_args(parser):