mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 17:36:23 +02:00
Feature/refactor entity embeddings (#235)
* Make schema changes * Core entity context flow in place * extract-def outputs entity contexts * Refactored qdrant write * Refactoring of all vector stores in place
This commit is contained in:
parent
9942f63773
commit
a458d57af2
9 changed files with 230 additions and 169 deletions
|
|
@ -4,8 +4,9 @@ Vectorizer, calls the embeddings service to get embeddings for a chunk.
|
|||
Input is text chunk, output is chunk and vectors.
|
||||
"""
|
||||
|
||||
from ... schema import Chunk, ChunkEmbeddings
|
||||
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
|
||||
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
|
||||
from ... schema import entity_contexts_ingest_queue
|
||||
from ... schema import graph_embeddings_store_queue
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... clients.embeddings_client import EmbeddingsClient
|
||||
from ... log_level import LogLevel
|
||||
|
|
@ -13,8 +14,8 @@ from ... base import ConsumerProducer
|
|||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = chunk_embeddings_ingest_queue
|
||||
default_input_queue = entity_contexts_ingest_queue
|
||||
default_output_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
|
@ -38,8 +39,8 @@ class Processor(ConsumerProducer):
|
|||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": ChunkEmbeddings,
|
||||
"input_schema": EntityContexts,
|
||||
"output_schema": GraphEmbeddings,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -50,9 +51,9 @@ class Processor(ConsumerProducer):
|
|||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
def emit(self, metadata, chunk, vectors):
|
||||
def emit(self, rec, vectors):
|
||||
|
||||
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
|
||||
r = GraphEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
|
||||
self.producer.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
|
@ -60,21 +61,34 @@ class Processor(ConsumerProducer):
|
|||
v = msg.value()
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
entities = []
|
||||
|
||||
try:
|
||||
|
||||
vectors = self.embeddings.request(chunk)
|
||||
for entity in v.entities:
|
||||
|
||||
self.emit(
|
||||
vectors = self.embeddings.request(entity.context)
|
||||
|
||||
entities.append(
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors
|
||||
)
|
||||
)
|
||||
|
||||
r = GraphEmbeddings(
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.encode("utf-8"),
|
||||
vectors=vectors
|
||||
entities=entities,
|
||||
)
|
||||
|
||||
self.producer.send(r)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# Retry
|
||||
raise e
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges.
|
||||
Simple decoder, accepts text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges along with
|
||||
entity/context definitions for embedding.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import json
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import EntityContext, EntityContexts
|
||||
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||
from .... schema import entity_contexts_ingest_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
|
|
@ -22,8 +25,9 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
|||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_entity_context_queue = entity_contexts_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
|
@ -32,6 +36,10 @@ class Processor(ConsumerProducer):
|
|||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
ec_queue = params.get(
|
||||
"entity_context_queue",
|
||||
default_entity_context_queue
|
||||
)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
|
|
@ -45,13 +53,30 @@ class Processor(ConsumerProducer):
|
|||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": Triples,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.ec_prod = self.client.create_producer(
|
||||
topic=ec_queue,
|
||||
schema=JsonSchema(EntityContexts),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"entity_context_queue": ec_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk.__name__,
|
||||
"output_schema": Triples.__name__,
|
||||
"vector_schema": EntityContexts.__name__,
|
||||
})
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
|
|
@ -79,6 +104,14 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_ecs(self, metadata, entities):
|
||||
|
||||
t = EntityContexts(
|
||||
metadata=metadata,
|
||||
entities=entities,
|
||||
)
|
||||
self.ec_prod.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
|
@ -91,6 +124,7 @@ class Processor(ConsumerProducer):
|
|||
defs = self.get_definitions(chunk)
|
||||
|
||||
triples = []
|
||||
entities = []
|
||||
|
||||
# FIXME: Putting metadata into triples store is duplicated in
|
||||
# relationships extractor too
|
||||
|
|
@ -129,6 +163,14 @@ class Processor(ConsumerProducer):
|
|||
o=Value(value=v.metadata.id, is_uri=True)
|
||||
))
|
||||
|
||||
ec = EntityContext(
|
||||
entity=s_value,
|
||||
context=defn.definition,
|
||||
)
|
||||
|
||||
entities.append(ec)
|
||||
|
||||
|
||||
self.emit_edges(
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
|
|
@ -139,6 +181,16 @@ class Processor(ConsumerProducer):
|
|||
triples
|
||||
)
|
||||
|
||||
self.emit_ecs(
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
entities
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
|
|
@ -152,6 +204,12 @@ class Processor(ConsumerProducer):
|
|||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--entity-context-queue',
|
||||
default=default_entity_context_queue,
|
||||
help=f'Entity context queue (default: {default_entity_context_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
|
|
|
|||
|
|
@ -1,18 +1,15 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts vector+text chunks input, applies entity
|
||||
Simple decoder, accepts text chunks input, applies entity
|
||||
relationship analysis to get entity relationship edges which are output as
|
||||
graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings
|
||||
from .... schema import Chunk, Triple, Triples
|
||||
from .... schema import Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
|
|
@ -25,9 +22,8 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
|||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_vector_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
|
@ -36,7 +32,6 @@ class Processor(ConsumerProducer):
|
|||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
|
|
@ -50,30 +45,13 @@ class Processor(ConsumerProducer):
|
|||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": Triples,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.vec_prod = self.client.create_producer(
|
||||
topic=vector_queue,
|
||||
schema=JsonSchema(GraphEmbeddings),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"vector_queue": vector_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings.__name__,
|
||||
"output_schema": Triples.__name__,
|
||||
"vector_schema": GraphEmbeddings.__name__,
|
||||
})
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
|
|
@ -101,11 +79,6 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, metadata, ent, vec):
|
||||
|
||||
r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
|
@ -193,12 +166,6 @@ class Processor(ConsumerProducer):
|
|||
o=Value(value=v.metadata.id, is_uri=True)
|
||||
))
|
||||
|
||||
self.emit_vec(v.metadata, s_value, v.vectors)
|
||||
self.emit_vec(v.metadata, p_value, v.vectors)
|
||||
|
||||
if rel.o_entity:
|
||||
self.emit_vec(v.metadata, o_value, v.vectors)
|
||||
|
||||
self.emit_edges(
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
|
|
@ -222,12 +189,6 @@ class Processor(ConsumerProducer):
|
|||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--vector-queue',
|
||||
default=default_vector_queue,
|
||||
help=f'Vector output queue (default: {default_vector_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges.
|
||||
Simple decoder, accepts text chunks input, applies entity analysis to
|
||||
get topics which are output as graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
|
|
@ -20,7 +20,7 @@ DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
|||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ class Processor(ConsumerProducer):
|
|||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": Triples,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
|
|
|
|||
|
|
@ -38,9 +38,11 @@ class Processor(Consumer):
|
|||
|
||||
v = msg.value()
|
||||
|
||||
if v.entity.value != "":
|
||||
for vec in v.vectors:
|
||||
self.vecstore.insert(vec, v.entity.value)
|
||||
for entity in v.entities:
|
||||
|
||||
if entity.entity.value != "" and entity.entity.value is not None:
|
||||
for vec in entity.vectors:
|
||||
self.vecstore.insert(vec, entity.entity.value)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -60,76 +60,83 @@ class Processor(Consumer):
|
|||
|
||||
self.last_index_name = None
|
||||
|
||||
def create_index(self, index_name, dim):
|
||||
|
||||
self.pinecone.create_index(
|
||||
name = index_name,
|
||||
dimension = dim,
|
||||
metric = "cosine",
|
||||
spec = ServerlessSpec(
|
||||
cloud = self.cloud,
|
||||
region = self.region,
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(0, 1000):
|
||||
|
||||
if self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if not self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
raise RuntimeError(
|
||||
"Gave up waiting for index creation"
|
||||
)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
if v.entity.value == "" or v.entity.value is None: return
|
||||
for entity in v.entities:
|
||||
|
||||
for vec in v.vectors:
|
||||
if entity.entity.value == "" or entity.entity.value is None:
|
||||
continue
|
||||
|
||||
dim = len(vec)
|
||||
for vec in entity.vectors:
|
||||
|
||||
index_name = (
|
||||
"t-" + v.metadata.user + "-" + str(dim)
|
||||
)
|
||||
dim = len(vec)
|
||||
|
||||
if index_name != self.last_index_name:
|
||||
index_name = (
|
||||
"t-" + v.metadata.user + "-" + str(dim)
|
||||
)
|
||||
|
||||
if not self.pinecone.has_index(index_name):
|
||||
if index_name != self.last_index_name:
|
||||
|
||||
try:
|
||||
if not self.pinecone.has_index(index_name):
|
||||
|
||||
self.pinecone.create_index(
|
||||
name = index_name,
|
||||
dimension = dim,
|
||||
metric = "cosine",
|
||||
spec = ServerlessSpec(
|
||||
cloud = self.cloud,
|
||||
region = self.region,
|
||||
)
|
||||
)
|
||||
try:
|
||||
|
||||
for i in range(0, 1000):
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
if self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
break
|
||||
except Exception as e:
|
||||
print("Pinecone index creation failed")
|
||||
raise e
|
||||
|
||||
time.sleep(1)
|
||||
print(f"Index {index_name} created", flush=True)
|
||||
|
||||
if not self.pinecone.describe_index(
|
||||
index_name
|
||||
).status["ready"]:
|
||||
raise RuntimeError(
|
||||
"Gave up waiting for index creation"
|
||||
)
|
||||
self.last_index_name = index_name
|
||||
|
||||
except Exception as e:
|
||||
print("Pinecone index creation failed")
|
||||
raise e
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
print(f"Index {index_name} created", flush=True)
|
||||
records = [
|
||||
{
|
||||
"id": id,
|
||||
"values": vec,
|
||||
"metadata": { "entity": entity.entity.value },
|
||||
}
|
||||
]
|
||||
|
||||
self.last_index_name = index_name
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
records = [
|
||||
{
|
||||
"id": id,
|
||||
"values": vec,
|
||||
"metadata": { "entity": v.entity.value },
|
||||
}
|
||||
]
|
||||
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
namespace = v.metadata.collection,
|
||||
)
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
namespace = v.metadata.collection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -40,49 +40,59 @@ class Processor(Consumer):
|
|||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
def get_collection(self, dim, user, collection):
|
||||
|
||||
cname = (
|
||||
"t_" + user + "_" + collection + "_" + str(dim)
|
||||
)
|
||||
|
||||
if cname != self.last_collection:
|
||||
|
||||
if not self.client.collection_exists(cname):
|
||||
|
||||
try:
|
||||
self.client.create_collection(
|
||||
collection_name=cname,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
print("Qdrant collection creation failed")
|
||||
raise e
|
||||
|
||||
self.last_collection = cname
|
||||
|
||||
return cname
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
if v.entity.value == "" or v.entity.value is None: return
|
||||
for entity in v.entities:
|
||||
|
||||
for vec in v.vectors:
|
||||
if entity.entity.value == "" or entity.entity.value is None: return
|
||||
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
for vec in entity.vectors:
|
||||
|
||||
if collection != self.last_collection:
|
||||
dim = len(vec)
|
||||
|
||||
if not self.client.collection_exists(collection):
|
||||
collection = self.get_collection(
|
||||
dim, v.metadata.user, v.metadata.collection
|
||||
)
|
||||
|
||||
try:
|
||||
self.client.create_collection(
|
||||
collection_name=collection,
|
||||
vectors_config=VectorParams(
|
||||
size=dim, distance=Distance.COSINE
|
||||
),
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"entity": entity.entity.value,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print("Qdrant collection creation failed")
|
||||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"entity": v.entity.value,
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue