mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-12 08:42:37 +02:00
- Fixed document-rag workspace problem - OpenAI text-completion processor now puts 'not-set' in the token if no token is set (new OpenAI library requires it to be set to something). - Update tests
316 lines
10 KiB
Python
316 lines
10 KiB
Python
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
# Provenance imports
|
|
from trustgraph.provenance import (
|
|
docrag_question_uri,
|
|
docrag_grounding_uri,
|
|
docrag_exploration_uri,
|
|
docrag_synthesis_uri,
|
|
docrag_question_triples,
|
|
grounding_triples,
|
|
docrag_exploration_triples,
|
|
docrag_synthesis_triples,
|
|
set_graph,
|
|
GRAPH_RETRIEVAL,
|
|
)
|
|
|
|
# Module logger
|
|
logger = logging.getLogger(__name__)
|
|
|
|
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
|
|
|
class Query:
    """
    One document-RAG retrieval pass: extracts key concepts from a query,
    embeds them, looks up matching chunks per concept in the vector store,
    and fetches the chunk content.
    """

    def __init__(
        self, rag, workspace, collection, verbose,
        doc_limit=20, track_usage=None,
    ):
        """
        Args:
            rag: owning DocumentRag providing prompt_client,
                embeddings_client, doc_embeddings_client and fetch_chunk
            workspace: workspace identifier. NOTE(review): stored but not
                used by the lookups below — presumably chunk lookup is
                workspace-scoped elsewhere; confirm.
            collection: collection identifier passed to the vector store
            verbose: enable debug logging
            doc_limit: overall budget of chunks to retrieve
            track_usage: optional callable(result) accumulating token usage
        """
        self.rag = rag
        self.workspace = workspace
        self.collection = collection
        self.verbose = verbose
        self.doc_limit = doc_limit
        self.track_usage = track_usage

    async def extract_concepts(self, query):
        """Extract key concepts from query for independent embedding."""
        result = await self.rag.prompt_client.prompt(
            "extract-concepts",
            variables={"query": query}
        )
        if self.track_usage:
            self.track_usage(result)

        # One concept per non-blank line of the LLM response
        concepts = []
        if result.text:
            for line in result.text.strip().split('\n'):
                line = line.strip()
                if line:
                    concepts.append(line)

        # Fallback to raw query if no concepts extracted
        if not concepts:
            concepts = [query]

        # Keep the raw prompt result so callers can report grounding usage
        self.concepts_usage = result

        if self.verbose:
            logger.debug(f"Extracted concepts: {concepts}")

        return concepts

    async def get_vectors(self, concepts):
        """Compute embeddings for a list of concepts."""
        if self.verbose:
            logger.debug("Computing embeddings...")

        qembeds = await self.rag.embeddings_client.embed(concepts)

        if self.verbose:
            logger.debug("Embeddings computed")

        return qembeds

    async def get_docs(self, concepts):
        """
        Get documents (chunks) matching the extracted concepts.

        Returns:
            tuple: (docs, chunk_ids) where:
                - docs: list of document content strings
                - chunk_ids: list of chunk IDs that were successfully fetched
        """
        vectors = await self.get_vectors(concepts)

        # Robustness fix: an empty embedding result would otherwise raise
        # ZeroDivisionError in the per-concept limit calculation below.
        if not vectors:
            return [], []

        if self.verbose:
            logger.debug("Getting chunks from embeddings store...")

        # Split the overall chunk budget across concepts (at least 1 each)
        per_concept_limit = max(
            1, self.doc_limit // len(vectors)
        )

        async def query_concept(vec):
            return await self.rag.doc_embeddings_client.query(
                vector=vec, limit=per_concept_limit,
                collection=self.collection,
            )

        # Query chunk matches for each concept concurrently
        results = await asyncio.gather(
            *[query_concept(v) for v in vectors]
        )

        # Deduplicate chunk matches by chunk_id, preserving first-seen order
        seen = set()
        chunk_matches = []
        for matches in results:
            for match in matches:
                if match.chunk_id and match.chunk_id not in seen:
                    seen.add(match.chunk_id)
                    chunk_matches.append(match)

        if self.verbose:
            logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")

        # Fetch chunk content from Garage; a failed fetch is logged and
        # skipped so one bad chunk doesn't sink the whole retrieval
        docs = []
        chunk_ids = []
        for match in chunk_matches:
            if match.chunk_id:
                try:
                    content = await self.rag.fetch_chunk(match.chunk_id)
                    docs.append(content)
                    chunk_ids.append(match.chunk_id)
                except Exception as e:
                    logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")

        if self.verbose:
            logger.debug("Documents fetched:")
            for doc in docs:
                logger.debug(f"  {doc[:100]}...")

        return docs, chunk_ids
|
|
|
|
class DocumentRag:
    """
    Document RAG pipeline: concept grounding, chunk retrieval, and LLM
    synthesis, with optional response streaming and explainability
    triple emission.
    """

    def __init__(
        self, prompt_client, embeddings_client, doc_embeddings_client,
        fetch_chunk,
        verbose=False,
    ):
        """
        Args:
            prompt_client: client exposing prompt() and document_prompt()
            embeddings_client: client exposing embed()
            doc_embeddings_client: vector-store client exposing query()
            fetch_chunk: async callable(chunk_id) returning chunk content
            verbose: enable debug logging
        """
        self.verbose = verbose

        self.prompt_client = prompt_client
        self.embeddings_client = embeddings_client
        self.doc_embeddings_client = doc_embeddings_client
        self.fetch_chunk = fetch_chunk

        if self.verbose:
            logger.debug("DocumentRag initialized")

    async def query(
        self, query, workspace="default", collection="default",
        doc_limit=20, streaming=False, chunk_callback=None,
        explain_callback=None, save_answer_callback=None,
    ):
        """
        Execute a Document RAG query with optional explainability tracking.

        Args:
            query: The query string
            workspace: Workspace for isolation (also scopes chunk lookup)
            collection: Collection identifier
            doc_limit: Max chunks to retrieve
            streaming: Enable streaming LLM response
            chunk_callback: async def callback(chunk, end_of_stream) for streaming
            explain_callback: async def callback(triples, explain_id) for explainability
            save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian

        Returns:
            tuple: (answer_text, usage) where usage is a dict with
                in_token, out_token, model
        """
        # Running token totals across all LLM calls made during this query
        total_in = 0
        total_out = 0
        last_model = None

        def track_usage(result):
            # Fold one LLM result's token counts into the running totals
            nonlocal total_in, total_out, last_model
            if result is not None:
                if result.in_token is not None:
                    total_in += result.in_token
                if result.out_token is not None:
                    total_out += result.out_token
                if result.model is not None:
                    last_model = result.model

        if self.verbose:
            logger.debug("Constructing prompt...")

        # Generate explainability URIs upfront so the question/grounding/
        # exploration/synthesis stages can be linked together
        session_id = str(uuid.uuid4())
        q_uri = docrag_question_uri(session_id)
        gnd_uri = docrag_grounding_uri(session_id)
        exp_uri = docrag_exploration_uri(session_id)
        syn_uri = docrag_synthesis_uri(session_id)

        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

        # Emit question explainability immediately
        if explain_callback:
            q_triples = set_graph(
                docrag_question_triples(q_uri, query, timestamp),
                GRAPH_RETRIEVAL
            )
            await explain_callback(q_triples, q_uri)

        q = Query(
            rag=self, workspace=workspace, collection=collection,
            verbose=self.verbose,
            doc_limit=doc_limit, track_usage=track_usage,
        )

        # Extract concepts from query (grounding step)
        concepts = await q.extract_concepts(query)

        # Emit grounding explainability after concept extraction
        if explain_callback:
            cu = getattr(q, 'concepts_usage', None)
            gnd_triples = set_graph(
                grounding_triples(
                    gnd_uri, q_uri, concepts,
                    in_token=cu.in_token if cu else None,
                    out_token=cu.out_token if cu else None,
                    model=cu.model if cu else None,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(gnd_triples, gnd_uri)

        docs, chunk_ids = await q.get_docs(concepts)

        # Emit exploration explainability after chunks retrieved
        if explain_callback:
            exp_triples = set_graph(
                docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
                GRAPH_RETRIEVAL
            )
            await explain_callback(exp_triples, exp_uri)

        if self.verbose:
            logger.debug("Invoking LLM...")
            logger.debug(f"Documents: {docs}")
            logger.debug(f"Query: {query}")

        if streaming and chunk_callback:
            # Accumulate chunks for answer storage while forwarding to callback
            accumulated_chunks = []

            async def accumulating_callback(chunk, end_of_stream):
                accumulated_chunks.append(chunk)
                await chunk_callback(chunk, end_of_stream)

            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs,
                streaming=True,
                chunk_callback=accumulating_callback
            )
            track_usage(synthesis_result)
            # Combine all chunks into full response
            resp = "".join(accumulated_chunks)
        else:
            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs
            )
            track_usage(synthesis_result)
            # Fix: guard against a None synthesis result, consistent with
            # the None checks on synthesis_result further below
            resp = synthesis_result.text if synthesis_result else None

        if self.verbose:
            logger.debug("Query processing complete")

        # Emit synthesis explainability after answer generated
        if explain_callback:
            synthesis_doc_id = None
            answer_text = resp if resp else ""

            # Save answer to librarian; a failed save is logged and the
            # synthesis triples simply omit the document id
            if save_answer_callback and answer_text:
                synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
                try:
                    await save_answer_callback(synthesis_doc_id, answer_text)
                    if self.verbose:
                        logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
                except Exception as e:
                    logger.warning(f"Failed to save answer to librarian: {e}")
                    synthesis_doc_id = None

            syn_triples = set_graph(
                docrag_synthesis_triples(
                    syn_uri, exp_uri,
                    document_id=synthesis_doc_id,
                    in_token=synthesis_result.in_token if synthesis_result else None,
                    out_token=synthesis_result.out_token if synthesis_result else None,
                    model=synthesis_result.model if synthesis_result else None,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(syn_triples, syn_uri)

            if self.verbose:
                logger.debug(f"Emitted explain for session {session_id}")

        # Zero totals mean "no usage reported" — surface as None
        usage = {
            "in_token": total_in if total_in > 0 else None,
            "out_token": total_out if total_out > 0 else None,
            "model": last_model,
        }

        return resp, usage
|
|
|