import asyncio
import logging
import uuid
from datetime import datetime, timezone

# Provenance imports
from trustgraph.provenance import (
    docrag_question_uri,
    docrag_grounding_uri,
    docrag_exploration_uri,
    docrag_synthesis_uri,
    docrag_question_triples,
    grounding_triples,
    docrag_exploration_triples,
    docrag_synthesis_triples,
    set_graph,
    GRAPH_RETRIEVAL,
)

# Module logger
logger = logging.getLogger(__name__)

LABEL = "http://www.w3.org/2000/01/rdf-schema#label"


class Query:
    """
    Per-request helper that turns a query into a set of document chunks.

    Holds the request-scoped settings (workspace, collection, limits) and
    uses the clients exposed by the owning ``rag`` object to extract
    concepts, embed them, look up matching chunks and fetch their content.
    """

    def __init__(
            self, rag, workspace, collection, verbose,
            doc_limit=20, track_usage=None,
    ):
        """
        Args:
            rag: Owning object providing prompt_client, embeddings_client,
                doc_embeddings_client and fetch_chunk
            workspace: Workspace passed to fetch_chunk
            collection: Collection passed to the embeddings store query
            verbose: Emit debug logging when true
            doc_limit: Chunk budget, split evenly across concepts
            track_usage: Optional callable invoked with each prompt result
        """
        self.rag = rag
        self.workspace = workspace
        self.collection = collection
        self.verbose = verbose
        self.doc_limit = doc_limit
        self.track_usage = track_usage

        # Prompt result of the last extract_concepts() call; initialised
        # here so readers never hit a missing attribute.
        self.concepts_usage = None

    async def extract_concepts(self, query):
        """Extract key concepts from query for independent embedding.

        Args:
            query: The query string

        Returns:
            list: One concept per non-blank line of the prompt response,
            or ``[query]`` when no concepts were produced.
        """
        result = await self.rag.prompt_client.prompt(
            "extract-concepts",
            variables={"query": query}
        )

        if self.track_usage:
            self.track_usage(result)

        concepts = []
        if result.text:
            concepts = [
                line.strip()
                for line in result.text.strip().split('\n')
                if line.strip()
            ]

        # Fallback to raw query if no concepts extracted
        if not concepts:
            concepts = [query]

        # Kept so the caller can report token usage for the grounding step
        self.concepts_usage = result

        if self.verbose:
            logger.debug(f"Extracted concepts: {concepts}")

        return concepts

    async def get_vectors(self, concepts):
        """Compute embeddings for a list of concepts.

        Args:
            concepts: List of concept strings

        Returns:
            Whatever the embeddings client returns for ``concepts``; the
            caller iterates it as one vector per concept.
        """
        if self.verbose:
            logger.debug("Computing embeddings...")

        qembeds = await self.rag.embeddings_client.embed(concepts)

        if self.verbose:
            logger.debug("Embeddings computed")

        return qembeds

    async def get_docs(self, concepts):
        """
        Get documents (chunks) matching the extracted concepts.

        Returns:
            tuple: (docs, chunk_ids) where:
                - docs: list of document content strings
                - chunk_ids: list of chunk IDs that were successfully fetched
        """
        vectors = await self.get_vectors(concepts)

        # Without vectors there is nothing to query, and the per-concept
        # limit below would divide by zero.
        if vectors is None or len(vectors) == 0:
            logger.warning(
                "No embeddings returned for concepts; no chunks retrieved"
            )
            return [], []

        if self.verbose:
            logger.debug("Getting chunks from embeddings store...")

        # Query chunk matches for each concept concurrently
        # NOTE(review): because of the floor of 1, the combined number of
        # chunks can exceed doc_limit when there are more concepts than
        # doc_limit -- confirm whether that is intended.
        per_concept_limit = max(
            1, self.doc_limit // len(vectors)
        )

        async def query_concept(vec):
            return await self.rag.doc_embeddings_client.query(
                vector=vec, limit=per_concept_limit,
                collection=self.collection,
            )

        results = await asyncio.gather(
            *[query_concept(v) for v in vectors]
        )

        # Deduplicate chunk matches by chunk_id (matches without an ID are
        # dropped here, so the fetch loop needs no further check)
        seen = set()
        chunk_matches = []
        for matches in results:
            for match in matches:
                if match.chunk_id and match.chunk_id not in seen:
                    seen.add(match.chunk_id)
                    chunk_matches.append(match)

        if self.verbose:
            logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")

        # Fetch chunk content from Garage; a failed fetch is logged and
        # skipped so one bad chunk does not fail the whole query
        docs = []
        chunk_ids = []
        for match in chunk_matches:
            try:
                content = await self.rag.fetch_chunk(
                    match.chunk_id, self.workspace
                )
                docs.append(content)
                chunk_ids.append(match.chunk_id)
            except Exception as e:
                logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")

        if self.verbose:
            logger.debug("Documents fetched:")
            for doc in docs:
                logger.debug(f"  {doc[:100]}...")

        return docs, chunk_ids


class DocumentRag:
    """
    Document RAG pipeline: concept extraction, chunk retrieval and answer
    synthesis, with optional explainability triples emitted at each step.
    """

    def __init__(
            self, prompt_client, embeddings_client, doc_embeddings_client,
            fetch_chunk,
            verbose=False,
    ):
        """
        Args:
            prompt_client: Client used for concept extraction and synthesis
            embeddings_client: Client used to embed concepts
            doc_embeddings_client: Client used to query chunk matches
            fetch_chunk: async callable (chunk_id, workspace) -> content
            verbose: Emit debug logging when true
        """
        self.verbose = verbose

        self.prompt_client = prompt_client
        self.embeddings_client = embeddings_client
        self.doc_embeddings_client = doc_embeddings_client
        self.fetch_chunk = fetch_chunk

        if self.verbose:
            logger.debug("DocumentRag initialized")

    async def query(
            self, query, workspace="default", collection="default",
            doc_limit=20, streaming=False, chunk_callback=None,
            explain_callback=None, save_answer_callback=None,
    ):
        """
        Execute a Document RAG query with optional explainability tracking.

        Args:
            query: The query string
            workspace: Workspace for isolation (also scopes chunk lookup)
            collection: Collection identifier
            doc_limit: Max chunks to retrieve
            streaming: Enable streaming LLM response
            chunk_callback: async def callback(chunk, end_of_stream) for streaming
            explain_callback: async def callback(triples, explain_id) for explainability
            save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian

        Returns:
            tuple: (answer_text, usage) where usage is a dict with in_token, out_token, model
        """
        total_in = 0
        total_out = 0
        last_model = None

        def track_usage(result):
            # Accumulate token counts across every prompt call of this query
            nonlocal total_in, total_out, last_model
            if result is not None:
                if result.in_token is not None:
                    total_in += result.in_token
                if result.out_token is not None:
                    total_out += result.out_token
                if result.model is not None:
                    last_model = result.model

        if self.verbose:
            logger.debug("Constructing prompt...")

        # Generate explainability URIs upfront
        session_id = str(uuid.uuid4())
        q_uri = docrag_question_uri(session_id)
        gnd_uri = docrag_grounding_uri(session_id)
        exp_uri = docrag_exploration_uri(session_id)
        syn_uri = docrag_synthesis_uri(session_id)

        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

        # Emit question explainability immediately
        if explain_callback:
            q_triples = set_graph(
                docrag_question_triples(q_uri, query, timestamp),
                GRAPH_RETRIEVAL
            )
            await explain_callback(q_triples, q_uri)

        q = Query(
            rag=self, workspace=workspace, collection=collection,
            verbose=self.verbose,
            doc_limit=doc_limit, track_usage=track_usage,
        )

        # Extract concepts from query (grounding step)
        concepts = await q.extract_concepts(query)

        # Emit grounding explainability after concept extraction
        if explain_callback:
            cu = q.concepts_usage
            gnd_triples = set_graph(
                grounding_triples(
                    gnd_uri, q_uri, concepts,
                    in_token=cu.in_token if cu else None,
                    out_token=cu.out_token if cu else None,
                    model=cu.model if cu else None,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(gnd_triples, gnd_uri)

        docs, chunk_ids = await q.get_docs(concepts)

        # Emit exploration explainability after chunks retrieved
        if explain_callback:
            exp_triples = set_graph(
                docrag_exploration_triples(
                    exp_uri, gnd_uri, len(chunk_ids), chunk_ids
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(exp_triples, exp_uri)

        if self.verbose:
            logger.debug("Invoking LLM...")
            logger.debug(f"Documents: {docs}")
            logger.debug(f"Query: {query}")

        if streaming and chunk_callback:
            # Accumulate chunks for answer storage while forwarding to callback
            accumulated_chunks = []

            async def accumulating_callback(chunk, end_of_stream):
                # Skip None/empty chunks so the join below cannot raise
                # TypeError; the caller's callback still sees every call.
                if chunk:
                    accumulated_chunks.append(chunk)
                await chunk_callback(chunk, end_of_stream)

            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs,
                streaming=True,
                chunk_callback=accumulating_callback
            )
            track_usage(synthesis_result)
            # Combine all chunks into full response
            resp = "".join(accumulated_chunks)
        else:
            synthesis_result = await self.prompt_client.document_prompt(
                query=query,
                documents=docs
            )
            track_usage(synthesis_result)
            resp = synthesis_result.text

        if self.verbose:
            logger.debug("Query processing complete")

        # Emit synthesis explainability after answer generated
        # NOTE(review): the answer is only saved via save_answer_callback
        # when explain_callback is also provided -- confirm this coupling
        # is intended.
        if explain_callback:
            synthesis_doc_id = None
            answer_text = resp if resp else ""

            # Save answer to librarian
            if save_answer_callback and answer_text:
                synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
                try:
                    await save_answer_callback(synthesis_doc_id, answer_text)
                    if self.verbose:
                        logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
                except Exception as e:
                    # Best effort: the triples are still emitted, just
                    # without a document reference
                    logger.warning(f"Failed to save answer to librarian: {e}")
                    synthesis_doc_id = None

            syn_triples = set_graph(
                docrag_synthesis_triples(
                    syn_uri, exp_uri,
                    document_id=synthesis_doc_id,
                    in_token=synthesis_result.in_token if synthesis_result else None,
                    out_token=synthesis_result.out_token if synthesis_result else None,
                    model=synthesis_result.model if synthesis_result else None,
                ),
                GRAPH_RETRIEVAL
            )
            await explain_callback(syn_triples, syn_uri)

        if self.verbose:
            logger.debug(f"Emitted explain for session {session_id}")

        # Zero totals are reported as None rather than 0
        usage = {
            "in_token": total_in if total_in > 0 else None,
            "out_token": total_out if total_out > 0 else None,
            "model": last_model,
        }

        return resp, usage