Feature/refactor entity embeddings (#235)

* Make schema changes
* Core entity context flow in place
* extract-def outputs entity contexts
* Refactored qdrant write
* Refactoring of all vector stores in place
This commit is contained in:
cybermaggedon 2024-12-30 12:53:19 +00:00 committed by GitHub
parent 9942f63773
commit a458d57af2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 230 additions and 169 deletions

View file

@@ -4,8 +4,9 @@ Vectorizer, calls the embeddings service to get embeddings for a chunk.
Input is text chunk, output is chunk and vectors.
"""
from ... schema import Chunk, ChunkEmbeddings
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
from ... schema import entity_contexts_ingest_queue
from ... schema import graph_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
@@ -13,8 +14,8 @@ from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_ingest_queue
default_output_queue = chunk_embeddings_ingest_queue
default_input_queue = entity_contexts_ingest_queue
default_output_queue = graph_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
@@ -38,8 +39,8 @@ class Processor(ConsumerProducer):
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": ChunkEmbeddings,
"input_schema": EntityContexts,
"output_schema": GraphEmbeddings,
}
)
@@ -50,9 +51,9 @@ class Processor(ConsumerProducer):
subscriber=module + "-emb",
)
def emit(self, metadata, chunk, vectors):
def emit(self, rec, vectors):
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
r = GraphEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
self.producer.send(r)
def handle(self, msg):
@@ -60,21 +61,34 @@ class Processor(ConsumerProducer):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
entities = []
try:
vectors = self.embeddings.request(chunk)
for entity in v.entities:
self.emit(
vectors = self.embeddings.request(entity.context)
entities.append(
EntityEmbeddings(
entity=entity.entity,
vectors=vectors
)
)
r = GraphEmbeddings(
metadata=v.metadata,
chunk=chunk.encode("utf-8"),
vectors=vectors
entities=entities,
)
self.producer.send(r)
except Exception as e:
print("Exception:", e, flush=True)
# Retry
raise e
print("Done.", flush=True)
@staticmethod

View file

@@ -1,14 +1,17 @@
"""
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
get entity definitions which are output as graph edges.
Simple decoder, accepts text chunks input, applies entity analysis to
get entity definitions which are output as graph edges along with
entity/context definitions for embedding.
"""
import urllib.parse
import json
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import EntityContext, EntityContexts
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import entity_contexts_ingest_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
@@ -22,8 +25,9 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_entity_context_queue = entity_contexts_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
@@ -32,6 +36,10 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
ec_queue = params.get(
"entity_context_queue",
default_entity_context_queue
)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
@@ -45,13 +53,30 @@ class Processor(ConsumerProducer):
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.ec_prod = self.client.create_producer(
topic=ec_queue,
schema=JsonSchema(EntityContexts),
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"entity_context_queue": ec_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": Chunk.__name__,
"output_schema": Triples.__name__,
"vector_schema": EntityContexts.__name__,
})
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
@@ -79,6 +104,14 @@ class Processor(ConsumerProducer):
)
self.producer.send(t)
def emit_ecs(self, metadata, entities):
t = EntityContexts(
metadata=metadata,
entities=entities,
)
self.ec_prod.send(t)
def handle(self, msg):
v = msg.value()
@@ -91,6 +124,7 @@ class Processor(ConsumerProducer):
defs = self.get_definitions(chunk)
triples = []
entities = []
# FIXME: Putting metadata into triples store is duplicated in
# relationships extractor too
@@ -129,6 +163,14 @@ class Processor(ConsumerProducer):
o=Value(value=v.metadata.id, is_uri=True)
))
ec = EntityContext(
entity=s_value,
context=defn.definition,
)
entities.append(ec)
self.emit_edges(
Metadata(
id=v.metadata.id,
@@ -139,6 +181,16 @@ class Processor(ConsumerProducer):
triples
)
self.emit_ecs(
Metadata(
id=v.metadata.id,
metadata=[],
user=v.metadata.user,
collection=v.metadata.collection,
),
entities
)
except Exception as e:
print("Exception: ", e, flush=True)
@@ -152,6 +204,12 @@ class Processor(ConsumerProducer):
default_output_queue,
)
parser.add_argument(
'-e', '--entity-context-queue',
default=default_entity_context_queue,
help=f'Entity context queue (default: {default_entity_context_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,

View file

@@ -1,18 +1,15 @@
"""
Simple decoder, accepts vector+text chunks input, applies entity
Simple decoder, accepts text chunks input, applies entity
relationship analysis to get entity relationship edges which are output as
graph edges.
"""
import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings
from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
@@ -25,9 +22,8 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_vector_queue = graph_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
@@ -36,7 +32,6 @@ class Processor(ConsumerProducer):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
vector_queue = params.get("vector_queue", default_vector_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
@@ -50,30 +45,13 @@ class Processor(ConsumerProducer):
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.vec_prod = self.client.create_producer(
topic=vector_queue,
schema=JsonSchema(GraphEmbeddings),
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"vector_queue": vector_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings.__name__,
"output_schema": Triples.__name__,
"vector_schema": GraphEmbeddings.__name__,
})
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
@@ -101,11 +79,6 @@ class Processor(ConsumerProducer):
)
self.producer.send(t)
def emit_vec(self, metadata, ent, vec):
r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
self.vec_prod.send(r)
def handle(self, msg):
v = msg.value()
@@ -193,12 +166,6 @@ class Processor(ConsumerProducer):
o=Value(value=v.metadata.id, is_uri=True)
))
self.emit_vec(v.metadata, s_value, v.vectors)
self.emit_vec(v.metadata, p_value, v.vectors)
if rel.o_entity:
self.emit_vec(v.metadata, o_value, v.vectors)
self.emit_edges(
Metadata(
id=v.metadata.id,
@@ -222,12 +189,6 @@ class Processor(ConsumerProducer):
default_output_queue,
)
parser.add_argument(
'-c', '--vector-queue',
default=default_vector_queue,
help=f'Vector output queue (default: {default_vector_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,

View file

@@ -1,14 +1,14 @@
"""
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
get entity definitions which are output as graph edges.
Simple decoder, accepts text chunks input, applies entity analysis to
get topics which are output as graph edges.
"""
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import Chunk, Triple, Triples, Metadata, Value
from .... schema import chunk_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
@@ -20,7 +20,7 @@ DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_input_queue = chunk_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module
@@ -43,7 +43,7 @@ class Processor(ConsumerProducer):
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"input_schema": Chunk,
"output_schema": Triples,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,

View file

@@ -38,9 +38,11 @@ class Processor(Consumer):
v = msg.value()
if v.entity.value != "":
for vec in v.vectors:
self.vecstore.insert(vec, v.entity.value)
for entity in v.entities:
if entity.entity.value != "" and entity.entity.value is not None:
for vec in entity.vectors:
self.vecstore.insert(vec, entity.entity.value)
@staticmethod
def add_args(parser):

View file

@@ -60,76 +60,83 @@ class Processor(Consumer):
self.last_index_name = None
def create_index(self, index_name, dim):
self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
for i in range(0, 1000):
if self.pinecone.describe_index(
index_name
).status["ready"]:
break
time.sleep(1)
if not self.pinecone.describe_index(
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
def handle(self, msg):
v = msg.value()
id = str(uuid.uuid4())
if v.entity.value == "" or v.entity.value is None: return
for entity in v.entities:
for vec in v.vectors:
if entity.entity.value == "" or entity.entity.value is None:
continue
dim = len(vec)
for vec in entity.vectors:
index_name = (
"t-" + v.metadata.user + "-" + str(dim)
)
dim = len(vec)
if index_name != self.last_index_name:
index_name = (
"t-" + v.metadata.user + "-" + str(dim)
)
if not self.pinecone.has_index(index_name):
if index_name != self.last_index_name:
try:
if not self.pinecone.has_index(index_name):
self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
try:
for i in range(0, 1000):
self.create_index(index_name, dim)
if self.pinecone.describe_index(
index_name
).status["ready"]:
break
except Exception as e:
print("Pinecone index creation failed")
raise e
time.sleep(1)
print(f"Index {index_name} created", flush=True)
if not self.pinecone.describe_index(
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
self.last_index_name = index_name
except Exception as e:
print("Pinecone index creation failed")
raise e
index = self.pinecone.Index(index_name)
print(f"Index {index_name} created", flush=True)
records = [
{
"id": id,
"values": vec,
"metadata": { "entity": entity.entity.value },
}
]
self.last_index_name = index_name
index = self.pinecone.Index(index_name)
records = [
{
"id": id,
"values": vec,
"metadata": { "entity": v.entity.value },
}
]
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
@staticmethod
def add_args(parser):

View file

@@ -40,49 +40,59 @@ class Processor(Consumer):
self.client = QdrantClient(url=store_uri)
def get_collection(self, dim, user, collection):
cname = (
"t_" + user + "_" + collection + "_" + str(dim)
)
if cname != self.last_collection:
if not self.client.collection_exists(cname):
try:
self.client.create_collection(
collection_name=cname,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
except Exception as e:
print("Qdrant collection creation failed")
raise e
self.last_collection = cname
return cname
def handle(self, msg):
v = msg.value()
if v.entity.value == "" or v.entity.value is None: return
for entity in v.entities:
for vec in v.vectors:
if entity.entity.value == "" or entity.entity.value is None: return
dim = len(vec)
collection = (
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
for vec in entity.vectors:
if collection != self.last_collection:
dim = len(vec)
if not self.client.collection_exists(collection):
collection = self.get_collection(
dim, v.metadata.user, v.metadata.collection
)
try:
self.client.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
self.client.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"entity": entity.entity.value,
}
)
except Exception as e:
print("Qdrant collection creation failed")
raise e
self.last_collection = collection
self.client.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"entity": v.entity.value,
}
)
]
)
]
)
@staticmethod
def add_args(parser):