trustgraph/trustgraph-flow/trustgraph/document_rag.py
cybermaggedon 6aa212061d
Fix/document embeddings (#247)
* Update schema for doc embeddings

* Rename embeddings-vectorize to graph-embeddings

* Added document-embeddings processor (broken, needs fixing)

* Added scripts

* Fixed DE queue schema

* Add missing DE process

* Fix doc RAG processing, put graph-rag and doc-rag in appropriate component files.
2025-01-04 21:51:28 +00:00

144 lines
4 KiB
Python

from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
def __init__(self, rag, user, collection, verbose):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.rag.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.rag.de_client.request(
vectors, limit=self.rag.doc_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
de_request_queue=None,
de_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if de_request_queue is None:
de_request_queue = document_embeddings_request_queue
if de_response_queue is None:
de_response_queue = document_embeddings_response_queue
if self.verbose:
print("Initialising...", flush=True)
# FIXME: Configurable
self.doc_limit = 20
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def query(self, query, user="trustgraph", collection="default"):
if self.verbose:
print("Construct prompt...", flush=True)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
)
docs = q.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = self.lang.request_document_prompt(query, docs)
if self.verbose:
print("Done", flush=True)
return resp