Feature/refactor entity embeddings (#235)

* Make schema changes
* Core entity context flow in place
* extract-def outputs entity contexts
* Refactored qdrant write
* Refactoring of all vector stores in place
This commit is contained in:
cybermaggedon 2024-12-30 12:53:19 +00:00 committed by GitHub
parent 9942f63773
commit a458d57af2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 230 additions and 169 deletions

View file

@ -35,17 +35,6 @@ chunk_ingest_queue = topic('chunk-load')
############################################################################ ############################################################################
# Chunk embeddings are the embedding vectors associated with a text chunk.
# One record carries the chunk together with its vectors.
class ChunkEmbeddings(Record):
# Metadata record accompanying the chunk
metadata = Metadata()
# One or more embedding vectors, each an array of doubles
vectors = Array(Array(Double()))
# The chunk itself as bytes (the vectorizer UTF-8 encodes the text)
chunk = Bytes()
chunk_embeddings_ingest_queue = topic('chunk-embeddings-load')
############################################################################
# Doc embeddings query # Doc embeddings query
class DocumentEmbeddingsRequest(Record): class DocumentEmbeddingsRequest(Record):
@ -62,3 +51,4 @@ document_embeddings_request_queue = topic(
document_embeddings_response_queue = topic( document_embeddings_response_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='response', 'doc-embeddings', kind='non-persistent', namespace='response',
) )

View file

@ -7,12 +7,31 @@ from . metadata import Metadata
############################################################################ ############################################################################
# An entity context is a graph entity paired with textual context for it;
# the context text is what the vectorizer sends to the embeddings service.
class EntityContext(Record):
# The graph entity the context belongs to
entity = Value()
# Text associated with the entity (extract-def supplies the definition)
context = String()
# This is a 'batching' mechanism for the above data: one message carries
# a list of entity contexts under a single metadata record.
class EntityContexts(Record):
# Metadata shared by every entry in the batch
metadata = Metadata()
# The batched EntityContext records
entities = Array(EntityContext())
# Topic on which EntityContexts batches are queued for embedding
entity_contexts_ingest_queue = topic('entity-contexts-load')
############################################################################
# Graph embeddings are embeddings associated with a graph entity # Graph embeddings are embeddings associated with a graph entity
# Embeddings for a single graph entity; batched by GraphEmbeddings.
class EntityEmbeddings(Record):
# The graph entity the vectors describe
entity = Value()
# One or more embedding vectors, each an array of doubles
vectors = Array(Array(Double()))
# This is a 'batching' mechanism for the above data
class GraphEmbeddings(Record): class GraphEmbeddings(Record):
metadata = Metadata() metadata = Metadata()
vectors = Array(Array(Double())) entities = Array(EntityEmbeddings())
entity = Value()
graph_embeddings_store_queue = topic('graph-embeddings-store') graph_embeddings_store_queue = topic('graph-embeddings-store')

View file

@ -4,8 +4,9 @@ Vectorizer, calls the embeddings service to get embeddings for a chunk.
Input is text chunk, output is chunk and vectors. Input is text chunk, output is chunk and vectors.
""" """
from ... schema import Chunk, ChunkEmbeddings from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue from ... schema import entity_contexts_ingest_queue
from ... schema import graph_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel from ... log_level import LogLevel
@ -13,8 +14,8 @@ from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_ingest_queue default_input_queue = entity_contexts_ingest_queue
default_output_queue = chunk_embeddings_ingest_queue default_output_queue = graph_embeddings_store_queue
default_subscriber = module default_subscriber = module
class Processor(ConsumerProducer): class Processor(ConsumerProducer):
@ -38,8 +39,8 @@ class Processor(ConsumerProducer):
"embeddings_request_queue": emb_request_queue, "embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue, "embeddings_response_queue": emb_response_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": Chunk, "input_schema": EntityContexts,
"output_schema": ChunkEmbeddings, "output_schema": GraphEmbeddings,
} }
) )
@ -50,9 +51,9 @@ class Processor(ConsumerProducer):
subscriber=module + "-emb", subscriber=module + "-emb",
) )
def emit(self, metadata, chunk, vectors): def emit(self, rec, vectors):
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors) r = GraphEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
self.producer.send(r) self.producer.send(r)
def handle(self, msg): def handle(self, msg):
@ -60,21 +61,34 @@ class Processor(ConsumerProducer):
v = msg.value() v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True) print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8") entities = []
try: try:
vectors = self.embeddings.request(chunk) for entity in v.entities:
self.emit( vectors = self.embeddings.request(entity.context)
entities.append(
EntityEmbeddings(
entity=entity.entity,
vectors=vectors
)
)
r = GraphEmbeddings(
metadata=v.metadata, metadata=v.metadata,
chunk=chunk.encode("utf-8"), entities=entities,
vectors=vectors
) )
self.producer.send(r)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
# Retry
raise e
print("Done.", flush=True) print("Done.", flush=True)
@staticmethod @staticmethod

View file

@ -1,14 +1,17 @@
""" """
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to Simple decoder, accepts text chunks input, applies entity analysis to
get entity definitions which are output as graph edges. get entity definitions which are output as graph edges along with
entity/context definitions for embedding.
""" """
import urllib.parse import urllib.parse
import json from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import EntityContext, EntityContexts
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import entity_contexts_ingest_queue
from .... schema import prompt_request_queue from .... schema import prompt_request_queue
from .... schema import prompt_response_queue from .... schema import prompt_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -22,8 +25,9 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue default_output_queue = triples_store_queue
default_entity_context_queue = entity_contexts_ingest_queue
default_subscriber = module default_subscriber = module
class Processor(ConsumerProducer): class Processor(ConsumerProducer):
@ -32,6 +36,10 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue) input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue) output_queue = params.get("output_queue", default_output_queue)
ec_queue = params.get(
"entity_context_queue",
default_entity_context_queue
)
subscriber = params.get("subscriber", default_subscriber) subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get( pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue "prompt_request_queue", prompt_request_queue
@ -45,13 +53,30 @@ class Processor(ConsumerProducer):
"input_queue": input_queue, "input_queue": input_queue,
"output_queue": output_queue, "output_queue": output_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": ChunkEmbeddings, "input_schema": Chunk,
"output_schema": Triples, "output_schema": Triples,
"prompt_request_queue": pr_request_queue, "prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue, "prompt_response_queue": pr_response_queue,
} }
) )
self.ec_prod = self.client.create_producer(
topic=ec_queue,
schema=JsonSchema(EntityContexts),
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"entity_context_queue": ec_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": Chunk.__name__,
"output_schema": Triples.__name__,
"vector_schema": EntityContexts.__name__,
})
self.prompt = PromptClient( self.prompt = PromptClient(
pulsar_host=self.pulsar_host, pulsar_host=self.pulsar_host,
input_queue=pr_request_queue, input_queue=pr_request_queue,
@ -79,6 +104,14 @@ class Processor(ConsumerProducer):
) )
self.producer.send(t) self.producer.send(t)
def emit_ecs(self, metadata, entities):
    """Publish one batch of entity contexts on the entity-context producer.

    metadata: Metadata record to attach to the batch.
    entities: list of EntityContext records to send together.
    """
    batch = EntityContexts(metadata=metadata, entities=entities)
    self.ec_prod.send(batch)
def handle(self, msg): def handle(self, msg):
v = msg.value() v = msg.value()
@ -91,6 +124,7 @@ class Processor(ConsumerProducer):
defs = self.get_definitions(chunk) defs = self.get_definitions(chunk)
triples = [] triples = []
entities = []
# FIXME: Putting metadata into triples store is duplicated in # FIXME: Putting metadata into triples store is duplicated in
# relationships extractor too # relationships extractor too
@ -129,6 +163,14 @@ class Processor(ConsumerProducer):
o=Value(value=v.metadata.id, is_uri=True) o=Value(value=v.metadata.id, is_uri=True)
)) ))
ec = EntityContext(
entity=s_value,
context=defn.definition,
)
entities.append(ec)
self.emit_edges( self.emit_edges(
Metadata( Metadata(
id=v.metadata.id, id=v.metadata.id,
@ -139,6 +181,16 @@ class Processor(ConsumerProducer):
triples triples
) )
self.emit_ecs(
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
entities
)
except Exception as e: except Exception as e:
print("Exception: ", e, flush=True) print("Exception: ", e, flush=True)
@ -152,6 +204,12 @@ class Processor(ConsumerProducer):
default_output_queue, default_output_queue,
) )
parser.add_argument(
'-e', '--entity-context-queue',
default=default_entity_context_queue,
help=f'Entity context queue (default: {default_entity_context_queue})'
)
parser.add_argument( parser.add_argument(
'--prompt-request-queue', '--prompt-request-queue',
default=prompt_request_queue, default=prompt_request_queue,

View file

@ -1,18 +1,15 @@
""" """
Simple decoder, accepts vector+text chunks input, applies entity Simple decoder, accepts text chunks input, applies entity
relationship analysis to get entity relationship edges which are output as relationship analysis to get entity relationship edges which are output as
graph edges. graph edges.
""" """
import urllib.parse import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value from .... schema import Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue from .... schema import prompt_request_queue
from .... schema import prompt_response_queue from .... schema import prompt_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -25,9 +22,8 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue default_output_queue = triples_store_queue
default_vector_queue = graph_embeddings_store_queue
default_subscriber = module default_subscriber = module
class Processor(ConsumerProducer): class Processor(ConsumerProducer):
@ -36,7 +32,6 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue) input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue) output_queue = params.get("output_queue", default_output_queue)
vector_queue = params.get("vector_queue", default_vector_queue)
subscriber = params.get("subscriber", default_subscriber) subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get( pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue "prompt_request_queue", prompt_request_queue
@ -50,30 +45,13 @@ class Processor(ConsumerProducer):
"input_queue": input_queue, "input_queue": input_queue,
"output_queue": output_queue, "output_queue": output_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": ChunkEmbeddings, "input_schema": Chunk,
"output_schema": Triples, "output_schema": Triples,
"prompt_request_queue": pr_request_queue, "prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue, "prompt_response_queue": pr_response_queue,
} }
) )
self.vec_prod = self.client.create_producer(
topic=vector_queue,
schema=JsonSchema(GraphEmbeddings),
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"vector_queue": vector_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings.__name__,
"output_schema": Triples.__name__,
"vector_schema": GraphEmbeddings.__name__,
})
self.prompt = PromptClient( self.prompt = PromptClient(
pulsar_host=self.pulsar_host, pulsar_host=self.pulsar_host,
input_queue=pr_request_queue, input_queue=pr_request_queue,
@ -101,11 +79,6 @@ class Processor(ConsumerProducer):
) )
self.producer.send(t) self.producer.send(t)
def emit_vec(self, metadata, ent, vec):
    """Send a GraphEmbeddings record for one entity on the vector producer.

    metadata: Metadata record to attach.
    ent: the entity value the vectors belong to.
    vec: the embedding vectors for that entity.
    """
    self.vec_prod.send(
        GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
    )
def handle(self, msg): def handle(self, msg):
v = msg.value() v = msg.value()
@ -193,12 +166,6 @@ class Processor(ConsumerProducer):
o=Value(value=v.metadata.id, is_uri=True) o=Value(value=v.metadata.id, is_uri=True)
)) ))
self.emit_vec(v.metadata, s_value, v.vectors)
self.emit_vec(v.metadata, p_value, v.vectors)
if rel.o_entity:
self.emit_vec(v.metadata, o_value, v.vectors)
self.emit_edges( self.emit_edges(
Metadata( Metadata(
id=v.metadata.id, id=v.metadata.id,
@ -222,12 +189,6 @@ class Processor(ConsumerProducer):
default_output_queue, default_output_queue,
) )
parser.add_argument(
'-c', '--vector-queue',
default=default_vector_queue,
help=f'Vector output queue (default: {default_vector_queue})'
)
parser.add_argument( parser.add_argument(
'--prompt-request-queue', '--prompt-request-queue',
default=prompt_request_queue, default=prompt_request_queue,

View file

@ -1,14 +1,14 @@
""" """
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to Simple decoder, accepts text chunks input, applies entity analysis to
get entity definitions which are output as graph edges. get topics which are output as graph edges.
""" """
import urllib.parse import urllib.parse
import json import json
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue from .... schema import prompt_request_queue
from .... schema import prompt_response_queue from .... schema import prompt_response_queue
from .... log_level import LogLevel from .... log_level import LogLevel
@ -20,7 +20,7 @@ DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue default_output_queue = triples_store_queue
default_subscriber = module default_subscriber = module
@ -43,7 +43,7 @@ class Processor(ConsumerProducer):
"input_queue": input_queue, "input_queue": input_queue,
"output_queue": output_queue, "output_queue": output_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": ChunkEmbeddings, "input_schema": Chunk,
"output_schema": Triples, "output_schema": Triples,
"prompt_request_queue": pr_request_queue, "prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue, "prompt_response_queue": pr_response_queue,

View file

@ -38,9 +38,11 @@ class Processor(Consumer):
v = msg.value() v = msg.value()
if v.entity.value != "": for entity in v.entities:
for vec in v.vectors:
self.vecstore.insert(vec, v.entity.value) if entity.entity.value != "" and entity.entity.value is not None:
for vec in entity.vectors:
self.vecstore.insert(vec, entity.entity.value)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -60,76 +60,83 @@ class Processor(Consumer):
self.last_index_name = None self.last_index_name = None
def create_index(self, index_name, dim):
    """Create a serverless, cosine-metric Pinecone index and wait for it.

    index_name: name of the index to create.
    dim: dimension of the vectors the index will hold.

    Raises RuntimeError if the index still does not report ready after
    polling once a second for 1000 attempts plus one final check.
    """
    self.pinecone.create_index(
        name=index_name,
        dimension=dim,
        metric="cosine",
        spec=ServerlessSpec(
            cloud=self.cloud,
            region=self.region,
        ),
    )

    def is_ready():
        return self.pinecone.describe_index(index_name).status["ready"]

    for _ in range(1000):
        if is_ready():
            # Ready: return straight away instead of asking Pinecone for
            # the status a second time after leaving the loop.
            return
        time.sleep(1)

    # One last look after the final sleep before giving up.
    if not is_ready():
        raise RuntimeError(
            "Gave up waiting for index creation"
        )
def handle(self, msg):
    """Store every entity embedding in a GraphEmbeddings message in Pinecone.

    msg: message whose value() has `metadata` (user, collection) and
    `entities`, each entry carrying an entity value and its vectors.

    Entities whose value is empty or None are skipped. Vectors go to the
    index "t-<user>-<dimension>", created on first use, in the namespace
    named after the collection. Index-creation errors propagate.
    """
    v = msg.value()

    for entity in v.entities:

        name = entity.entity.value
        if name == "" or name is None:
            continue

        for vec in entity.vectors:

            dim = len(vec)
            index_name = "t-" + v.metadata.user + "-" + str(dim)

            # Only check for (and possibly create) the index when it is
            # not the one used for the previous vector.
            if index_name != self.last_index_name:

                if not self.pinecone.has_index(index_name):
                    try:
                        self.create_index(index_name, dim)
                    except Exception:
                        print("Pinecone index creation failed", flush=True)
                        raise
                    print(f"Index {index_name} created", flush=True)

                self.last_index_name = index_name

            index = self.pinecone.Index(index_name)

            # A fresh id per record: upserting with a shared id would make
            # each vector in the batch overwrite the previous one.
            records = [
                {
                    "id": str(uuid.uuid4()),
                    "values": vec,
                    "metadata": {"entity": name},
                }
            ]

            index.upsert(
                vectors=records,
                namespace=v.metadata.collection,
            )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -40,49 +40,59 @@ class Processor(Consumer):
self.client = QdrantClient(url=store_uri) self.client = QdrantClient(url=store_uri)
def get_collection(self, dim, user, collection):
    """Return the Qdrant collection name for these vectors, creating the
    collection first when it does not exist.

    dim: vector dimension; part of the name and the created vector size.
    user, collection: strings that form the rest of the name.

    The name is "t_<user>_<collection>_<dim>". The existence check is
    skipped when the name equals the last one handled. Errors from
    collection creation are reported and re-raised.
    """
    cname = (
        "t_" + user + "_" + collection + "_" + str(dim)
    )

    # Same collection as last time: nothing to check.
    if cname == self.last_collection:
        return cname

    if not self.client.collection_exists(cname):
        try:
            self.client.create_collection(
                collection_name=cname,
                vectors_config=VectorParams(
                    size=dim, distance=Distance.COSINE
                ),
            )
        except Exception:
            print("Qdrant collection creation failed", flush=True)
            raise

    self.last_collection = cname

    return cname
def handle(self, msg):
    """Store every entity embedding in a GraphEmbeddings message in Qdrant.

    msg: message whose value() has `metadata` (user, collection) and
    `entities`, each entry carrying an entity value and its vectors.

    Entities whose value is empty or None are skipped; the rest of the
    batch is still processed. Each vector is upserted as its own point,
    with a fresh UUID, into the collection chosen by get_collection().
    """
    v = msg.value()

    for entity in v.entities:

        name = entity.entity.value

        # Skip just this entry: returning here would silently drop every
        # remaining entity in the batch.
        if name == "" or name is None:
            continue

        for vec in entity.vectors:

            dim = len(vec)

            collection = self.get_collection(
                dim, v.metadata.user, v.metadata.collection
            )

            self.client.upsert(
                collection_name=collection,
                points=[
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=vec,
                        payload={
                            "entity": name,
                        }
                    )
                ]
            )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):