trustgraph/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
cybermaggedon 14e49d83c7
Expose LLM token usage across all service layers (#782)
Expose LLM token usage (in_token, out_token, model) across all
service layers

Propagate token counts from LLM services through the prompt,
text-completion, graph-RAG, document-RAG, and agent orchestrator
pipelines to the API gateway and Python SDK. All fields are Optional
— None means "not available", distinguishing from a real zero count.

Key changes:

- Schema: Add in_token/out_token/model to TextCompletionResponse,
  PromptResponse, GraphRagResponse, DocumentRagResponse,
  AgentResponse

- TextCompletionClient: New TextCompletionResult return type. Split
  into text_completion() (non-streaming) and
  text_completion_stream() (streaming with per-chunk handler
  callback)

- PromptClient: New PromptResult with response_type
  (text/json/jsonl), typed fields (text/object/objects), and token
  usage. All callers updated.

- RAG services: Accumulate token usage across all prompt calls
  (extract-concepts, edge-scoring, edge-reasoning,
  synthesis). Non-streaming path sends single combined response
  instead of chunk + end_of_session.

- Agent orchestrator: UsageTracker accumulates tokens across
  meta-router, pattern prompt calls, and react reasoning. Attached
  to end_of_dialog.

- Translators: Encode token fields whenever the value is not None (tested
  with `is not None`, not truthiness, so a real zero count is still encoded)

- Python SDK: RAG and text-completion methods return
  TextCompletionResult (non-streaming) or RAGChunk/AgentAnswer with
  token fields (streaming)

- CLI: --show-usage flag on tg-invoke-llm, tg-invoke-prompt,
  tg-invoke-graph-rag, tg-invoke-document-rag, tg-invoke-agent
2026-04-13 14:38:34 +01:00

304 lines
9.8 KiB
Python

import asyncio
import logging
import uuid
from datetime import datetime, timezone

# Provenance imports
from trustgraph.provenance import (
    docrag_question_uri,
    docrag_grounding_uri,
    docrag_exploration_uri,
    docrag_synthesis_uri,
    docrag_question_triples,
    grounding_triples,
    docrag_exploration_triples,
    docrag_synthesis_triples,
    set_graph,
    GRAPH_RETRIEVAL,
)
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
    """A single document-RAG retrieval session.

    Holds the per-request parameters (user, collection, chunk budget) and
    drives the retrieval stages: concept extraction via the prompt service,
    concept embedding, and chunk lookup/fetch.
    """

    def __init__(
            self, rag, user, collection, verbose,
            doc_limit=20, track_usage=None,
    ):
        """
        Args:
            rag: owning DocumentRag, providing the service clients
            user: user identifier for scoping vector queries
            collection: collection identifier for scoping vector queries
            verbose: emit debug logging when true
            doc_limit: overall budget of chunks to retrieve
            track_usage: optional callable fed each prompt result so the
                caller can accumulate token usage
        """
        self.rag = rag
        self.user = user
        self.collection = collection
        self.verbose = verbose
        self.doc_limit = doc_limit
        self.track_usage = track_usage

    async def extract_concepts(self, query):
        """Extract key concepts from query for independent embedding."""

        outcome = await self.rag.prompt_client.prompt(
            "extract-concepts",
            variables={"query": query}
        )

        if self.track_usage:
            self.track_usage(outcome)

        # One concept per non-empty line of the LLM response.
        text = outcome.text or ""
        concepts = [
            ln.strip()
            for ln in text.strip().split('\n')
            if ln.strip()
        ]

        # Fallback to raw query if no concepts extracted
        if not concepts:
            concepts = [query]

        if self.verbose:
            logger.debug(f"Extracted concepts: {concepts}")

        return concepts

    async def get_vectors(self, concepts):
        """Compute embeddings for a list of concepts."""

        if self.verbose:
            logger.debug("Computing embeddings...")

        embeddings = await self.rag.embeddings_client.embed(concepts)

        if self.verbose:
            logger.debug("Embeddings computed")

        return embeddings

    async def get_docs(self, concepts):
        """
        Get documents (chunks) matching the extracted concepts.

        Returns:
            tuple: (docs, chunk_ids) where:
                - docs: list of document content strings
                - chunk_ids: list of chunk IDs that were successfully fetched
        """

        vectors = await self.get_vectors(concepts)

        if self.verbose:
            logger.debug("Getting chunks from embeddings store...")

        # Split the overall chunk budget evenly across concepts, at
        # least one chunk each; query all concepts concurrently.
        limit_each = max(1, self.doc_limit // len(vectors))

        results = await asyncio.gather(*(
            self.rag.doc_embeddings_client.query(
                vector=vec, limit=limit_each,
                user=self.user, collection=self.collection,
            )
            for vec in vectors
        ))

        # Deduplicate by chunk_id, keeping first-seen order; matches
        # without a chunk_id are dropped.
        unique = {}
        for batch in results:
            for m in batch:
                if m.chunk_id and m.chunk_id not in unique:
                    unique[m.chunk_id] = m
        chunk_matches = list(unique.values())

        if self.verbose:
            logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")

        # Fetch chunk content; a failed fetch is logged and skipped so
        # one bad chunk does not sink the whole retrieval.
        docs = []
        chunk_ids = []
        for m in chunk_matches:
            if m.chunk_id:
                try:
                    content = await self.rag.fetch_chunk(m.chunk_id, self.user)
                except Exception as e:
                    logger.warning(f"Failed to fetch chunk {m.chunk_id}: {e}")
                else:
                    docs.append(content)
                    chunk_ids.append(m.chunk_id)

        if self.verbose:
            logger.debug("Documents fetched:")
            for doc in docs:
                logger.debug(f" {doc[:100]}...")

        return docs, chunk_ids
class DocumentRag:
    """Document-RAG orchestrator.

    Pipeline per query: extract concepts, embed them, retrieve matching
    chunks, then synthesize an answer with the LLM — optionally streaming
    the answer, emitting explainability triples at each stage, and saving
    the final answer to the librarian. Token usage from every prompt call
    is accumulated and returned alongside the answer.
    """

    def __init__(
            self, prompt_client, embeddings_client, doc_embeddings_client,
            fetch_chunk,
            verbose=False,
    ):
        """
        Args:
            prompt_client: client providing prompt() and document_prompt()
            embeddings_client: client providing embed()
            doc_embeddings_client: vector store client providing query()
            fetch_chunk: async callable (chunk_id, user) -> chunk content
            verbose: emit debug logging when true
        """
        self.verbose = verbose
        self.prompt_client = prompt_client
        self.embeddings_client = embeddings_client
        self.doc_embeddings_client = doc_embeddings_client
        self.fetch_chunk = fetch_chunk

        if self.verbose:
            logger.debug("DocumentRag initialized")

    async def query(
            self, query, user="trustgraph", collection="default",
            doc_limit=20, streaming=False, chunk_callback=None,
            explain_callback=None, save_answer_callback=None,
    ):
        """
        Execute a Document RAG query with optional explainability tracking.

        Args:
            query: The query string
            user: User identifier
            collection: Collection identifier
            doc_limit: Max chunks to retrieve
            streaming: Enable streaming LLM response
            chunk_callback: async def callback(chunk, end_of_stream) for streaming
            explain_callback: async def callback(triples, explain_id) for explainability
            save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian

        Returns:
            tuple: (answer_text, usage) where usage is a dict with
            in_token, out_token, model. Token fields are None when no
            usage was reported, distinguishing "unavailable" from a
            real zero count.
        """

        # Token usage accumulated across all prompt calls in this query
        # (concept extraction + synthesis).
        total_in = 0
        total_out = 0
        last_model = None

        def track_usage(result):
            # Fold one prompt result into the running totals. None fields
            # mean "not reported" and are skipped, so partially-reporting
            # backends still produce useful totals.
            nonlocal total_in, total_out, last_model
            if result is not None:
                if result.in_token is not None:
                    total_in += result.in_token
                if result.out_token is not None:
                    total_out += result.out_token
                if result.model is not None:
                    last_model = result.model

        if self.verbose:
            logger.debug("Constructing prompt...")

        # Generate explainability URIs upfront
        session_id = str(uuid.uuid4())
        q_uri = docrag_question_uri(session_id)
        gnd_uri = docrag_grounding_uri(session_id)
        exp_uri = docrag_exploration_uri(session_id)
        syn_uri = docrag_synthesis_uri(session_id)

        # datetime.utcnow() is deprecated (Python 3.12+): use an aware
        # UTC timestamp, keeping the original trailing-"Z" wire format.
        timestamp = datetime.now(timezone.utc).isoformat().replace(
            "+00:00", "Z"
        )

        # Emit question explainability immediately
        if explain_callback:
            q_triples = set_graph(
                docrag_question_triples(q_uri, query, timestamp),
                GRAPH_RETRIEVAL
            )
            await explain_callback(q_triples, q_uri)

        q = Query(
            rag=self, user=user, collection=collection, verbose=self.verbose,
            doc_limit=doc_limit, track_usage=track_usage,
        )

        # Extract concepts from query (grounding step)
        concepts = await q.extract_concepts(query)

        # Emit grounding explainability after concept extraction
        if explain_callback:
            gnd_triples = set_graph(
                grounding_triples(gnd_uri, q_uri, concepts),
                GRAPH_RETRIEVAL
            )
            await explain_callback(gnd_triples, gnd_uri)

        docs, chunk_ids = await q.get_docs(concepts)

        # Emit exploration explainability after chunks retrieved
        if explain_callback:
            exp_triples = set_graph(
                docrag_exploration_triples(
                    exp_uri, gnd_uri, len(chunk_ids), chunk_ids
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(exp_triples, exp_uri)

        if self.verbose:
            logger.debug("Invoking LLM...")
            logger.debug(f"Documents: {docs}")
            logger.debug(f"Query: {query}")

        if streaming and chunk_callback:
            # Accumulate chunks for answer storage while forwarding to
            # the caller's callback.
            accumulated_chunks = []

            async def accumulating_callback(chunk, end_of_stream):
                accumulated_chunks.append(chunk)
                await chunk_callback(chunk, end_of_stream)

            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs,
                streaming=True,
                chunk_callback=accumulating_callback
            )
            track_usage(synthesis_result)

            # Combine all chunks into full response
            resp = "".join(accumulated_chunks)
        else:
            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs
            )
            track_usage(synthesis_result)
            resp = synthesis_result.text

        if self.verbose:
            logger.debug("Query processing complete")

        # Emit synthesis explainability after answer generated
        if explain_callback:
            synthesis_doc_id = None
            answer_text = resp if resp else ""

            # Save answer to librarian; a failed save is logged and the
            # synthesis triples are emitted without a document id.
            if save_answer_callback and answer_text:
                synthesis_doc_id = (
                    f"urn:trustgraph:docrag:{session_id}/answer"
                )
                try:
                    await save_answer_callback(synthesis_doc_id, answer_text)
                    if self.verbose:
                        logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
                except Exception as e:
                    logger.warning(f"Failed to save answer to librarian: {e}")
                    synthesis_doc_id = None

            syn_triples = set_graph(
                docrag_synthesis_triples(
                    syn_uri, exp_uri,
                    document_id=synthesis_doc_id,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(syn_triples, syn_uri)

            if self.verbose:
                logger.debug(f"Emitted explain for session {session_id}")

        # Zero totals mean no backend reported usage at all; surface
        # that as None rather than a (misleading) real zero count.
        usage = {
            "in_token": total_in if total_in > 0 else None,
            "out_token": total_out if total_out > 0 else None,
            "model": last_model,
        }

        return resp, usage