diff --git a/docs/tech-specs/graph-rag-semantic-filter.md b/docs/tech-specs/graph-rag-semantic-filter.md new file mode 100644 index 00000000..0401947e --- /dev/null +++ b/docs/tech-specs/graph-rag-semantic-filter.md @@ -0,0 +1,523 @@ +# GraphRAG Semantic Filter Improvement + +## Problem Statement + +The GraphRAG semantic filter is observed to be ineffective with certain +LLM models. Smaller models in particular produce poor-quality edge +relevance scores, and there is a suspicion that models trained or +evaluated heavily on non-Roman-script datasets offer lower performance +on the semantic ranking operation. + +The root cause is that the current implementation delegates edge +relevance scoring to the LLM via a prompt that asks the model to +assign a 1–10 relevance score to each knowledge-graph edge. This +task — ranking structured triples for relevance to a natural-language +query — is not well covered in standard LLM evaluation suites, so +model benchmark scores are not predictive of performance on this +operation. The result is that GraphRAG quality varies unpredictably +across model choices, undermining confidence in the pipeline. + +Beyond model variability, the LLM scoring step has further problems: + +- **Cost and latency.** The LLM call consumes tokens and adds + latency to every query, yet its output is unreliable. Even when + the model performs well, the cost is disproportionate for what is + fundamentally a ranking operation. + +- **Subjective scoring scale.** The 1–10 relevance scale gives the + model no objective criteria for what constitutes a 5 versus a 7. + Different models interpret the scale differently, and even the same + model can produce inconsistent scores across runs. + +- **Redundancy with the embedding pre-filter.** The pipeline already + contains a cosine-similarity stage that ranks edges by semantic + relevance using embeddings. 
The LLM scoring step is a second + filter applied on top of this, and it is not clear that it adds + enough value to justify the additional cost and risk of + degradation. + +### Industry context + +Semantic ranking is rigorously evaluated on dedicated benchmarks such +as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking +Information Retrieval), which test retrieval and reranking across +diverse domains. The current TrustGraph approach — prompting a +general-purpose LLM to score and rank documents (the "listwise" +approach) — is known to be poorly optimized for this task. It +suffers from positional bias, formatting failures, and +inconsistency at scale. + +The industry standard for semantic ranking has moved to +cross-encoder models: lightweight, purpose-built models that take a +query–document pair as input and produce a single relevance score. +These models are fine-tuned on millions of relevance-labelled pairs +and dominate retrieval benchmarks. They are fast, deterministic, +and do not require an LLM inference call. + +## Architecture + +### Cross-encoder service + +A new request/response service that exposes a generic semantic +ranking API. The service is not specific to GraphRAG — it is a +reusable building block for any component that needs to rank text +by relevance. + +The service interface is pluggable. Alternative implementations +can be swapped in behind the same API. + +**Packaging options considered:** + +- *`sentence-transformers`.* Full-featured, widely used. + However, it pulls in PyTorch (~2 GB), making containers + very large. Tested at ~1.8 seconds for 2200 edges. + +- *`optimum.onnxruntime`.* ONNX-based inference. Still + depends on PyTorch at import time despite using ONNX for + inference. Tested at ~4.2 seconds for 2200 edges. + +- *`flashrank`.* Lightweight wrapper around ONNX Runtime + with a clean API (`Ranker`, `RerankRequest`). No PyTorch + dependency. Tested at ~4.4 seconds for 2200 edges. 
+ +- *Pure `onnxruntime` + `tokenizers`.* Leanest option + (~200 MB total). Requires manual tokenisation, padding, + and numpy array management — more boilerplate to maintain. + +- *External API (e.g. Cohere Rerank).* No local model at + all. Adds network latency and an external dependency. + +**Decision:** `flashrank` for the initial implementation. +No PyTorch dependency, clean API, comparable performance. +The pluggable interface allows swapping to another backend +later. + +**Request:** + +- `queries` — list of `{id, text}` objects. In the GraphRAG use + case these are the concepts extracted from the user's question. +- `documents` — list of `{id, text}` objects. In the GraphRAG + use case these are the candidate knowledge-graph edges + represented as text. +- `limit` — integer. Maximum number of results to return. + +**Scoring:** + +The service produces the cartesian product of all query–document +pairs and scores each pair through the cross-encoder model. For +each document, the maximum score across all queries is taken as the +document's relevance score. Documents are then ranked by this +score and the top `limit` results are returned. + +**Response:** + +A list of the top `limit` results, each containing: + +- `document_id` — the ID of the matched document. +- `query_id` — the ID of the query (concept) that produced the + highest score for this document. +- `score` — the relevance score. + +Including `query_id` in the response supports the explainability +interface: it records that an edge was selected because it is +related to a specific concept. + +### Integration + +The cross-encoder service follows the standard TrustGraph service +integration pattern: + +- **Base package (trustgraph-base).** Schema definitions for the + cross-encoder request/response messages. A client class that + other components (e.g. GraphRAG) can use to call the + cross-encoder service. Message translator registration so the + pub/sub layer can serialise/deserialise the messages. 
+ +- **Flow package (trustgraph-flow).** The cross-encoder service + implementation itself — loads the model, listens for requests, + scores pairs, returns results. Flow definition support so the + cross-encoder can be introduced into a processing flow via the + standard flow configuration. `flashrank` is added as a + dependency of `trustgraph-flow`. The service runs in its own + container. + +- **API gateway.** A gateway endpoint that routes cross-encoder + requests from the HTTP API to the service over pub/sub and + returns the response. + +- **CLI tool.** A command-line utility + (e.g. `tg-invoke-cross-encoder`) that calls the gateway + endpoint for manual testing and debugging. + +### Current GraphRAG pipeline + +The current pipeline follows these steps: + +1. **Concept extraction.** An LLM prompt extracts key concepts + from the user's query. + +2. **Graph exploration.** Seed entities are found via embedding + similarity. A subgraph is built by multi-hop traversal from + the seed entities (up to `max_path_length` hops, capped at + `max_subgraph_size` edges). + +3. **Embedding pre-filter.** Each edge is embedded as + `"subject, predicate, object"` and scored by cosine similarity + against the concept embeddings. The top `edge_score_limit` + (default 30) edges are kept. + +4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the + LLM to assign a 1–10 relevance score to each remaining edge. + The top `edge_limit` (default 25) edges are kept. + +5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks + the LLM to explain why each selected edge is relevant to the + query. Used for the explainability interface. + +6. **Document tracing.** Selected edges are traced back to their + source documents in the librarian. Runs concurrently with + step 5. + +7. **Synthesis.** The `kg-synthesis` prompt generates the final + answer from the selected edges and source document metadata. 
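Step 3 of the current pipeline, the embedding pre-filter, can be sketched in a few lines. This is a minimal illustration, not the TrustGraph implementation: `toy_embed` is a hypothetical stand-in for the embeddings service, and taking the maximum similarity across concepts is an assumption, since the text does not say how per-concept similarities are combined.

```python
import math
from typing import Callable, List, Sequence, Tuple

Edge = Tuple[str, str, str]  # (subject, predicate, object)


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def prefilter_edges(
    edges: List[Edge],
    concepts: List[str],
    embed: Callable[[str], Sequence[float]],
    edge_score_limit: int = 30,
) -> List[Edge]:
    """Keep the edges whose text is most similar to any concept."""
    concept_vecs = [embed(c) for c in concepts]
    scored = []
    for s, p, o in edges:
        # Edge text format taken from the spec: "subject, predicate, object".
        vec = embed(f"{s}, {p}, {o}")
        # Assumption: an edge's score is its best match over all concepts.
        score = max((cosine(vec, cv) for cv in concept_vecs), default=0.0)
        scored.append((score, (s, p, o)))
    scored.sort(key=lambda item: item[0], reverse=True)
    return [edge for _, edge in scored[:edge_score_limit]]


def toy_embed(text: str) -> List[float]:
    """Hypothetical embedding: letter counts, for demonstration only."""
    vec = [0.0] * 26
    for ch in text.lower():
        if "a" <= ch <= "z":
            vec[ord(ch) - ord("a")] += 1.0
    return vec
```

A real deployment would pass a function backed by the embeddings service in place of `toy_embed`; the ranking logic stays the same.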
+ +### Potential improvements + +#### Replace LLM edge scoring with cross-encoder (step 4) + +The LLM edge scoring step is replaced by a call to the +cross-encoder service. The candidate edges are the documents and +`edge_limit` is the limit. This is a direct substitution: faster, +cheaper, deterministic, and more reliable across model choices. +The LLM `kg-edge-scoring` prompt is retired. + +**Cross-encoder query input: concepts vs. raw query.** There are +two options for what to use as the cross-encoder queries: + +- *Option A: Raw user query.* Pass the original question as a + single query string. Simpler, no dependency on concept + extraction. However, raw queries contain noise words and + conversational phrasing that do not match well against the + structured vocabulary of knowledge-graph edges. A single query + also means every edge competes against the full question — a + partial match on one aspect is diluted. + +- *Option B: Extracted concepts.* Pass the concepts from step 1 + as separate queries. The concepts are distilled, focused terms + that are closer to the language of the edges. With multiple + concepts as independent queries, the cross-encoder scores each + edge against each concept separately, giving better coverage — + an edge only needs to match one concept well to be selected. + The trade-off is a dependency on the LLM concept extraction + step, but this is already in the pipeline and is a lightweight, + reliable LLM call. + +**Decision:** Option B — use extracted concepts. The concept +extraction is fast, and the resulting terms produce better +cross-encoder matches against structured triples. + +#### Edge text representation + +The current embedding pre-filter represents each edge as +`"subject, predicate, object"`. Two changes: + +- **Drop commas.** Commas add tokenisation noise without semantic + value. 
+
+- **Drop the subject.** The subject identifies which entity the
+  edge belongs to, but it does not contribute to whether the
+  edge's content is relevant to the query. The predicate and
+  object carry the semantic meaning: what relationship exists
+  and what it connects to. Representing edges as `"{p} {o}"`
+  produces cleaner cross-encoder matches.
+
+#### Remove the embedding pre-filter (step 3)
+
+The embedding pre-filter was introduced to reduce the number of
+edges before the expensive LLM scoring call. With the
+cross-encoder replacing the LLM call, this cost equation changes.
+
+**Arguments for removal:**
+
+- The cross-encoder is fast enough to score the full subgraph
+  directly. In testing, 2200 edges scored in ~1.8 seconds with
+  `sentence-transformers` and ~4.4 seconds with the chosen
+  `flashrank` backend; at the default `max_subgraph_size` of
+  150 edges, scoring takes a fraction of a second with either.
+
+- The pre-filter is a weaker version of what the cross-encoder
+  does. Bi-encoder cosine similarity embeds the query and
+  document independently and compares vectors; the cross-encoder
+  processes both texts together through the full transformer,
+  giving it much better relevance judgement. Running a weaker
+  filter before a stronger one adds latency without improving
+  quality.
+
+- Removing it eliminates an embedding service call (two batches:
+  concepts + edges) and the associated latency.
+
+**Arguments for keeping it:**
+
+- If the subgraph is very large (thousands of edges), the
+  cross-encoder's linear scaling could become a bottleneck.
+  The pre-filter would act as a safety valve.
+
+- The embedding call is cheap compared to an LLM call, so the
+  overhead is modest.
+
+**Decision:** Remove the pre-filter. The `max_subgraph_size`
+parameter (default 150) already caps the number of edges entering
+this stage, so the cross-encoder will not face an unbounded
+workload. If very large subgraphs become a concern in future,
+the pre-filter can be reintroduced or `max_subgraph_size` can be
+tuned.
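The scoring rule of the cross-encoder service (score every query and document pair, keep each document's best score, return the top `limit`) can be sketched together with the `"{p} {o}"` edge representation. This is a hedged illustration, not the service code: `score_pair` is a hypothetical stand-in for the cross-encoder model (a simple token-overlap ratio), and the names `RankResult`, `rank` and `edge_text` are assumptions chosen to mirror the request and response fields described earlier.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class RankResult:
    document_id: str
    query_id: str  # the query (concept) that gave the best score
    score: float


def score_pair(query: str, document: str) -> float:
    """Hypothetical stand-in for a cross-encoder: token-overlap ratio."""
    q = set(query.lower().split())
    d = set(document.lower().split())
    return len(q & d) / len(q | d) if q and d else 0.0


def edge_text(predicate: str, obj: str) -> str:
    """Edge representation from the spec: predicate and object, no commas."""
    return f"{predicate} {obj}"


def rank(
    queries: List[Dict[str, str]],
    documents: List[Dict[str, str]],
    limit: int,
    scorer: Callable[[str, str], float] = score_pair,
) -> List[RankResult]:
    """Score all query/document pairs and keep each document's maximum."""
    best: List[RankResult] = []
    for doc in documents:
        top = None
        for query in queries:
            s = scorer(query["text"], doc["text"])
            if top is None or s > top.score:
                top = RankResult(doc["id"], query["id"], s)
        if top is not None:
            best.append(top)
    best.sort(key=lambda r: r.score, reverse=True)
    return best[:limit]
```

Because each result records the winning `query_id`, the caller can report which concept caused an edge to be selected without any further model call.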
+ +#### Iterative graph traversal with cross-encoder filtering + +The current pipeline performs graph exploration and edge filtering +as separate phases: first build the full subgraph (up to +`max_path_length` hops), then score and filter edges. An +alternative is to interleave traversal and filtering — at each +hop, use the cross-encoder to select relevant edges before +expanding further. + +**Option A: Big-bang traversal then filter.** Traverse the full +subgraph up to `max_path_length` hops from the seed entities, +collecting all edges up to `max_subgraph_size`. Then +cross-encode the entire result to select the top edges. + +- Simple to implement — the current traversal logic is largely + unchanged. +- Produces large, unfocused subgraphs. Irrelevant branches are + explored and scored even though they will be discarded. +- Poorly suited to multi-hop reasoning. For a query about + Voyager 1, the subgraph includes Voyager 2's edges because + they are within hop distance, and the filter must then + separate them. + +**Option B: Iterative hop-and-filter.** At each hop: + +1. Retrieve all edges one hop from the current frontier nodes. +2. Cross-encode these edges against the query concepts. +3. Select the top relevant edges. +4. The target nodes of the selected edges become the frontier + for the next hop. +5. Repeat up to `max_path_length` hops. + +The final set of selected edges across all hops is the input to +synthesis. + +- **Guided exploration.** Each hop focuses the search by + pruning irrelevant branches before expanding further. The + working set stays small and relevant at every step. +- **Multi-hop reasoning works naturally.** Following + "Voyager 1 → has-event → crossed the heliopause" succeeds + because each hop is individually relevant and leads to the + next. +- **Smaller total workload.** Fewer edges are scored overall + because irrelevant branches are never expanded. 
+- **Trade-off: greedy pruning.** An edge discarded at hop 1 + cannot lead to relevant edges at hop 2. This is inherent in + any bounded traversal, and the cross-encoder is better + equipped to make this relevance judgement than a blind hop + limit. +- **Trade-off: sequential latency.** Hops cannot be + parallelised since each depends on the previous. However, + each cross-encoder call on a small edge set is very fast + (sub-second for typical working sets). + +**Decision:** Option B — iterative hop-and-filter. The guided +traversal produces more focused subgraphs and supports multi-hop +reasoning, which is a significant quality improvement over the +current approach. + +#### Replace LLM edge reasoning with cross-encoder metadata (step 5) + +The current `kg-edge-reasoning` prompt asks the LLM to explain why +each edge is relevant. With the cross-encoder now making the +selection, this explanation would be a post-hoc fabrication — the +LLM was not involved in the decision. + +- *Option A: Keep LLM reasoning.* Generates natural-language + explanations but they are not grounded in the actual selection + process. Adds an LLM call per query. + +- *Option B: Record cross-encoder metadata.* The cross-encoder + already returns the matched concept and score for each selected + edge. Use this directly as the explanation. + +**Decision:** Option B. The cross-encoder metadata is the true +reason the edge was selected. The `kg-edge-reasoning` prompt is +retired. + +#### Explainability interface update + +The explainability interface uses a `Focus` entity containing +`EdgeSelection` sub-entities. Each `EdgeSelection` currently +carries an `edge` (the quoted triple) and a `reasoning` field +(free-text LLM prose), stored as `tg:reasoning` in the +provenance graph. + +With the cross-encoder replacing LLM reasoning, the +`EdgeSelection` type gains two new predicates and drops one: + +- **Remove** `tg:reasoning` — no longer produced. 
+- **Add** `tg:concept` — the concept text that produced the + highest cross-encoder score for this edge. +- **Add** `tg:score` — the cross-encoder relevance score. + +This is an evolution of the existing `EdgeSelection` type, not a +new entity type. The edge selection sub-entities currently have +no `rdf:type` declared; a new `tg:EdgeSelection` type should be +added so that consumers can identify them in the provenance +graph. The `Focus` entity and its relationship to `Exploration` +are unchanged. + +The `Focus` entity's token-usage metadata (`tg:inToken`, +`tg:outToken`, `tg:llmModel`) no longer applies since there is +no LLM call. These fields are dropped from the Focus entity. + +### Proposed pipeline + +1. **Concept extraction.** Unchanged — LLM extracts key concepts + from the user's query. + +2. **Seed entity lookup.** Find seed entities via embedding + similarity against the extracted concepts. + +3. **Iterative hop-and-filter.** For each hop up to + `max_path_length`: + + a. Retrieve all edges one hop from the current frontier nodes. + + b. Represent each edge as `"{predicate} {object}"`. + + c. Score edges against the extracted concepts using the + cross-encoder service. + + d. Select the top relevant edges. The target nodes of the + selected edges become the frontier for the next hop. + +4. **Document tracing.** Selected edges are traced back to source + documents. + +5. **Synthesis.** The `kg-synthesis` prompt generates the final + answer from the selected edges and source document metadata. + +### Implementation order + +1. Cross-encoder service with full integration (base schema, + flow service, gateway endpoint, CLI tool). +2. GraphRAG pipeline changes (iterative hop-and-filter, + edge representation, remove pre-filter). +3. Explainability update (`tg:EdgeSelection` type, concept + and score predicates, retire `tg:reasoning`). +4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts. +5. 
Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
+   to display the new metadata. Use these as the main
+   end-to-end test.
+6. Fix any failing unit tests, then add new tests as needed.
+7. Write guidance for UX devs to update the UI for the new
+   explainability predicates.
+
+## UX developer guidance
+
+This section describes the changes to the explainability interface
+that affect frontend rendering of GraphRAG Focus events.
+
+### What changed
+
+Edge selection in GraphRAG previously used LLM-based scoring and
+reasoning. Each selected edge carried a `tg:reasoning` predicate
+with free-text explanation from the LLM. This has been replaced
+by a cross-encoder reranker that scores edges against query
+concepts. The explainability data now carries structured metadata
+instead of free text.
+
+### Removed
+
+- **`tg:reasoning`** is no longer emitted on edge selection
+  entities in GraphRAG Focus events. UX code that reads
+  `edge_sel.reasoning` will get an empty string. Remove any
+  rendering that displays a "Reasoning" or "Reason" field for
+  Focus edges.
+
+- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
+  **`kg-edge-selection`** prompts are retired. Any UX that
+  references these prompt names should be cleaned up.
+
+### Added
+
+Each edge selection entity within a Focus event now has three
+new properties:
+
+| RDF predicate | API field | Type | Description |
+|---|---|---|---|
+| `rdf:type tg:EdgeSelection` | (type check) | n/a | Each edge selection entity is now explicitly typed |
+| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
+| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0–1.0) |
+
+The `tg:edge` predicate (RDF-star quoted triple) is unchanged.
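Since the score field is typed `float` or `None`, a consumer has to allow for edge selections without a score. The sketch below orders selections by relevance and places unscored entries last. `EdgeSel` is a simplified, hypothetical stand-in for the API object that carries only the two new fields from the table, and `by_relevance` is an illustrative helper name, not part of the TrustGraph API.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EdgeSel:
    # Hypothetical stand-in holding only the new fields from the table.
    concept: str = ""
    score: Optional[float] = None


def by_relevance(selections: List[EdgeSel]) -> List[EdgeSel]:
    """Sort by score descending; entries with no score go to the end."""
    return sorted(
        selections,
        key=lambda e: (e.score is None, -(e.score or 0.0)),
    )
```

The sort key is a tuple, so scored entries (first element `False`) always precede unscored ones, and ties keep their original order.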
+
+### How to render
+
+The recommended rendering for each selected edge in a Focus event:
+
+```
+Edge: (subject_label, predicate_label, object_label)
+  Concept: Score:
+```
+
+Scores near 1.0 indicate high relevance; scores near 0.0 indicate
+low relevance. UX could use the score to drive visual indicators
+such as colour intensity or a relevance bar.
+
+Edges are not returned in score order; they arrive in traversal
+order across hops. If the UX wants to display edges ranked by
+relevance, sort by `edge_sel.score` descending.
+
+### API classes (Python)
+
+The `EdgeSelection` dataclass in `trustgraph.api.explainability`
+has these fields:
+
+```python
+@dataclass
+class EdgeSelection:
+    uri: str
+    edge: Optional[Dict[str, str]]  # {"s": ..., "p": ..., "o": ...}
+    reasoning: str = ""  # Legacy, always empty for new traces
+    concept: str = ""  # Query concept that matched
+    score: Optional[float] = None  # Cross-encoder relevance score
+```
+
+These are populated when calling
+`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
+inline provenance triples from the streaming response.
+
+### WebSocket response format
+
+For inline explainability via the streaming WebSocket, Focus events
+arrive as `message_type: "explain"` responses. The `explain_triples`
+array contains the edge selection triples. The relevant predicates
+in wire format are:
+
+```json
+{"s": {"t": "i", "i": ""},
+ "p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
+ "o": {"t": "l", "v": "flyby event"}}
+
+{"s": {"t": "i", "i": ""},
+ "p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
+ "o": {"t": "l", "v": "0.9962"}}
+```
+
+Note that `tg:score` is transmitted as a string literal and must
+be parsed to a float on the client side.
+
+### Exploration event
+
+The Exploration event's `edge_count` field now reports the number
+of edges selected by the cross-encoder across all hops (previously
+it reported the total number of edges retrieved before filtering).
+The `entities` list continues to report the seed entities found +by vector search. diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 696df7ec..8930d159 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -95,10 +95,6 @@ class TestGraphRagIntegration: async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-scoring": - return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-reasoning": - return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": return PromptResult( response_type="text", @@ -113,14 +109,22 @@ class TestGraphRagIntegration: client.prompt.side_effect = mock_prompt return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_prompt_client): + mock_triples_client, mock_reranker_client, mock_prompt_client): """Create GraphRag instance with mocked dependencies""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_prompt_client, verbose=True ) @@ -167,8 +171,8 @@ class TestGraphRagIntegration: # 3. Should query triples to build knowledge subgraph assert mock_triples_client.query_stream.call_count > 0 - # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis) - assert mock_prompt_client.prompt.call_count == 4 + # 4. 
Should call prompt twice (extract-concepts + synthesis) + assert mock_prompt_client.prompt.call_count == 2 # Verify final response response, usage = response diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 48e26618..8dfd9c2b 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -63,11 +63,6 @@ class TestGraphRagStreaming: async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): if prompt_id == "extract-concepts": return PromptResult(response_type="text", text="") - elif prompt_id == "kg-edge-scoring": - # Edge scoring returns JSONL with IDs and scores - return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n') - elif prompt_id == "kg-edge-reasoning": - return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n') elif prompt_id == "kg-synthesis": if streaming and chunk_callback: # Simulate streaming chunks with end_of_stream flags @@ -88,14 +83,23 @@ class TestGraphRagStreaming: client.prompt.side_effect = prompt_side_effect return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_streaming_prompt_client): + mock_triples_client, mock_reranker_client, + mock_streaming_prompt_client): """Create GraphRag instance with streaming support""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_streaming_prompt_client, verbose=True ) diff --git 
a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 279c81ef..efce5922 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol: client = AsyncMock() async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": + if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: @@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol: client.prompt.side_effect = prompt_side_effect return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_streaming_prompt_client): + mock_triples_client, mock_reranker_client, + mock_streaming_prompt_client): """Create GraphRag instance with mocked dependencies""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_streaming_prompt_client, verbose=False ) @@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases: client = AsyncMock() async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": + if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: @@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases: client.prompt.side_effect = prompt_with_empties + mock_reranker = AsyncMock() + 
mock_reranker.rerank.return_value = [] + rag = GraphRag( embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])), graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), triples_client=AsyncMock(query=AsyncMock(return_value=[])), + reranker_client=mock_reranker, prompt_client=client, verbose=False ) diff --git a/tests/unit/test_base/test_prompt_client_streaming.py b/tests/unit/test_base/test_prompt_client_streaming.py index 83a4b90e..fecf6095 100644 --- a/tests/unit/test_base/test_prompt_client_streaming.py +++ b/tests/unit/test_base/test_prompt_client_streaming.py @@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback: assert callback.call_args_list[0] == call("test", False) assert callback.call_args_list[1] == call("", True) - @pytest.mark.asyncio - async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client): - """Test that kg_prompt correctly passes streaming parameters""" - # Arrange - async def mock_request(request, recipient=None, timeout=600): - if recipient: - responses = [ - PromptResponse(text="Answer", object=None, error=None, end_of_stream=False), - PromptResponse(text="", object=None, error=None, end_of_stream=True), - ] - for resp in responses: - should_stop = await recipient(resp) - if should_stop: - break - - prompt_client.request = mock_request - - callback = AsyncMock() - - # Act - await prompt_client.kg_prompt( - query="What is machine learning?", - kg=[("subject", "predicate", "object")], - streaming=True, - chunk_callback=callback - ) - - # Assert - assert callback.call_count == 2 - assert callback.call_args_list[0] == call("Answer", False) - assert callback.call_args_list[1] == call("", True) - @pytest.mark.asyncio async def test_document_prompt_passes_parameters_to_callback(self, prompt_client): """Test that document_prompt correctly passes streaming parameters""" diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py index 
e65ef2e3..d1ce097a 100644 --- a/tests/unit/test_provenance/test_dag_structure.py +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -107,6 +107,7 @@ class TestGraphRagDagStructure: embeddings_client = AsyncMock() graph_embeddings_client = AsyncMock() triples_client = AsyncMock() + reranker_client = AsyncMock() embeddings_client.embed.return_value = [[0.1, 0.2]] graph_embeddings_client.query.return_value = [ @@ -121,27 +122,22 @@ class TestGraphRagDagStructure: ] triples_client.query.return_value = [] + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.95 + reranker_client.rerank.return_value = [result] + async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": return PromptResult(response_type="text", text="concept") - elif template_id == "kg-edge-scoring": - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[{"id": e["id"], "score": 10} for e in edges], - ) - elif template_id == "kg-edge-reasoning": - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges], - ) elif template_id == "kg-synthesis": return PromptResult(response_type="text", text="Answer.") return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt - return prompt_client, embeddings_client, graph_embeddings_client, triples_client + return (prompt_client, embeddings_client, graph_embeddings_client, + triples_client, reranker_client) @pytest.mark.asyncio async def test_dag_chain(self, mock_clients): @@ -152,7 +148,7 @@ class TestGraphRagDagStructure: events.append({"explain_id": explain_id, "triples": triples}) await rag.query( - query="test", explain_callback=explain_cb, edge_score_limit=0, + query="test", explain_callback=explain_cb, ) dag = _collect_events(events) diff --git a/tests/unit/test_retrieval/test_graph_rag.py 
b/tests/unit/test_retrieval/test_graph_rag.py index d1979211..15ffdc9d 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -15,54 +15,52 @@ class TestGraphRag: def test_graph_rag_initialization_with_defaults(self): """Test GraphRag initialization with default verbose setting""" - # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() + mock_reranker_client = MagicMock() - # Initialize GraphRag - graph_rag = GraphRag( - prompt_client=mock_prompt_client, - embeddings_client=mock_embeddings_client, - graph_embeddings_client=mock_graph_embeddings_client, - triples_client=mock_triples_client - ) - - # Verify initialization - assert graph_rag.prompt_client == mock_prompt_client - assert graph_rag.embeddings_client == mock_embeddings_client - assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client - assert graph_rag.triples_client == mock_triples_client - assert graph_rag.verbose is False # Default value - # Verify label_cache is an LRUCacheWithTTL instance - from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL - assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) - - def test_graph_rag_initialization_with_verbose(self): - """Test GraphRag initialization with verbose enabled""" - # Create mock clients - mock_prompt_client = MagicMock() - mock_embeddings_client = MagicMock() - mock_graph_embeddings_client = MagicMock() - mock_triples_client = MagicMock() - - # Initialize GraphRag with verbose=True graph_rag = GraphRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, - verbose=True + reranker_client=mock_reranker_client, ) - # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == 
mock_embeddings_client assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client assert graph_rag.triples_client == mock_triples_client + assert graph_rag.reranker_client == mock_reranker_client + assert graph_rag.verbose is False + from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL + assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) + + def test_graph_rag_initialization_with_verbose(self): + """Test GraphRag initialization with verbose enabled""" + mock_prompt_client = MagicMock() + mock_embeddings_client = MagicMock() + mock_graph_embeddings_client = MagicMock() + mock_triples_client = MagicMock() + mock_reranker_client = MagicMock() + + graph_rag = GraphRag( + prompt_client=mock_prompt_client, + embeddings_client=mock_embeddings_client, + graph_embeddings_client=mock_graph_embeddings_client, + triples_client=mock_triples_client, + reranker_client=mock_reranker_client, + verbose=True, + ) + + assert graph_rag.prompt_client == mock_prompt_client + assert graph_rag.embeddings_client == mock_embeddings_client + assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client + assert graph_rag.triples_client == mock_triples_client + assert graph_rag.reranker_client == mock_reranker_client assert graph_rag.verbose is True - # Verify label_cache is an LRUCacheWithTTL instance from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) @@ -365,244 +363,162 @@ class TestQuery: assert "workspace" not in c.kwargs @pytest.mark.asyncio - async def test_follow_edges_never_passes_workspace(self): - """Verify follow_edges never passes workspace to query_stream.""" + async def test_hop_and_filter_never_passes_workspace(self): + """Verify hop_and_filter never passes workspace to query_stream.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = 
mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None mock_triple = MagicMock() - mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triple.s = "e1" + mock_triple.p = "p1" + mock_triple.o = "o1" mock_triples_client.query_stream.return_value = [mock_triple] + mock_triples_client.query.return_value = [] + + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.9 + mock_reranker_client.rerank.return_value = [result] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - triple_limit=10 + triple_limit=10, ) - subgraph = set() - await query.follow_edges("e1", subgraph, path_length=1) + await query.hop_and_filter(["e1"], ["concept"]) for c in mock_triples_client.query_stream.call_args_list: assert "workspace" not in c.kwargs @pytest.mark.asyncio - async def test_follow_edges_basic_functionality(self): - """Test Query.follow_edges method basic triple discovery""" + async def test_hop_and_filter_basic_functionality(self): + """Test hop_and_filter retrieves edges and scores them with reranker.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None - mock_triple1 = MagicMock() - mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1" + mock_triple = MagicMock() + mock_triple.s = "entity1" + mock_triple.p = "predicate1" + mock_triple.o = "object1" + mock_triples_client.query_stream.return_value = [mock_triple] + mock_triples_client.query.return_value = [] - mock_triple2 = MagicMock() - mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2" - - mock_triple3 = MagicMock() - mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" - - 
mock_triples_client.query_stream.side_effect = [ - [mock_triple1], # s=ent - [mock_triple2], # p=ent - [mock_triple3], # o=ent - ] + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.95 + mock_reranker_client.rerank.return_value = [result] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - triple_limit=10 + triple_limit=10, + edge_limit=25, ) - subgraph = set() - await query.follow_edges("entity1", subgraph, path_length=1) - - assert mock_triples_client.query_stream.call_count == 3 - - mock_triples_client.query_stream.assert_any_call( - s="entity1", p=None, o=None, limit=10, - collection="test_collection", batch_size=20, g="" - ) - mock_triples_client.query_stream.assert_any_call( - s=None, p="entity1", o=None, limit=10, - collection="test_collection", batch_size=20, g="" - ) - mock_triples_client.query_stream.assert_any_call( - s=None, p=None, o="entity1", limit=10, - collection="test_collection", batch_size=20, g="" + selected, uri_map, edge_meta = await query.hop_and_filter( + ["entity1"], ["test concept"], ) - expected_subgraph = { - ("entity1", "predicate1", "object1"), - ("subject2", "entity1", "object2"), - ("subject3", "predicate3", "entity1") - } - assert subgraph == expected_subgraph + assert len(selected) == 1 + assert len(uri_map) == 1 + assert len(edge_meta) == 1 + + mock_reranker_client.rerank.assert_called_once() + call_kwargs = mock_reranker_client.rerank.call_args + assert call_kwargs.kwargs["limit"] == 25 @pytest.mark.asyncio - async def test_follow_edges_with_path_length_zero(self): - """Test Query.follow_edges method with path_length=0""" + async def test_hop_and_filter_with_empty_frontier(self): + """Test hop_and_filter with no seed entities returns empty.""" + mock_rag = MagicMock() + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + ) + + selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"]) + + assert selected == [] + 
assert uri_map == {} + assert edge_meta == {} + + @pytest.mark.asyncio + async def test_hop_and_filter_filters_label_triples(self): + """Test hop_and_filter skips rdfs:label edges.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False - ) + label_triple = MagicMock() + label_triple.s = "entity1" + label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label" + label_triple.o = "Entity One" - subgraph = set() - await query.follow_edges("entity1", subgraph, path_length=0) - - mock_triples_client.query_stream.assert_not_called() - assert subgraph == set() - - @pytest.mark.asyncio - async def test_follow_edges_with_max_subgraph_size_limit(self): - """Test Query.follow_edges method respects max_subgraph_size""" - mock_rag = MagicMock() - mock_triples_client = AsyncMock() - mock_rag.triples_client = mock_triples_client + mock_triples_client.query_stream.return_value = [label_triple] + mock_triples_client.query.return_value = [] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - max_subgraph_size=2 + triple_limit=10, ) - subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")} - - await query.follow_edges("entity1", subgraph, path_length=1) - - mock_triples_client.query_stream.assert_not_called() - assert len(subgraph) == 3 - - @pytest.mark.asyncio - async def test_get_subgraph_method(self): - """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple""" - mock_rag = MagicMock() - - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False, - max_path_length=1 + selected, uri_map, edge_meta = await query.hop_and_filter( + ["entity1"], ["concept"], ) - # Mock get_entities to return (entities, concepts) 
tuple - query.get_entities = AsyncMock( - return_value=(["entity1", "entity2"], ["concept1"]) - ) - - query.follow_edges_batch = AsyncMock(return_value=( - { - ("entity1", "predicate1", "object1"), - ("entity2", "predicate2", "object2") - }, - {} - )) - - subgraph, term_map, entities, concepts = await query.get_subgraph("test query") - - query.get_entities.assert_called_once_with("test query") - query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) - - assert isinstance(subgraph, list) - assert len(subgraph) == 2 - assert ("entity1", "predicate1", "object1") in subgraph - assert ("entity2", "predicate2", "object2") in subgraph - assert entities == ["entity1", "entity2"] - assert concepts == ["concept1"] - - @pytest.mark.asyncio - async def test_get_labelgraph_method(self): - """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)""" - mock_rag = MagicMock() - - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False, - max_subgraph_size=100 - ) - - test_subgraph = [ - ("entity1", "predicate1", "object1"), - ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), - ("entity3", "predicate3", "object3") - ] - test_entities = ["entity1", "entity3"] - test_concepts = ["concept1"] - query.get_subgraph = AsyncMock( - return_value=(test_subgraph, {}, test_entities, test_concepts) - ) - - async def mock_maybe_label(entity): - label_map = { - "entity1": "Human Entity One", - "predicate1": "Human Predicate One", - "object1": "Human Object One", - "entity3": "Human Entity Three", - "predicate3": "Human Predicate Three", - "object3": "Human Object Three" - } - return label_map.get(entity, entity) - - query.maybe_label = AsyncMock(side_effect=mock_maybe_label) - - labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query") - - query.get_subgraph.assert_called_once_with("test query") - - # Label triples filtered out - assert len(labeled_edges) == 2 - - # maybe_label 
called for non-label triples - assert query.maybe_label.call_count == 6 - - expected_edges = [ - ("Human Entity One", "Human Predicate One", "Human Object One"), - ("Human Entity Three", "Human Predicate Three", "Human Object Three") - ] - assert labeled_edges == expected_edges - - assert len(uri_map) == 2 - assert entities == test_entities - assert concepts == test_concepts + assert selected == [] + mock_reranker_client.rerank.assert_not_called() @pytest.mark.asyncio async def test_graph_rag_query_method(self): """Test GraphRag.query method orchestrates full RAG pipeline with provenance""" - import json from trustgraph.retrieval.graph_rag.graph_rag import edge_id mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() expected_response = "This is the RAG response" - test_labelgraph = [("Subject", "Predicate", "Object")] - test_edge_id = edge_id("Subject", "Predicate", "Object") + test_selected_edges = [("Subject", "Predicate", "Object")] + test_eid = edge_id("Subject", "Predicate", "Object") test_uri_map = { - test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + } + test_edge_metadata = { + test_eid: {"concept": "test concept", "score": 0.95} } - test_entities = ["http://example.org/subject"] - test_concepts = ["test concept"] - # Mock prompt responses for the multi-step process + mock_embeddings_client.embed.return_value = [[0.1, 0.2]] + mock_graph_embeddings_client.query.return_value = [] + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-scoring": - return PromptResult(response_type="jsonl", objects=[{"id": 
test_edge_id, "score": 0.9}]) - elif prompt_name == "kg-edge-reasoning": - return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}]) + return PromptResult(response_type="text", text="test concept") elif prompt_name == "kg-synthesis": return PromptResult(response_type="text", text=expected_response) return PromptResult(response_type="text", text="") @@ -614,16 +530,16 @@ class TestQuery: embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, - verbose=False + reranker_client=mock_reranker_client, + verbose=False, ) - # Patch Query.get_labelgraph to return test data - original_get_labelgraph = Query.get_labelgraph + original_hop_and_filter = Query.hop_and_filter - async def mock_get_labelgraph(self, query_text): - return test_labelgraph, test_uri_map, test_entities, test_concepts + async def mock_hop_and_filter(self, seed_entities, concepts): + return test_selected_edges, test_uri_map, test_edge_metadata - Query.get_labelgraph = mock_get_labelgraph + Query.hop_and_filter = mock_hop_and_filter provenance_events = [] @@ -636,7 +552,7 @@ class TestQuery: collection="test_collection", entity_limit=25, triple_limit=15, - explain_callback=collect_provenance + explain_callback=collect_provenance, ) response_text, usage = response @@ -650,7 +566,6 @@ class TestQuery: assert len(triples) > 0 assert prov_id.startswith("urn:trustgraph:") - # Verify order assert "question" in provenance_events[0][1] assert "grounding" in provenance_events[1][1] assert "exploration" in provenance_events[2][1] @@ -658,4 +573,4 @@ class TestQuery: assert "synthesis" in provenance_events[4][1] finally: - Query.get_labelgraph = original_get_labelgraph + Query.hop_and_filter = original_hop_and_filter diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py index 1eb0dd72..bc2cb368 100644 --- 
a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import ( TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE, TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT, - TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION, ) @@ -91,17 +91,17 @@ def build_mock_clients(): 1. prompt_client.prompt("extract-concepts", ...) -> concepts 2. embeddings_client.embed(concepts) -> vectors 3. graph_embeddings_client.query(vector, ...) -> entity matches - 4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch) + 4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter) 5. triples_client.query(s, LABEL, ...) -> labels (maybe_label) - 6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges - 7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning - 8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) - 9. prompt_client.prompt("kg-synthesis", ...) -> final answer + 6. reranker_client.rerank(queries, documents, limit) -> scored edges + 7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) + 8. prompt_client.prompt("kg-synthesis", ...) -> final answer """ prompt_client = AsyncMock() embeddings_client = AsyncMock() graph_embeddings_client = AsyncMock() triples_client = AsyncMock() + reranker_client = AsyncMock() # 1. Concept extraction prompt_responses = {} @@ -116,7 +116,7 @@ def build_mock_clients(): EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)), ] - # 4. Triple queries (follow_edges_batch) - return our edges + # 4. 
Triple queries (hop_and_filter) - return our edges kg_triples = [ make_schema_triple(*EDGE_1), make_schema_triple(*EDGE_2), @@ -130,9 +130,18 @@ def build_mock_clients(): return [] # No labels found, will fall back to URI triples_client.query.side_effect = mock_label_query - # 6+7. Edge scoring and reasoning: dynamically score/reason about - # whatever edges the query method sends us, since edge IDs are computed - # from str(Term) representations which include the full dataclass repr. + # 6. Reranker: select all documents with high scores + async def mock_rerank(queries, documents, limit): + results = [] + for i, doc in enumerate(documents): + result = MagicMock() + result.document_id = doc["id"] + result.query_id = queries[0]["id"] if queries else "0" + result.score = 0.9 - (i * 0.1) + results.append(result) + return results[:limit] + reranker_client.rerank.side_effect = mock_rerank + synthesis_answer = "Quantum computing applies physics principles to computation." async def mock_prompt(template_id, variables=None, **kwargs): @@ -141,26 +150,6 @@ def build_mock_clients(): response_type="text", text=prompt_responses["extract-concepts"], ) - elif template_id == "kg-edge-scoring": - # Score all edges highly, using the IDs that GraphRag computed - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[ - {"id": e["id"], "score": 10 - i} - for i, e in enumerate(edges) - ], - ) - elif template_id == "kg-edge-reasoning": - # Provide reasoning for each edge - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[ - {"id": e["id"], "reasoning": f"Relevant edge {i}"} - for i, e in enumerate(edges) - ], - ) elif template_id == "kg-synthesis": return PromptResult( response_type="text", @@ -170,7 +159,8 @@ def build_mock_clients(): prompt_client.prompt.side_effect = mock_prompt - return prompt_client, embeddings_client, graph_embeddings_client, triples_client + return (prompt_client, 
embeddings_client, graph_embeddings_client, + triples_client, reranker_client) # --------------------------------------------------------------------------- @@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, # skip semantic pre-filter for simplicity + ) assert len(events) == 5, ( @@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) expected_types = [ @@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) uris = [e["explain_id"] for e in events] @@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) q_uri = events[0]["explain_id"] @@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) gnd_uri = events[1]["explain_id"] @@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) exp_uri = events[2]["explain_id"] @@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance: assert int(t.o.value) > 0 @pytest.mark.asyncio - async def test_focus_has_selected_edges_with_reasoning(self): + async def test_focus_has_selected_edges_with_concept_and_score(self): """ The focus event should carry selected edges as quoted triples - with reasoning text. + with cross-encoder concept and score metadata. 
""" clients = build_mock_clients() rag = GraphRag(*clients) @@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, ) foc_uri = events[3]["explain_id"] @@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance: for t in edge_t: assert t.o.triple is not None, "tg:edge object must be a quoted triple" - # Should have reasoning - reasoning = find_triples(foc_triples, TG_REASONING) - assert len(reasoning) > 0, "Focus should have reasoning for selected edges" - reasoning_texts = {t.o.value for t in reasoning} - assert any(r for r in reasoning_texts), "Reasoning should not be empty" + # Edge selections should be typed as EdgeSelection + edge_sel_uris = [t.o.iri for t in selected] + for uri in edge_sel_uris: + assert has_type(foc_triples, uri, TG_EDGE_SELECTION) + + # Should have concept and score + concepts = find_triples(foc_triples, TG_CONCEPT) + assert len(concepts) > 0, "Focus should have tg:concept for selected edges" + + scores = find_triples(foc_triples, TG_SCORE) + assert len(scores) > 0, "Focus should have tg:score for selected edges" + for t in scores: + float(t.o.value) # Should be parseable as float @pytest.mark.asyncio async def test_synthesis_is_answer_type(self): @@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) syn_uri = events[4]["explain_id"] @@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance: result_text, usage = await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) assert result_text == "Quantum computing applies physics principles to computation." 
@@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + parent_uri=parent, ) @@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance: result_text, usage = await rag.query( query="What is quantum computing?", - edge_score_limit=0, + ) assert result_text == "Quantum computing applies physics principles to computation." @@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) for event in events: diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index bf0b2ba1..de592b59 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -646,6 +646,16 @@ class AsyncFlowInstance: return await self.request("embeddings", request_data) + async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any): + request_data = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request_data.update(kwargs) + + return await self.request("reranker", request_data) + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any): """ Query RDF triples using pattern matching. 
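Note on the `rerank` methods added to the API clients above: they forward `queries` and `documents` unchanged, so the caller decides the shape of each entry. The following is a minimal sketch of a payload builder. It assumes the gateway decodes the request with the `query_id`/`query_text` and `document_id`/`document_text` keys that `RerankerRequestTranslator.decode` reads later in this diff. `build_rerank_request` is a hypothetical helper for illustration and is not part of this change.

```python
from typing import Any, Dict, List


def build_rerank_request(
    query_texts: List[str],
    document_texts: List[str],
    limit: int = 10,
) -> Dict[str, Any]:
    """Hypothetical helper: build a reranker request dictionary.

    IDs are positional indices rendered as strings, matching the
    string IDs ("0") used by the mocked results in the unit tests.
    The key names are an assumption taken from the translator's
    decode() method in this diff.
    """
    return {
        "queries": [
            {"query_id": str(i), "query_text": text}
            for i, text in enumerate(query_texts)
        ],
        "documents": [
            {"document_id": str(i), "document_text": text}
            for i, text in enumerate(document_texts)
        ],
        "limit": limit,
    }
```

The resulting dictionary could then be unpacked into a client call such as `flow.rerank(**request)`, assuming `flow` is one of the flow instances extended above.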
diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index fdbe2f67..78b608a7 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -443,6 +443,19 @@ class AsyncSocketFlowInstance: return await self.client._send_request("embeddings", self.flow_id, request) + async def rerank(self, queries: list, documents: list, limit: int = 10, + **kwargs): + request = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request.update(kwargs) + + return await self.client._send_request( + "reranker", self.flow_id, request, + ) + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs): """Triple pattern query""" request = {"limit": limit} diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 656ff95f..74a8f32e 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" +TG_SCORE = TG + "score" TG_DOCUMENT = TG + "document" TG_CONCEPT = TG + "concept" TG_ENTITY = TG + "entity" @@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" @dataclass class EdgeSelection: - """A selected edge with reasoning from GraphRAG Focus step.""" + """A selected edge with cross-encoder metadata from GraphRAG Focus step.""" uri: str edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...} reasoning: str = "" + concept: str = "" + score: Optional[float] = None @dataclass @@ -209,7 +212,7 @@ class Exploration(ExplainEntity): @dataclass class Focus(ExplainEntity): - """Focus entity - selected edges with LLM reasoning (GraphRAG only).""" + """Focus entity - selected edges with cross-encoder scoring (GraphRAG 
only).""" selected_edge_uris: List[str] = field(default_factory=list) edge_selections: List[EdgeSelection] = field(default_factory=list) @@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel uri = triples[0][0] if triples else "" edge = None reasoning = "" + concept = "" + score = None for s, p, o in triples: if p == TG_EDGE and isinstance(o, dict): edge = o elif p == TG_REASONING: reasoning = o + elif p == TG_CONCEPT: + concept = o + elif p == TG_SCORE: + try: + score = float(o) + except (ValueError, TypeError): + score = None - return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning) + return EdgeSelection( + uri=uri, edge=edge, reasoning=reasoning, + concept=concept, score=score, + ) def extract_term_value(term: Dict[str, Any]) -> Any: diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 961e348b..886306b3 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -491,6 +491,19 @@ class FlowInstance: input )["vectors"] + def rerank(self, queries, documents, limit=10): + + input = { + "queries": queries, + "documents": documents, + "limit": limit, + } + + return self.request( + "service/reranker", + input + ) + def graph_embeddings_query(self, text, collection, limit=10): """ Query knowledge graph entities using semantic similarity. 
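Note on the `parse_edge_selection_triples` change in `explainability.py` earlier in this diff: the new `tg:score` value arrives as a literal and is converted with `float()`, falling back to `None` when it cannot be parsed. The following standalone sketch mirrors that behaviour so it can be run in isolation. The `TG` namespace value here is a placeholder assumption (the real constant is defined outside this section), and the `reasoning` field is omitted for brevity.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

# Placeholder namespace, assumed for illustration only.
TG = "urn:example:tg:"
TG_EDGE = TG + "edge"
TG_CONCEPT = TG + "concept"
TG_SCORE = TG + "score"


@dataclass
class EdgeSelection:
    """Simplified stand-in for the dataclass in explainability.py."""
    uri: str
    edge: Optional[Dict[str, str]] = None
    concept: str = ""
    score: Optional[float] = None


def parse_edge_selection(triples: List[Tuple[str, str, Any]]) -> EdgeSelection:
    """Collect edge, concept and score from (s, p, o) tuples."""
    uri = triples[0][0] if triples else ""
    edge = None
    concept = ""
    score = None

    for _s, p, o in triples:
        if p == TG_EDGE and isinstance(o, dict):
            edge = o
        elif p == TG_CONCEPT:
            concept = o
        elif p == TG_SCORE:
            # Scores are literals; tolerate values that are not numeric.
            try:
                score = float(o)
            except (ValueError, TypeError):
                score = None

    return EdgeSelection(uri=uri, edge=edge, concept=concept, score=score)
```

Keeping `score` as `Optional[float]` lets consumers distinguish a missing or malformed score from a genuine score of `0.0`.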
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index e87d85ac..3a06e0d8 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -885,6 +885,19 @@ class SocketFlowInstance: return self.client._send_request_sync("embeddings", self.flow_id, request, False) + def rerank(self, queries: list, documents: list, limit: int = 10, + **kwargs: Any) -> Dict[str, Any]: + request = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request.update(kwargs) + + return self.client._send_request_sync( + "reranker", self.flow_id, request, False, + ) + def triples_query( self, s: Optional[Union[str, Dict[str, Any]]] = None, diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 6062543b..be905116 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService from . tool_service_client import ToolServiceClientSpec from . agent_client import AgentClientSpec from . structured_query_client import StructuredQueryClientSpec +from . reranker_client import RerankerClientSpec +from . reranker_service import RerankerService from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec from . 
collection_config_handler import CollectionConfigHandler
diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py
index d4822ece..b1813ba2 100644
--- a/trustgraph-base/trustgraph/base/prompt_client.py
+++ b/trustgraph-base/trustgraph/base/prompt_client.py
@@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
             timeout = timeout,
         )
 
-    async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
-        return await self.prompt(
-            id = "kg-prompt",
-            variables = {
-                "query": query,
-                "knowledge": [
-                    { "s": v[0], "p": v[1], "o": v[2] }
-                    for v in kg
-                ]
-            },
-            timeout = timeout,
-            streaming = streaming,
-            chunk_callback = chunk_callback,
-        )
-
     async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
         return await self.prompt(
             id = "document-prompt",
diff --git a/trustgraph-base/trustgraph/base/reranker_client.py b/trustgraph-base/trustgraph/base/reranker_client.py
new file mode 100644
index 00000000..d0bed394
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/reranker_client.py
@@ -0,0 +1,43 @@
+
+from . request_response_spec import RequestResponse, RequestResponseSpec
+from ..
schema import (
+    RerankerRequest, RerankerResponse,
+    RerankerQuery, RerankerDocument,
+)
+
+class RerankerClient(RequestResponse):
+    async def rerank(self, queries, documents, limit=10, timeout=300):
+
+        resp = await self.request(
+            RerankerRequest(
+                queries=[
+                    RerankerQuery(query_id=q["id"], query_text=q["text"])
+                    for q in queries
+                ],
+                documents=[
+                    RerankerDocument(
+                        document_id=d["id"], document_text=d["text"]
+                    )
+                    for d in documents
+                ],
+                limit=limit,
+            ),
+            timeout=timeout
+        )
+
+        if resp.error:
+            raise RuntimeError(resp.error.message)
+
+        return resp.results
+
+class RerankerClientSpec(RequestResponseSpec):
+    def __init__(
+        self, request_name, response_name,
+    ):
+        super(RerankerClientSpec, self).__init__(
+            request_name = request_name,
+            request_schema = RerankerRequest,
+            response_name = response_name,
+            response_schema = RerankerResponse,
+            impl = RerankerClient,
+        )
diff --git a/trustgraph-base/trustgraph/base/reranker_service.py b/trustgraph-base/trustgraph/base/reranker_service.py
new file mode 100644
index 00000000..1da3a8bf
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/reranker_service.py
@@ -0,0 +1,109 @@
+
+from __future__ import annotations
+
+from argparse import ArgumentParser
+
+import logging
+
+from .. schema import (
+    RerankerRequest, RerankerResponse, RerankerResult, Error,
+)
+from .. exceptions import TooManyRequests
+from ..
base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
+
+logger = logging.getLogger(__name__)
+
+default_ident = "reranker"
+default_concurrency = 1
+
+class RerankerService(FlowProcessor):
+
+    def __init__(self, **params):
+
+        id = params.get("id")
+        concurrency = params.get("concurrency", 1)
+
+        super(RerankerService, self).__init__(**params | {
+            "id": id,
+            "concurrency": concurrency,
+        })
+
+        self.register_specification(
+            ConsumerSpec(
+                name = "request",
+                schema = RerankerRequest,
+                handler = self.on_request,
+                concurrency = concurrency,
+            )
+        )
+
+        self.register_specification(
+            ProducerSpec(
+                name = "response",
+                schema = RerankerResponse
+            )
+        )
+
+        self.register_specification(
+            ParameterSpec(
+                name = "model",
+            )
+        )
+
+    async def on_request(self, msg, consumer, flow):
+
+        try:
+
+            request = msg.value()
+
+            id = msg.properties()["id"]
+
+            logger.debug(f"Handling reranker request {id}...")
+
+            model = flow("model")
+            results = await self.on_rerank(
+                request.queries, request.documents,
+                request.limit, model=model,
+            )
+
+            await flow("response").send(
+                RerankerResponse(
+                    error = None,
+                    results = results,
+                ),
+                properties={"id": id}
+            )
+
+            logger.debug("Reranker request handled successfully")
+
+        except TooManyRequests as e:
+            raise e
+
+        except Exception as e:
+
+            logger.error(f"Exception in reranker service: {e}", exc_info=True)
+
+            logger.info("Sending error response...")
+
+            await flow.producer["response"].send(
+                RerankerResponse(
+                    error=Error(
+                        type = "reranker-error",
+                        message = str(e),
+                    ),
+                    results=[],
+                ),
+                properties={"id": id}
+            )
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+
+        parser.add_argument(
+            '-c', '--concurrency',
+            type=int,
+            default=default_concurrency,
+            help=f'Concurrent processing threads (default: {default_concurrency})'
+        )
+
+        FlowProcessor.add_args(parser)
diff --git a/trustgraph-base/trustgraph/clients/prompt_client.py b/trustgraph-base/trustgraph/clients/prompt_client.py
index 12c9c194..ff29ec0a 100644 --- a/trustgraph-base/trustgraph/clients/prompt_client.py +++ b/trustgraph-base/trustgraph/clients/prompt_client.py @@ -140,20 +140,6 @@ class PromptClient(BaseClient): timeout=timeout ) - def request_kg_prompt(self, query, kg, timeout=300): - - return self.request( - id="kg-prompt", - variables={ - "query": query, - "knowledge": [ - { "s": v[0], "p": v[1], "o": v[2] } - for v in kg - ] - }, - timeout=timeout - ) - def request_document_prompt(self, query, documents, timeout=300): return self.request( diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 9fcfa6f7..097153ac 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator @@ -163,6 +164,12 @@ TranslatorRegistry.register_service( SparqlQueryResponseTranslator() ) +TranslatorRegistry.register_service( + "reranker", + RerankerRequestTranslator(), + RerankerResponseTranslator() +) + # Register single-direction translators for document loading TranslatorRegistry.register_request("document", DocumentTranslator()) TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git 
a/trustgraph-base/trustgraph/messaging/translators/__init__.py b/trustgraph-base/trustgraph/messaging/translators/__init__.py index 5b5820fa..b0f88e88 100644 --- a/trustgraph-base/trustgraph/messaging/translators/__init__.py +++ b/trustgraph-base/trustgraph/messaging/translators/__init__.py @@ -20,3 +20,4 @@ from .embeddings_query import ( ) from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .reranker import RerankerRequestTranslator, RerankerResponseTranslator diff --git a/trustgraph-base/trustgraph/messaging/translators/reranker.py b/trustgraph-base/trustgraph/messaging/translators/reranker.py new file mode 100644 index 00000000..2d5dabc2 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/reranker.py @@ -0,0 +1,73 @@ +from typing import Dict, Any, Tuple +from ...schema import ( + RerankerRequest, RerankerResponse, + RerankerQuery, RerankerDocument, RerankerResult, +) +from .base import MessageTranslator + + +class RerankerRequestTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> RerankerRequest: + return RerankerRequest( + queries=[ + RerankerQuery( + query_id=q["query_id"], + query_text=q["query_text"], + ) + for q in data.get("queries", []) + ], + documents=[ + RerankerDocument( + document_id=d["document_id"], + document_text=d["document_text"], + ) + for d in data.get("documents", []) + ], + limit=data.get("limit", 10), + ) + + def encode(self, obj: RerankerRequest) -> Dict[str, Any]: + return { + "queries": [ + {"query_id": q.query_id, "query_text": q.query_text} + for q in obj.queries + ], + "documents": [ + {"document_id": d.document_id, "document_text": d.document_text} + for d in obj.documents + ], + "limit": obj.limit, + } + + +class RerankerResponseTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> RerankerResponse: + return RerankerResponse( + 
results=[ + RerankerResult( + document_id=r["document_id"], + query_id=r["query_id"], + score=r["score"], + ) + for r in data.get("results", []) + ], + ) + + def encode(self, obj: RerankerResponse) -> Dict[str, Any]: + return { + "results": [ + { + "document_id": r.document_id, + "query_id": r.query_id, + "score": r.score, + } + for r in obj.results + ], + } + + def encode_with_completion( + self, obj: RerankerResponse + ) -> Tuple[Dict[str, Any], bool]: + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index 051efc66..ce91a3cb 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -89,7 +89,9 @@ from . namespaces import ( TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE, # Query-time provenance predicates (GraphRAG) TG_QUERY, TG_CONCEPT, TG_ENTITY, - TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE, + # Edge selection entity type + TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, # Explainability entity types @@ -212,7 +214,9 @@ __all__ = [ "TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE", # Query-time provenance predicates (GraphRAG) "TG_QUERY", "TG_CONCEPT", "TG_ENTITY", - "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", + "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE", + # Edge selection entity type + "TG_EDGE_SELECTION", # Query-time provenance predicates (DocumentRAG) "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", # Explainability entity types diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 0b14f1b9..6f81f122 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -66,8 +66,12 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + 
"selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" +TG_SCORE = TG + "score" TG_DOCUMENT = TG + "document" # Reference to document in librarian +# Edge selection entity type (cross-encoder scored edge in Focus) +TG_EDGE_SELECTION = TG + "EdgeSelection" + # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT = TG + "chunkCount" TG_SELECTED_CHUNK = TG + "selectedChunk" diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 8dedff9a..8e4871c3 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -24,8 +24,10 @@ from . namespaces import ( TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT, # Query-time provenance predicates (GraphRAG) TG_QUERY, TG_CONCEPT, TG_ENTITY, - TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE, TG_DOCUMENT, + # Edge selection entity type + TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, # Explainability entity types @@ -536,10 +538,9 @@ def focus_triples( _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), ] - # Add each selected edge with its reasoning via intermediate entity + # Add each selected edge with metadata via intermediate entity for idx, edge_info in enumerate(selected_edges_with_reasoning): edge = edge_info.get("edge") - reasoning = edge_info.get("reasoning", "") if edge: s, p, o = edge @@ -552,13 +553,32 @@ def focus_triples( _triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri)) ) + # Type the edge selection entity + triples.append( + _triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION)) + ) + # Attach quoted triple to edge selection entity quoted = _quoted_triple(s, p, o) triples.append( Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted) ) - # Attach reasoning to edge selection entity + # Structured cross-encoder metadata + 
concept = edge_info.get("concept") + if concept: + triples.append( + _triple(edge_sel_uri, TG_CONCEPT, _literal(concept)) + ) + + score = edge_info.get("score") + if score is not None: + triples.append( + _triple(edge_sel_uri, TG_SCORE, _literal(str(score))) + ) + + # Legacy reasoning text (for non-cross-encoder callers) + reasoning = edge_info.get("reasoning", "") if reasoning: triples.append( _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py index afb5c30f..1434d45d 100644 --- a/trustgraph-base/trustgraph/provenance/vocabulary.py +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -29,6 +29,7 @@ from . namespaces import ( TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_EDGE_SELECTION, TG_SCORE, ) @@ -93,6 +94,7 @@ TG_CLASS_LABELS = [ _label_triple(TG_FINDING, "Finding"), _label_triple(TG_PLAN_TYPE, "Plan"), _label_triple(TG_STEP_RESULT, "Step Result"), + _label_triple(TG_EDGE_SELECTION, "Edge Selection"), ] # TrustGraph predicate labels @@ -117,6 +119,7 @@ TG_PREDICATE_LABELS = [ _label_triple(TG_ENTITY, "entity"), _label_triple(TG_SUBAGENT_GOAL, "subagent goal"), _label_triple(TG_PLAN_STEP, "plan step"), + _label_triple(TG_SCORE, "score"), ] diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index 2a214201..63dc05fd 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -15,4 +15,5 @@ from .diagnosis import * from .collection import * from .storage import * from .tool_service import * -from .sparql_query import * \ No newline at end of file +from .sparql_query import * +from .reranker import * \ No newline at end of file diff --git 
a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index 1696790b..0a9c23ef 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -6,17 +6,6 @@ from ..core.primitives import Error # Prompt services, abstract the prompt generation -# extract-definitions: -# chunk -> definitions -# extract-relationships: -# chunk -> relationships -# kg-prompt: -# query, triples -> answer -# document-prompt: -# query, documents -> answer -# extract-rows -# schema, chunk -> rows - @dataclass class PromptRequest: id: str = "" @@ -46,4 +35,4 @@ class PromptResponse: out_token: int | None = None model: str | None = None -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/reranker.py b/trustgraph-base/trustgraph/schema/services/reranker.py new file mode 100644 index 00000000..948746e4 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/reranker.py @@ -0,0 +1,35 @@ + +from dataclasses import dataclass, field + +from ..core.primitives import Error + +############################################################################ + +# Cross-encoder reranker + +@dataclass +class RerankerQuery: + query_id: str = "" + query_text: str = "" + +@dataclass +class RerankerDocument: + document_id: str = "" + document_text: str = "" + +@dataclass +class RerankerRequest: + queries: list[RerankerQuery] = field(default_factory=list) + documents: list[RerankerDocument] = field(default_factory=list) + limit: int = 10 + +@dataclass +class RerankerResult: + document_id: str = "" + query_id: str = "" + score: float = 0.0 + +@dataclass +class RerankerResponse: + error: Error | None = None + results: list[RerankerResult] = field(default_factory=list) diff --git 
a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 006a07f4..193ee1cd 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main" tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" +tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main" tg-load-kg-core = "trustgraph.cli.load_kg_core:main" diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index f39cdab0..892d2d35 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -112,14 +112,13 @@ def _question_explainable_api( if focus_full and focus_full.edge_selections: for edge_sel in focus_full.edge_selections: if edge_sel.edge: - # Resolve labels for edge components s_label, p_label, o_label = explain_client.resolve_edge_labels( edge_sel.edge, collection ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if edge_sel.reasoning: - r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning - print(f" Reason: {r_short}", file=sys.stderr) + if edge_sel.concept or edge_sel.score is not None: + score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?" 
+ print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr) elif isinstance(entity, Synthesis): print(f"\n [synthesis] {prov_id}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/invoke_reranker.py b/trustgraph-cli/trustgraph/cli/invoke_reranker.py new file mode 100644 index 00000000..91337c97 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_reranker.py @@ -0,0 +1,127 @@ +""" +Invokes the reranker service to score and rank documents by relevance +to one or more queries. +""" + +import argparse +import json +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def query(url, flow_id, queries, documents, limit, token=None, + workspace="default"): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + + query_objects = [ + {"query_id": str(i), "query_text": q} + for i, q in enumerate(queries) + ] + + document_objects = [ + {"document_id": str(i), "document_text": d} + for i, d in enumerate(documents) + ] + + result = flow.rerank( + queries=query_objects, + documents=document_objects, + limit=limit, + ) + + if "error" in result and result["error"]: + err = result["error"] + print(f"Error: [{err.get('type', '')}] {err.get('message', '')}") + return + + for r in result.get("results", []): + doc_idx = int(r["document_id"]) + query_idx = int(r["query_id"]) + print( + f" {r['score']:.4f} | " + f"query: {queries[query_idx]} | " + f"doc: {documents[doc_idx]}" + ) + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-reranker', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + 
help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10, + help='Maximum number of results (default: 10)', + ) + + parser.add_argument( + '-q', '--query', + action='append', + required=True, + help='Query text (can be specified multiple times)', + ) + + parser.add_argument( + 'documents', + nargs='+', + help='Documents to rerank', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + queries=args.query, + documents=args.documents, + limit=args.limit, + token=args.token, + workspace=args.workspace, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index 17aaca1a..ed4a9807 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_ ) print(f" {i}. ({s_label}, {p_label}, {o_label})") - if edge_sel.reasoning: - r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning - print(f" Reasoning: {r_short}") + if edge_sel.concept or edge_sel.score is not None: + score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?" 
+ print(f" Concept: {edge_sel.concept} Score: {score_str}") if show_provenance and edge_sel.edge: provenance = trace_edge_provenance( @@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type): "selected_edges": [ { "edge": edge_sel.edge, - "reasoning": edge_sel.reasoning, + "concept": edge_sel.concept, + "score": edge_sel.score, } for edge_sel in focus.edge_selections ], diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index f9f6c5d9..90647104 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "faiss-cpu", "falkordb", "fastembed", + "flashrank", "ibis", "jsonschema", "langchain", @@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone: graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run" graph-embeddings = "trustgraph.embeddings.graph_embeddings:run" graph-rag = "trustgraph.retrieval.graph_rag:run" +reranker-flashrank = "trustgraph.reranker.flashrank:run" kg-extract-agent = "trustgraph.extract.kg.agent:run" kg-extract-definitions = "trustgraph.extract.kg.definitions:run" kg-extract-rows = "trustgraph.extract.kg.rows:run" diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index bddb009d..7285250f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . document_embeddings_query import DocumentEmbeddingsQueryRequestor from . row_embeddings_query import RowEmbeddingsQueryRequestor from . mcp_tool import McpToolRequestor +from . reranker import RerankerRequestor from . text_load import TextLoad from . 
document_load import DocumentLoad @@ -74,6 +75,7 @@ request_response_dispatchers = { "structured-diag": StructuredDiagRequestor, "row-embeddings": RowEmbeddingsQueryRequestor, "sparql": SparqlQueryRequestor, + "reranker": RerankerRequestor, } system_dispatchers = { diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py b/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py new file mode 100644 index 00000000..e456f3d1 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py @@ -0,0 +1,31 @@ + +from ... schema import RerankerRequest, RerankerResponse +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + +class RerankerRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(RerankerRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=RerankerRequest, + response_schema=RerankerResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("reranker") + self.response_translator = TranslatorRegistry.get_response_translator("reranker") + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index bdc3ed4c..14f820c2 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -518,6 +518,7 @@ _FLOW_SERVICES = { "structured-diag": "structured-query:read", "row-embeddings": "row-embeddings:read", "sparql": "sparql:read", + "reranker": "reranker", } for _kind, _cap in _FLOW_SERVICES.items(): _register_flow_kind("flow-service", _kind, _cap) diff --git 
a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index f1f7d92d..fced972e 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -72,6 +72,7 @@ _READER_CAPS = { "row-embeddings:read", "llm", "embeddings", + "reranker", "mcp", "config:read", "flows:read", diff --git a/trustgraph-flow/trustgraph/reranker/__init__.py b/trustgraph-flow/trustgraph/reranker/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/__init__.py @@ -0,0 +1 @@ + diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py b/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py new file mode 100644 index 00000000..bd3b0e96 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py @@ -0,0 +1,2 @@ + +from . processor import * diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py b/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py new file mode 100644 index 00000000..1ebce4d4 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . processor import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/processor.py b/trustgraph-flow/trustgraph/reranker/flashrank/processor.py new file mode 100644 index 00000000..481d1a79 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/processor.py @@ -0,0 +1,109 @@ + +""" +Reranker service using flashrank. +Scores query-document pairs and returns the top results ranked by +relevance. +""" + +import asyncio +import logging + +from ... base import RerankerService +from ... 
schema import RerankerResult + +from flashrank import Ranker, RerankRequest + +logger = logging.getLogger(__name__) + +default_ident = "reranker" + +default_model = "ms-marco-MiniLM-L-12-v2" + +class Processor(RerankerService): + + def __init__(self, **params): + + model = params.get("model", default_model) + + super(Processor, self).__init__( + **params | { "model": model } + ) + + self.default_model = model + + self.cached_model_name = None + self.ranker = None + + self._load_model(model) + + def _load_model(self, model_name): + if self.cached_model_name != model_name: + logger.info(f"Loading flashrank model: {model_name}") + self.ranker = Ranker(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"flashrank model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + def _run_rerank(self, query, passages): + request = RerankRequest(query=query, passages=passages) + return self.ranker.rerank(request) + + async def on_rerank(self, queries, documents, limit, model=None): + + if not queries or not documents: + return [] + + use_model = model or self.default_model + + if self.cached_model_name != use_model: + await asyncio.to_thread(self._load_model, use_model) + + passages = [ + {"id": d.document_id, "text": d.document_text} + for d in documents + ] + + best_scores = {} + + for q in queries: + ranked = await asyncio.to_thread( + self._run_rerank, q.query_text, passages, + ) + + for r in ranked: + doc_id = r["id"] + score = r["score"] + score = float(score) + if doc_id not in best_scores or score > best_scores[doc_id][1]: + best_scores[doc_id] = (q.query_id, score) + + results = sorted( + best_scores.items(), + key=lambda x: x[1][1], + reverse=True, + )[:limit] + + return [ + RerankerResult( + document_id=doc_id, + query_id=query_id, + score=score, + ) + for doc_id, (query_id, score) in results + ] + + @staticmethod + def add_args(parser): + + RerankerService.add_args(parser) + + parser.add_argument( 
+ '-m', '--model', + default=default_model, + help=f'Reranker model (default: {default_model})' + ) + +def run(): + + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 81dc8fe2..06c0b5b4 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -120,7 +120,7 @@ class Query: def __init__( self, rag, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, track_usage=None, + max_path_length=2, edge_limit=25, track_usage=None, ): self.rag = rag self.collection = collection @@ -129,6 +129,7 @@ class Query: self.triple_limit = triple_limit self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length + self.edge_limit = edge_limit self.track_usage = track_usage async def extract_concepts(self, query): @@ -217,12 +218,9 @@ class Query: logger.debug(f" {ent}") return entities, concepts - + async def maybe_label(self, e): - # The label cache lives on a per-request GraphRag instance — no - # cross-request isolation concern. The collection prefix keeps - # entries from different collections distinct within one request. 
cache_key = f"{self.collection}:{e}" cached_label = self.rag.label_cache.get(cache_key) @@ -244,11 +242,10 @@ class Query: return label async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently using streaming""" + """Execute triple queries for multiple entities concurrently.""" tasks = [] for entity in entities: - # Create concurrent streaming tasks for all 3 query types per entity tasks.extend([ self.rag.triples_client.query_stream( s=entity, p=None, o=None, @@ -270,10 +267,8 @@ class Query: ) ]) - # Execute all queries concurrently results = await asyncio.gather(*tasks, return_exceptions=True) - # Combine all results all_triples = [] for result in results: if not isinstance(result, Exception) and result is not None: @@ -281,168 +276,151 @@ class Query: return all_triples - async def follow_edges_batch(self, entities, max_depth): - """Optimized iterative graph traversal with batching. - - Returns: - tuple: (subgraph, term_map) where subgraph is a set of - (str, str, str) tuples and term_map maps each string tuple - to its original (Term, Term, Term) for type-preserving - provenance. 
- """ - visited = set() - current_level = set(entities) - subgraph = set() - term_map = {} # (str, str, str) -> (Term, Term, Term) - - for depth in range(max_depth): - if not current_level or len(subgraph) >= self.max_subgraph_size: - break - - # Filter out already visited entities - unvisited_entities = [e for e in current_level if e not in visited] - if not unvisited_entities: - break - - # Batch query all unvisited entities at current level - triples = await self.execute_batch_triple_queries( - unvisited_entities, self.triple_limit - ) - - # Process results and collect next level entities - next_level = set() - for triple in triples: - triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) - subgraph.add(triple_tuple) - term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o)) - - # Collect entities for next level (only from s and o positions) - if depth < max_depth - 1: # Don't collect for final depth - s, p, o = triple_tuple - if s not in visited: - next_level.add(s) - if o not in visited: - next_level.add(o) - - # Stop if subgraph size limit reached - if len(subgraph) >= self.max_subgraph_size: - return subgraph, term_map - - # Update for next iteration - visited.update(current_level) - current_level = next_level - - return subgraph, term_map - - async def follow_edges(self, ent, subgraph, path_length): - """Legacy method - replaced by follow_edges_batch""" - # Maintain backward compatibility with early termination checks - if path_length <= 0: - return - - if len(subgraph) >= self.max_subgraph_size: - return - - # For backward compatibility, convert to new approach - batch_result, _ = await self.follow_edges_batch([ent], path_length) - subgraph.update(batch_result) - - async def get_subgraph(self, query): - """ - Get subgraph by extracting concepts, finding entities, and traversing. 
- - Returns: - tuple: (subgraph, term_map, entities, concepts) where subgraph is - a list of (s, p, o) string tuples, term_map maps each string - tuple to its original (Term, Term, Term), entities is the seed - entity list, and concepts is the extracted concept list. - """ - - entities, concepts = await self.get_entities(query) - - if self.verbose: - logger.debug("Getting subgraph...") - - # Use optimized batch traversal instead of sequential processing - subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length) - - return list(subgraph), term_map, entities, concepts - async def resolve_labels_batch(self, entities): - """Resolve labels for multiple entities in parallel""" - tasks = [] - for entity in entities: - tasks.append(self.maybe_label(entity)) - + """Resolve labels for multiple entities in parallel.""" + tasks = [self.maybe_label(entity) for entity in entities] return await asyncio.gather(*tasks, return_exceptions=True) - async def get_labelgraph(self, query): - """ - Get subgraph with labels resolved for display. + async def hop_and_filter(self, seed_entities, concepts): + """Iterative hop-and-filter graph traversal with cross-encoder. + + At each hop: + 1. Retrieve all edges one hop from the frontier. + 2. Resolve labels and represent each edge as "{p} {o}". + 3. Score edges against concepts using the cross-encoder. + 4. Select the top edges; their target nodes become the next + frontier. 
Returns: - tuple: (labeled_edges, uri_map, entities, concepts) where: - - labeled_edges: list of (label_s, label_p, label_o) tuples - - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) - - entities: list of seed entity URI strings - - concepts: list of concept strings extracted from query + tuple: (selected_edges, uri_map, edge_metadata) where: + - selected_edges: list of (label_s, label_p, label_o) + - uri_map: dict mapping edge_id -> (Term, Term, Term) + - edge_metadata: dict mapping edge_id -> {concept, score} """ - subgraph, term_map, entities, concepts = await self.get_subgraph(query) + all_selected_edges = [] + uri_map = {} + edge_metadata = {} + frontier = set(seed_entities) + visited_entities = set() + seen_edges = set() - # Filter out label triples - filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] + for hop in range(self.max_path_length): + if not frontier: + break - # Collect all unique entities that need label resolution - entities_to_resolve = set() - for s, p, o in filtered_subgraph: - entities_to_resolve.update([s, p, o]) + unvisited = [e for e in frontier if e not in visited_entities] + if not unvisited: + break - # Batch resolve labels for all entities in parallel - entity_list = list(entities_to_resolve) - resolved_labels = await self.resolve_labels_batch(entity_list) + if self.verbose: + logger.debug( + f"Hop {hop + 1}: {len(unvisited)} frontier entities" + ) - # Create entity-to-label mapping - label_map = {} - for entity, label in zip(entity_list, resolved_labels): - if not isinstance(label, Exception): - label_map[entity] = label - else: - label_map[entity] = entity # Fallback to entity itself - - # Apply labels to subgraph and build URI mapping - labeled_edges = [] - uri_map = {} # Maps edge_id of labeled edge -> original Term triple - - for s, p, o in filtered_subgraph: - labeled_triple = ( - label_map.get(s, s), - label_map.get(p, p), - label_map.get(o, o) + # Retrieve edges one hop from 
frontier
+            triples = await self.execute_batch_triple_queries(
+                unvisited, self.triple_limit,
             )
-            labeled_edges.append(labeled_triple)

-            # Map from labeled edge ID to original Terms (preserving types)
-            labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
-            uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
+            # Deduplicate and filter already-seen edges
+            hop_triples = []
+            hop_term_map = {}
+            for triple in triples:
+                triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
+                if triple_tuple[1] == LABEL:
+                    continue
+                if triple_tuple in seen_edges:
+                    continue
+                seen_edges.add(triple_tuple)
+                hop_triples.append(triple_tuple)
+                hop_term_map[triple_tuple] = (
+                    to_term(triple.s), to_term(triple.p), to_term(triple.o),
+                )

-        labeled_edges = labeled_edges[0:self.max_subgraph_size]
+            if not hop_triples:
+                visited_entities.update(frontier)
+                break

-        if self.verbose:
-            logger.debug("Subgraph:")
-            for edge in labeled_edges:
-                logger.debug(f" {str(edge)}")
+            if self.verbose:
+                logger.debug(
+                    f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
+                )

-        if self.verbose:
-            logger.debug("Done.")
+            # Resolve labels for all entities in hop edges
+            entities_to_resolve = set()
+            for s, p, o in hop_triples:
+                entities_to_resolve.update([s, p, o])

-        return labeled_edges, uri_map, entities, concepts
+            entity_list = list(entities_to_resolve)
+            resolved = await self.resolve_labels_batch(entity_list)
+
+            label_map = {}
+            for entity, label in zip(entity_list, resolved):
+                if not isinstance(label, Exception):
+                    label_map[entity] = label
+                else:
+                    label_map[entity] = entity
+
+            # Build labeled edges and documents for cross-encoder
+            labeled_hop = []
+            for s, p, o in hop_triples:
+                ls = label_map.get(s, s)
+                lp = label_map.get(p, p)
+                lo = label_map.get(o, o)
+                labeled_hop.append((ls, lp, lo))
+
+            documents = [
+                {"id": str(i), "text": f"{lp} {lo}"}
+                for i, (ls, lp, lo) in enumerate(labeled_hop)
+            ]
+
+            queries = [
+                {"id": str(i), "text": c}
+                for i, c in
enumerate(concepts)
+            ]
+
+            # Score with cross-encoder
+            results = await self.rag.reranker_client.rerank(
+                queries=queries,
+                documents=documents,
+                limit=self.edge_limit,
+            )
+
+            # Collect selected edges and metadata
+            next_frontier = set()
+            for r in results:
+                idx = int(r.document_id)
+                ls, lp, lo = labeled_hop[idx]
+                s, p, o = hop_triples[idx]
+                eid = edge_id(ls, lp, lo)
+
+                all_selected_edges.append((ls, lp, lo))
+                uri_map[eid] = hop_term_map[(s, p, o)]
+                edge_metadata[eid] = {
+                    "concept": concepts[int(r.query_id)],
+                    "score": r.score,
+                }
+
+                # Target nodes become next frontier
+                next_frontier.add(s)
+                next_frontier.add(o)
+
+            if self.verbose:
+                logger.debug(
+                    f"Hop {hop + 1}: selected {len(results)} edges"
+                )
+
+            visited_entities.update(frontier)
+            frontier = next_frontier - visited_entities
+
+        return all_selected_edges, uri_map, edge_metadata

     async def trace_source_documents(self, edge_uris):
         """
         Trace selected edges back to their source documents via provenance.

-        Follows the chain: edge → subgraph (via tg:contains) → chunk →
-        page → document (via prov:wasDerivedFrom), all in urn:graph:source.
+        Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
+        page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
         Args:
             edge_uris: List of (s, p, o) URI string tuples
@@ -453,7 +431,6 @@ class Query:
         # Step 1: Find subgraphs containing these edges via tg:contains
         subgraph_tasks = []
         for s, p, o in edge_uris:
-            # s, p, o may be Term objects (preserving types) or strings
             s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
             p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
             o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
@@ -487,12 +464,10 @@ class Query:
             return []

         # Step 2: Walk prov:wasDerivedFrom chain to find documents
-        # Each level: query ?entity prov:wasDerivedFrom ?parent
-        # Stop when we find entities typed tg:Document
         current_uris = subgraph_uris
         doc_uris = set()

-        for depth in range(4): # Max depth: subgraph → chunk → page → doc
+        for depth in range(4):
             if not current_uris:
                 break

@@ -509,7 +484,6 @@ class Query:
                 *derivation_tasks, return_exceptions=True
             )

-            # URIs with no parent are root documents
             next_uris = set()
             for uri, result in zip(current_uris, derivation_results):
                 if isinstance(result, Exception) or not result:
@@ -524,7 +498,6 @@ class Query:
             return []

         # Step 3: Get all document metadata properties
-        # Skip structural predicates that aren't useful context
         SKIP_PREDICATES = {
             PROV_WAS_DERIVED_FROM,
             "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
@@ -565,7 +538,7 @@ class GraphRag:

     def __init__(
             self, prompt_client, embeddings_client, graph_embeddings_client,
-            triples_client, verbose=False,
+            triples_client, reranker_client, verbose=False,
     ):

         self.verbose = verbose
@@ -574,9 +547,8 @@ class GraphRag:
         self.embeddings_client = embeddings_client
         self.graph_embeddings_client = graph_embeddings_client
         self.triples_client = triples_client
+        self.reranker_client = reranker_client

-        # Replace simple dict with LRU cache with TTL
-        # CRITICAL: This cache only lives for one request due to per-request instantiation
         self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)

         if self.verbose:
@@ -585,33 +557,12 @@ class GraphRag:
     async def
query(
             self, query, collection = "default",
             entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
-            max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
+            max_path_length = 2, edge_limit = 25,
             streaming = False, chunk_callback = None,
             explain_callback = None, save_answer_callback = None,
             parent_uri = "",
     ):
-        """
-        Execute a GraphRAG query with real-time explainability tracking.
-
-        Args:
-            query: The query string
-            collection: Collection identifier
-            entity_limit: Max entities to retrieve
-            triple_limit: Max triples per entity
-            max_subgraph_size: Max edges in subgraph
-            max_path_length: Max hops from seed entities
-            edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
-            edge_limit: Max edges after LLM scoring
-            streaming: Enable streaming LLM response
-            chunk_callback: async def callback(chunk, end_of_stream) for streaming
-            explain_callback: async def callback(triples, explain_id) for real-time explainability
-            save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
-
-        Returns:
-            tuple: (answer_text, usage) where usage is a dict with
-                   in_token, out_token, model
-        """
         # Accumulate token usage across all prompt calls
         total_in = 0
         total_out = 0
@@ -638,7 +589,9 @@ class GraphRag:
         foc_uri = make_focus_uri(session_id)
         syn_uri = make_synthesis_uri(session_id)

-        timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
+        timestamp = datetime.now(timezone.utc).isoformat().replace(
+            "+00:00", "Z",
+        )

         # Emit question explainability immediately
         if explain_callback:
@@ -657,10 +610,12 @@ class GraphRag:
             triple_limit = triple_limit,
             max_subgraph_size = max_subgraph_size,
             max_path_length = max_path_length,
+            edge_limit = edge_limit,
             track_usage = track_usage,
         )

-        kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
+        # Step 1: Extract concepts and find seed entities
+        seed_entities, concepts = await q.get_entities(query)

         # Emit grounding explain after concept
extraction
         if explain_callback:
@@ -676,11 +631,16 @@ class GraphRag:
             )
             await explain_callback(gnd_triples, gnd_uri)

-        # Emit exploration explain after graph retrieval completes
+        # Step 2: Iterative hop-and-filter with cross-encoder
+        selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
+            seed_entities, concepts,
+        )
+
+        # Emit exploration explain
         if explain_callback:
             exp_triples = set_graph(
                 exploration_triples(
-                    exp_uri, gnd_uri, len(kg),
+                    exp_uri, gnd_uri, len(selected_edges),
                     entities=seed_entities,
                 ),
                 GRAPH_RETRIEVAL
@@ -688,235 +648,63 @@ class GraphRag:
             await explain_callback(exp_triples, exp_uri)

         if self.verbose:
-            logger.debug("Invoking LLM...")
-            logger.debug(f"Knowledge graph: {kg}")
-            logger.debug(f"Query: {query}")
-
-        # Semantic pre-filter: reduce edges before expensive LLM scoring
-        if edge_score_limit > 0 and len(kg) > edge_score_limit:
-
-            if self.verbose:
+            logger.debug(f"Selected {len(selected_edges)} edges")
+            for s, p, o in selected_edges:
+                eid = edge_id(s, p, o)
+                meta = edge_metadata.get(eid, {})
                 logger.debug(
-                    f"Semantic pre-filter: {len(kg)} edges > "
-                    f"limit {edge_score_limit}, filtering..."
+                    f" {meta.get('score', 0):.4f} "
+                    f"[{meta.get('concept', '')}] "
+                    f"{s} | {p} | {o}"
                 )

-            # Embed edge descriptions: "subject, predicate, object"
-            edge_descriptions = [
-                f"{s}, {p}, {o}" for s, p, o in kg
-            ]
-
-            # Embed concepts and edge descriptions concurrently
-            concept_embed_task = self.embeddings_client.embed(concepts)
-            edge_embed_task = self.embeddings_client.embed(edge_descriptions)
-
-            concept_vectors, edge_vectors = await asyncio.gather(
-                concept_embed_task, edge_embed_task
-            )
-
-            # Score each edge by max cosine similarity to any concept
-            def cosine_similarity(a, b):
-                dot = sum(x * y for x, y in zip(a, b))
-                norm_a = math.sqrt(sum(x * x for x in a))
-                norm_b = math.sqrt(sum(x * x for x in b))
-                if norm_a == 0 or norm_b == 0:
-                    return 0.0
-                return dot / (norm_a * norm_b)
-
-            edge_scores = []
-            for i, edge_vec in enumerate(edge_vectors):
-                max_sim = max(
-                    cosine_similarity(edge_vec, cv)
-                    for cv in concept_vectors
-                )
-                edge_scores.append((max_sim, i))
-
-            # Sort by similarity descending and keep top edge_score_limit
-            edge_scores.sort(reverse=True)
-            keep_indices = set(
-                idx for _, idx in edge_scores[:edge_score_limit]
-            )
-
-            # Filter kg and rebuild uri_map
-            filtered_kg = []
-            filtered_uri_map = {}
-            for i, (s, p, o) in enumerate(kg):
-                if i in keep_indices:
-                    filtered_kg.append((s, p, o))
-                    eid = edge_id(s, p, o)
-                    if eid in uri_map:
-                        filtered_uri_map[eid] = uri_map[eid]
-
-            if self.verbose:
-                logger.debug(
-                    f"Semantic pre-filter kept {len(filtered_kg)} "
-                    f"of {len(kg)} edges"
-                )
-
-            kg = filtered_kg
-            uri_map = filtered_uri_map
-
-        # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
-        # uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
-        edge_map = {}
-        edges_with_ids = []
-        for s, p, o in kg:
-            eid = edge_id(s, p, o)
-            edge_map[eid] = (s, p, o)
-            edges_with_ids.append({
-                "id": eid,
-                "s": s,
-                "p": p,
-                "o": o
-            })
-
-        if self.verbose:
-            logger.debug(f"Built edge map with {len(edge_map)} edges")
-
-        # Step 1a: Edge
Scoring - LLM scores edges for relevance
-        scoring_result = await self.prompt_client.prompt(
-            "kg-edge-scoring",
-            variables={
-                "query": query,
-                "knowledge": edges_with_ids
-            }
-        )
-        track_usage(scoring_result)
-
-        if self.verbose:
-            logger.debug(f"Edge scoring result: {scoring_result}")
-
-        # Parse scoring response (jsonl) to get edge IDs with scores
-        scored_edges = []
-
-        for obj in scoring_result.objects or []:
-            if isinstance(obj, dict) and "id" in obj and "score" in obj:
-                try:
-                    score = int(obj["score"])
-                except (ValueError, TypeError):
-                    score = 0
-                scored_edges.append({"id": obj["id"], "score": score})
-
-        # Select top N edges by score
-        scored_edges.sort(key=lambda x: x["score"], reverse=True)
-        top_edges = scored_edges[:edge_limit]
-        selected_ids = {e["id"] for e in top_edges}
-
-        if self.verbose:
-            logger.debug(
-                f"Scored {len(scored_edges)} edges, "
-                f"selected top {len(selected_ids)}"
-            )
-
-        # Filter to selected edges
-        selected_edges = []
-        for eid in selected_ids:
-            if eid in edge_map:
-                selected_edges.append(edge_map[eid])
-
-        # Step 1b: Edge Reasoning + Document Tracing (concurrent)
-        selected_edges_with_ids = [
-            {"id": eid, "s": s, "p": p, "o": o}
-            for eid in selected_ids
-            if eid in edge_map
-            for s, p, o in [edge_map[eid]]
-        ]
-
-        # Collect selected edge URIs for document tracing
+        # Step 3: Document tracing
         selected_edge_uris = [
-            uri_map[eid]
-            for eid in selected_ids
-            if eid in uri_map
+            uri_map[edge_id(s, p, o)]
+            for s, p, o in selected_edges
+            if edge_id(s, p, o) in uri_map
         ]

-        # Run reasoning and document tracing concurrently
-        async def _get_reasoning():
-            result = await self.prompt_client.prompt(
-                "kg-edge-reasoning",
-                variables={
-                    "query": query,
-                    "knowledge": selected_edges_with_ids
-                }
-            )
-            track_usage(result)
-            return result
-
-        reasoning_task = _get_reasoning()
-        doc_trace_task = q.trace_source_documents(selected_edge_uris)
-
-        reasoning_result, source_documents = await asyncio.gather(
-            reasoning_task,
doc_trace_task, return_exceptions=True
+        source_documents = await q.trace_source_documents(
+            selected_edge_uris,
         )

-        # Handle exceptions from gather
-        if isinstance(reasoning_result, Exception):
-            logger.warning(
-                f"Edge reasoning failed: {reasoning_result}"
-            )
-            reasoning_result = None
         if isinstance(source_documents, Exception):
             logger.warning(
                 f"Document tracing failed: {source_documents}"
             )
             source_documents = []
-
-        if self.verbose:
-            logger.debug(f"Edge reasoning result: {reasoning_result}")
-
-        # Parse reasoning response (jsonl) and build explainability data
-        reasoning_map = {}
-
-        if reasoning_result is not None:
-            for obj in reasoning_result.objects or []:
-                if isinstance(obj, dict) and "id" in obj:
-                    reasoning_map[obj["id"]] = obj.get("reasoning", "")
-
+        # Build focus explainability data with cross-encoder metadata
         selected_edges_with_reasoning = []
-        for eid in selected_ids:
+        for s, p, o in selected_edges:
+            eid = edge_id(s, p, o)
             if eid in uri_map:
                 uri_s, uri_p, uri_o = uri_map[eid]
+                meta = edge_metadata.get(eid, {})
                 selected_edges_with_reasoning.append({
                     "edge": (uri_s, uri_p, uri_o),
-                    "reasoning": reasoning_map.get(eid, ""),
+                    "concept": meta.get("concept", ""),
+                    "score": meta.get("score", 0),
                 })

-        if self.verbose:
-            logger.debug(f"Filtered to {len(selected_edges)} edges")
-
-        # Emit focus explain after edge selection completes
+        # Emit focus explain
         if explain_callback:
-            # Sum scoring + reasoning token usage for focus event
-            focus_in = 0
-            focus_out = 0
-            focus_model = None
-            for r in [scoring_result, reasoning_result]:
-                if r is not None:
-                    if r.in_token is not None:
-                        focus_in += r.in_token
-                    if r.out_token is not None:
-                        focus_out += r.out_token
-                    if r.model is not None:
-                        focus_model = r.model
-
             foc_triples = set_graph(
                 focus_triples(
-                    foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
-                    in_token=focus_in or None,
-                    out_token=focus_out or None,
-                    model=focus_model,
+                    foc_uri, exp_uri,
+                    selected_edges_with_reasoning,
session_id,
                 ),
                 GRAPH_RETRIEVAL
             )
             await explain_callback(foc_triples, foc_uri)

-        # Step 2: Synthesis - LLM generates answer from selected edges only
+        # Step 4: Synthesis
         selected_edge_dicts = [
             {"s": s, "p": p, "o": o}
             for s, p, o in selected_edges
         ]

-        # Add source document metadata as knowledge edges
         for s, p, o in source_documents:
             selected_edge_dicts.append({
                 "s": s, "p": p, "o": o,
@@ -928,7 +716,6 @@ class GraphRag:
         }

         if streaming and chunk_callback:
-            # Accumulate chunks for answer storage while forwarding to callback
             accumulated_chunks = []

             async def accumulating_callback(chunk, end_of_stream):
@@ -942,7 +729,6 @@ class GraphRag:
                 chunk_callback=accumulating_callback
             )
             track_usage(synthesis_result)
-            # Combine all chunks into full response
             resp = "".join(accumulated_chunks)
         else:
             synthesis_result = await self.prompt_client.prompt(
@@ -955,29 +741,42 @@ class GraphRag:
         if self.verbose:
             logger.debug("Query processing complete")

-        # Emit synthesis explain after synthesis completes
+        # Emit synthesis explain
         if explain_callback:
             synthesis_doc_id = None
             answer_text = resp if resp else ""

-            # Save answer to librarian
             if save_answer_callback and answer_text:
                 synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
                 try:
                     await save_answer_callback(synthesis_doc_id, answer_text)
                     if self.verbose:
-                        logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
+                        logger.debug(
+                            f"Saved answer to librarian: "
+                            f"{synthesis_doc_id}"
+                        )
                 except Exception as e:
-                    logger.warning(f"Failed to save answer to librarian: {e}")
+                    logger.warning(
+                        f"Failed to save answer to librarian: {e}"
+                    )
                     synthesis_doc_id = None

             syn_triples = set_graph(
                 synthesis_triples(
                     syn_uri, foc_uri,
                     document_id=synthesis_doc_id,
-                    in_token=synthesis_result.in_token if synthesis_result else None,
-                    out_token=synthesis_result.out_token if synthesis_result else None,
-                    model=synthesis_result.model if synthesis_result else None,
+                    in_token=(
+                        synthesis_result.in_token
+                        if synthesis_result else None
+                    ),
+                    out_token=(
+                        synthesis_result.out_token
+                        if synthesis_result else None
+                    ),
+                    model=(
+                        synthesis_result.model
+                        if synthesis_result else None
+                    ),
                 ),
                 GRAPH_RETRIEVAL
             )
@@ -993,4 +792,3 @@ class GraphRag:
         }

         return resp, usage
-
diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
index 959ae8e0..27ec4937 100755
--- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
@@ -13,6 +13,7 @@ from . graph_rag import GraphRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
+from ... base import RerankerClientSpec
 from ... base import LibrarianSpec

 # Module logger
@@ -32,7 +33,6 @@ class Processor(FlowProcessor):
         triple_limit = params.get("triple_limit", 30)
         max_subgraph_size = params.get("max_subgraph_size", 150)
         max_path_length = params.get("max_path_length", 2)
-        edge_score_limit = params.get("edge_score_limit", 30)
         edge_limit = params.get("edge_limit", 25)

         super(Processor, self).__init__(
@@ -43,7 +43,6 @@ class Processor(FlowProcessor):
                 "triple_limit": triple_limit,
                 "max_subgraph_size": max_subgraph_size,
                 "max_path_length": max_path_length,
-                "edge_score_limit": edge_score_limit,
                 "edge_limit": edge_limit,
             }
         )
@@ -52,7 +51,6 @@ class Processor(FlowProcessor):
         self.default_triple_limit = triple_limit
         self.default_max_subgraph_size = max_subgraph_size
         self.default_max_path_length = max_path_length
-        self.default_edge_score_limit = edge_score_limit
         self.default_edge_limit = edge_limit

         # Workspace isolation is enforced by the flow layer (flow.workspace).
@@ -96,6 +94,13 @@ class Processor(FlowProcessor):
             )
         )

+        self.register_specification(
+            RerankerClientSpec(
+                request_name = "reranker-request",
+                response_name = "reranker-response",
+            )
+        )
+
         self.register_specification(
             ProducerSpec(
                 name = "response",
@@ -163,6 +168,7 @@ class Processor(FlowProcessor):
                 graph_embeddings_client=flow("graph-embeddings-request"),
                 triples_client=flow("triples-request"),
                 prompt_client=flow("prompt-request"),
+                reranker_client=flow("reranker-request"),
                 verbose=True,
             )

@@ -186,11 +192,6 @@ class Processor(FlowProcessor):
             else:
                 max_path_length = self.default_max_path_length

-            if v.edge_score_limit:
-                edge_score_limit = v.edge_score_limit
-            else:
-                edge_score_limit = self.default_edge_score_limit
-
             if v.edge_limit:
                 edge_limit = v.edge_limit
             else:
@@ -225,7 +226,7 @@ class Processor(FlowProcessor):
                     entity_limit = entity_limit, triple_limit = triple_limit,
                     max_subgraph_size = max_subgraph_size,
                     max_path_length = max_path_length,
-                    edge_score_limit = edge_score_limit,
+                    edge_limit = edge_limit,
                     streaming = True,
                     chunk_callback = send_chunk,
@@ -241,7 +242,7 @@ class Processor(FlowProcessor):
                     entity_limit = entity_limit, triple_limit = triple_limit,
                     max_subgraph_size = max_subgraph_size,
                     max_path_length = max_path_length,
-                    edge_score_limit = edge_score_limit,
+                    edge_limit = edge_limit,
                     explain_callback = send_explainability,
                     save_answer_callback = save_answer,
@@ -338,18 +339,11 @@ class Processor(FlowProcessor):
             help=f'Default max path length (default: 2)'
         )

-        parser.add_argument(
-            '--edge-score-limit',
-            type=int,
-            default=30,
-            help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
-        )
-
         parser.add_argument(
             '--edge-limit',
             type=int,
             default=25,
-            help=f'Max edges after LLM scoring (default: 25)'
+            help=f'Max edges selected per hop by cross-encoder (default: 25)'
         )

         # Note: Explainability triples are now stored in the request's collection