Fix/document embeddings (#247)

* Update schema for doc embeddings

* Rename embeddings-vectorize to graph-embeddings

* Added document-embeddings processor (broken, needs fixing)

* Added scripts

* Fixed DE queue schema

* Add missing DE process

* Fix doc RAG processing, put graph-rag and doc-rag in appropriate component files.
This commit is contained in:
cybermaggedon 2025-01-04 21:51:28 +00:00 committed by GitHub
parent c633652fd2
commit 6aa212061d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 421 additions and 189 deletions

View file

@ -39,5 +39,35 @@ local prompts = import "prompts/mixtral.jsonnet";
}, },
"document-embeddings" +: {
create:: function(engine)
local container =
engine.container("document-embeddings")
.with_image(images.trustgraph)
.with_command([
"document-embeddings",
"-p",
url.pulsar,
])
.with_limits("1.0", "512M")
.with_reservations("0.5", "512M");
local containerSet = engine.containers(
"document-embeddings", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8000, 8000, "metrics");
engine.resources([
containerSet,
service,
])
},
} }

View file

@ -138,5 +138,35 @@ local url = import "values/url.jsonnet";
}, },
"graph-embeddings" +: {
create:: function(engine)
local container =
engine.container("graph-embeddings")
.with_image(images.trustgraph)
.with_command([
"graph-embeddings",
"-p",
url.pulsar,
])
.with_limits("1.0", "512M")
.with_reservations("0.5", "512M");
local containerSet = engine.containers(
"graph-embeddings", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8000, 8000, "metrics");
engine.resources([
containerSet,
service,
])
},
} }

View file

@ -119,36 +119,6 @@ local prompt = import "prompt-template.jsonnet";
}, },
"vectorize" +: {
create:: function(engine)
local container =
engine.container("vectorize")
.with_image(images.trustgraph)
.with_command([
"embeddings-vectorize",
"-p",
url.pulsar,
])
.with_limits("1.0", "512M")
.with_reservations("0.5", "512M");
local containerSet = engine.containers(
"vectorize", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8000, 8000, "metrics");
engine.resources([
containerSet,
service,
])
},
"metering" +: { "metering" +: {
create:: function(engine) create:: function(engine)

View file

@ -131,6 +131,35 @@ class Api:
except: except:
raise ProtocolException(f"Response not formatted correctly") raise ProtocolException(f"Response not formatted correctly")
def document_rag(self, question):
# The input consists of a question
input = {
"query": question
}
url = f"{self.url}document-rag"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")
try:
# Parse the response as JSON
object = resp.json()
except:
raise ProtocolException(f"Expected JSON response")
self.check_error(resp)
try:
return object["response"]
except:
raise ProtocolException(f"Response not formatted correctly")
def embeddings(self, text): def embeddings(self, text):
# The input consists of a text block # The input consists of a text block

View file

@ -38,8 +38,12 @@ class DocumentEmbeddingsClient(BaseClient):
output_schema=DocumentEmbeddingsResponse, output_schema=DocumentEmbeddingsResponse,
) )
def request(self, vectors, limit=10, timeout=300): def request(
self, vectors, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call( return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout vectors=vectors, limit=limit, timeout=timeout
).documents ).documents

View file

@ -35,11 +35,28 @@ chunk_ingest_queue = topic('chunk-load')
############################################################################ ############################################################################
# Document embeddings are embeddings associated with a chunk
class ChunkEmbeddings(Record):
chunk = Bytes()
vectors = Array(Array(Double()))
# This is a 'batching' mechanism for the above data
class DocumentEmbeddings(Record):
metadata = Metadata()
chunks = Array(ChunkEmbeddings())
document_embeddings_store_queue = topic('document-embeddings-store')
############################################################################
# Doc embeddings query # Doc embeddings query
class DocumentEmbeddingsRequest(Record): class DocumentEmbeddingsRequest(Record):
vectors = Array(Array(Double())) vectors = Array(Array(Double()))
limit = Integer() limit = Integer()
user = String()
collection = String()
class DocumentEmbeddingsResponse(Record): class DocumentEmbeddingsResponse(Record):
error = Error() error = Error()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.embeddings.document_embeddings import run
run()

View file

@ -1,6 +0,0 @@
#!/usr/bin/env python3
from trustgraph.embeddings.vectorize import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.embeddings.graph_embeddings import run
run()

View file

@ -63,29 +63,30 @@ setuptools.setup(
"falkordb", "falkordb",
], ],
scripts=[ scripts=[
"scripts/api-gateway",
"scripts/agent-manager-react", "scripts/agent-manager-react",
"scripts/api-gateway",
"scripts/chunker-recursive", "scripts/chunker-recursive",
"scripts/chunker-token", "scripts/chunker-token",
"scripts/de-query-milvus", "scripts/de-query-milvus",
"scripts/de-query-qdrant",
"scripts/de-query-pinecone", "scripts/de-query-pinecone",
"scripts/de-query-qdrant",
"scripts/de-write-milvus", "scripts/de-write-milvus",
"scripts/de-write-qdrant",
"scripts/de-write-pinecone", "scripts/de-write-pinecone",
"scripts/de-write-qdrant",
"scripts/document-embeddings",
"scripts/document-rag", "scripts/document-rag",
"scripts/embeddings-ollama", "scripts/embeddings-ollama",
"scripts/embeddings-vectorize",
"scripts/ge-query-milvus", "scripts/ge-query-milvus",
"scripts/ge-query-pinecone", "scripts/ge-query-pinecone",
"scripts/ge-query-qdrant", "scripts/ge-query-qdrant",
"scripts/ge-write-milvus", "scripts/ge-write-milvus",
"scripts/ge-write-pinecone", "scripts/ge-write-pinecone",
"scripts/ge-write-qdrant", "scripts/ge-write-qdrant",
"scripts/graph-embeddings",
"scripts/graph-rag", "scripts/graph-rag",
"scripts/kg-extract-definitions", "scripts/kg-extract-definitions",
"scripts/kg-extract-topics",
"scripts/kg-extract-relationships", "scripts/kg-extract-relationships",
"scripts/kg-extract-topics",
"scripts/metering", "scripts/metering",
"scripts/object-extract-row", "scripts/object-extract-row",
"scripts/oe-write-milvus", "scripts/oe-write-milvus",
@ -103,13 +104,13 @@ setuptools.setup(
"scripts/text-completion-ollama", "scripts/text-completion-ollama",
"scripts/text-completion-openai", "scripts/text-completion-openai",
"scripts/triples-query-cassandra", "scripts/triples-query-cassandra",
"scripts/triples-query-neo4j",
"scripts/triples-query-memgraph",
"scripts/triples-query-falkordb", "scripts/triples-query-falkordb",
"scripts/triples-query-memgraph",
"scripts/triples-query-neo4j",
"scripts/triples-write-cassandra", "scripts/triples-write-cassandra",
"scripts/triples-write-neo4j",
"scripts/triples-write-memgraph",
"scripts/triples-write-falkordb", "scripts/triples-write-falkordb",
"scripts/triples-write-memgraph",
"scripts/triples-write-neo4j",
"scripts/wikipedia-lookup", "scripts/wikipedia-lookup",
] ]
) )

View file

@ -16,6 +16,44 @@ from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label" LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition" DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(self, rag, user, collection, verbose):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.rag.doc_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
class DocumentRag: class DocumentRag:
def __init__( def __init__(
@ -55,7 +93,7 @@ class DocumentRag:
print("Initialising...", flush=True) print("Initialising...", flush=True)
# FIXME: Configurable # FIXME: Configurable
self.entity_limit = 20 self.doc_limit = 20
self.de_client = DocumentEmbeddingsClient( self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host, pulsar_host=pulsar_host,
@ -81,42 +119,16 @@ class DocumentRag:
if self.verbose: if self.verbose:
print("Initialised", flush=True) print("Initialised", flush=True)
def get_vector(self, query): def query(self, query, user="trustgraph", collection="default"):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.de_client.request(
vectors, self.entity_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
def query(self, query):
if self.verbose: if self.verbose:
print("Construct prompt...", flush=True) print("Construct prompt...", flush=True)
docs = self.get_docs(query) q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
)
docs = q.get_docs(query)
if self.verbose: if self.verbose:
print("Invoke LLM...", flush=True) print("Invoke LLM...", flush=True)

View file

@ -0,0 +1,3 @@
from . embeddings import *

View file

@ -1,5 +1,5 @@
from . vectorize import run from . embeddings import run
if __name__ == '__main__': if __name__ == '__main__':
run() run()

View file

@ -0,0 +1,109 @@
"""
Document embeddings, calls the embeddings service to get embeddings for a
chunk of text. Input is chunk of text plus metadata.
Output is chunk plus embedding.
"""
from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings
from ... schema import chunk_ingest_queue
from ... schema import document_embeddings_store_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_ingest_queue
default_output_queue = document_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": DocumentEmbeddings,
}
)
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.metadata.id}...", flush=True)
try:
vectors = self.embeddings.request(v.chunk)
embeds = [
ChunkEmbeddings(
chunk=v.chunk,
vectors=vectors,
)
]
r = DocumentEmbeddings(
metadata=v.metadata,
chunks=embeds,
)
self.producer.send(r)
except Exception as e:
print("Exception:", e, flush=True)
# Retry
raise e
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . embeddings import *

View file

@ -0,0 +1,6 @@
from . embeddings import run
if __name__ == '__main__':
run()

View file

@ -1,7 +1,8 @@
""" """
Vectorizer, calls the embeddings service to get embeddings for a chunk. Graph embeddings, calls the embeddings service to get embeddings for a
Input is text chunk, output is chunk and vectors. set of entity contexts. Input is entity plus textual context.
Output is entity plus embedding.
""" """
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
@ -51,11 +52,6 @@ class Processor(ConsumerProducer):
subscriber=module + "-emb", subscriber=module + "-emb",
) )
def emit(self, rec, vectors):
r = GraphEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
self.producer.send(r)
def handle(self, msg): def handle(self, msg):
v = msg.value() v = msg.value()

View file

@ -1,3 +0,0 @@
from . vectorize import *

View file

@ -31,6 +31,7 @@ from . subscriber import Subscriber
from . text_completion import TextCompletionRequestor from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor from . triples_query import TriplesQueryRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . embeddings import EmbeddingsRequestor from . embeddings import EmbeddingsRequestor
@ -91,6 +92,10 @@ class Api:
pulsar_host=self.pulsar_host, timeout=self.timeout, pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth, auth = self.auth,
), ),
"document-rag": DocumentRagRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"triples-query": TriplesQueryRequestor( "triples-query": TriplesQueryRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout, pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth, auth = self.auth,
@ -140,6 +145,10 @@ class Api:
endpoint_path = "/api/v1/graph-rag", auth=self.auth, endpoint_path = "/api/v1/graph-rag", auth=self.auth,
requestor = self.services["graph-rag"], requestor = self.services["graph-rag"],
), ),
ServiceEndpoint(
endpoint_path = "/api/v1/document-rag", auth=self.auth,
requestor = self.services["document-rag"],
),
ServiceEndpoint( ServiceEndpoint(
endpoint_path = "/api/v1/triples-query", auth=self.auth, endpoint_path = "/api/v1/triples-query", auth=self.auth,
requestor = self.services["triples-query"], requestor = self.services["triples-query"],

View file

@ -3,15 +3,16 @@
Accepts entity/vector pairs and writes them to a Milvus store. Accepts entity/vector pairs and writes them to a Milvus store.
""" """
from .... schema import ChunkEmbeddings
from .... schema import chunk_embeddings_ingest_queue
from .... log_level import LogLevel
from .... direct.milvus_doc_embeddings import DocVectors from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddings
from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel
from .... base import Consumer from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue default_input_queue = document_embeddings_store_queue
default_subscriber = module default_subscriber = module
default_store_uri = 'http://localhost:19530' default_store_uri = 'http://localhost:19530'
@ -27,7 +28,7 @@ class Processor(Consumer):
**params | { **params | {
"input_queue": input_queue, "input_queue": input_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": ChunkEmbeddings, "input_schema": DocumentEmbeddings,
"store_uri": store_uri, "store_uri": store_uri,
} }
) )
@ -38,11 +39,16 @@ class Processor(Consumer):
v = msg.value() v = msg.value()
chunk = v.chunk.decode("utf-8") for emb in v.chunks:
if v.chunk != "" and v.chunk is not None: chunk = emb.chunk.decode("utf-8")
for vec in v.vectors: if chunk == "" or chunk is None: continue
self.vecstore.insert(vec, chunk)
for vec in emb.vectors:
if chunk != "" and v.chunk is not None:
for vec in v.vectors:
self.vecstore.insert(vec, chunk)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -11,14 +11,14 @@ import time
import uuid import uuid
import os import os
from .... schema import ChunkEmbeddings from .... schema import DocumentEmbeddings
from .... schema import chunk_embeddings_ingest_queue from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import Consumer from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue default_input_queue = document_embeddings_store_queue
default_subscriber = module default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws" default_cloud = "aws"
@ -54,7 +54,7 @@ class Processor(Consumer):
**params | { **params | {
"input_queue": input_queue, "input_queue": input_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": ChunkEmbeddings, "input_schema": DocumentEmbeddings,
"url": self.url, "url": self.url,
} }
) )
@ -65,71 +65,74 @@ class Processor(Consumer):
v = msg.value() v = msg.value()
chunk = v.chunk.decode("utf-8") for emb in v.chunks:
if chunk == "": return chunk = emb.chunk.decode("utf-8")
if chunk == "" or chunk is None: continue
for vec in v.vectors: for vec in emb.vectors:
dim = len(vec) for vec in v.vectors:
collection = (
"d-" + v.metadata.user + "-" + str(dim)
)
if index_name != self.last_index_name: dim = len(vec)
collection = (
"d-" + v.metadata.user + "-" + str(dim)
)
if not self.pinecone.has_index(index_name): if index_name != self.last_index_name:
try: if not self.pinecone.has_index(index_name):
self.pinecone.create_index( try:
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
for i in range(0, 1000): self.pinecone.create_index(
name = index_name,
dimension = dim,
metric = "cosine",
spec = ServerlessSpec(
cloud = self.cloud,
region = self.region,
)
)
if self.pinecone.describe_index( for i in range(0, 1000):
index_name
).status["ready"]:
break
time.sleep(1) if self.pinecone.describe_index(
index_name
).status["ready"]:
break
if not self.pinecone.describe_index( time.sleep(1)
index_name
).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
except Exception as e: if not self.pinecone.describe_index(
print("Pinecone index creation failed") index_name
raise e ).status["ready"]:
raise RuntimeError(
"Gave up waiting for index creation"
)
print(f"Index {index_name} created", flush=True) except Exception as e:
print("Pinecone index creation failed")
raise e
self.last_index_name = index_name print(f"Index {index_name} created", flush=True)
index = self.pinecone.Index(index_name) self.last_index_name = index_name
records = [ index = self.pinecone.Index(index_name)
{
"id": id,
"values": vec,
"metadata": { "doc": chunk },
}
]
index.upsert( records = [
vectors = records, {
namespace = v.metadata.collection, "id": id,
) "values": vec,
"metadata": { "doc": chunk },
}
]
index.upsert(
vectors = records,
namespace = v.metadata.collection,
)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View file

@ -8,14 +8,14 @@ from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
import uuid import uuid
from .... schema import ChunkEmbeddings from .... schema import DocumentEmbeddings
from .... schema import chunk_embeddings_ingest_queue from .... schema import document_embeddings_store_queue
from .... log_level import LogLevel from .... log_level import LogLevel
from .... base import Consumer from .... base import Consumer
module = ".".join(__name__.split(".")[1:-1]) module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue default_input_queue = document_embeddings_store_queue
default_subscriber = module default_subscriber = module
default_store_uri = 'http://localhost:6333' default_store_uri = 'http://localhost:6333'
@ -31,7 +31,7 @@ class Processor(Consumer):
**params | { **params | {
"input_queue": input_queue, "input_queue": input_queue,
"subscriber": subscriber, "subscriber": subscriber,
"input_schema": ChunkEmbeddings, "input_schema": DocumentEmbeddings,
"store_uri": store_uri, "store_uri": store_uri,
} }
) )
@ -44,47 +44,48 @@ class Processor(Consumer):
v = msg.value() v = msg.value()
chunk = v.chunk.decode("utf-8") for emb in v.chunks:
if chunk == "": return chunk = emb.chunk.decode("utf-8")
if chunk == "": return
for vec in v.vectors: for vec in emb.vectors:
dim = len(vec) dim = len(vec)
collection = ( collection = (
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" + "d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim) str(dim)
) )
if collection != self.last_collection: if collection != self.last_collection:
if not self.client.collection_exists(collection): if not self.client.collection_exists(collection):
try: try:
self.client.create_collection( self.client.create_collection(
collection_name=collection, collection_name=collection,
vectors_config=VectorParams( vectors_config=VectorParams(
size=dim, distance=Distance.DOT size=dim, distance=Distance.COSINE
), ),
)
except Exception as e:
print("Qdrant collection creation failed")
raise e
self.last_collection = collection
self.client.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"doc": chunk,
}
) )
except Exception as e: ]
print("Qdrant collection creation failed") )
raise e
self.last_collection = collection
self.client.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"doc": chunk,
}
)
]
)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):