trustgraph/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
cybermaggedon d282d72db1
Fixed document-rag workspace problem (#866)
- Fixed document-rag workspace problem
- OpenAI text-completion processor now puts 'not-set' in the token
  if no token is set (the new OpenAI library requires it to be set to
  something).
- Update tests
2026-05-06 14:55:21 +01:00

316 lines
10 KiB
Python

import asyncio
import logging
import uuid
from datetime import datetime, timezone
# Provenance imports
from trustgraph.provenance import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_synthesis_uri,
docrag_question_triples,
grounding_triples,
docrag_exploration_triples,
docrag_synthesis_triples,
set_graph,
GRAPH_RETRIEVAL,
)
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
    """One document-RAG retrieval pass: extracts concepts from a query,
    embeds them, and pulls matching document chunks from the store."""

    def __init__(
            self, rag, workspace, collection, verbose,
            doc_limit=20, track_usage=None,
    ):
        # rag supplies the prompt/embeddings/doc-embeddings clients and
        # fetch_chunk; track_usage (optional) receives each LLM result.
        self.rag = rag
        self.workspace = workspace
        self.collection = collection
        self.verbose = verbose
        self.doc_limit = doc_limit
        self.track_usage = track_usage

    async def extract_concepts(self, query):
        """Extract key concepts from query for independent embedding."""
        result = await self.rag.prompt_client.prompt(
            "extract-concepts",
            variables={"query": query}
        )
        if self.track_usage:
            self.track_usage(result)

        # One concept per non-blank line of the LLM response.
        concepts = []
        if result.text:
            concepts = [
                stripped
                for stripped in (
                    raw.strip() for raw in result.text.strip().split('\n')
                )
                if stripped
            ]

        # Nothing usable extracted: fall back to embedding the raw query.
        if not concepts:
            concepts = [query]

        # Retained so the caller can inspect token usage of this step.
        self.concepts_usage = result

        if self.verbose:
            logger.debug(f"Extracted concepts: {concepts}")
        return concepts

    async def get_vectors(self, concepts):
        """Compute embeddings for a list of concepts."""
        if self.verbose:
            logger.debug("Computing embeddings...")
        embeddings = await self.rag.embeddings_client.embed(concepts)
        if self.verbose:
            logger.debug("Embeddings computed")
        return embeddings

    async def get_docs(self, concepts):
        """
        Get documents (chunks) matching the extracted concepts.

        Returns:
            tuple: (docs, chunk_ids) where:
                - docs: list of document content strings
                - chunk_ids: list of chunk IDs that were successfully fetched
        """
        vectors = await self.get_vectors(concepts)

        if self.verbose:
            logger.debug("Getting chunks from embeddings store...")

        # Split the overall chunk budget evenly across concepts,
        # guaranteeing at least one match per concept.
        per_concept_limit = max(1, self.doc_limit // len(vectors))

        async def query_one(vec):
            return await self.rag.doc_embeddings_client.query(
                vector=vec, limit=per_concept_limit,
                collection=self.collection,
            )

        # One store query per concept vector, run concurrently.
        result_sets = await asyncio.gather(
            *(query_one(vec) for vec in vectors)
        )

        # Keep the first occurrence of each chunk_id, preserving order.
        seen_ids = set()
        unique_matches = []
        for result_set in result_sets:
            for candidate in result_set:
                cid = candidate.chunk_id
                if cid and cid not in seen_ids:
                    seen_ids.add(cid)
                    unique_matches.append(candidate)

        if self.verbose:
            logger.debug(f"Got {len(unique_matches)} chunks, fetching content from Garage...")

        # Fetch content for each match; a failed fetch is logged and
        # skipped rather than failing the whole retrieval.
        docs = []
        chunk_ids = []
        for candidate in unique_matches:
            cid = candidate.chunk_id
            if not cid:
                continue
            try:
                content = await self.rag.fetch_chunk(cid)
            except Exception as e:
                logger.warning(f"Failed to fetch chunk {cid}: {e}")
            else:
                docs.append(content)
                chunk_ids.append(cid)

        if self.verbose:
            logger.debug("Documents fetched:")
            for doc in docs:
                logger.debug(f" {doc[:100]}...")

        return docs, chunk_ids
class DocumentRag:
    """Document RAG pipeline: extracts concepts from a query, retrieves
    matching document chunks, and synthesizes an answer with an LLM,
    optionally emitting provenance/explainability triples per stage."""

    def __init__(
            self, prompt_client, embeddings_client, doc_embeddings_client,
            fetch_chunk,
            verbose=False,
    ):
        # prompt_client: LLM prompt service client
        # embeddings_client: text -> vector embedding client
        # doc_embeddings_client: vector-store client for chunk matches
        # fetch_chunk: async callable (chunk_id) -> chunk content string
        self.verbose = verbose
        self.prompt_client = prompt_client
        self.embeddings_client = embeddings_client
        self.doc_embeddings_client = doc_embeddings_client
        self.fetch_chunk = fetch_chunk
        if self.verbose:
            logger.debug("DocumentRag initialized")

    async def query(
            self, query, workspace="default", collection="default",
            doc_limit=20, streaming=False, chunk_callback=None,
            explain_callback=None, save_answer_callback=None,
    ):
        """
        Execute a Document RAG query with optional explainability tracking.

        Args:
            query: The query string
            workspace: Workspace for isolation (also scopes chunk lookup)
                (NOTE(review): workspace is stored on Query but its use in
                retrieval isn't visible in this file — presumably consumed
                by fetch_chunk; confirm against the processor wiring)
            collection: Collection identifier
            doc_limit: Max chunks to retrieve
            streaming: Enable streaming LLM response
            chunk_callback: async def callback(chunk, end_of_stream) for streaming
            explain_callback: async def callback(triples, explain_id) for explainability
            save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian

        Returns:
            tuple: (answer_text, usage) where usage is a dict with
                in_token, out_token, model

        Note:
            save_answer_callback is only invoked when explain_callback is
            also provided — the save happens inside the synthesis
            explainability step.
        """
        # Aggregate token usage across the concept-extraction and
        # synthesis LLM calls; last_model records the most recent model.
        total_in = 0
        total_out = 0
        last_model = None

        def track_usage(result):
            # Fold one LLM result's token counts into the totals
            # (tolerates None result and None fields).
            nonlocal total_in, total_out, last_model
            if result is not None:
                if result.in_token is not None:
                    total_in += result.in_token
                if result.out_token is not None:
                    total_out += result.out_token
                if result.model is not None:
                    last_model = result.model

        if self.verbose:
            logger.debug("Constructing prompt...")

        # Generate explainability URIs upfront so each stage's triples
        # can reference the preceding stage's URI.
        session_id = str(uuid.uuid4())
        q_uri = docrag_question_uri(session_id)
        gnd_uri = docrag_grounding_uri(session_id)
        exp_uri = docrag_exploration_uri(session_id)
        syn_uri = docrag_synthesis_uri(session_id)
        # ISO-8601 UTC timestamp with trailing 'Z' instead of '+00:00'.
        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

        # Emit question explainability immediately
        if explain_callback:
            q_triples = set_graph(
                docrag_question_triples(q_uri, query, timestamp),
                GRAPH_RETRIEVAL
            )
            await explain_callback(q_triples, q_uri)

        q = Query(
            rag=self, workspace=workspace, collection=collection,
            verbose=self.verbose,
            doc_limit=doc_limit, track_usage=track_usage,
        )

        # Extract concepts from query (grounding step)
        concepts = await q.extract_concepts(query)

        # Emit grounding explainability after concept extraction;
        # concepts_usage is set by extract_concepts (getattr guards
        # against it being absent).
        if explain_callback:
            cu = getattr(q, 'concepts_usage', None)
            gnd_triples = set_graph(
                grounding_triples(
                    gnd_uri, q_uri, concepts,
                    in_token=cu.in_token if cu else None,
                    out_token=cu.out_token if cu else None,
                    model=cu.model if cu else None,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(gnd_triples, gnd_uri)

        docs, chunk_ids = await q.get_docs(concepts)

        # Emit exploration explainability after chunks retrieved
        if explain_callback:
            exp_triples = set_graph(
                docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
                GRAPH_RETRIEVAL
            )
            await explain_callback(exp_triples, exp_uri)

        if self.verbose:
            logger.debug("Invoking LLM...")
            logger.debug(f"Documents: {docs}")
            logger.debug(f"Query: {query}")

        if streaming and chunk_callback:
            # Accumulate chunks for answer storage while forwarding to callback
            accumulated_chunks = []

            async def accumulating_callback(chunk, end_of_stream):
                accumulated_chunks.append(chunk)
                await chunk_callback(chunk, end_of_stream)

            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs,
                streaming=True,
                chunk_callback=accumulating_callback
            )
            track_usage(synthesis_result)
            # Combine all chunks into full response
            resp = "".join(accumulated_chunks)
        else:
            # Non-streaming path: single call, full text in the result.
            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs
            )
            track_usage(synthesis_result)
            resp = synthesis_result.text

        if self.verbose:
            logger.debug("Query processing complete")

        # Emit synthesis explainability after answer generated
        if explain_callback:
            synthesis_doc_id = None
            answer_text = resp if resp else ""
            # Save answer to librarian
            if save_answer_callback and answer_text:
                synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
                try:
                    await save_answer_callback(synthesis_doc_id, answer_text)
                    if self.verbose:
                        logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
                except Exception as e:
                    # Best effort: a failed save must not fail the query;
                    # clear the doc id so the synthesis triples don't
                    # reference a document that was never stored.
                    logger.warning(f"Failed to save answer to librarian: {e}")
                    synthesis_doc_id = None
            syn_triples = set_graph(
                docrag_synthesis_triples(
                    syn_uri, exp_uri,
                    document_id=synthesis_doc_id,
                    in_token=synthesis_result.in_token if synthesis_result else None,
                    out_token=synthesis_result.out_token if synthesis_result else None,
                    model=synthesis_result.model if synthesis_result else None,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(syn_triples, syn_uri)
            if self.verbose:
                logger.debug(f"Emitted explain for session {session_id}")

        # Zero totals are reported as None (no usage recorded).
        usage = {
            "in_token": total_in if total_in > 0 else None,
            "out_token": total_out if total_out > 0 else None,
            "model": last_model,
        }

        return resp, usage