diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py index 2a3d3d0c..38add83d 100644 --- a/trustgraph-base/trustgraph/schema/documents.py +++ b/trustgraph-base/trustgraph/schema/documents.py @@ -35,17 +35,6 @@ chunk_ingest_queue = topic('chunk-load') ############################################################################ -# Chunk embeddings are an embeddings associated with a text chunk - -class ChunkEmbeddings(Record): - metadata = Metadata() - vectors = Array(Array(Double())) - chunk = Bytes() - -chunk_embeddings_ingest_queue = topic('chunk-embeddings-load') - -############################################################################ - # Doc embeddings query class DocumentEmbeddingsRequest(Record): @@ -62,3 +51,4 @@ document_embeddings_request_queue = topic( document_embeddings_response_queue = topic( 'doc-embeddings', kind='non-persistent', namespace='response', ) + diff --git a/trustgraph-base/trustgraph/schema/graph.py b/trustgraph-base/trustgraph/schema/graph.py index 78c1a99c..7c304e1d 100644 --- a/trustgraph-base/trustgraph/schema/graph.py +++ b/trustgraph-base/trustgraph/schema/graph.py @@ -7,12 +7,31 @@ from . 
metadata import Metadata ############################################################################ +# Entity context are an entity associated with textual context + +class EntityContext(Record): + entity = Value() + context = String() + +# This is a 'batching' mechanism for the above data +class EntityContexts(Record): + metadata = Metadata() + entities = Array(EntityContext()) + +entity_contexts_ingest_queue = topic('entity-contexts-load') + +############################################################################ + # Graph embeddings are embeddings associated with a graph entity +class EntityEmbeddings(Record): + entity = Value() + vectors = Array(Array(Double())) + +# This is a 'batching' mechanism for the above data class GraphEmbeddings(Record): metadata = Metadata() - vectors = Array(Array(Double())) - entity = Value() + entities = Array(EntityEmbeddings()) graph_embeddings_store_queue = topic('graph-embeddings-store') diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py index 4cf2af05..5630a7b5 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py @@ -4,8 +4,9 @@ Vectorizer, calls the embeddings service to get embeddings for a chunk. Input is text chunk, output is chunk and vectors. """ -from ... schema import Chunk, ChunkEmbeddings -from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue +from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings +from ... schema import entity_contexts_ingest_queue +from ... schema import graph_embeddings_store_queue from ... schema import embeddings_request_queue, embeddings_response_queue from ... clients.embeddings_client import EmbeddingsClient from ... log_level import LogLevel @@ -13,8 +14,8 @@ from ... 
base import ConsumerProducer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_ingest_queue -default_output_queue = chunk_embeddings_ingest_queue +default_input_queue = entity_contexts_ingest_queue +default_output_queue = graph_embeddings_store_queue default_subscriber = module class Processor(ConsumerProducer): @@ -38,8 +39,8 @@ class Processor(ConsumerProducer): "embeddings_request_queue": emb_request_queue, "embeddings_response_queue": emb_response_queue, "subscriber": subscriber, - "input_schema": Chunk, - "output_schema": ChunkEmbeddings, + "input_schema": EntityContexts, + "output_schema": GraphEmbeddings, } ) @@ -50,9 +51,9 @@ class Processor(ConsumerProducer): subscriber=module + "-emb", ) - def emit(self, metadata, chunk, vectors): + def emit(self, rec, entities): - r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors) + r = GraphEmbeddings(metadata=rec.metadata, entities=entities) self.producer.send(r) def handle(self, msg): @@ -60,21 +61,34 @@ v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) - chunk = v.chunk.decode("utf-8") + entities = [] try: - vectors = self.embeddings.request(chunk) + for entity in v.entities: - self.emit( + vectors = self.embeddings.request(entity.context) + + entities.append( + EntityEmbeddings( + entity=entity.entity, + vectors=vectors + ) + ) + + r = GraphEmbeddings( metadata=v.metadata, - chunk=chunk.encode("utf-8"), - vectors=vectors + entities=entities, ) + self.producer.send(r) + except Exception as e: print("Exception:", e, flush=True) + # Retry + raise e + print("Done.", flush=True) @staticmethod diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index eed34574..dcb1123e 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -1,14 +1,17 @@ """ -Simple 
decoder, accepts embeddings+text chunks input, applies entity analysis to -get entity definitions which are output as graph edges. +Simple decoder, accepts text chunks input, applies entity analysis to +get entity definitions which are output as graph edges along with +entity/context definitions for embedding. """ import urllib.parse -import json +from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value -from .... schema import chunk_embeddings_ingest_queue, triples_store_queue +from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import EntityContext, EntityContexts +from .... schema import chunk_ingest_queue, triples_store_queue +from .... schema import entity_contexts_ingest_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... log_level import LogLevel @@ -22,8 +25,9 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue +default_entity_context_queue = entity_contexts_ingest_queue default_subscriber = module class Processor(ConsumerProducer): @@ -32,6 +36,10 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) + ec_queue = params.get( + "entity_context_queue", + default_entity_context_queue + ) subscriber = params.get("subscriber", default_subscriber) pr_request_queue = params.get( "prompt_request_queue", prompt_request_queue @@ -45,13 +53,30 @@ class Processor(ConsumerProducer): "input_queue": input_queue, "output_queue": output_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": Chunk, "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } ) 
+ self.ec_prod = self.client.create_producer( + topic=ec_queue, + schema=JsonSchema(EntityContexts), + ) + + __class__.pubsub_metric.info({ + "input_queue": input_queue, + "output_queue": output_queue, + "entity_context_queue": ec_queue, + "prompt_request_queue": pr_request_queue, + "prompt_response_queue": pr_response_queue, + "subscriber": subscriber, + "input_schema": Chunk.__name__, + "output_schema": Triples.__name__, + "vector_schema": EntityContexts.__name__, + }) + self.prompt = PromptClient( pulsar_host=self.pulsar_host, input_queue=pr_request_queue, @@ -79,6 +104,14 @@ class Processor(ConsumerProducer): ) self.producer.send(t) + def emit_ecs(self, metadata, entities): + + t = EntityContexts( + metadata=metadata, + entities=entities, + ) + self.ec_prod.send(t) + def handle(self, msg): v = msg.value() @@ -91,6 +124,7 @@ class Processor(ConsumerProducer): defs = self.get_definitions(chunk) triples = [] + entities = [] # FIXME: Putting metadata into triples store is duplicated in # relationships extractor too @@ -129,6 +163,14 @@ class Processor(ConsumerProducer): o=Value(value=v.metadata.id, is_uri=True) )) + ec = EntityContext( + entity=s_value, + context=defn.definition, + ) + + entities.append(ec) + + self.emit_edges( Metadata( id=v.metadata.id, @@ -139,6 +181,16 @@ class Processor(ConsumerProducer): triples ) + self.emit_ecs( + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + entities + ) + except Exception as e: print("Exception: ", e, flush=True) @@ -152,6 +204,12 @@ class Processor(ConsumerProducer): default_output_queue, ) + parser.add_argument( + '-e', '--entity-context-queue', + default=default_entity_context_queue, + help=f'Entity context queue (default: {default_entity_context_queue})' + ) + parser.add_argument( '--prompt-request-queue', default=prompt_request_queue, diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py 
b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index d2dea062..0fd7b9a8 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -1,18 +1,15 @@ """ -Simple decoder, accepts vector+text chunks input, applies entity +Simple decoder, accepts text chunks input, applies entity relationship analysis to get entity relationship edges which are output as graph edges. """ import urllib.parse -import os -from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings +from .... schema import Chunk, Triple, Triples from .... schema import Metadata, Value -from .... schema import chunk_embeddings_ingest_queue, triples_store_queue -from .... schema import graph_embeddings_store_queue +from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... 
log_level import LogLevel @@ -25,9 +22,8 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue -default_vector_queue = graph_embeddings_store_queue default_subscriber = module class Processor(ConsumerProducer): @@ -36,7 +32,6 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) - vector_queue = params.get("vector_queue", default_vector_queue) subscriber = params.get("subscriber", default_subscriber) pr_request_queue = params.get( "prompt_request_queue", prompt_request_queue @@ -50,30 +45,13 @@ class Processor(ConsumerProducer): "input_queue": input_queue, "output_queue": output_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": Chunk, "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } ) - self.vec_prod = self.client.create_producer( - topic=vector_queue, - schema=JsonSchema(GraphEmbeddings), - ) - - __class__.pubsub_metric.info({ - "input_queue": input_queue, - "output_queue": output_queue, - "vector_queue": vector_queue, - "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, - "subscriber": subscriber, - "input_schema": ChunkEmbeddings.__name__, - "output_schema": Triples.__name__, - "vector_schema": GraphEmbeddings.__name__, - }) - self.prompt = PromptClient( pulsar_host=self.pulsar_host, input_queue=pr_request_queue, @@ -101,11 +79,6 @@ class Processor(ConsumerProducer): ) self.producer.send(t) - def emit_vec(self, metadata, ent, vec): - - r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec) - self.vec_prod.send(r) - def handle(self, msg): v = msg.value() @@ -193,12 +166,6 @@ class Processor(ConsumerProducer): 
o=Value(value=v.metadata.id, is_uri=True) )) - self.emit_vec(v.metadata, s_value, v.vectors) - self.emit_vec(v.metadata, p_value, v.vectors) - - if rel.o_entity: - self.emit_vec(v.metadata, o_value, v.vectors) - self.emit_edges( Metadata( id=v.metadata.id, @@ -222,12 +189,6 @@ class Processor(ConsumerProducer): default_output_queue, ) - parser.add_argument( - '-c', '--vector-queue', - default=default_vector_queue, - help=f'Vector output queue (default: {default_vector_queue})' - ) - parser.add_argument( '--prompt-request-queue', default=prompt_request_queue, diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py index 8dfc3e6e..9181ae2c 100755 --- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py @@ -1,14 +1,14 @@ """ -Simple decoder, accepts embeddings+text chunks input, applies entity analysis to -get entity definitions which are output as graph edges. +Simple decoder, accepts text chunks input, applies entity analysis to +get topics which are output as graph edges. """ import urllib.parse import json -from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value -from .... schema import chunk_embeddings_ingest_queue, triples_store_queue +from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... 
log_level import LogLevel @@ -20,7 +20,7 @@ DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue default_subscriber = module @@ -43,7 +43,7 @@ class Processor(ConsumerProducer): "input_queue": input_queue, "output_queue": output_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": Chunk, "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 98fe7915..e1379577 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -38,9 +38,11 @@ class Processor(Consumer): v = msg.value() - if v.entity.value != "": - for vec in v.vectors: - self.vecstore.insert(vec, v.entity.value) + for entity in v.entities: + + if entity.entity.value != "" and entity.entity.value is not None: + for vec in entity.vectors: + self.vecstore.insert(vec, entity.entity.value) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index b918c10b..a32ff627 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -60,76 +60,83 @@ class Processor(Consumer): self.last_index_name = None + def create_index(self, index_name, dim): + + self.pinecone.create_index( + name = index_name, + dimension = dim, + metric = "cosine", + spec = ServerlessSpec( + cloud = self.cloud, + region = self.region, + ) + ) + + for i in range(0, 1000): + + if 
self.pinecone.describe_index( + index_name + ).status["ready"]: + break + + time.sleep(1) + + if not self.pinecone.describe_index( + index_name + ).status["ready"]: + raise RuntimeError( + "Gave up waiting for index creation" + ) + def handle(self, msg): v = msg.value() id = str(uuid.uuid4()) - if v.entity.value == "" or v.entity.value is None: return + for entity in v.entities: - for vec in v.vectors: + if entity.entity.value == "" or entity.entity.value is None: + continue - dim = len(vec) + for vec in entity.vectors: - index_name = ( - "t-" + v.metadata.user + "-" + str(dim) - ) + dim = len(vec) - if index_name != self.last_index_name: + index_name = ( + "t-" + v.metadata.user + "-" + str(dim) + ) - if not self.pinecone.has_index(index_name): + if index_name != self.last_index_name: - try: + if not self.pinecone.has_index(index_name): - self.pinecone.create_index( - name = index_name, - dimension = dim, - metric = "cosine", - spec = ServerlessSpec( - cloud = self.cloud, - region = self.region, - ) - ) + try: - for i in range(0, 1000): + self.create_index(index_name, dim) - if self.pinecone.describe_index( - index_name - ).status["ready"]: - break + except Exception as e: + print("Pinecone index creation failed") + raise e - time.sleep(1) + print(f"Index {index_name} created", flush=True) - if not self.pinecone.describe_index( - index_name - ).status["ready"]: - raise RuntimeError( - "Gave up waiting for index creation" - ) + self.last_index_name = index_name - except Exception as e: - print("Pinecone index creation failed") - raise e + index = self.pinecone.Index(index_name) - print(f"Index {index_name} created", flush=True) + records = [ + { + "id": str(uuid.uuid4()), + "values": vec, + "metadata": { "entity": entity.entity.value }, + } + ] - self.last_index_name = index_name - - index = self.pinecone.Index(index_name) - - records = [ - { - "id": id, - "values": vec, - "metadata": { "entity": v.entity.value }, - } - ] - - index.upsert( - vectors = records, - namespace = 
v.metadata.collection, - ) + index.upsert( + vectors = records, + namespace = v.metadata.collection, + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 47b53979..7bc5778c 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -40,49 +40,59 @@ class Processor(Consumer): self.client = QdrantClient(url=store_uri) + def get_collection(self, dim, user, collection): + + cname = ( + "t_" + user + "_" + collection + "_" + str(dim) + ) + + if cname != self.last_collection: + + if not self.client.collection_exists(cname): + + try: + self.client.create_collection( + collection_name=cname, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + except Exception as e: + print("Qdrant collection creation failed") + raise e + + self.last_collection = cname + + return cname + def handle(self, msg): v = msg.value() - if v.entity.value == "" or v.entity.value is None: return + for entity in v.entities: - for vec in v.vectors: + if entity.entity.value == "" or entity.entity.value is None: continue - dim = len(vec) - collection = ( - "t_" + v.metadata.user + "_" + v.metadata.collection + "_" + - str(dim) - ) + for vec in entity.vectors: - if collection != self.last_collection: + dim = len(vec) - if not self.client.collection_exists(collection): + collection = self.get_collection( + dim, v.metadata.user, v.metadata.collection + ) - try: - self.client.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), + self.client.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "entity": entity.entity.value, + } ) - except Exception as e: - print("Qdrant collection creation failed") - raise e - - 
self.last_collection = collection - - self.client.upsert( - collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "entity": v.entity.value, - } - ) - ] - ) + ] + ) @staticmethod def add_args(parser):