From 01cc8dbc640a0dc89a1bbf0148a802704b9f7ab7 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 30 Jun 2026 14:36:37 +0100 Subject: [PATCH 1/9] feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005) Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach. - Add reranker service (trustgraph-base client/service, FlashRank processor) - Add gateway dispatch for reranker via API and WebSocket - Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring - Remove kg_prompt() and edge_score_limit from prompt client - Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates - Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata - Add tg-invoke-reranker CLI tool - Add tech spec and UX developer guidance - Update all unit and integration tests --- docs/tech-specs/graph-rag-semantic-filter.md | 523 ++++++++++++++++ .../integration/test_graph_rag_integration.py | 18 +- .../test_graph_rag_streaming_integration.py | 16 +- .../test_rag_streaming_protocol.py | 19 +- .../test_base/test_prompt_client_streaming.py | 32 - .../test_provenance/test_dag_structure.py | 24 +- tests/unit/test_retrieval/test_graph_rag.py | 341 ++++------- .../test_graph_rag_provenance_integration.py | 97 ++- trustgraph-base/trustgraph/api/async_flow.py | 10 + .../trustgraph/api/async_socket_client.py | 13 + .../trustgraph/api/explainability.py | 21 +- trustgraph-base/trustgraph/api/flow.py | 13 + .../trustgraph/api/socket_client.py | 13 + trustgraph-base/trustgraph/base/__init__.py | 2 + .../trustgraph/base/prompt_client.py | 15 - .../trustgraph/base/reranker_client.py | 43 ++ .../trustgraph/base/reranker_service.py | 109 ++++ 
.../trustgraph/clients/prompt_client.py | 14 - .../trustgraph/messaging/__init__.py | 7 + .../messaging/translators/__init__.py | 1 + .../messaging/translators/reranker.py | 73 +++ .../trustgraph/provenance/__init__.py | 8 +- .../trustgraph/provenance/namespaces.py | 4 + .../trustgraph/provenance/triples.py | 28 +- .../trustgraph/provenance/vocabulary.py | 3 + .../trustgraph/schema/services/__init__.py | 3 +- .../trustgraph/schema/services/prompt.py | 13 +- .../trustgraph/schema/services/reranker.py | 35 ++ trustgraph-cli/pyproject.toml | 1 + .../trustgraph/cli/invoke_graph_rag.py | 7 +- .../trustgraph/cli/invoke_reranker.py | 127 ++++ .../trustgraph/cli/show_explain_trace.py | 9 +- trustgraph-flow/pyproject.toml | 2 + .../trustgraph/gateway/dispatch/manager.py | 2 + .../trustgraph/gateway/dispatch/reranker.py | 31 + .../trustgraph/gateway/registry.py | 1 + trustgraph-flow/trustgraph/iam/service/iam.py | 1 + .../trustgraph/reranker/__init__.py | 1 + .../trustgraph/reranker/flashrank/__init__.py | 2 + .../trustgraph/reranker/flashrank/__main__.py | 6 + .../reranker/flashrank/processor.py | 109 ++++ .../retrieval/graph_rag/graph_rag.py | 578 ++++++------------ .../trustgraph/retrieval/graph_rag/rag.py | 30 +- 43 files changed, 1613 insertions(+), 792 deletions(-) create mode 100644 docs/tech-specs/graph-rag-semantic-filter.md create mode 100644 trustgraph-base/trustgraph/base/reranker_client.py create mode 100644 trustgraph-base/trustgraph/base/reranker_service.py create mode 100644 trustgraph-base/trustgraph/messaging/translators/reranker.py create mode 100644 trustgraph-base/trustgraph/schema/services/reranker.py create mode 100644 trustgraph-cli/trustgraph/cli/invoke_reranker.py create mode 100644 trustgraph-flow/trustgraph/gateway/dispatch/reranker.py create mode 100644 trustgraph-flow/trustgraph/reranker/__init__.py create mode 100644 trustgraph-flow/trustgraph/reranker/flashrank/__init__.py create mode 100644 
trustgraph-flow/trustgraph/reranker/flashrank/__main__.py create mode 100644 trustgraph-flow/trustgraph/reranker/flashrank/processor.py diff --git a/docs/tech-specs/graph-rag-semantic-filter.md b/docs/tech-specs/graph-rag-semantic-filter.md new file mode 100644 index 00000000..0401947e --- /dev/null +++ b/docs/tech-specs/graph-rag-semantic-filter.md @@ -0,0 +1,523 @@ +# GraphRAG Semantic Filter Improvement + +## Problem Statement + +The GraphRAG semantic filter is observed to be ineffective with certain +LLM models. Smaller models in particular produce poor-quality edge +relevance scores, and there is a suspicion that models trained or +evaluated heavily on non-Roman-script datasets offer lower performance +on the semantic ranking operation. + +The root cause is that the current implementation delegates edge +relevance scoring to the LLM via a prompt that asks the model to +assign a 1–10 relevance score to each knowledge-graph edge. This +task — ranking structured triples for relevance to a natural-language +query — is not well covered in standard LLM evaluation suites, so +model benchmark scores are not predictive of performance on this +operation. The result is that GraphRAG quality varies unpredictably +across model choices, undermining confidence in the pipeline. + +Beyond model variability, the LLM scoring step has further problems: + +- **Cost and latency.** The LLM call consumes tokens and adds + latency to every query, yet its output is unreliable. Even when + the model performs well, the cost is disproportionate for what is + fundamentally a ranking operation. + +- **Subjective scoring scale.** The 1–10 relevance scale gives the + model no objective criteria for what constitutes a 5 versus a 7. + Different models interpret the scale differently, and even the same + model can produce inconsistent scores across runs. 
+ +- **Redundancy with the embedding pre-filter.** The pipeline already + contains a cosine-similarity stage that ranks edges by semantic + relevance using embeddings. The LLM scoring step is a second + filter applied on top of this, and it is not clear that it adds + enough value to justify the additional cost and risk of + degradation. + +### Industry context + +Semantic ranking is rigorously evaluated on dedicated benchmarks such +as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking +Information Retrieval), which test retrieval and reranking across +diverse domains. The current TrustGraph approach — prompting a +general-purpose LLM to score and rank documents (the "listwise" +approach) — is known to be poorly optimized for this task. It +suffers from positional bias, formatting failures, and +inconsistency at scale. + +The industry standard for semantic ranking has moved to +cross-encoder models: lightweight, purpose-built models that take a +query–document pair as input and produce a single relevance score. +These models are fine-tuned on millions of relevance-labelled pairs +and dominate retrieval benchmarks. They are fast, deterministic, +and do not require an LLM inference call. + +## Architecture + +### Cross-encoder service + +A new request/response service that exposes a generic semantic +ranking API. The service is not specific to GraphRAG — it is a +reusable building block for any component that needs to rank text +by relevance. + +The service interface is pluggable. Alternative implementations +can be swapped in behind the same API. + +**Packaging options considered:** + +- *`sentence-transformers`.* Full-featured, widely used. + However, it pulls in PyTorch (~2 GB), making containers + very large. Tested at ~1.8 seconds for 2200 edges. + +- *`optimum.onnxruntime`.* ONNX-based inference. Still + depends on PyTorch at import time despite using ONNX for + inference. Tested at ~4.2 seconds for 2200 edges. 
+ +- *`flashrank`.* Lightweight wrapper around ONNX Runtime + with a clean API (`Ranker`, `RerankRequest`). No PyTorch + dependency. Tested at ~4.4 seconds for 2200 edges. + +- *Pure `onnxruntime` + `tokenizers`.* Leanest option + (~200 MB total). Requires manual tokenisation, padding, + and numpy array management — more boilerplate to maintain. + +- *External API (e.g. Cohere Rerank).* No local model at + all. Adds network latency and an external dependency. + +**Decision:** `flashrank` for the initial implementation. +No PyTorch dependency, clean API, comparable performance. +The pluggable interface allows swapping to another backend +later. + +**Request:** + +- `queries` — list of `{id, text}` objects. In the GraphRAG use + case these are the concepts extracted from the user's question. +- `documents` — list of `{id, text}` objects. In the GraphRAG + use case these are the candidate knowledge-graph edges + represented as text. +- `limit` — integer. Maximum number of results to return. + +**Scoring:** + +The service produces the cartesian product of all query–document +pairs and scores each pair through the cross-encoder model. For +each document, the maximum score across all queries is taken as the +document's relevance score. Documents are then ranked by this +score and the top `limit` results are returned. + +**Response:** + +A list of the top `limit` results, each containing: + +- `document_id` — the ID of the matched document. +- `query_id` — the ID of the query (concept) that produced the + highest score for this document. +- `score` — the relevance score. + +Including `query_id` in the response supports the explainability +interface: it records that an edge was selected because it is +related to a specific concept. + +### Integration + +The cross-encoder service follows the standard TrustGraph service +integration pattern: + +- **Base package (trustgraph-base).** Schema definitions for the + cross-encoder request/response messages. 
A client class that + other components (e.g. GraphRAG) can use to call the + cross-encoder service. Message translator registration so the + pub/sub layer can serialise/deserialise the messages. + +- **Flow package (trustgraph-flow).** The cross-encoder service + implementation itself — loads the model, listens for requests, + scores pairs, returns results. Flow definition support so the + cross-encoder can be introduced into a processing flow via the + standard flow configuration. `flashrank` is added as a + dependency of `trustgraph-flow`. The service runs in its own + container. + +- **API gateway.** A gateway endpoint that routes cross-encoder + requests from the HTTP API to the service over pub/sub and + returns the response. + +- **CLI tool.** A command-line utility + (`tg-invoke-reranker`) that calls the gateway + endpoint for manual testing and debugging. + +### Current GraphRAG pipeline + +The current pipeline follows these steps: + +1. **Concept extraction.** An LLM prompt extracts key concepts + from the user's query. + +2. **Graph exploration.** Seed entities are found via embedding + similarity. A subgraph is built by multi-hop traversal from + the seed entities (up to `max_path_length` hops, capped at + `max_subgraph_size` edges). + +3. **Embedding pre-filter.** Each edge is embedded as + `"subject, predicate, object"` and scored by cosine similarity + against the concept embeddings. The top `edge_score_limit` + (default 30) edges are kept. + +4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the + LLM to assign a 1–10 relevance score to each remaining edge. + The top `edge_limit` (default 25) edges are kept. + +5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks + the LLM to explain why each selected edge is relevant to the + query. Used for the explainability interface. + +6. **Document tracing.** Selected edges are traced back to their + source documents in the librarian. Runs concurrently with + step 5. + +7.
**Synthesis.** The `kg-synthesis` prompt generates the final + answer from the selected edges and source document metadata. + +### Potential improvements + +#### Replace LLM edge scoring with cross-encoder (step 4) + +The LLM edge scoring step is replaced by a call to the +cross-encoder service. The candidate edges are the documents and +`edge_limit` is the limit. This is a direct substitution: faster, +cheaper, deterministic, and more reliable across model choices. +The LLM `kg-edge-scoring` prompt is retired. + +**Cross-encoder query input: concepts vs. raw query.** There are +two options for what to use as the cross-encoder queries: + +- *Option A: Raw user query.* Pass the original question as a + single query string. Simpler, no dependency on concept + extraction. However, raw queries contain noise words and + conversational phrasing that do not match well against the + structured vocabulary of knowledge-graph edges. A single query + also means every edge competes against the full question — a + partial match on one aspect is diluted. + +- *Option B: Extracted concepts.* Pass the concepts from step 1 + as separate queries. The concepts are distilled, focused terms + that are closer to the language of the edges. With multiple + concepts as independent queries, the cross-encoder scores each + edge against each concept separately, giving better coverage — + an edge only needs to match one concept well to be selected. + The trade-off is a dependency on the LLM concept extraction + step, but this is already in the pipeline and is a lightweight, + reliable LLM call. + +**Decision:** Option B — use extracted concepts. The concept +extraction is fast, and the resulting terms produce better +cross-encoder matches against structured triples. + +#### Edge text representation + +The current embedding pre-filter represents each edge as +`"subject, predicate, object"`. Two changes: + +- **Drop commas.** Commas add tokenisation noise without semantic + value. 
+ +- **Drop the subject.** The subject identifies which entity the + edge belongs to, but it does not contribute to whether the + edge's content is relevant to the query. The predicate and + object carry the semantic meaning — what relationship exists + and what it connects to. Representing edges as `"{p} {o}"` + produces cleaner cross-encoder matches. + +#### Remove the embedding pre-filter (step 3) + +The embedding pre-filter was introduced to reduce the number of +edges before the expensive LLM scoring call. With the +cross-encoder replacing the LLM call, this cost equation changes. + +**Arguments for removal:** + +- The cross-encoder is fast enough to score the full subgraph + directly. In testing, 2200 edges scored in ~1.8 seconds with `sentence-transformers` and ~4.4 seconds with `flashrank`; at + the default `max_subgraph_size` of 150 edges, scoring takes + a fraction of a second. + +- The pre-filter is a weaker version of what the cross-encoder + does. Bi-encoder cosine similarity embeds the query and + document independently and compares vectors; the cross-encoder + processes both texts together through the full transformer, + giving it much better relevance judgement. Running a weaker + filter before a stronger one adds latency without improving + quality. + +- Removing it eliminates an embedding service call (two batches: + concepts + edges) and the associated latency. + +**Arguments for keeping it:** + +- If the subgraph is very large (thousands of edges), the + cross-encoder's linear scaling could become a bottleneck. + The pre-filter would act as a safety valve. + +- The embedding call is cheap compared to an LLM call, so the + overhead is modest. + +**Decision:** Remove the pre-filter. The `max_subgraph_size` +parameter (default 150) already caps the number of edges entering +this stage, so the cross-encoder will not face an unbounded +workload. If very large subgraphs become a concern in future, +the pre-filter can be reintroduced or `max_subgraph_size` can be +tuned.
+ +#### Iterative graph traversal with cross-encoder filtering + +The current pipeline performs graph exploration and edge filtering +as separate phases: first build the full subgraph (up to +`max_path_length` hops), then score and filter edges. An +alternative is to interleave traversal and filtering — at each +hop, use the cross-encoder to select relevant edges before +expanding further. + +**Option A: Big-bang traversal then filter.** Traverse the full +subgraph up to `max_path_length` hops from the seed entities, +collecting all edges up to `max_subgraph_size`. Then +cross-encode the entire result to select the top edges. + +- Simple to implement — the current traversal logic is largely + unchanged. +- Produces large, unfocused subgraphs. Irrelevant branches are + explored and scored even though they will be discarded. +- Poorly suited to multi-hop reasoning. For a query about + Voyager 1, the subgraph includes Voyager 2's edges because + they are within hop distance, and the filter must then + separate them. + +**Option B: Iterative hop-and-filter.** At each hop: + +1. Retrieve all edges one hop from the current frontier nodes. +2. Cross-encode these edges against the query concepts. +3. Select the top relevant edges. +4. The target nodes of the selected edges become the frontier + for the next hop. +5. Repeat up to `max_path_length` hops. + +The final set of selected edges across all hops is the input to +synthesis. + +- **Guided exploration.** Each hop focuses the search by + pruning irrelevant branches before expanding further. The + working set stays small and relevant at every step. +- **Multi-hop reasoning works naturally.** Following + "Voyager 1 → has-event → crossed the heliopause" succeeds + because each hop is individually relevant and leads to the + next. +- **Smaller total workload.** Fewer edges are scored overall + because irrelevant branches are never expanded. 
+- **Trade-off: greedy pruning.** An edge discarded at hop 1 + cannot lead to relevant edges at hop 2. This is inherent in + any bounded traversal, and the cross-encoder is better + equipped to make this relevance judgement than a blind hop + limit. +- **Trade-off: sequential latency.** Hops cannot be + parallelised since each depends on the previous. However, + each cross-encoder call on a small edge set is very fast + (sub-second for typical working sets). + +**Decision:** Option B — iterative hop-and-filter. The guided +traversal produces more focused subgraphs and supports multi-hop +reasoning, which is a significant quality improvement over the +current approach. + +#### Replace LLM edge reasoning with cross-encoder metadata (step 5) + +The current `kg-edge-reasoning` prompt asks the LLM to explain why +each edge is relevant. With the cross-encoder now making the +selection, this explanation would be a post-hoc fabrication — the +LLM was not involved in the decision. + +- *Option A: Keep LLM reasoning.* Generates natural-language + explanations but they are not grounded in the actual selection + process. Adds an LLM call per query. + +- *Option B: Record cross-encoder metadata.* The cross-encoder + already returns the matched concept and score for each selected + edge. Use this directly as the explanation. + +**Decision:** Option B. The cross-encoder metadata is the true +reason the edge was selected. The `kg-edge-reasoning` prompt is +retired. + +#### Explainability interface update + +The explainability interface uses a `Focus` entity containing +`EdgeSelection` sub-entities. Each `EdgeSelection` currently +carries an `edge` (the quoted triple) and a `reasoning` field +(free-text LLM prose), stored as `tg:reasoning` in the +provenance graph. + +With the cross-encoder replacing LLM reasoning, the +`EdgeSelection` type gains two new predicates and drops one: + +- **Remove** `tg:reasoning` — no longer produced. 
+- **Add** `tg:concept` — the concept text that produced the + highest cross-encoder score for this edge. +- **Add** `tg:score` — the cross-encoder relevance score. + +This is an evolution of the existing `EdgeSelection` type, not a +new entity type. The edge selection sub-entities currently have +no `rdf:type` declared; a new `tg:EdgeSelection` type should be +added so that consumers can identify them in the provenance +graph. The `Focus` entity and its relationship to `Exploration` +are unchanged. + +The `Focus` entity's token-usage metadata (`tg:inToken`, +`tg:outToken`, `tg:llmModel`) no longer applies since there is +no LLM call. These fields are dropped from the Focus entity. + +### Proposed pipeline + +1. **Concept extraction.** Unchanged — LLM extracts key concepts + from the user's query. + +2. **Seed entity lookup.** Find seed entities via embedding + similarity against the extracted concepts. + +3. **Iterative hop-and-filter.** For each hop up to + `max_path_length`: + + a. Retrieve all edges one hop from the current frontier nodes. + + b. Represent each edge as `"{predicate} {object}"`. + + c. Score edges against the extracted concepts using the + cross-encoder service. + + d. Select the top relevant edges. The target nodes of the + selected edges become the frontier for the next hop. + +4. **Document tracing.** Selected edges are traced back to source + documents. + +5. **Synthesis.** The `kg-synthesis` prompt generates the final + answer from the selected edges and source document metadata. + +### Implementation order + +1. Cross-encoder service with full integration (base schema, + flow service, gateway endpoint, CLI tool). +2. GraphRAG pipeline changes (iterative hop-and-filter, + edge representation, remove pre-filter). +3. Explainability update (`tg:EdgeSelection` type, concept + and score predicates, retire `tg:reasoning`). +4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts. +5. 
Update `tg-invoke-graph-rag` and `tg-show-explain-trace` + to display the new metadata. Use these as the main + end-to-end test. +6. Fix any failing unit tests, then add new tests as needed. +7. Write guidance for UX devs to update the UI for the new + explainability predicates. + +## UX developer guidance + +This section describes the changes to the explainability interface +that affect frontend rendering of GraphRAG Focus events. + +### What changed + +Edge selection in GraphRAG previously used LLM-based scoring and +reasoning. Each selected edge carried a `tg:reasoning` predicate +with free-text explanation from the LLM. This has been replaced +by a cross-encoder reranker that scores edges against query +concepts. The explainability data now carries structured metadata +instead of free text. + +### Removed + +- **`tg:reasoning`** is no longer emitted on edge selection + entities in GraphRAG Focus events. UX code that reads + `edge_sel.reasoning` will get an empty string. Remove any + rendering that displays a "Reasoning" or "Reason" field for + Focus edges. + +- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and + **`kg-edge-selection`** prompts are retired. Any UX that + references these prompt names should be cleaned up. + +### Added + +Each edge selection entity within a Focus event now has three +new properties: + +| RDF predicate | API field | Type | Description | +|---|---|---|---| +| `rdf:type tg:EdgeSelection` | (type check) | — | Each edge selection entity is now explicitly typed | +| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge | +| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0–1.0) | + +The `tg:edge` predicate (RDF-star quoted triple) is unchanged. 
+ +### How to render + +The recommended rendering for each selected edge in a Focus event: + +``` +Edge: (subject_label, predicate_label, object_label) + Concept: <concept> Score: <score> +``` + +Scores near 1.0 indicate high relevance; scores near 0.0 indicate +low relevance. UX could use the score to drive visual indicators +such as colour intensity or a relevance bar. + +Edges are not returned in score order — they arrive in traversal +order across hops. If the UX wants to display edges ranked by +relevance, sort by `edge_sel.score` descending. + +### API classes (Python) + +The `EdgeSelection` dataclass in `trustgraph.api.explainability` +has these fields: + +```python +@dataclass +class EdgeSelection: + uri: str + edge: Optional[Dict[str, str]] # {"s": ..., "p": ..., "o": ...} + reasoning: str = "" # Legacy, always empty for new traces + concept: str = "" # Query concept that matched + score: Optional[float] = None # Cross-encoder relevance score +``` + +These are populated when calling +`ExplainabilityClient.fetch_focus_with_edges()` or when parsing +inline provenance triples from the streaming response. + +### WebSocket response format + +For inline explainability via the streaming WebSocket, Focus events +arrive as `message_type: "explain"` responses. The `explain_triples` +array contains the edge selection triples. The relevant predicates +in wire format are: + +```json +{"s": {"t": "i", "i": "<edge-selection-uri>"}, + "p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"}, + "o": {"t": "l", "v": "flyby event"}} + +{"s": {"t": "i", "i": "<edge-selection-uri>"}, + "p": {"t": "i", "i": "https://trustgraph.ai/ns/score"}, + "o": {"t": "l", "v": "0.9962"}} +``` + +Note that `tg:score` is transmitted as a string literal and must +be parsed to a float on the client side. + +### Exploration event + +The Exploration event's `edge_count` field now reports the number +of edges selected by the cross-encoder across all hops (previously +it reported the total number of edges retrieved before filtering).
+The `entities` list continues to report the seed entities found +by vector search. diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 696df7ec..8930d159 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -95,10 +95,6 @@ class TestGraphRagIntegration: async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-scoring": - return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-reasoning": - return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": return PromptResult( response_type="text", @@ -113,14 +109,22 @@ class TestGraphRagIntegration: client.prompt.side_effect = mock_prompt return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_prompt_client): + mock_triples_client, mock_reranker_client, mock_prompt_client): """Create GraphRag instance with mocked dependencies""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_prompt_client, verbose=True ) @@ -167,8 +171,8 @@ class TestGraphRagIntegration: # 3. Should query triples to build knowledge subgraph assert mock_triples_client.query_stream.call_count > 0 - # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis) - assert mock_prompt_client.prompt.call_count == 4 + # 4. 
Should call prompt twice (extract-concepts + synthesis) + assert mock_prompt_client.prompt.call_count == 2 # Verify final response response, usage = response diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 48e26618..8dfd9c2b 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -63,11 +63,6 @@ class TestGraphRagStreaming: async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): if prompt_id == "extract-concepts": return PromptResult(response_type="text", text="") - elif prompt_id == "kg-edge-scoring": - # Edge scoring returns JSONL with IDs and scores - return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n') - elif prompt_id == "kg-edge-reasoning": - return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n') elif prompt_id == "kg-synthesis": if streaming and chunk_callback: # Simulate streaming chunks with end_of_stream flags @@ -88,14 +83,23 @@ class TestGraphRagStreaming: client.prompt.side_effect = prompt_side_effect return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_streaming_prompt_client): + mock_triples_client, mock_reranker_client, + mock_streaming_prompt_client): """Create GraphRag instance with streaming support""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_streaming_prompt_client, verbose=True ) diff --git 
a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 279c81ef..efce5922 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol: client = AsyncMock() async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": + if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: @@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol: client.prompt.side_effect = prompt_side_effect return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_streaming_prompt_client): + mock_triples_client, mock_reranker_client, + mock_streaming_prompt_client): """Create GraphRag instance with mocked dependencies""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_streaming_prompt_client, verbose=False ) @@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases: client = AsyncMock() async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": + if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: @@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases: client.prompt.side_effect = prompt_with_empties + mock_reranker = AsyncMock() + 
mock_reranker.rerank.return_value = [] + rag = GraphRag( embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])), graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), triples_client=AsyncMock(query=AsyncMock(return_value=[])), + reranker_client=mock_reranker, prompt_client=client, verbose=False ) diff --git a/tests/unit/test_base/test_prompt_client_streaming.py b/tests/unit/test_base/test_prompt_client_streaming.py index 83a4b90e..fecf6095 100644 --- a/tests/unit/test_base/test_prompt_client_streaming.py +++ b/tests/unit/test_base/test_prompt_client_streaming.py @@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback: assert callback.call_args_list[0] == call("test", False) assert callback.call_args_list[1] == call("", True) - @pytest.mark.asyncio - async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client): - """Test that kg_prompt correctly passes streaming parameters""" - # Arrange - async def mock_request(request, recipient=None, timeout=600): - if recipient: - responses = [ - PromptResponse(text="Answer", object=None, error=None, end_of_stream=False), - PromptResponse(text="", object=None, error=None, end_of_stream=True), - ] - for resp in responses: - should_stop = await recipient(resp) - if should_stop: - break - - prompt_client.request = mock_request - - callback = AsyncMock() - - # Act - await prompt_client.kg_prompt( - query="What is machine learning?", - kg=[("subject", "predicate", "object")], - streaming=True, - chunk_callback=callback - ) - - # Assert - assert callback.call_count == 2 - assert callback.call_args_list[0] == call("Answer", False) - assert callback.call_args_list[1] == call("", True) - @pytest.mark.asyncio async def test_document_prompt_passes_parameters_to_callback(self, prompt_client): """Test that document_prompt correctly passes streaming parameters""" diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py index 
e65ef2e3..d1ce097a 100644 --- a/tests/unit/test_provenance/test_dag_structure.py +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -107,6 +107,7 @@ class TestGraphRagDagStructure: embeddings_client = AsyncMock() graph_embeddings_client = AsyncMock() triples_client = AsyncMock() + reranker_client = AsyncMock() embeddings_client.embed.return_value = [[0.1, 0.2]] graph_embeddings_client.query.return_value = [ @@ -121,27 +122,22 @@ class TestGraphRagDagStructure: ] triples_client.query.return_value = [] + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.95 + reranker_client.rerank.return_value = [result] + async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": return PromptResult(response_type="text", text="concept") - elif template_id == "kg-edge-scoring": - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[{"id": e["id"], "score": 10} for e in edges], - ) - elif template_id == "kg-edge-reasoning": - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges], - ) elif template_id == "kg-synthesis": return PromptResult(response_type="text", text="Answer.") return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt - return prompt_client, embeddings_client, graph_embeddings_client, triples_client + return (prompt_client, embeddings_client, graph_embeddings_client, + triples_client, reranker_client) @pytest.mark.asyncio async def test_dag_chain(self, mock_clients): @@ -152,7 +148,7 @@ class TestGraphRagDagStructure: events.append({"explain_id": explain_id, "triples": triples}) await rag.query( - query="test", explain_callback=explain_cb, edge_score_limit=0, + query="test", explain_callback=explain_cb, ) dag = _collect_events(events) diff --git a/tests/unit/test_retrieval/test_graph_rag.py 
b/tests/unit/test_retrieval/test_graph_rag.py index d1979211..15ffdc9d 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -15,54 +15,52 @@ class TestGraphRag: def test_graph_rag_initialization_with_defaults(self): """Test GraphRag initialization with default verbose setting""" - # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() + mock_reranker_client = MagicMock() - # Initialize GraphRag - graph_rag = GraphRag( - prompt_client=mock_prompt_client, - embeddings_client=mock_embeddings_client, - graph_embeddings_client=mock_graph_embeddings_client, - triples_client=mock_triples_client - ) - - # Verify initialization - assert graph_rag.prompt_client == mock_prompt_client - assert graph_rag.embeddings_client == mock_embeddings_client - assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client - assert graph_rag.triples_client == mock_triples_client - assert graph_rag.verbose is False # Default value - # Verify label_cache is an LRUCacheWithTTL instance - from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL - assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) - - def test_graph_rag_initialization_with_verbose(self): - """Test GraphRag initialization with verbose enabled""" - # Create mock clients - mock_prompt_client = MagicMock() - mock_embeddings_client = MagicMock() - mock_graph_embeddings_client = MagicMock() - mock_triples_client = MagicMock() - - # Initialize GraphRag with verbose=True graph_rag = GraphRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, - verbose=True + reranker_client=mock_reranker_client, ) - # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == 
mock_embeddings_client assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client assert graph_rag.triples_client == mock_triples_client + assert graph_rag.reranker_client == mock_reranker_client + assert graph_rag.verbose is False + from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL + assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) + + def test_graph_rag_initialization_with_verbose(self): + """Test GraphRag initialization with verbose enabled""" + mock_prompt_client = MagicMock() + mock_embeddings_client = MagicMock() + mock_graph_embeddings_client = MagicMock() + mock_triples_client = MagicMock() + mock_reranker_client = MagicMock() + + graph_rag = GraphRag( + prompt_client=mock_prompt_client, + embeddings_client=mock_embeddings_client, + graph_embeddings_client=mock_graph_embeddings_client, + triples_client=mock_triples_client, + reranker_client=mock_reranker_client, + verbose=True, + ) + + assert graph_rag.prompt_client == mock_prompt_client + assert graph_rag.embeddings_client == mock_embeddings_client + assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client + assert graph_rag.triples_client == mock_triples_client + assert graph_rag.reranker_client == mock_reranker_client assert graph_rag.verbose is True - # Verify label_cache is an LRUCacheWithTTL instance from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) @@ -365,244 +363,162 @@ class TestQuery: assert "workspace" not in c.kwargs @pytest.mark.asyncio - async def test_follow_edges_never_passes_workspace(self): - """Verify follow_edges never passes workspace to query_stream.""" + async def test_hop_and_filter_never_passes_workspace(self): + """Verify hop_and_filter never passes workspace to query_stream.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = 
mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None mock_triple = MagicMock() - mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triple.s = "e1" + mock_triple.p = "p1" + mock_triple.o = "o1" mock_triples_client.query_stream.return_value = [mock_triple] + mock_triples_client.query.return_value = [] + + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.9 + mock_reranker_client.rerank.return_value = [result] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - triple_limit=10 + triple_limit=10, ) - subgraph = set() - await query.follow_edges("e1", subgraph, path_length=1) + await query.hop_and_filter(["e1"], ["concept"]) for c in mock_triples_client.query_stream.call_args_list: assert "workspace" not in c.kwargs @pytest.mark.asyncio - async def test_follow_edges_basic_functionality(self): - """Test Query.follow_edges method basic triple discovery""" + async def test_hop_and_filter_basic_functionality(self): + """Test hop_and_filter retrieves edges and scores them with reranker.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None - mock_triple1 = MagicMock() - mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1" + mock_triple = MagicMock() + mock_triple.s = "entity1" + mock_triple.p = "predicate1" + mock_triple.o = "object1" + mock_triples_client.query_stream.return_value = [mock_triple] + mock_triples_client.query.return_value = [] - mock_triple2 = MagicMock() - mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2" - - mock_triple3 = MagicMock() - mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" - - 
mock_triples_client.query_stream.side_effect = [ - [mock_triple1], # s=ent - [mock_triple2], # p=ent - [mock_triple3], # o=ent - ] + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.95 + mock_reranker_client.rerank.return_value = [result] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - triple_limit=10 + triple_limit=10, + edge_limit=25, ) - subgraph = set() - await query.follow_edges("entity1", subgraph, path_length=1) - - assert mock_triples_client.query_stream.call_count == 3 - - mock_triples_client.query_stream.assert_any_call( - s="entity1", p=None, o=None, limit=10, - collection="test_collection", batch_size=20, g="" - ) - mock_triples_client.query_stream.assert_any_call( - s=None, p="entity1", o=None, limit=10, - collection="test_collection", batch_size=20, g="" - ) - mock_triples_client.query_stream.assert_any_call( - s=None, p=None, o="entity1", limit=10, - collection="test_collection", batch_size=20, g="" + selected, uri_map, edge_meta = await query.hop_and_filter( + ["entity1"], ["test concept"], ) - expected_subgraph = { - ("entity1", "predicate1", "object1"), - ("subject2", "entity1", "object2"), - ("subject3", "predicate3", "entity1") - } - assert subgraph == expected_subgraph + assert len(selected) == 1 + assert len(uri_map) == 1 + assert len(edge_meta) == 1 + + mock_reranker_client.rerank.assert_called_once() + call_kwargs = mock_reranker_client.rerank.call_args + assert call_kwargs.kwargs["limit"] == 25 @pytest.mark.asyncio - async def test_follow_edges_with_path_length_zero(self): - """Test Query.follow_edges method with path_length=0""" + async def test_hop_and_filter_with_empty_frontier(self): + """Test hop_and_filter with no seed entities returns empty.""" + mock_rag = MagicMock() + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + ) + + selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"]) + + assert selected == [] + 
assert uri_map == {} + assert edge_meta == {} + + @pytest.mark.asyncio + async def test_hop_and_filter_filters_label_triples(self): + """Test hop_and_filter skips rdfs:label edges.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False - ) + label_triple = MagicMock() + label_triple.s = "entity1" + label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label" + label_triple.o = "Entity One" - subgraph = set() - await query.follow_edges("entity1", subgraph, path_length=0) - - mock_triples_client.query_stream.assert_not_called() - assert subgraph == set() - - @pytest.mark.asyncio - async def test_follow_edges_with_max_subgraph_size_limit(self): - """Test Query.follow_edges method respects max_subgraph_size""" - mock_rag = MagicMock() - mock_triples_client = AsyncMock() - mock_rag.triples_client = mock_triples_client + mock_triples_client.query_stream.return_value = [label_triple] + mock_triples_client.query.return_value = [] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - max_subgraph_size=2 + triple_limit=10, ) - subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")} - - await query.follow_edges("entity1", subgraph, path_length=1) - - mock_triples_client.query_stream.assert_not_called() - assert len(subgraph) == 3 - - @pytest.mark.asyncio - async def test_get_subgraph_method(self): - """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple""" - mock_rag = MagicMock() - - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False, - max_path_length=1 + selected, uri_map, edge_meta = await query.hop_and_filter( + ["entity1"], ["concept"], ) - # Mock get_entities to return (entities, concepts) 
tuple - query.get_entities = AsyncMock( - return_value=(["entity1", "entity2"], ["concept1"]) - ) - - query.follow_edges_batch = AsyncMock(return_value=( - { - ("entity1", "predicate1", "object1"), - ("entity2", "predicate2", "object2") - }, - {} - )) - - subgraph, term_map, entities, concepts = await query.get_subgraph("test query") - - query.get_entities.assert_called_once_with("test query") - query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) - - assert isinstance(subgraph, list) - assert len(subgraph) == 2 - assert ("entity1", "predicate1", "object1") in subgraph - assert ("entity2", "predicate2", "object2") in subgraph - assert entities == ["entity1", "entity2"] - assert concepts == ["concept1"] - - @pytest.mark.asyncio - async def test_get_labelgraph_method(self): - """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)""" - mock_rag = MagicMock() - - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False, - max_subgraph_size=100 - ) - - test_subgraph = [ - ("entity1", "predicate1", "object1"), - ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), - ("entity3", "predicate3", "object3") - ] - test_entities = ["entity1", "entity3"] - test_concepts = ["concept1"] - query.get_subgraph = AsyncMock( - return_value=(test_subgraph, {}, test_entities, test_concepts) - ) - - async def mock_maybe_label(entity): - label_map = { - "entity1": "Human Entity One", - "predicate1": "Human Predicate One", - "object1": "Human Object One", - "entity3": "Human Entity Three", - "predicate3": "Human Predicate Three", - "object3": "Human Object Three" - } - return label_map.get(entity, entity) - - query.maybe_label = AsyncMock(side_effect=mock_maybe_label) - - labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query") - - query.get_subgraph.assert_called_once_with("test query") - - # Label triples filtered out - assert len(labeled_edges) == 2 - - # maybe_label 
called for non-label triples - assert query.maybe_label.call_count == 6 - - expected_edges = [ - ("Human Entity One", "Human Predicate One", "Human Object One"), - ("Human Entity Three", "Human Predicate Three", "Human Object Three") - ] - assert labeled_edges == expected_edges - - assert len(uri_map) == 2 - assert entities == test_entities - assert concepts == test_concepts + assert selected == [] + mock_reranker_client.rerank.assert_not_called() @pytest.mark.asyncio async def test_graph_rag_query_method(self): """Test GraphRag.query method orchestrates full RAG pipeline with provenance""" - import json from trustgraph.retrieval.graph_rag.graph_rag import edge_id mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() expected_response = "This is the RAG response" - test_labelgraph = [("Subject", "Predicate", "Object")] - test_edge_id = edge_id("Subject", "Predicate", "Object") + test_selected_edges = [("Subject", "Predicate", "Object")] + test_eid = edge_id("Subject", "Predicate", "Object") test_uri_map = { - test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + } + test_edge_metadata = { + test_eid: {"concept": "test concept", "score": 0.95} } - test_entities = ["http://example.org/subject"] - test_concepts = ["test concept"] - # Mock prompt responses for the multi-step process + mock_embeddings_client.embed.return_value = [[0.1, 0.2]] + mock_graph_embeddings_client.query.return_value = [] + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-scoring": - return PromptResult(response_type="jsonl", objects=[{"id": 
test_edge_id, "score": 0.9}]) - elif prompt_name == "kg-edge-reasoning": - return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}]) + return PromptResult(response_type="text", text="test concept") elif prompt_name == "kg-synthesis": return PromptResult(response_type="text", text=expected_response) return PromptResult(response_type="text", text="") @@ -614,16 +530,16 @@ class TestQuery: embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, - verbose=False + reranker_client=mock_reranker_client, + verbose=False, ) - # Patch Query.get_labelgraph to return test data - original_get_labelgraph = Query.get_labelgraph + original_hop_and_filter = Query.hop_and_filter - async def mock_get_labelgraph(self, query_text): - return test_labelgraph, test_uri_map, test_entities, test_concepts + async def mock_hop_and_filter(self, seed_entities, concepts): + return test_selected_edges, test_uri_map, test_edge_metadata - Query.get_labelgraph = mock_get_labelgraph + Query.hop_and_filter = mock_hop_and_filter provenance_events = [] @@ -636,7 +552,7 @@ class TestQuery: collection="test_collection", entity_limit=25, triple_limit=15, - explain_callback=collect_provenance + explain_callback=collect_provenance, ) response_text, usage = response @@ -650,7 +566,6 @@ class TestQuery: assert len(triples) > 0 assert prov_id.startswith("urn:trustgraph:") - # Verify order assert "question" in provenance_events[0][1] assert "grounding" in provenance_events[1][1] assert "exploration" in provenance_events[2][1] @@ -658,4 +573,4 @@ class TestQuery: assert "synthesis" in provenance_events[4][1] finally: - Query.get_labelgraph = original_get_labelgraph + Query.hop_and_filter = original_hop_and_filter diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py index 1eb0dd72..bc2cb368 100644 --- 
a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import ( TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE, TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT, - TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION, ) @@ -91,17 +91,17 @@ def build_mock_clients(): 1. prompt_client.prompt("extract-concepts", ...) -> concepts 2. embeddings_client.embed(concepts) -> vectors 3. graph_embeddings_client.query(vector, ...) -> entity matches - 4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch) + 4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter) 5. triples_client.query(s, LABEL, ...) -> labels (maybe_label) - 6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges - 7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning - 8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) - 9. prompt_client.prompt("kg-synthesis", ...) -> final answer + 6. reranker_client.rerank(queries, documents, limit) -> scored edges + 7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) + 8. prompt_client.prompt("kg-synthesis", ...) -> final answer """ prompt_client = AsyncMock() embeddings_client = AsyncMock() graph_embeddings_client = AsyncMock() triples_client = AsyncMock() + reranker_client = AsyncMock() # 1. Concept extraction prompt_responses = {} @@ -116,7 +116,7 @@ def build_mock_clients(): EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)), ] - # 4. Triple queries (follow_edges_batch) - return our edges + # 4. 
Triple queries (hop_and_filter) - return our edges kg_triples = [ make_schema_triple(*EDGE_1), make_schema_triple(*EDGE_2), @@ -130,9 +130,18 @@ def build_mock_clients(): return [] # No labels found, will fall back to URI triples_client.query.side_effect = mock_label_query - # 6+7. Edge scoring and reasoning: dynamically score/reason about - # whatever edges the query method sends us, since edge IDs are computed - # from str(Term) representations which include the full dataclass repr. + # 6. Reranker: select all documents with high scores + async def mock_rerank(queries, documents, limit): + results = [] + for i, doc in enumerate(documents): + result = MagicMock() + result.document_id = doc["id"] + result.query_id = queries[0]["id"] if queries else "0" + result.score = 0.9 - (i * 0.1) + results.append(result) + return results[:limit] + reranker_client.rerank.side_effect = mock_rerank + synthesis_answer = "Quantum computing applies physics principles to computation." async def mock_prompt(template_id, variables=None, **kwargs): @@ -141,26 +150,6 @@ def build_mock_clients(): response_type="text", text=prompt_responses["extract-concepts"], ) - elif template_id == "kg-edge-scoring": - # Score all edges highly, using the IDs that GraphRag computed - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[ - {"id": e["id"], "score": 10 - i} - for i, e in enumerate(edges) - ], - ) - elif template_id == "kg-edge-reasoning": - # Provide reasoning for each edge - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[ - {"id": e["id"], "reasoning": f"Relevant edge {i}"} - for i, e in enumerate(edges) - ], - ) elif template_id == "kg-synthesis": return PromptResult( response_type="text", @@ -170,7 +159,8 @@ def build_mock_clients(): prompt_client.prompt.side_effect = mock_prompt - return prompt_client, embeddings_client, graph_embeddings_client, triples_client + return (prompt_client, 
embeddings_client, graph_embeddings_client, + triples_client, reranker_client) # --------------------------------------------------------------------------- @@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, # skip semantic pre-filter for simplicity + ) assert len(events) == 5, ( @@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) expected_types = [ @@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) uris = [e["explain_id"] for e in events] @@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) q_uri = events[0]["explain_id"] @@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) gnd_uri = events[1]["explain_id"] @@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) exp_uri = events[2]["explain_id"] @@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance: assert int(t.o.value) > 0 @pytest.mark.asyncio - async def test_focus_has_selected_edges_with_reasoning(self): + async def test_focus_has_selected_edges_with_concept_and_score(self): """ The focus event should carry selected edges as quoted triples - with reasoning text. + with cross-encoder concept and score metadata. 
""" clients = build_mock_clients() rag = GraphRag(*clients) @@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, ) foc_uri = events[3]["explain_id"] @@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance: for t in edge_t: assert t.o.triple is not None, "tg:edge object must be a quoted triple" - # Should have reasoning - reasoning = find_triples(foc_triples, TG_REASONING) - assert len(reasoning) > 0, "Focus should have reasoning for selected edges" - reasoning_texts = {t.o.value for t in reasoning} - assert any(r for r in reasoning_texts), "Reasoning should not be empty" + # Edge selections should be typed as EdgeSelection + edge_sel_uris = [t.o.iri for t in selected] + for uri in edge_sel_uris: + assert has_type(foc_triples, uri, TG_EDGE_SELECTION) + + # Should have concept and score + concepts = find_triples(foc_triples, TG_CONCEPT) + assert len(concepts) > 0, "Focus should have tg:concept for selected edges" + + scores = find_triples(foc_triples, TG_SCORE) + assert len(scores) > 0, "Focus should have tg:score for selected edges" + for t in scores: + float(t.o.value) # Should be parseable as float @pytest.mark.asyncio async def test_synthesis_is_answer_type(self): @@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) syn_uri = events[4]["explain_id"] @@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance: result_text, usage = await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) assert result_text == "Quantum computing applies physics principles to computation." 
@@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + parent_uri=parent, ) @@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance: result_text, usage = await rag.query( query="What is quantum computing?", - edge_score_limit=0, + ) assert result_text == "Quantum computing applies physics principles to computation." @@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) for event in events: diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index bf0b2ba1..de592b59 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -646,6 +646,16 @@ class AsyncFlowInstance: return await self.request("embeddings", request_data) + async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any): + request_data = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request_data.update(kwargs) + + return await self.request("reranker", request_data) + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any): """ Query RDF triples using pattern matching. 
diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index fdbe2f67..78b608a7 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -443,6 +443,19 @@ class AsyncSocketFlowInstance: return await self.client._send_request("embeddings", self.flow_id, request) + async def rerank(self, queries: list, documents: list, limit: int = 10, + **kwargs): + request = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request.update(kwargs) + + return await self.client._send_request( + "reranker", self.flow_id, request, + ) + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs): """Triple pattern query""" request = {"limit": limit} diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 656ff95f..74a8f32e 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" +TG_SCORE = TG + "score" TG_DOCUMENT = TG + "document" TG_CONCEPT = TG + "concept" TG_ENTITY = TG + "entity" @@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" @dataclass class EdgeSelection: - """A selected edge with reasoning from GraphRAG Focus step.""" + """A selected edge with cross-encoder metadata from GraphRAG Focus step.""" uri: str edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...} reasoning: str = "" + concept: str = "" + score: Optional[float] = None @dataclass @@ -209,7 +212,7 @@ class Exploration(ExplainEntity): @dataclass class Focus(ExplainEntity): - """Focus entity - selected edges with LLM reasoning (GraphRAG only).""" + """Focus entity - selected edges with cross-encoder scoring (GraphRAG 
only).""" selected_edge_uris: List[str] = field(default_factory=list) edge_selections: List[EdgeSelection] = field(default_factory=list) @@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel uri = triples[0][0] if triples else "" edge = None reasoning = "" + concept = "" + score = None for s, p, o in triples: if p == TG_EDGE and isinstance(o, dict): edge = o elif p == TG_REASONING: reasoning = o + elif p == TG_CONCEPT: + concept = o + elif p == TG_SCORE: + try: + score = float(o) + except (ValueError, TypeError): + score = None - return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning) + return EdgeSelection( + uri=uri, edge=edge, reasoning=reasoning, + concept=concept, score=score, + ) def extract_term_value(term: Dict[str, Any]) -> Any: diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 961e348b..886306b3 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -491,6 +491,19 @@ class FlowInstance: input )["vectors"] + def rerank(self, queries, documents, limit=10): + + input = { + "queries": queries, + "documents": documents, + "limit": limit, + } + + return self.request( + "service/reranker", + input + ) + def graph_embeddings_query(self, text, collection, limit=10): """ Query knowledge graph entities using semantic similarity. 
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index e87d85ac..3a06e0d8 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -885,6 +885,19 @@ class SocketFlowInstance: return self.client._send_request_sync("embeddings", self.flow_id, request, False) + def rerank(self, queries: list, documents: list, limit: int = 10, + **kwargs: Any) -> Dict[str, Any]: + request = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request.update(kwargs) + + return self.client._send_request_sync( + "reranker", self.flow_id, request, False, + ) + def triples_query( self, s: Optional[Union[str, Dict[str, Any]]] = None, diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 6062543b..be905116 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService from . tool_service_client import ToolServiceClientSpec from . agent_client import AgentClientSpec from . structured_query_client import StructuredQueryClientSpec +from . reranker_client import RerankerClientSpec +from . reranker_service import RerankerService from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec from . 
collection_config_handler import CollectionConfigHandler diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index d4822ece..b1813ba2 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -157,21 +157,6 @@ class PromptClient(RequestResponse): timeout = timeout, ) - async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None): - return await self.prompt( - id = "kg-prompt", - variables = { - "query": query, - "knowledge": [ - { "s": v[0], "p": v[1], "o": v[2] } - for v in kg - ] - }, - timeout = timeout, - streaming = streaming, - chunk_callback = chunk_callback, - ) - async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None): return await self.prompt( id = "document-prompt", diff --git a/trustgraph-base/trustgraph/base/reranker_client.py b/trustgraph-base/trustgraph/base/reranker_client.py new file mode 100644 index 00000000..d0bed394 --- /dev/null +++ b/trustgraph-base/trustgraph/base/reranker_client.py @@ -0,0 +1,43 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. 
schema import ( + RerankerRequest, RerankerResponse, + RerankerQuery, RerankerDocument, +) + +class RerankerClient(RequestResponse): + async def rerank(self, queries, documents, limit=10, timeout=300): + + resp = await self.request( + RerankerRequest( + queries=[ + RerankerQuery(query_id=q["id"], query_text=q["text"]) + for q in queries + ], + documents=[ + RerankerDocument( + document_id=d["id"], document_text=d["text"] + ) + for d in documents + ], + limit=limit, + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.results + +class RerankerClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(RerankerClientSpec, self).__init__( + request_name = request_name, + request_schema = RerankerRequest, + response_name = response_name, + response_schema = RerankerResponse, + impl = RerankerClient, + ) diff --git a/trustgraph-base/trustgraph/base/reranker_service.py b/trustgraph-base/trustgraph/base/reranker_service.py new file mode 100644 index 00000000..1da3a8bf --- /dev/null +++ b/trustgraph-base/trustgraph/base/reranker_service.py @@ -0,0 +1,109 @@ + +from __future__ import annotations + +from argparse import ArgumentParser + +import logging + +from .. schema import ( + RerankerRequest, RerankerResponse, RerankerResult, Error, +) +from .. exceptions import TooManyRequests +from .. 
base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec + +logger = logging.getLogger(__name__) + +default_ident = "reranker" +default_concurrency = 1 + +class RerankerService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + concurrency = params.get("concurrency", 1) + + super(RerankerService, self).__init__(**params | { + "id": id, + "concurrency": concurrency, + }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = RerankerRequest, + handler = self.on_request, + concurrency = concurrency, + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = RerankerResponse + ) + ) + + self.register_specification( + ParameterSpec( + name = "model", + ) + ) + + async def on_request(self, msg, consumer, flow): + + try: + + request = msg.value() + + id = msg.properties()["id"] + + logger.debug(f"Handling reranker request {id}...") + + model = flow("model") + results = await self.on_rerank( + request.queries, request.documents, + request.limit, model=model, + ) + + await flow("response").send( + RerankerResponse( + error = None, + results = results, + ), + properties={"id": id} + ) + + logger.debug("Reranker request handled successfully") + + except TooManyRequests as e: + raise e + + except Exception as e: + + logger.error(f"Exception in reranker service: {e}", exc_info=True) + + logger.info("Sending error response...") + + await flow.producer["response"].send( + RerankerResponse( + error=Error( + type = "reranker-error", + message = str(e), + ), + results=[], + ), + properties={"id": id} + ) + + @staticmethod + def add_args(parser: ArgumentParser) -> None: + + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + + FlowProcessor.add_args(parser) diff --git a/trustgraph-base/trustgraph/clients/prompt_client.py b/trustgraph-base/trustgraph/clients/prompt_client.py 
index 12c9c194..ff29ec0a 100644 --- a/trustgraph-base/trustgraph/clients/prompt_client.py +++ b/trustgraph-base/trustgraph/clients/prompt_client.py @@ -140,20 +140,6 @@ class PromptClient(BaseClient): timeout=timeout ) - def request_kg_prompt(self, query, kg, timeout=300): - - return self.request( - id="kg-prompt", - variables={ - "query": query, - "knowledge": [ - { "s": v[0], "p": v[1], "o": v[2] } - for v in kg - ] - }, - timeout=timeout - ) - def request_document_prompt(self, query, documents, timeout=300): return self.request( diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 9fcfa6f7..097153ac 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator @@ -163,6 +164,12 @@ TranslatorRegistry.register_service( SparqlQueryResponseTranslator() ) +TranslatorRegistry.register_service( + "reranker", + RerankerRequestTranslator(), + RerankerResponseTranslator() +) + # Register single-direction translators for document loading TranslatorRegistry.register_request("document", DocumentTranslator()) TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git 
a/trustgraph-base/trustgraph/messaging/translators/__init__.py b/trustgraph-base/trustgraph/messaging/translators/__init__.py index 5b5820fa..b0f88e88 100644 --- a/trustgraph-base/trustgraph/messaging/translators/__init__.py +++ b/trustgraph-base/trustgraph/messaging/translators/__init__.py @@ -20,3 +20,4 @@ from .embeddings_query import ( ) from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .reranker import RerankerRequestTranslator, RerankerResponseTranslator diff --git a/trustgraph-base/trustgraph/messaging/translators/reranker.py b/trustgraph-base/trustgraph/messaging/translators/reranker.py new file mode 100644 index 00000000..2d5dabc2 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/reranker.py @@ -0,0 +1,73 @@ +from typing import Dict, Any, Tuple +from ...schema import ( + RerankerRequest, RerankerResponse, + RerankerQuery, RerankerDocument, RerankerResult, +) +from .base import MessageTranslator + + +class RerankerRequestTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> RerankerRequest: + return RerankerRequest( + queries=[ + RerankerQuery( + query_id=q["query_id"], + query_text=q["query_text"], + ) + for q in data.get("queries", []) + ], + documents=[ + RerankerDocument( + document_id=d["document_id"], + document_text=d["document_text"], + ) + for d in data.get("documents", []) + ], + limit=data.get("limit", 10), + ) + + def encode(self, obj: RerankerRequest) -> Dict[str, Any]: + return { + "queries": [ + {"query_id": q.query_id, "query_text": q.query_text} + for q in obj.queries + ], + "documents": [ + {"document_id": d.document_id, "document_text": d.document_text} + for d in obj.documents + ], + "limit": obj.limit, + } + + +class RerankerResponseTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> RerankerResponse: + return RerankerResponse( + 
results=[ + RerankerResult( + document_id=r["document_id"], + query_id=r["query_id"], + score=r["score"], + ) + for r in data.get("results", []) + ], + ) + + def encode(self, obj: RerankerResponse) -> Dict[str, Any]: + return { + "results": [ + { + "document_id": r.document_id, + "query_id": r.query_id, + "score": r.score, + } + for r in obj.results + ], + } + + def encode_with_completion( + self, obj: RerankerResponse + ) -> Tuple[Dict[str, Any], bool]: + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index 051efc66..ce91a3cb 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -89,7 +89,9 @@ from . namespaces import ( TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE, # Query-time provenance predicates (GraphRAG) TG_QUERY, TG_CONCEPT, TG_ENTITY, - TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE, + # Edge selection entity type + TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, # Explainability entity types @@ -212,7 +214,9 @@ __all__ = [ "TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE", # Query-time provenance predicates (GraphRAG) "TG_QUERY", "TG_CONCEPT", "TG_ENTITY", - "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", + "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE", + # Edge selection entity type + "TG_EDGE_SELECTION", # Query-time provenance predicates (DocumentRAG) "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", # Explainability entity types diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 0b14f1b9..6f81f122 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -66,8 +66,12 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + 
"selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" +TG_SCORE = TG + "score" TG_DOCUMENT = TG + "document" # Reference to document in librarian +# Edge selection entity type (cross-encoder scored edge in Focus) +TG_EDGE_SELECTION = TG + "EdgeSelection" + # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT = TG + "chunkCount" TG_SELECTED_CHUNK = TG + "selectedChunk" diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 8dedff9a..8e4871c3 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -24,8 +24,10 @@ from . namespaces import ( TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT, # Query-time provenance predicates (GraphRAG) TG_QUERY, TG_CONCEPT, TG_ENTITY, - TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE, TG_DOCUMENT, + # Edge selection entity type + TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, # Explainability entity types @@ -536,10 +538,9 @@ def focus_triples( _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), ] - # Add each selected edge with its reasoning via intermediate entity + # Add each selected edge with metadata via intermediate entity for idx, edge_info in enumerate(selected_edges_with_reasoning): edge = edge_info.get("edge") - reasoning = edge_info.get("reasoning", "") if edge: s, p, o = edge @@ -552,13 +553,32 @@ def focus_triples( _triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri)) ) + # Type the edge selection entity + triples.append( + _triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION)) + ) + # Attach quoted triple to edge selection entity quoted = _quoted_triple(s, p, o) triples.append( Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted) ) - # Attach reasoning to edge selection entity + # Structured cross-encoder metadata + 
concept = edge_info.get("concept") + if concept: + triples.append( + _triple(edge_sel_uri, TG_CONCEPT, _literal(concept)) + ) + + score = edge_info.get("score") + if score is not None: + triples.append( + _triple(edge_sel_uri, TG_SCORE, _literal(str(score))) + ) + + # Legacy reasoning text (for non-cross-encoder callers) + reasoning = edge_info.get("reasoning", "") if reasoning: triples.append( _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py index afb5c30f..1434d45d 100644 --- a/trustgraph-base/trustgraph/provenance/vocabulary.py +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -29,6 +29,7 @@ from . namespaces import ( TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_EDGE_SELECTION, TG_SCORE, ) @@ -93,6 +94,7 @@ TG_CLASS_LABELS = [ _label_triple(TG_FINDING, "Finding"), _label_triple(TG_PLAN_TYPE, "Plan"), _label_triple(TG_STEP_RESULT, "Step Result"), + _label_triple(TG_EDGE_SELECTION, "Edge Selection"), ] # TrustGraph predicate labels @@ -117,6 +119,7 @@ TG_PREDICATE_LABELS = [ _label_triple(TG_ENTITY, "entity"), _label_triple(TG_SUBAGENT_GOAL, "subagent goal"), _label_triple(TG_PLAN_STEP, "plan step"), + _label_triple(TG_SCORE, "score"), ] diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index 2a214201..63dc05fd 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -15,4 +15,5 @@ from .diagnosis import * from .collection import * from .storage import * from .tool_service import * -from .sparql_query import * \ No newline at end of file +from .sparql_query import * +from .reranker import * \ No newline at end of file diff --git 
a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index 1696790b..0a9c23ef 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -6,17 +6,6 @@ from ..core.primitives import Error # Prompt services, abstract the prompt generation -# extract-definitions: -# chunk -> definitions -# extract-relationships: -# chunk -> relationships -# kg-prompt: -# query, triples -> answer -# document-prompt: -# query, documents -> answer -# extract-rows -# schema, chunk -> rows - @dataclass class PromptRequest: id: str = "" @@ -46,4 +35,4 @@ class PromptResponse: out_token: int | None = None model: str | None = None -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/reranker.py b/trustgraph-base/trustgraph/schema/services/reranker.py new file mode 100644 index 00000000..948746e4 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/reranker.py @@ -0,0 +1,35 @@ + +from dataclasses import dataclass, field + +from ..core.primitives import Error + +############################################################################ + +# Cross-encoder reranker + +@dataclass +class RerankerQuery: + query_id: str = "" + query_text: str = "" + +@dataclass +class RerankerDocument: + document_id: str = "" + document_text: str = "" + +@dataclass +class RerankerRequest: + queries: list[RerankerQuery] = field(default_factory=list) + documents: list[RerankerDocument] = field(default_factory=list) + limit: int = 10 + +@dataclass +class RerankerResult: + document_id: str = "" + query_id: str = "" + score: float = 0.0 + +@dataclass +class RerankerResponse: + error: Error | None = None + results: list[RerankerResult] = field(default_factory=list) diff --git 
a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 006a07f4..193ee1cd 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main" tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" +tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main" tg-load-kg-core = "trustgraph.cli.load_kg_core:main" diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index f39cdab0..892d2d35 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -112,14 +112,13 @@ def _question_explainable_api( if focus_full and focus_full.edge_selections: for edge_sel in focus_full.edge_selections: if edge_sel.edge: - # Resolve labels for edge components s_label, p_label, o_label = explain_client.resolve_edge_labels( edge_sel.edge, collection ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if edge_sel.reasoning: - r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning - print(f" Reason: {r_short}", file=sys.stderr) + if edge_sel.concept or edge_sel.score is not None: + score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?" 
+ print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr) elif isinstance(entity, Synthesis): print(f"\n [synthesis] {prov_id}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/invoke_reranker.py b/trustgraph-cli/trustgraph/cli/invoke_reranker.py new file mode 100644 index 00000000..91337c97 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_reranker.py @@ -0,0 +1,127 @@ +""" +Invokes the reranker service to score and rank documents by relevance +to one or more queries. +""" + +import argparse +import json +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def query(url, flow_id, queries, documents, limit, token=None, + workspace="default"): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + + query_objects = [ + {"query_id": str(i), "query_text": q} + for i, q in enumerate(queries) + ] + + document_objects = [ + {"document_id": str(i), "document_text": d} + for i, d in enumerate(documents) + ] + + result = flow.rerank( + queries=query_objects, + documents=document_objects, + limit=limit, + ) + + if "error" in result and result["error"]: + err = result["error"] + print(f"Error: [{err.get('type', '')}] {err.get('message', '')}") + return + + for r in result.get("results", []): + doc_idx = int(r["document_id"]) + query_idx = int(r["query_id"]) + print( + f" {r['score']:.4f} | " + f"query: {queries[query_idx]} | " + f"doc: {documents[doc_idx]}" + ) + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-reranker', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + 
help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10, + help='Maximum number of results (default: 10)', + ) + + parser.add_argument( + '-q', '--query', + action='append', + required=True, + help='Query text (can be specified multiple times)', + ) + + parser.add_argument( + 'documents', + nargs='+', + help='Documents to rerank', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + queries=args.query, + documents=args.documents, + limit=args.limit, + token=args.token, + workspace=args.workspace, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index 17aaca1a..ed4a9807 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_ ) print(f" {i}. ({s_label}, {p_label}, {o_label})") - if edge_sel.reasoning: - r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning - print(f" Reasoning: {r_short}") + if edge_sel.concept or edge_sel.score is not None: + score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?" 
+ print(f" Concept: {edge_sel.concept} Score: {score_str}") if show_provenance and edge_sel.edge: provenance = trace_edge_provenance( @@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type): "selected_edges": [ { "edge": edge_sel.edge, - "reasoning": edge_sel.reasoning, + "concept": edge_sel.concept, + "score": edge_sel.score, } for edge_sel in focus.edge_selections ], diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index f9f6c5d9..90647104 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "faiss-cpu", "falkordb", "fastembed", + "flashrank", "ibis", "jsonschema", "langchain", @@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone: graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run" graph-embeddings = "trustgraph.embeddings.graph_embeddings:run" graph-rag = "trustgraph.retrieval.graph_rag:run" +reranker-flashrank = "trustgraph.reranker.flashrank:run" kg-extract-agent = "trustgraph.extract.kg.agent:run" kg-extract-definitions = "trustgraph.extract.kg.definitions:run" kg-extract-rows = "trustgraph.extract.kg.rows:run" diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index bddb009d..7285250f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . document_embeddings_query import DocumentEmbeddingsQueryRequestor from . row_embeddings_query import RowEmbeddingsQueryRequestor from . mcp_tool import McpToolRequestor +from . reranker import RerankerRequestor from . text_load import TextLoad from . 
document_load import DocumentLoad @@ -74,6 +75,7 @@ request_response_dispatchers = { "structured-diag": StructuredDiagRequestor, "row-embeddings": RowEmbeddingsQueryRequestor, "sparql": SparqlQueryRequestor, + "reranker": RerankerRequestor, } system_dispatchers = { diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py b/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py new file mode 100644 index 00000000..e456f3d1 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py @@ -0,0 +1,31 @@ + +from ... schema import RerankerRequest, RerankerResponse +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + +class RerankerRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(RerankerRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=RerankerRequest, + response_schema=RerankerResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("reranker") + self.response_translator = TranslatorRegistry.get_response_translator("reranker") + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index bdc3ed4c..14f820c2 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -518,6 +518,7 @@ _FLOW_SERVICES = { "structured-diag": "structured-query:read", "row-embeddings": "row-embeddings:read", "sparql": "sparql:read", + "reranker": "reranker", } for _kind, _cap in _FLOW_SERVICES.items(): _register_flow_kind("flow-service", _kind, _cap) diff --git 
a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index f1f7d92d..fced972e 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -72,6 +72,7 @@ _READER_CAPS = { "row-embeddings:read", "llm", "embeddings", + "reranker", "mcp", "config:read", "flows:read", diff --git a/trustgraph-flow/trustgraph/reranker/__init__.py b/trustgraph-flow/trustgraph/reranker/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/__init__.py @@ -0,0 +1 @@ + diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py b/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py new file mode 100644 index 00000000..bd3b0e96 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py @@ -0,0 +1,2 @@ + +from . processor import * diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py b/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py new file mode 100644 index 00000000..1ebce4d4 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . processor import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/processor.py b/trustgraph-flow/trustgraph/reranker/flashrank/processor.py new file mode 100644 index 00000000..481d1a79 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/processor.py @@ -0,0 +1,109 @@ + +""" +Reranker service using flashrank. +Scores query-document pairs and returns the top results ranked by +relevance. +""" + +import asyncio +import logging + +from ... base import RerankerService +from ... 
schema import RerankerResult + +from flashrank import Ranker, RerankRequest + +logger = logging.getLogger(__name__) + +default_ident = "reranker" + +default_model = "ms-marco-MiniLM-L-12-v2" + +class Processor(RerankerService): + + def __init__(self, **params): + + model = params.get("model", default_model) + + super(Processor, self).__init__( + **params | { "model": model } + ) + + self.default_model = model + + self.cached_model_name = None + self.ranker = None + + self._load_model(model) + + def _load_model(self, model_name): + if self.cached_model_name != model_name: + logger.info(f"Loading flashrank model: {model_name}") + self.ranker = Ranker(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"flashrank model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + def _run_rerank(self, query, passages): + request = RerankRequest(query=query, passages=passages) + return self.ranker.rerank(request) + + async def on_rerank(self, queries, documents, limit, model=None): + + if not queries or not documents: + return [] + + use_model = model or self.default_model + + if self.cached_model_name != use_model: + await asyncio.to_thread(self._load_model, use_model) + + passages = [ + {"id": d.document_id, "text": d.document_text} + for d in documents + ] + + best_scores = {} + + for q in queries: + ranked = await asyncio.to_thread( + self._run_rerank, q.query_text, passages, + ) + + for r in ranked: + doc_id = r["id"] + score = r["score"] + score = float(score) + if doc_id not in best_scores or score > best_scores[doc_id][1]: + best_scores[doc_id] = (q.query_id, score) + + results = sorted( + best_scores.items(), + key=lambda x: x[1][1], + reverse=True, + )[:limit] + + return [ + RerankerResult( + document_id=doc_id, + query_id=query_id, + score=score, + ) + for doc_id, (query_id, score) in results + ] + + @staticmethod + def add_args(parser): + + RerankerService.add_args(parser) + + parser.add_argument( 
+ '-m', '--model', + default=default_model, + help=f'Reranker model (default: {default_model})' + ) + +def run(): + + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 81dc8fe2..06c0b5b4 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -120,7 +120,7 @@ class Query: def __init__( self, rag, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, track_usage=None, + max_path_length=2, edge_limit=25, track_usage=None, ): self.rag = rag self.collection = collection @@ -129,6 +129,7 @@ class Query: self.triple_limit = triple_limit self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length + self.edge_limit = edge_limit self.track_usage = track_usage async def extract_concepts(self, query): @@ -217,12 +218,9 @@ class Query: logger.debug(f" {ent}") return entities, concepts - + async def maybe_label(self, e): - # The label cache lives on a per-request GraphRag instance — no - # cross-request isolation concern. The collection prefix keeps - # entries from different collections distinct within one request. 
cache_key = f"{self.collection}:{e}" cached_label = self.rag.label_cache.get(cache_key) @@ -244,11 +242,10 @@ class Query: return label async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently using streaming""" + """Execute triple queries for multiple entities concurrently.""" tasks = [] for entity in entities: - # Create concurrent streaming tasks for all 3 query types per entity tasks.extend([ self.rag.triples_client.query_stream( s=entity, p=None, o=None, @@ -270,10 +267,8 @@ class Query: ) ]) - # Execute all queries concurrently results = await asyncio.gather(*tasks, return_exceptions=True) - # Combine all results all_triples = [] for result in results: if not isinstance(result, Exception) and result is not None: @@ -281,168 +276,151 @@ class Query: return all_triples - async def follow_edges_batch(self, entities, max_depth): - """Optimized iterative graph traversal with batching. - - Returns: - tuple: (subgraph, term_map) where subgraph is a set of - (str, str, str) tuples and term_map maps each string tuple - to its original (Term, Term, Term) for type-preserving - provenance. 
- """ - visited = set() - current_level = set(entities) - subgraph = set() - term_map = {} # (str, str, str) -> (Term, Term, Term) - - for depth in range(max_depth): - if not current_level or len(subgraph) >= self.max_subgraph_size: - break - - # Filter out already visited entities - unvisited_entities = [e for e in current_level if e not in visited] - if not unvisited_entities: - break - - # Batch query all unvisited entities at current level - triples = await self.execute_batch_triple_queries( - unvisited_entities, self.triple_limit - ) - - # Process results and collect next level entities - next_level = set() - for triple in triples: - triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) - subgraph.add(triple_tuple) - term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o)) - - # Collect entities for next level (only from s and o positions) - if depth < max_depth - 1: # Don't collect for final depth - s, p, o = triple_tuple - if s not in visited: - next_level.add(s) - if o not in visited: - next_level.add(o) - - # Stop if subgraph size limit reached - if len(subgraph) >= self.max_subgraph_size: - return subgraph, term_map - - # Update for next iteration - visited.update(current_level) - current_level = next_level - - return subgraph, term_map - - async def follow_edges(self, ent, subgraph, path_length): - """Legacy method - replaced by follow_edges_batch""" - # Maintain backward compatibility with early termination checks - if path_length <= 0: - return - - if len(subgraph) >= self.max_subgraph_size: - return - - # For backward compatibility, convert to new approach - batch_result, _ = await self.follow_edges_batch([ent], path_length) - subgraph.update(batch_result) - - async def get_subgraph(self, query): - """ - Get subgraph by extracting concepts, finding entities, and traversing. 
- - Returns: - tuple: (subgraph, term_map, entities, concepts) where subgraph is - a list of (s, p, o) string tuples, term_map maps each string - tuple to its original (Term, Term, Term), entities is the seed - entity list, and concepts is the extracted concept list. - """ - - entities, concepts = await self.get_entities(query) - - if self.verbose: - logger.debug("Getting subgraph...") - - # Use optimized batch traversal instead of sequential processing - subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length) - - return list(subgraph), term_map, entities, concepts - async def resolve_labels_batch(self, entities): - """Resolve labels for multiple entities in parallel""" - tasks = [] - for entity in entities: - tasks.append(self.maybe_label(entity)) - + """Resolve labels for multiple entities in parallel.""" + tasks = [self.maybe_label(entity) for entity in entities] return await asyncio.gather(*tasks, return_exceptions=True) - async def get_labelgraph(self, query): - """ - Get subgraph with labels resolved for display. + async def hop_and_filter(self, seed_entities, concepts): + """Iterative hop-and-filter graph traversal with cross-encoder. + + At each hop: + 1. Retrieve all edges one hop from the frontier. + 2. Resolve labels and represent each edge as "{p} {o}". + 3. Score edges against concepts using the cross-encoder. + 4. Select the top edges; their target nodes become the next + frontier. 
Returns: - tuple: (labeled_edges, uri_map, entities, concepts) where: - - labeled_edges: list of (label_s, label_p, label_o) tuples - - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) - - entities: list of seed entity URI strings - - concepts: list of concept strings extracted from query + tuple: (selected_edges, uri_map, edge_metadata) where: + - selected_edges: list of (label_s, label_p, label_o) + - uri_map: dict mapping edge_id -> (Term, Term, Term) + - edge_metadata: dict mapping edge_id -> {concept, score} """ - subgraph, term_map, entities, concepts = await self.get_subgraph(query) + all_selected_edges = [] + uri_map = {} + edge_metadata = {} + frontier = set(seed_entities) + visited_entities = set() + seen_edges = set() - # Filter out label triples - filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] + for hop in range(self.max_path_length): + if not frontier: + break - # Collect all unique entities that need label resolution - entities_to_resolve = set() - for s, p, o in filtered_subgraph: - entities_to_resolve.update([s, p, o]) + unvisited = [e for e in frontier if e not in visited_entities] + if not unvisited: + break - # Batch resolve labels for all entities in parallel - entity_list = list(entities_to_resolve) - resolved_labels = await self.resolve_labels_batch(entity_list) + if self.verbose: + logger.debug( + f"Hop {hop + 1}: {len(unvisited)} frontier entities" + ) - # Create entity-to-label mapping - label_map = {} - for entity, label in zip(entity_list, resolved_labels): - if not isinstance(label, Exception): - label_map[entity] = label - else: - label_map[entity] = entity # Fallback to entity itself - - # Apply labels to subgraph and build URI mapping - labeled_edges = [] - uri_map = {} # Maps edge_id of labeled edge -> original Term triple - - for s, p, o in filtered_subgraph: - labeled_triple = ( - label_map.get(s, s), - label_map.get(p, p), - label_map.get(o, o) + # Retrieve edges one hop from 
frontier + triples = await self.execute_batch_triple_queries( + unvisited, self.triple_limit, ) - labeled_edges.append(labeled_triple) - # Map from labeled edge ID to original Terms (preserving types) - labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2]) - uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o)) + # Deduplicate and filter already-seen edges + hop_triples = [] + hop_term_map = {} + for triple in triples: + triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) + if triple_tuple[1] == LABEL: + continue + if triple_tuple in seen_edges: + continue + seen_edges.add(triple_tuple) + hop_triples.append(triple_tuple) + hop_term_map[triple_tuple] = ( + to_term(triple.s), to_term(triple.p), to_term(triple.o), + ) - labeled_edges = labeled_edges[0:self.max_subgraph_size] + if not hop_triples: + visited_entities.update(frontier) + break - if self.verbose: - logger.debug("Subgraph:") - for edge in labeled_edges: - logger.debug(f" {str(edge)}") + if self.verbose: + logger.debug( + f"Hop {hop + 1}: {len(hop_triples)} candidate edges" + ) - if self.verbose: - logger.debug("Done.") + # Resolve labels for all entities in hop edges + entities_to_resolve = set() + for s, p, o in hop_triples: + entities_to_resolve.update([s, p, o]) - return labeled_edges, uri_map, entities, concepts + entity_list = list(entities_to_resolve) + resolved = await self.resolve_labels_batch(entity_list) + + label_map = {} + for entity, label in zip(entity_list, resolved): + if not isinstance(label, Exception): + label_map[entity] = label + else: + label_map[entity] = entity + + # Build labeled edges and documents for cross-encoder + labeled_hop = [] + for s, p, o in hop_triples: + ls = label_map.get(s, s) + lp = label_map.get(p, p) + lo = label_map.get(o, o) + labeled_hop.append((ls, lp, lo)) + + documents = [ + {"id": str(i), "text": f"{lp} {lo}"} + for i, (ls, lp, lo) in enumerate(labeled_hop) + ] + + queries = [ + {"id": str(i), "text": c} + for i, c in 
enumerate(concepts) + ] + + # Score with cross-encoder + results = await self.rag.reranker_client.rerank( + queries=queries, + documents=documents, + limit=self.edge_limit, + ) + + # Collect selected edges and metadata + next_frontier = set() + for r in results: + idx = int(r.document_id) + ls, lp, lo = labeled_hop[idx] + s, p, o = hop_triples[idx] + eid = edge_id(ls, lp, lo) + + all_selected_edges.append((ls, lp, lo)) + uri_map[eid] = hop_term_map[(s, p, o)] + edge_metadata[eid] = { + "concept": concepts[int(r.query_id)], + "score": r.score, + } + + # Target nodes become next frontier + next_frontier.add(s) + next_frontier.add(o) + + if self.verbose: + logger.debug( + f"Hop {hop + 1}: selected {len(results)} edges" + ) + + visited_entities.update(frontier) + frontier = next_frontier - visited_entities + + return all_selected_edges, uri_map, edge_metadata async def trace_source_documents(self, edge_uris): """ Trace selected edges back to their source documents via provenance. - Follows the chain: edge → subgraph (via tg:contains) → chunk → - page → document (via prov:wasDerivedFrom), all in urn:graph:source. + Follows the chain: edge -> subgraph (via tg:contains) -> chunk -> + page -> document (via prov:wasDerivedFrom), all in urn:graph:source. 
Args: edge_uris: List of (s, p, o) URI string tuples @@ -453,7 +431,6 @@ class Query: # Step 1: Find subgraphs containing these edges via tg:contains subgraph_tasks = [] for s, p, o in edge_uris: - # s, p, o may be Term objects (preserving types) or strings s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s) p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p) o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o) @@ -487,12 +464,10 @@ class Query: return [] # Step 2: Walk prov:wasDerivedFrom chain to find documents - # Each level: query ?entity prov:wasDerivedFrom ?parent - # Stop when we find entities typed tg:Document current_uris = subgraph_uris doc_uris = set() - for depth in range(4): # Max depth: subgraph → chunk → page → doc + for depth in range(4): if not current_uris: break @@ -509,7 +484,6 @@ class Query: *derivation_tasks, return_exceptions=True ) - # URIs with no parent are root documents next_uris = set() for uri, result in zip(current_uris, derivation_results): if isinstance(result, Exception) or not result: @@ -524,7 +498,6 @@ class Query: return [] # Step 3: Get all document metadata properties - # Skip structural predicates that aren't useful context SKIP_PREDICATES = { PROV_WAS_DERIVED_FROM, "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", @@ -565,7 +538,7 @@ class GraphRag: def __init__( self, prompt_client, embeddings_client, graph_embeddings_client, - triples_client, verbose=False, + triples_client, reranker_client, verbose=False, ): self.verbose = verbose @@ -574,9 +547,8 @@ class GraphRag: self.embeddings_client = embeddings_client self.graph_embeddings_client = graph_embeddings_client self.triples_client = triples_client + self.reranker_client = reranker_client - # Replace simple dict with LRU cache with TTL - # CRITICAL: This cache only lives for one request due to per-request instantiation self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300) if self.verbose: @@ -585,33 +557,12 @@ class GraphRag: async def 
query( self, query, collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, - max_path_length = 2, edge_score_limit = 30, edge_limit = 25, + max_path_length = 2, edge_limit = 25, streaming = False, chunk_callback = None, explain_callback = None, save_answer_callback = None, parent_uri = "", ): - """ - Execute a GraphRAG query with real-time explainability tracking. - - Args: - query: The query string - collection: Collection identifier - entity_limit: Max entities to retrieve - triple_limit: Max triples per entity - max_subgraph_size: Max edges in subgraph - max_path_length: Max hops from seed entities - edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter) - edge_limit: Max edges after LLM scoring - streaming: Enable streaming LLM response - chunk_callback: async def callback(chunk, end_of_stream) for streaming - explain_callback: async def callback(triples, explain_id) for real-time explainability - save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian - - Returns: - tuple: (answer_text, usage) where usage is a dict with - in_token, out_token, model - """ # Accumulate token usage across all prompt calls total_in = 0 total_out = 0 @@ -638,7 +589,9 @@ class GraphRag: foc_uri = make_focus_uri(session_id) syn_uri = make_synthesis_uri(session_id) - timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + timestamp = datetime.now(timezone.utc).isoformat().replace( + "+00:00", "Z", + ) # Emit question explainability immediately if explain_callback: @@ -657,10 +610,12 @@ class GraphRag: triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_limit = edge_limit, track_usage = track_usage, ) - kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) + # Step 1: Extract concepts and find seed entities + seed_entities, concepts = await q.get_entities(query) # Emit grounding explain after concept 
extraction if explain_callback: @@ -676,11 +631,16 @@ class GraphRag: ) await explain_callback(gnd_triples, gnd_uri) - # Emit exploration explain after graph retrieval completes + # Step 2: Iterative hop-and-filter with cross-encoder + selected_edges, uri_map, edge_metadata = await q.hop_and_filter( + seed_entities, concepts, + ) + + # Emit exploration explain if explain_callback: exp_triples = set_graph( exploration_triples( - exp_uri, gnd_uri, len(kg), + exp_uri, gnd_uri, len(selected_edges), entities=seed_entities, ), GRAPH_RETRIEVAL @@ -688,235 +648,63 @@ class GraphRag: await explain_callback(exp_triples, exp_uri) if self.verbose: - logger.debug("Invoking LLM...") - logger.debug(f"Knowledge graph: {kg}") - logger.debug(f"Query: {query}") - - # Semantic pre-filter: reduce edges before expensive LLM scoring - if edge_score_limit > 0 and len(kg) > edge_score_limit: - - if self.verbose: + logger.debug(f"Selected {len(selected_edges)} edges") + for s, p, o in selected_edges: + eid = edge_id(s, p, o) + meta = edge_metadata.get(eid, {}) logger.debug( - f"Semantic pre-filter: {len(kg)} edges > " - f"limit {edge_score_limit}, filtering..." 
+ f" {meta.get('score', 0):.4f} " + f"[{meta.get('concept', '')}] " + f"{s} | {p} | {o}" ) - # Embed edge descriptions: "subject, predicate, object" - edge_descriptions = [ - f"{s}, {p}, {o}" for s, p, o in kg - ] - - # Embed concepts and edge descriptions concurrently - concept_embed_task = self.embeddings_client.embed(concepts) - edge_embed_task = self.embeddings_client.embed(edge_descriptions) - - concept_vectors, edge_vectors = await asyncio.gather( - concept_embed_task, edge_embed_task - ) - - # Score each edge by max cosine similarity to any concept - def cosine_similarity(a, b): - dot = sum(x * y for x, y in zip(a, b)) - norm_a = math.sqrt(sum(x * x for x in a)) - norm_b = math.sqrt(sum(x * x for x in b)) - if norm_a == 0 or norm_b == 0: - return 0.0 - return dot / (norm_a * norm_b) - - edge_scores = [] - for i, edge_vec in enumerate(edge_vectors): - max_sim = max( - cosine_similarity(edge_vec, cv) - for cv in concept_vectors - ) - edge_scores.append((max_sim, i)) - - # Sort by similarity descending and keep top edge_score_limit - edge_scores.sort(reverse=True) - keep_indices = set( - idx for _, idx in edge_scores[:edge_score_limit] - ) - - # Filter kg and rebuild uri_map - filtered_kg = [] - filtered_uri_map = {} - for i, (s, p, o) in enumerate(kg): - if i in keep_indices: - filtered_kg.append((s, p, o)) - eid = edge_id(s, p, o) - if eid in uri_map: - filtered_uri_map[eid] = uri_map[eid] - - if self.verbose: - logger.debug( - f"Semantic pre-filter kept {len(filtered_kg)} " - f"of {len(kg)} edges" - ) - - kg = filtered_kg - uri_map = filtered_uri_map - - # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)} - # uri_map already maps edge_id -> (uri_s, uri_p, uri_o) - edge_map = {} - edges_with_ids = [] - for s, p, o in kg: - eid = edge_id(s, p, o) - edge_map[eid] = (s, p, o) - edges_with_ids.append({ - "id": eid, - "s": s, - "p": p, - "o": o - }) - - if self.verbose: - logger.debug(f"Built edge map with {len(edge_map)} edges") - - # Step 1a: Edge 
Scoring - LLM scores edges for relevance - scoring_result = await self.prompt_client.prompt( - "kg-edge-scoring", - variables={ - "query": query, - "knowledge": edges_with_ids - } - ) - track_usage(scoring_result) - - if self.verbose: - logger.debug(f"Edge scoring result: {scoring_result}") - - # Parse scoring response (jsonl) to get edge IDs with scores - scored_edges = [] - - for obj in scoring_result.objects or []: - if isinstance(obj, dict) and "id" in obj and "score" in obj: - try: - score = int(obj["score"]) - except (ValueError, TypeError): - score = 0 - scored_edges.append({"id": obj["id"], "score": score}) - - # Select top N edges by score - scored_edges.sort(key=lambda x: x["score"], reverse=True) - top_edges = scored_edges[:edge_limit] - selected_ids = {e["id"] for e in top_edges} - - if self.verbose: - logger.debug( - f"Scored {len(scored_edges)} edges, " - f"selected top {len(selected_ids)}" - ) - - # Filter to selected edges - selected_edges = [] - for eid in selected_ids: - if eid in edge_map: - selected_edges.append(edge_map[eid]) - - # Step 1b: Edge Reasoning + Document Tracing (concurrent) - selected_edges_with_ids = [ - {"id": eid, "s": s, "p": p, "o": o} - for eid in selected_ids - if eid in edge_map - for s, p, o in [edge_map[eid]] - ] - - # Collect selected edge URIs for document tracing + # Step 3: Document tracing selected_edge_uris = [ - uri_map[eid] - for eid in selected_ids - if eid in uri_map + uri_map[edge_id(s, p, o)] + for s, p, o in selected_edges + if edge_id(s, p, o) in uri_map ] - # Run reasoning and document tracing concurrently - async def _get_reasoning(): - result = await self.prompt_client.prompt( - "kg-edge-reasoning", - variables={ - "query": query, - "knowledge": selected_edges_with_ids - } - ) - track_usage(result) - return result - - reasoning_task = _get_reasoning() - doc_trace_task = q.trace_source_documents(selected_edge_uris) - - reasoning_result, source_documents = await asyncio.gather( - reasoning_task, 
doc_trace_task, return_exceptions=True + source_documents = await q.trace_source_documents( + selected_edge_uris, ) - # Handle exceptions from gather - if isinstance(reasoning_result, Exception): - logger.warning( - f"Edge reasoning failed: {reasoning_result}" - ) - reasoning_result = None if isinstance(source_documents, Exception): logger.warning( f"Document tracing failed: {source_documents}" ) source_documents = [] - - if self.verbose: - logger.debug(f"Edge reasoning result: {reasoning_result}") - - # Parse reasoning response (jsonl) and build explainability data - reasoning_map = {} - - if reasoning_result is not None: - for obj in reasoning_result.objects or []: - if isinstance(obj, dict) and "id" in obj: - reasoning_map[obj["id"]] = obj.get("reasoning", "") - + # Build focus explainability data with cross-encoder metadata selected_edges_with_reasoning = [] - for eid in selected_ids: + for s, p, o in selected_edges: + eid = edge_id(s, p, o) if eid in uri_map: uri_s, uri_p, uri_o = uri_map[eid] + meta = edge_metadata.get(eid, {}) selected_edges_with_reasoning.append({ "edge": (uri_s, uri_p, uri_o), - "reasoning": reasoning_map.get(eid, ""), + "concept": meta.get("concept", ""), + "score": meta.get("score", 0), }) - if self.verbose: - logger.debug(f"Filtered to {len(selected_edges)} edges") - - # Emit focus explain after edge selection completes + # Emit focus explain if explain_callback: - # Sum scoring + reasoning token usage for focus event - focus_in = 0 - focus_out = 0 - focus_model = None - for r in [scoring_result, reasoning_result]: - if r is not None: - if r.in_token is not None: - focus_in += r.in_token - if r.out_token is not None: - focus_out += r.out_token - if r.model is not None: - focus_model = r.model - foc_triples = set_graph( focus_triples( - foc_uri, exp_uri, selected_edges_with_reasoning, session_id, - in_token=focus_in or None, - out_token=focus_out or None, - model=focus_model, + foc_uri, exp_uri, + selected_edges_with_reasoning, 
session_id, ), GRAPH_RETRIEVAL ) await explain_callback(foc_triples, foc_uri) - # Step 2: Synthesis - LLM generates answer from selected edges only + # Step 4: Synthesis selected_edge_dicts = [ {"s": s, "p": p, "o": o} for s, p, o in selected_edges ] - # Add source document metadata as knowledge edges for s, p, o in source_documents: selected_edge_dicts.append({ "s": s, "p": p, "o": o, @@ -928,7 +716,6 @@ class GraphRag: } if streaming and chunk_callback: - # Accumulate chunks for answer storage while forwarding to callback accumulated_chunks = [] async def accumulating_callback(chunk, end_of_stream): @@ -942,7 +729,6 @@ class GraphRag: chunk_callback=accumulating_callback ) track_usage(synthesis_result) - # Combine all chunks into full response resp = "".join(accumulated_chunks) else: synthesis_result = await self.prompt_client.prompt( @@ -955,29 +741,42 @@ class GraphRag: if self.verbose: logger.debug("Query processing complete") - # Emit synthesis explain after synthesis completes + # Emit synthesis explain if explain_callback: synthesis_doc_id = None answer_text = resp if resp else "" - # Save answer to librarian if save_answer_callback and answer_text: synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}" try: await save_answer_callback(synthesis_doc_id, answer_text) if self.verbose: - logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") + logger.debug( + f"Saved answer to librarian: " + f"{synthesis_doc_id}" + ) except Exception as e: - logger.warning(f"Failed to save answer to librarian: {e}") + logger.warning( + f"Failed to save answer to librarian: {e}" + ) synthesis_doc_id = None syn_triples = set_graph( synthesis_triples( syn_uri, foc_uri, document_id=synthesis_doc_id, - in_token=synthesis_result.in_token if synthesis_result else None, - out_token=synthesis_result.out_token if synthesis_result else None, - model=synthesis_result.model if synthesis_result else None, + in_token=( + synthesis_result.in_token + if synthesis_result else None + 
), + out_token=( + synthesis_result.out_token + if synthesis_result else None + ), + model=( + synthesis_result.model + if synthesis_result else None + ), ), GRAPH_RETRIEVAL ) @@ -993,4 +792,3 @@ class GraphRag: } return resp, usage - diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 959ae8e0..27ec4937 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -13,6 +13,7 @@ from . graph_rag import GraphRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec +from ... base import RerankerClientSpec from ... base import LibrarianSpec # Module logger @@ -32,7 +33,6 @@ class Processor(FlowProcessor): triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) - edge_score_limit = params.get("edge_score_limit", 30) edge_limit = params.get("edge_limit", 25) super(Processor, self).__init__( @@ -43,7 +43,6 @@ class Processor(FlowProcessor): "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, - "edge_score_limit": edge_score_limit, "edge_limit": edge_limit, } ) @@ -52,7 +51,6 @@ class Processor(FlowProcessor): self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length - self.default_edge_score_limit = edge_score_limit self.default_edge_limit = edge_limit # Workspace isolation is enforced by the flow layer (flow.workspace). 
@@ -96,6 +94,13 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + RerankerClientSpec( + request_name = "reranker-request", + response_name = "reranker-response", + ) + ) + self.register_specification( ProducerSpec( name = "response", @@ -163,6 +168,7 @@ class Processor(FlowProcessor): graph_embeddings_client=flow("graph-embeddings-request"), triples_client=flow("triples-request"), prompt_client=flow("prompt-request"), + reranker_client=flow("reranker-request"), verbose=True, ) @@ -186,11 +192,6 @@ class Processor(FlowProcessor): else: max_path_length = self.default_max_path_length - if v.edge_score_limit: - edge_score_limit = v.edge_score_limit - else: - edge_score_limit = self.default_edge_score_limit - if v.edge_limit: edge_limit = v.edge_limit else: @@ -225,7 +226,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, - edge_score_limit = edge_score_limit, + edge_limit = edge_limit, streaming = True, chunk_callback = send_chunk, @@ -241,7 +242,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, - edge_score_limit = edge_score_limit, + edge_limit = edge_limit, explain_callback = send_explainability, save_answer_callback = save_answer, @@ -338,18 +339,11 @@ class Processor(FlowProcessor): help=f'Default max path length (default: 2)' ) - parser.add_argument( - '--edge-score-limit', - type=int, - default=30, - help=f'Semantic pre-filter limit before LLM scoring (default: 30)' - ) - parser.add_argument( '--edge-limit', type=int, default=25, - help=f'Max edges after LLM scoring (default: 25)' + help=f'Max edges selected per hop by cross-encoder (default: 25)' ) # Note: Explainability triples are now stored in the request's collection From f20b50cfb2b7713f098a7722fd75c5b06a56aae8 Mon Sep 17 00:00:00 2001 From: 
cybermaggedon Date: Wed, 1 Jul 2026 14:48:32 +0100 Subject: [PATCH 2/9] feat: add API variant profiles and thinking support to OpenAI processor (#1007) Add a --variant flag (openai, deepseek, qwen, mistral, llama) that encapsulates provider-specific API differences: output token parameter names, thinking/reasoning toggles, temperature rules, and thinking output extraction. Add --thinking flag (off, low, medium, high) to control reasoning effort. --- .../model/text_completion/openai/llm.py | 73 ++++++-- .../model/text_completion/openai/variants.py | 176 ++++++++++++++++++ 2 files changed, 233 insertions(+), 16 deletions(-) create mode 100644 trustgraph-flow/trustgraph/model/text_completion/openai/variants.py diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index c8ab9c36..01035bc9 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -10,6 +10,7 @@ import logging from .... exceptions import TooManyRequests, LlmError from .... base import LlmService, LlmResult, LlmChunk +from . 
variants import get_variant, DEFAULT_VARIANT, VARIANTS # Module logger logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ default_temperature = 0.0 default_max_output = 4096 default_api_key = os.getenv("OPENAI_TOKEN") default_base_url = os.getenv("OPENAI_BASE_URL") +default_thinking = "off" if default_base_url is None or default_base_url == "": default_base_url = "https://api.openai.com/v1" @@ -28,16 +30,21 @@ if default_base_url is None or default_base_url == "": class Processor(LlmService): def __init__(self, **params): - + model = params.get("model", default_model) api_key = params.get("api_key", default_api_key) base_url = params.get("url", default_base_url) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) + thinking = params.get("thinking", default_thinking) + variant_name = params.get("variant", DEFAULT_VARIANT) if not api_key: api_key = "not-set" + self.variant = get_variant(variant_name) + self.thinking = thinking + super(Processor, self).__init__( **params | { "model": model, @@ -56,13 +63,28 @@ class Processor(LlmService): else: self.openai = OpenAI(api_key=api_key) - logger.info("OpenAI LLM service initialized") + logger.info( + f"OpenAI LLM service initialized " + f"(variant={self.variant.name}, thinking={self.thinking})" + ) + + def _build_kwargs(self, model_name, temperature): + """Build API call kwargs using the active variant.""" + return self.variant.completion_kwargs( + max_output=self.max_output, + temperature=temperature, + thinking=self.thinking, + ) + + def _extract_content(self, message): + """Extract visible content from a response message.""" + if hasattr(self.variant, "extract_content"): + return self.variant.extract_content(message) + return message.content async def generate_content(self, system, prompt, model=None, temperature=None): - # Use provided model or fall back to default model_name = model or self.default_model - # Use provided temperature or fall back to 
default effective_temperature = temperature if temperature is not None else self.temperature logger.debug(f"Using model: {model_name}") @@ -72,6 +94,8 @@ class Processor(LlmService): try: + api_kwargs = self._build_kwargs(model_name, effective_temperature) + resp = self.openai.chat.completions.create( model=model_name, messages=[ @@ -85,18 +109,23 @@ class Processor(LlmService): ] } ], - temperature=effective_temperature, - max_completion_tokens=self.max_output, + **api_kwargs, ) - + inputtokens = resp.usage.prompt_tokens outputtokens = resp.usage.completion_tokens - logger.debug(f"LLM response: {resp.choices[0].message.content}") + + content = self._extract_content(resp.choices[0].message) + thinking = self.variant.extract_thinking(resp.choices[0].message) + + logger.debug(f"LLM response: {content}") + if thinking: + logger.debug(f"LLM thinking: {thinking[:200]}...") logger.info(f"Input Tokens: {inputtokens}") logger.info(f"Output Tokens: {outputtokens}") resp = LlmResult( - text = resp.choices[0].message.content, + text = content, in_token = inputtokens, out_token = outputtokens, model = model_name @@ -136,9 +165,7 @@ class Processor(LlmService): Stream content generation from OpenAI. Yields LlmChunk objects with is_final=True on the last chunk. 
""" - # Use provided model or fall back to default model_name = model or self.default_model - # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature logger.debug(f"Using model (streaming): {model_name}") @@ -147,6 +174,8 @@ class Processor(LlmService): prompt = system + "\n\n" + prompt try: + api_kwargs = self._build_kwargs(model_name, effective_temperature) + response = self.openai.chat.completions.create( model=model_name, messages=[ @@ -160,16 +189,14 @@ class Processor(LlmService): ] } ], - temperature=effective_temperature, - max_completion_tokens=self.max_output, stream=True, - stream_options={"include_usage": True} + stream_options={"include_usage": True}, + **api_kwargs, ) total_input_tokens = 0 total_output_tokens = 0 - # Stream chunks for chunk in response: if chunk.choices and chunk.choices[0].delta.content: yield LlmChunk( @@ -254,6 +281,20 @@ class Processor(LlmService): help=f'LLM max output tokens (default: {default_max_output})' ) + parser.add_argument( + '--thinking', + choices=["off", "low", "medium", "high"], + default=default_thinking, + help=f'Thinking/reasoning effort level (default: {default_thinking})' + ) + + parser.add_argument( + '--variant', + choices=sorted(VARIANTS.keys()), + default=DEFAULT_VARIANT, + help=f'API variant (default: {DEFAULT_VARIANT})' + ) + def run(): - + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py new file mode 100644 index 00000000..d49b8991 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py @@ -0,0 +1,176 @@ +""" +OpenAI API variant profiles. + +Different providers expose OpenAI-compatible APIs with subtle differences +in parameter names, thinking/reasoning support, and temperature handling. 
+Each variant encapsulates those quirks so the processor doesn't need +provider-specific conditionals. +""" + +import re +import logging + +logger = logging.getLogger(__name__) + + +class Variant: + """Base variant — defines the interface all variants implement.""" + + name = None + token_param = "max_completion_tokens" + temperature_with_thinking = False + + def completion_kwargs(self, max_output, temperature, thinking): + """Build provider-specific kwargs for chat.completions.create(). + + Parameters + ---------- + max_output : int + Configured max output tokens. + temperature : float + Configured temperature. + thinking : str + Thinking effort level: "off", "low", "medium", "high". + + Returns + ------- + dict + Extra kwargs to spread into the API call. + """ + kwargs = {self.token_param: max_output} + + if thinking != "off": + kwargs.update(self.thinking_kwargs(thinking)) + if not self.temperature_with_thinking: + kwargs["temperature"] = 1.0 + else: + kwargs["temperature"] = temperature + else: + kwargs["temperature"] = temperature + + return kwargs + + def thinking_kwargs(self, effort): + """Return kwargs to enable thinking at the given effort level.""" + return {} + + def extract_thinking(self, message): + """Extract thinking/reasoning content from a response message.""" + return getattr(message, "reasoning_content", None) + + def extract_thinking_stream(self, delta): + """Extract thinking content from a streaming delta.""" + return getattr(delta, "reasoning_content", None) + + +class OpenAIVariant(Variant): + """Standard OpenAI API (GPT-4o, o1, o3, etc.).""" + + name = "openai" + token_param = "max_completion_tokens" + temperature_with_thinking = False + + def thinking_kwargs(self, effort): + return {"reasoning_effort": effort} + + +class DeepSeekVariant(Variant): + """DeepSeek API (R1, V3, etc.).""" + + name = "deepseek" + token_param = "max_completion_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, 
thinking): + enabled = "enabled" if thinking != "off" else "disabled" + kwargs = { + self.token_param: max_output, + "temperature": temperature, + "extra_body": { + "thinking": {"type": enabled}, + }, + } + return kwargs + + def thinking_kwargs(self, effort): + return {} + + +class QwenVariant(Variant): + """Qwen / Alibaba Cloud API.""" + + name = "qwen" + token_param = "max_completion_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, thinking): + enabled = thinking != "off" + kwargs = { + self.token_param: max_output, + "temperature": temperature, + "extra_body": { + "enable_thinking": enabled, + }, + } + return kwargs + + def thinking_kwargs(self, effort): + return {} + + +class MistralVariant(Variant): + """Mistral API (Mistral Large, etc.).""" + + name = "mistral" + token_param = "max_tokens" + temperature_with_thinking = False + + def thinking_kwargs(self, effort): + return {"reasoning_effort": effort} + + +class LlamaVariant(Variant): + """Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.). + + Thinking is typically always-on or always-off depending on the model. + When present, thinking appears inline as ... tags. 
+ """ + + name = "llama" + token_param = "max_tokens" + temperature_with_thinking = True + + def thinking_kwargs(self, effort): + return {} + + def extract_thinking(self, message): + content = message.content or "" + match = re.search(r"<think>(.*?)</think>", content, re.DOTALL) + return match.group(1).strip() if match else None + + def extract_content(self, message): + """Strip think tags from visible content.""" + content = message.content or "" + return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip() + + +VARIANTS = { + "openai": OpenAIVariant, + "deepseek": DeepSeekVariant, + "qwen": QwenVariant, + "mistral": MistralVariant, + "llama": LlamaVariant, +} + +DEFAULT_VARIANT = "openai" + + +def get_variant(name): + """Look up a variant by name, raising ValueError if unknown.""" + cls = VARIANTS.get(name) + if cls is None: + raise ValueError( + f"Unknown variant {name!r}. " + f"Available: {', '.join(sorted(VARIANTS))}" + ) + return cls() From 656ca430b95ac56f3810bfebf776a791196ddf52 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Jul 2026 15:40:23 +0100 Subject: [PATCH 3/9] fix: wire variant into text-completion integration test mocks (#1008) Tests using MagicMock processors need the variant, thinking mode, and _build_kwargs/_extract_content methods bound to work with the new variant-based API kwargs construction. 
--- .../test_text_completion_integration.py | 15 ++++++++++++++- .../test_text_completion_streaming_integration.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_text_completion_integration.py b/tests/integration/test_text_completion_integration.py index 6615bf84..521f7d74 100644 --- a/tests/integration/test_text_completion_integration.py +++ b/tests/integration/test_text_completion_integration.py @@ -15,11 +15,20 @@ from openai.types.chat.chat_completion import Choice from openai.types.completion_usage import CompletionUsage from trustgraph.model.text_completion.openai.llm import Processor +from trustgraph.model.text_completion.openai.variants import get_variant from trustgraph.exceptions import TooManyRequests from trustgraph.base import LlmResult from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error +def _wire_variant(processor): + """Attach variant methods to a MagicMock processor.""" + processor.variant = get_variant("openai") + processor.thinking = "off" + processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor) + processor._extract_content = Processor._extract_content.__get__(processor, Processor) + + @pytest.mark.integration class TestTextCompletionIntegration: """Integration tests for OpenAI text completion service coordination""" @@ -66,6 +75,7 @@ class TestTextCompletionIntegration: # Add the actual generate_content method from Processor class processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) return processor @@ -119,6 +129,7 @@ class TestTextCompletionIntegration: # Add the actual generate_content method processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) # Act result = await processor.generate_content("System prompt", "User prompt") @@ -129,7 +140,7 @@ class TestTextCompletionIntegration: assert result.in_token == 50 assert 
result.out_token == 100 # Note: result.model comes from mock response, not processor config - + # Verify configuration was applied call_args = mock_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == config['model'] @@ -247,6 +258,7 @@ class TestTextCompletionIntegration: processor.max_output = processor_config["max_output"] processor.openai = mock_openai_client processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) processors.append(processor) # Simulate multiple concurrent requests @@ -354,6 +366,7 @@ class TestTextCompletionIntegration: processor.max_output = 2048 processor.openai = mock_openai_client processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) # Act await processor.generate_content("System prompt", "User prompt") diff --git a/tests/integration/test_text_completion_streaming_integration.py b/tests/integration/test_text_completion_streaming_integration.py index 6968affa..caa3ec9c 100644 --- a/tests/integration/test_text_completion_streaming_integration.py +++ b/tests/integration/test_text_completion_streaming_integration.py @@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta from trustgraph.model.text_completion.openai.llm import Processor +from trustgraph.model.text_completion.openai.variants import get_variant from trustgraph.base import LlmChunk from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, @@ -18,6 +19,14 @@ from tests.utils.streaming_assertions import ( ) +def _wire_variant(processor): + """Attach variant methods to a MagicMock processor.""" + processor.variant = get_variant("openai") + processor.thinking = "off" + processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor) + processor._extract_content = Processor._extract_content.__get__(processor, 
Processor) + + @pytest.mark.integration class TestTextCompletionStreaming: """Integration tests for Text Completion streaming""" @@ -69,6 +78,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) return processor @@ -190,6 +200,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act chunks = [] @@ -223,6 +234,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act chunks = [] From 11ca7c89c44a9fee92b3a32baaf9f8dc4fd1411a Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Jul 2026 16:20:43 +0100 Subject: [PATCH 4/9] feat: add GLM (Zhipu AI) variant for OpenAI processor (#1009) --- .../model/text_completion/openai/variants.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py index d49b8991..0c314a04 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py @@ -129,6 +129,28 @@ class MistralVariant(Variant): return {"reasoning_effort": effort} +class GlmVariant(Variant): + """GLM / Zhipu AI API (GLM-4, GLM-4.7, etc.).""" + + name = "glm" + token_param = "max_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, thinking): + enabled = "enabled" if thinking != "off" else "disabled" + kwargs = { + self.token_param: max_output, + "temperature": temperature, + "extra_body": { + "thinking": {"type": enabled}, + }, + } + return kwargs + + def thinking_kwargs(self, effort): + return {} + + class LlamaVariant(Variant): """Llama 
models via OpenAI-compatible servers (vLLM, Ollama, etc.). @@ -159,6 +181,7 @@ VARIANTS = { "deepseek": DeepSeekVariant, "qwen": QwenVariant, "mistral": MistralVariant, + "glm": GlmVariant, "llama": LlamaVariant, } From 55e2a2a3cedee2a0b83f5d3a10d58eb2f8b2773d Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Jul 2026 16:50:14 +0100 Subject: [PATCH 5/9] feat: add guided macOS installer and developer install guide (#1003) Interactive bash installer (install_trustgraph.sh) that detects hardware, recommends an LLM mode (OpenAI or Ollama), installs missing prerequisites via Homebrew, sets up a Python venv, runs the test suite, generates a deployment via npx @trustgraph/config, starts the Docker Compose stack, health-checks the API gateway, and opens the Workbench UI. Includes README.dev-install.md with usage documentation covering CLI options, environment variables, LLM mode selection, non-interactive/CI usage, uninstall, and troubleshooting. Currently macOS only. --- README.dev-install.md | 218 ++++ install_trustgraph.sh | 2603 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2821 insertions(+) create mode 100644 README.dev-install.md create mode 100644 install_trustgraph.sh diff --git a/README.dev-install.md b/README.dev-install.md new file mode 100644 index 00000000..d57cb1f3 --- /dev/null +++ b/README.dev-install.md @@ -0,0 +1,218 @@ +# TrustGraph Developer Install Guide + +A guided installer that gets TrustGraph running locally in a single +command. It detects your hardware, recommends an LLM backend, installs +missing prerequisites, runs the test suite, generates a compose deployment, +starts the stack, and opens the Workbench UI. + +> **macOS only.** This installer has only been tested on macOS. If you are +> on Linux or Windows, use the standard docker-compose / podman-compose +> installation instructions instead. + +## Quick start + +```bash +./install_trustgraph.sh +``` + +The installer walks you through each step interactively. 
When it finishes, +the Workbench UI opens at `http://localhost:8888` and the API gateway is +available at `http://localhost:8088/`. + +## Prerequisites + +The installer checks for these and offers to install any that are missing +(via Homebrew): + +- **Python 3** with venv support +- **Node.js / npx** (drives the `@trustgraph/config` deployment generator) +- **Docker** (with Compose) or **Podman** (with podman-compose) +- **curl** and **unzip** +- **Ollama** (only if you choose local LLMs) + +The installer can also launch Docker Desktop or the Ollama app for you if +they are installed but not running. + +## What the installer does + +1. **Detects hardware** -- OS, architecture, CPU cores, memory, and GPU. +2. **Recommends an LLM mode** -- `ollama` for machines with >= 16 GB RAM and + a GPU or >= 8 cores; `openai` otherwise. +3. **Collects configuration** -- API key, LLM provider, model choices, + install directory. Answers are saved to + `<install-dir>/trustgraph-installer.env` and reused on subsequent runs. +4. **Checks and installs prerequisites** -- Python, Node/npx, Docker or + Podman, Ollama (if selected). +5. **Downloads Ollama models** (if using Ollama) -- chat model + (`granite4:350m` by default) and embeddings model (`mxbai-embed-large`). +6. **Creates a Python venv** and installs the local TrustGraph packages into + it, along with NLTK data and tiktoken caches. +7. **Runs the full pytest suite** against the local source tree. +8. **Runs `npx @trustgraph/config`** -- the existing interactive config + wizard that produces a `deploy.zip` with a compose file. +9. **Starts the compose stack** and waits for the API gateway to respond. +10. **Bootstraps IAM** and verifies the API key authenticates. +11. **Opens the Workbench UI** in your default browser. 
+ +## Command-line options + +| Option | Description | +|---|---| +| `--install-dir PATH` | Directory for deployment files (default: `./trustgraph-deploy`) | +| `--api-url URL` | API gateway URL for health checks (default: `http://localhost:8088/`) | +| `--ui-url URL` | Workbench UI URL to open (default: `http://localhost:8888`) | +| `--use-existing-compose FILE` | Skip config generation and start this compose file directly | +| `--skip-tests` | Do not run the pytest suite | +| `--no-launch` | Do not open the Workbench UI at the end | +| `--non-interactive` | Accept all defaults without prompting | +| `--yes` | Auto-accept confirmation prompts | +| `--fresh` | Remove installer-managed files before generating a new deployment | +| `--remove-all` | Uninstall: stop containers, remove compose volumes, delete installer files | +| `--dry-run` | Print detected hardware and planned defaults, then exit | +| `-h`, `--help` | Show the built-in help text | + +## Environment variables + +These override the interactive prompts when set: + +| Variable | Purpose | +|---|---| +| `TRUSTGRAPH_TOKEN` | Admin/bootstrap API key (must start with `tg_`) | +| `TRUSTGRAPH_URL` | API gateway URL | +| `TRUSTGRAPH_UI_URL` | Workbench UI URL | +| `OPENAI_TOKEN` | OpenAI-compatible API key | +| `OPENAI_BASE_URL` | OpenAI-compatible base URL | +| `OLLAMA_HOST` / `OLLAMA_BASE_URL` | Ollama service URL | +| `OLLAMA_MODEL` | Ollama chat model (default: `granite4:350m`) | +| `OLLAMA_EMBEDDINGS_MODEL` | Ollama embeddings model (default: `mxbai-embed-large`) | +| `TG_INSTALL_DIR` | Override the install directory | +| `TG_VENV_DIR` | Override the Python venv location | +| `TG_NLTK_DATA_DIR` | Override the NLTK data directory | +| `TIKTOKEN_CACHE_DIR` | Override the tiktoken cache directory | +| `TG_HEALTH_TIMEOUT` | Seconds to wait for the API gateway (default: 240) | + +## Choosing an LLM mode + +### OpenAI (or any OpenAI-compatible provider) + +Best when you already have an API key or are running 
against a remote +endpoint. The installer asks for a base URL and an API key. + +```bash +OPENAI_TOKEN=sk-... ./install_trustgraph.sh +``` + +### Ollama (local models) + +Best on machines with enough RAM to run a small model. The installer detects +locally installed Ollama models and offers to pull missing ones. It uses +`host.docker.internal` so the Docker containers can reach the host-side +Ollama service. + +```bash +./install_trustgraph.sh # choose "ollama" when prompted +``` + +### None + +Start the platform without an LLM. Agent and RAG features will not work +until you configure one later through the Workbench. + +## Saved answers and re-running + +The installer saves your answers to +`<install-dir>/trustgraph-installer.env`. On the next run it loads those +answers as defaults, so you can re-run with a single Enter through each +prompt. + +To start completely fresh: + +```bash +./install_trustgraph.sh --fresh +``` + +This stops any running containers (keeping Docker volumes), removes +installer-managed files, and re-runs the full flow. + +## Using an existing compose file + +If you already have a compose file from the config tool or another source: + +```bash +./install_trustgraph.sh --use-existing-compose path/to/docker-compose.yaml +``` + +This skips the config wizard and `npx` prerequisite check, and goes straight +to starting the stack. + +## Non-interactive / CI usage + +```bash +TRUSTGRAPH_TOKEN=tg_my-token \ +OPENAI_TOKEN=sk-... \ +./install_trustgraph.sh --non-interactive --yes --skip-tests +``` + +In non-interactive mode the installer uses defaults for every prompt. Pair +with `--yes` to auto-accept confirmation prompts and `--skip-tests` if you +want a faster run. + +## Dry run + +Preview what the installer would do without making any changes: + +```bash +./install_trustgraph.sh --dry-run +``` + +This prints the detected hardware, recommended LLM mode, and planned +install paths, then exits. 
+ +## Uninstalling + +```bash +./install_trustgraph.sh --remove-all +``` + +This stops containers, removes compose-managed volumes, and deletes +installer-managed files (venv, deploy output, logs, saved answers). It does +**not** remove Docker/Podman itself, container images, Ollama, or Ollama +models. + +## Troubleshooting + +### Logs + +All long-running operations write logs to `<install-dir>/logs/`. Key files: + +- `pytest.log` -- test suite output +- `compose-up.log` -- docker compose output +- `iam-bootstrap.log` -- IAM bootstrap output +- `ollama-pull-*.log` -- Ollama model downloads +- `pip-*.log` -- Python package installs +- `brew-install-*.log` -- Homebrew installs + +### API key rejected after reinstall + +If the API gateway returns 401/403 with your saved key, the compose volumes +likely contain IAM data from a previous install with a different key. Run: + +```bash +./install_trustgraph.sh --remove-all +./install_trustgraph.sh +``` + +This clears the old volumes and starts fresh. + +### Ollama not reachable from containers + +The Ollama base URL should use `host.docker.internal` instead of +`localhost` so that containers running in Docker Desktop can reach the +host-side Ollama service. The installer sets this automatically; if you +override `OLLAMA_HOST`, make sure the URL is reachable from inside the +container network. + +### Docker daemon not running + +The installer detects Docker Desktop and offers to start it. If that +doesn't work, start Docker Desktop manually and re-run the installer. 
diff --git a/install_trustgraph.sh b/install_trustgraph.sh new file mode 100644 index 00000000..b3919791 --- /dev/null +++ b/install_trustgraph.sh @@ -0,0 +1,2603 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +APP_NAME="TrustGraph" +DEFAULT_API_URL="http://localhost:8088/" +DEFAULT_UI_URL="http://localhost:8888" +DEFAULT_INSTALL_DIR="trustgraph-deploy" +DEFAULT_OLLAMA_MODEL="granite4:350m" +DEFAULT_OLLAMA_EMBEDDINGS_MODEL="mxbai-embed-large" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +INSTALL_DIR="${TG_INSTALL_DIR:-$SCRIPT_DIR/$DEFAULT_INSTALL_DIR}" +VENV_DIR="${TG_VENV_DIR:-$INSTALL_DIR/.venv}" +NLTK_DATA_DIR="${TG_NLTK_DATA_DIR:-$INSTALL_DIR/nltk_data}" +TIKTOKEN_CACHE_DIR_VALUE="${TIKTOKEN_CACHE_DIR:-$INSTALL_DIR/tiktoken_cache}" +PYTHON_BIN="python3" +API_URL="${TRUSTGRAPH_URL:-$DEFAULT_API_URL}" +UI_URL="${TRUSTGRAPH_UI_URL:-$DEFAULT_UI_URL}" + +RUN_TESTS=1 +AUTO_LAUNCH=1 +NON_INTERACTIVE=0 +DRY_RUN=0 +YES=0 +FRESH_INSTALL=0 +REMOVE_ALL=0 +USE_EXISTING_COMPOSE="" +HEALTH_TIMEOUT="${TG_HEALTH_TIMEOUT:-240}" +AUTH_CHECK_TIMEOUT="${TG_AUTH_CHECK_TIMEOUT:-45}" + +AUTH_TOKEN="${TRUSTGRAPH_TOKEN:-}" +LLM_MODE="" +OPENAI_TOKEN_VALUE="${OPENAI_TOKEN:-}" +OPENAI_BASE_URL_VALUE="${OPENAI_BASE_URL:-https://api.openai.com/v1}" +OLLAMA_BASE_URL_VALUE="${OLLAMA_HOST:-${OLLAMA_BASE_URL:-}}" +OLLAMA_MODEL="${OLLAMA_MODEL:-$DEFAULT_OLLAMA_MODEL}" +OLLAMA_EMBEDDINGS_MODEL="${OLLAMA_EMBEDDINGS_MODEL:-$DEFAULT_OLLAMA_EMBEDDINGS_MODEL}" + +HW_OS="" +HW_ARCH="" +HW_CPU_CORES="unknown" +HW_MEMORY_GB="unknown" +HW_GPU="none detected" +HW_CONTAINER_HINT="" +RECOMMENDED_LLM_MODE="openai" +RECOMMENDATION_REASON="" +COMPOSE_CMD=() +COLOR_RESET="" +COLOR_HEADING="" +COLOR_INFO="" +COLOR_WARN="" +COLOR_ERROR="" +COLOR_ACCENT="" + +usage() { + cat <<'USAGE' +Usage: ./install_trustgraph.sh [options] + +Guided local installer for TrustGraph. 
It detects the machine hardware, +recommends a local or hosted LLM path, asks for the few required values, +enumerates local Ollama models when relevant, runs the repo tests, generates +a deployment with the existing config tool, starts the stack, checks health, +and opens the Workbench UI. + +Options: + --install-dir PATH Directory for generated deployment files. + --api-url URL API gateway URL for health checks. + --ui-url URL Workbench UI URL to open. + --use-existing-compose F Skip config generation and start this compose file. + --skip-tests Do not run the full pytest suite. + --no-launch Do not open the Workbench UI at the end. + --non-interactive Use defaults where possible. Best with --dry-run or + --use-existing-compose. + --yes Accept confirmation prompts. + --fresh Remove installer-managed files in --install-dir + before generating a new deployment. + --remove-all Uninstall the installer-managed deployment: + stop containers, remove compose volumes, and + delete only installer-managed files. + --dry-run Show detected hardware and planned defaults only. + -h, --help Show this help. 
+ +Environment defaults: + TRUSTGRAPH_TOKEN, TRUSTGRAPH_URL, OPENAI_TOKEN, OPENAI_BASE_URL, + OLLAMA_HOST, OLLAMA_BASE_URL, OLLAMA_MODEL, OLLAMA_EMBEDDINGS_MODEL, + TG_INSTALL_DIR, TG_VENV_DIR, TG_NLTK_DATA_DIR, TIKTOKEN_CACHE_DIR, + TG_HEALTH_TIMEOUT +USAGE +} + +say() { + printf '\n%b%s%b\n' "$COLOR_HEADING" "$*" "$COLOR_RESET" +} + +info() { + printf ' %b%s%b\n' "$COLOR_INFO" "$*" "$COLOR_RESET" +} + +warn() { + printf '%bWarning:%b %s\n' "$COLOR_WARN" "$COLOR_RESET" "$*" >&2 +} + +die() { + printf '%bError:%b %s\n' "$COLOR_ERROR" "$COLOR_RESET" "$*" >&2 + exit 1 +} + +command_exists() { + command -v "$1" >/dev/null 2>&1 +} + +spinner_enabled() { + [[ "${TG_NO_SPINNER:-0}" != "1" ]] && { [[ -t 2 ]] || [[ "${TG_FORCE_SPINNER:-0}" == "1" ]]; } +} + +clear_spinner_line() { + printf '\r\033[K' >&2 +} + +run_with_spinner() { + local message="$1" + shift + local frames=('|' '/' '-' '\') + local frame=0 + local pid + local status + + if ! spinner_enabled; then + "$@" + return + fi + + "$@" & + pid=$! + while kill -0 "$pid" 2>/dev/null; do + printf '\r %b%s%b %s' "$COLOR_ACCENT" "${frames[$frame]}" "$COLOR_RESET" "$message" >&2 + frame=$(((frame + 1) % ${#frames[@]})) + sleep 0.2 + done + + if wait "$pid"; then + status=0 + else + status=$? + fi + + clear_spinner_line + if [[ "$status" -eq 0 ]]; then + info "Done: $message" + else + warn "Failed: $message" + fi + return "$status" +} + +run_with_spinner_logged() { + local message="$1" + local log_file="$2" + shift 2 + local frames=('|' '/' '-' '\') + local frame=0 + local pid + local status + + if ! spinner_enabled; then + "$@" + return + fi + + mkdir -p "$(dirname "$log_file")" + "$@" >"$log_file" 2>&1 & + pid=$! + while kill -0 "$pid" 2>/dev/null; do + printf '\r %b%s%b %s' "$COLOR_ACCENT" "${frames[$frame]}" "$COLOR_RESET" "$message" >&2 + frame=$(((frame + 1) % ${#frames[@]})) + sleep 0.2 + done + + if wait "$pid"; then + status=0 + else + status=$? 
+ fi + + clear_spinner_line + if [[ "$status" -eq 0 ]]; then + info "Done: $message" + else + warn "Failed: $message" + warn "Last log lines from $log_file:" + tail -n 40 "$log_file" >&2 || true + fi + return "$status" +} + +installer_log_file() { + local name="$1" + mkdir -p "$INSTALL_DIR/logs" + printf '%s/logs/%s.log\n' "$INSTALL_DIR" "$name" +} + +command_to_text() { + local arg + local out="" + + for arg in "$@"; do + if [[ -n "$out" ]]; then + out="$out " + fi + out="$out$(printf '%q' "$arg")" + done + + printf '%s\n' "$out" +} + +root_command_to_text() { + if [[ "${EUID:-$(id -u)}" -eq 0 ]]; then + command_to_text "$@" + elif command_exists sudo; then + command_to_text sudo "$@" + else + command_to_text "$@" + fi +} + +run_root_command() { + if [[ "${EUID:-$(id -u)}" -eq 0 ]]; then + "$@" + elif command_exists sudo; then + sudo "$@" + else + warn "Could not find sudo. Run this installer as an administrator or install the prerequisite manually." + return 1 + fi +} + +confirm_install_command() { + local question="$1" + local command_text="$2" + + info "Command: $command_text" + + if [[ "$YES" -eq 1 ]]; then + return 0 + fi + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + return 1 + fi + + confirm "$question" 1 +} + +init_colors() { + if [[ -n "${NO_COLOR:-}" || ! 
-t 1 ]]; then + return + fi + + if command_exists tput && tput colors >/dev/null 2>&1 && [[ "$(tput colors)" -ge 8 ]]; then + COLOR_RESET="$(tput sgr0)" + COLOR_HEADING="$(tput bold)$(tput setaf 6)" + COLOR_INFO="$(tput setaf 2)" + COLOR_WARN="$(tput setaf 3)" + COLOR_ERROR="$(tput bold)$(tput setaf 1)" + COLOR_ACCENT="$(tput bold)$(tput setaf 5)" + fi +} + +print_banner() { + printf '\n%b+---------------------------+%b\n' "$COLOR_ACCENT" "$COLOR_RESET" + printf '%b| TrustGraph Easy Installer |%b\n' "$COLOR_ACCENT" "$COLOR_RESET" + printf '%b+---------------------------+%b\n' "$COLOR_ACCENT" "$COLOR_RESET" +} + +parse_args() { + while [[ $# -gt 0 ]]; do + case "$1" in + --install-dir) + [[ $# -ge 2 ]] || die "--install-dir needs a path" + INSTALL_DIR="$2" + shift 2 + ;; + --api-url) + [[ $# -ge 2 ]] || die "--api-url needs a URL" + API_URL="$2" + shift 2 + ;; + --ui-url) + [[ $# -ge 2 ]] || die "--ui-url needs a URL" + UI_URL="$2" + shift 2 + ;; + --use-existing-compose) + [[ $# -ge 2 ]] || die "--use-existing-compose needs a file path" + USE_EXISTING_COMPOSE="$2" + shift 2 + ;; + --skip-tests) + RUN_TESTS=0 + shift + ;; + --no-launch) + AUTO_LAUNCH=0 + shift + ;; + --non-interactive) + NON_INTERACTIVE=1 + shift + ;; + --yes) + YES=1 + shift + ;; + --fresh) + FRESH_INSTALL=1 + shift + ;; + --remove-all) + REMOVE_ALL=1 + shift + ;; + --dry-run) + DRY_RUN=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + die "Unknown option: $1" + ;; + esac + done + + case "$API_URL" in + */) ;; + *) API_URL="$API_URL/" ;; + esac +} + +prompt_value() { + local label="$1" + local default="$2" + local helper="$3" + local answer="" + + if [[ -n "$helper" ]]; then + printf ' %s\n' "$helper" >&2 + fi + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + printf '%s\n' "$default" + return + fi + + if [[ -n "$default" ]]; then + read -r -p "$label [$default]: " answer + printf '%s\n' "${answer:-$default}" + else + read -r -p "$label: " answer + printf '%s\n' "$answer" + fi +} + 
+looks_like_embedding_ollama_model() { + local model + model="$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')" + + case "$model" in + *embed*|*embedding*|*nomic*|*mxbai*|*bge*|*e5*|*gte*|*minilm*|*snowflake-arctic*) + return 0 + ;; + *) + return 1 + ;; + esac +} + +ollama_model_candidates() { + local kind="$1" + shift + local model + local selected=() + + for model in "$@"; do + case "$kind" in + embeddings) + if looks_like_embedding_ollama_model "$model"; then + selected+=("$model") + fi + ;; + chat) + if ! looks_like_embedding_ollama_model "$model"; then + selected+=("$model") + fi + ;; + *) + selected+=("$model") + ;; + esac + done + + if [[ "${#selected[@]}" -eq 0 ]]; then + selected=("$@") + fi + + for model in "${selected[@]}"; do + printf '%s\n' "$model" + done +} + +ollama_api_bases_for_host() { + local base="${OLLAMA_BASE_URL_VALUE%/}" + base="${base%/v1}" + + [[ -n "$base" ]] || base="http://localhost:11434" + printf '%s\n' "$base" + + case "$base" in + *host.docker.internal*) + printf '%s\n' "${base//host.docker.internal/localhost}" + ;; + *0.0.0.0*) + printf '%s\n' "${base//0.0.0.0/localhost}" + ;; + esac +} + +list_ollama_models_from_cli_for_host() { + local host="${1:-}" + + command_exists ollama || return 0 + + if [[ -n "$host" ]]; then + OLLAMA_HOST="$host" ollama list 2>/dev/null | awk 'NR > 1 && $1 != "" { print $1 }' || true + else + ollama list 2>/dev/null | awk 'NR > 1 && $1 != "" { print $1 }' || true + fi +} + +list_ollama_models_from_cli() { + local base + + list_ollama_models_from_cli_for_host + + if [[ -n "$OLLAMA_BASE_URL_VALUE" ]]; then + while IFS= read -r base; do + [[ -n "$base" ]] || continue + list_ollama_models_from_cli_for_host "$base" + done < <(ollama_api_bases_for_host) + fi +} + +list_ollama_models_from_api() { + command_exists curl || return 0 + command_exists python3 || return 0 + + local base + local response + + while IFS= read -r base; do + [[ -n "$base" ]] || continue + response="$(curl -fsS --max-time 2 
"${base%/}/api/tags" 2>/dev/null || true)" + [[ -n "$response" ]] || continue + + printf '%s' "$response" | python3 -c 'import json, sys +try: + data = json.load(sys.stdin) +except Exception: + raise SystemExit(0) +for model in data.get("models", []): + name = model.get("name") or model.get("model") + if name: + print(name) +' 2>/dev/null || true + done < <(ollama_api_bases_for_host) +} + +list_ollama_models() { + { + list_ollama_models_from_cli + list_ollama_models_from_api + } | awk 'NF && !seen[$0]++' +} + +ollama_model_name_matches() { + local installed="$1" + local target="$2" + + [[ "$installed" == "$target" ]] && return 0 + [[ "$target" != *:* && "$installed" == "$target:latest" ]] && return 0 + [[ "$installed" != *:* && "$target" == "$installed:latest" ]] && return 0 + + return 1 +} + +find_reachable_ollama_cli_host() { + local base + + command_exists ollama || return 1 + + if ollama list >/dev/null 2>&1; then + printf '\n' + return 0 + fi + + while IFS= read -r base; do + [[ -n "$base" ]] || continue + if OLLAMA_HOST="$base" ollama list >/dev/null 2>&1; then + printf '%s\n' "$base" + return 0 + fi + done < <(ollama_api_bases_for_host) + + return 1 +} + +ollama_model_available_via_cli_host() { + local host="$1" + local target="$2" + local model + + while IFS= read -r model; do + ollama_model_name_matches "$model" "$target" && return 0 + done < <(list_ollama_models_from_cli_for_host "$host") + + return 1 +} + +pull_ollama_model() { + local host="$1" + local model="$2" + local log_file + log_file="$(installer_log_file "ollama-pull-${model//\//-}")" + + if [[ -n "$host" ]]; then + run_with_spinner_logged "Downloading Ollama model $model" "$log_file" env OLLAMA_HOST="$host" ollama pull "$model" + else + run_with_spinner_logged "Downloading Ollama model $model" "$log_file" ollama pull "$model" + fi +} + +wait_for_ollama_service() { + local timeout="${1:-30}" + local deadline=$((SECONDS + timeout)) + + while (( SECONDS < deadline )); do + if 
find_reachable_ollama_cli_host >/dev/null 2>&1; then + return 0 + fi + sleep 2 + done + + return 1 +} + +start_ollama_service_if_possible() { + local command_text + local log_file="$INSTALL_DIR/ollama.log" + + say "Ollama service is not running" + + if [[ "$HW_OS" == "Darwin" ]] && command_exists open && [[ -d /Applications/Ollama.app ]]; then + command_text="$(command_to_text open -a Ollama)" + if confirm_install_command "Start the Ollama app now?" "$command_text"; then + open -a Ollama + wait_for_ollama_service 45 + return + fi + fi + + if command_exists brew; then + command_text="$(command_to_text brew services start ollama)" + if confirm_install_command "Start the Ollama service with Homebrew now?" "$command_text"; then + brew services start ollama + wait_for_ollama_service 45 + return + fi + fi + + command_text="$(command_to_text ollama serve) > $(printf '%q' "$log_file") 2>&1 &" + if confirm_install_command "Start Ollama in the background now?" "$command_text"; then + mkdir -p "$INSTALL_DIR" + nohup ollama serve > "$log_file" 2>&1 & + wait_for_ollama_service 45 + return + fi + + return 1 +} + +offer_single_ollama_model_download() { + local kind="$1" + local default_model="$2" + local selected_model="$3" + local processor_label="$4" + local cli_host="$5" + local question + + say "Preparing Ollama $kind model" + info "TrustGraph's Ollama $processor_label default is $default_model." + + if ollama_model_available_via_cli_host "$cli_host" "$selected_model"; then + info "Ollama $kind model already available: $selected_model" + return 0 + fi + + if [[ "$selected_model" == "$default_model" ]]; then + question="Download TrustGraph's preferred Ollama $kind model ($selected_model) now?" + else + question="Download the selected Ollama $kind model ($selected_model) now?" + fi + + if confirm "$question" 1; then + info "Downloading $selected_model with Ollama. This may take a while." + if ! 
pull_ollama_model "$cli_host" "$selected_model"; then + die "Ollama could not download $selected_model. Try running: ollama pull $selected_model" + fi + else + warn "Skipping Ollama $kind model download. TrustGraph's Ollama processor will try to pull $selected_model on first use." + fi +} + +offer_ollama_model_downloads() { + local cli_host + + [[ "$LLM_MODE" == "ollama" ]] || return 0 + + if ! command_exists ollama; then + warn "Ollama was selected, but the ollama CLI was not found. Install Ollama and run: ollama pull $OLLAMA_MODEL && ollama pull $OLLAMA_EMBEDDINGS_MODEL" + return 0 + fi + + if ! cli_host="$(find_reachable_ollama_cli_host)"; then + start_ollama_service_if_possible || true + if ! cli_host="$(find_reachable_ollama_cli_host)"; then + warn "Ollama CLI is installed, but the Ollama service is not reachable. Start Ollama and run: ollama pull $OLLAMA_MODEL && ollama pull $OLLAMA_EMBEDDINGS_MODEL" + return 0 + fi + fi + + if [[ -n "$cli_host" ]]; then + info "Ollama service: $cli_host" + else + info "Ollama service: local Ollama default" + fi + + offer_single_ollama_model_download \ + "chat" \ + "$DEFAULT_OLLAMA_MODEL" \ + "$OLLAMA_MODEL" \ + "text-completion" \ + "$cli_host" + + offer_single_ollama_model_download \ + "embeddings" \ + "$DEFAULT_OLLAMA_EMBEDDINGS_MODEL" \ + "$OLLAMA_EMBEDDINGS_MODEL" \ + "embeddings" \ + "$cli_host" +} + +prompt_ollama_model_choice() { + local label="$1" + local default="$2" + local kind="$3" + local helper="$4" + shift 4 + local all_models=("$@") + local candidates=() + local options=() + local model + local answer + local idx + local found_default=0 + local detected_default="" + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + printf '%s\n' "$default" + return + fi + + if [[ -n "$helper" ]]; then + printf ' %s\n' "$helper" >&2 + fi + + if [[ "${#all_models[@]}" -eq 0 ]]; then + prompt_value \ + "$label" \ + "$default" \ + "No local Ollama models were detected. 
Pull the recommended default with: ollama pull $default" + return + fi + + while IFS= read -r model; do + [[ -n "$model" ]] && candidates+=("$model") + done < <(ollama_model_candidates "$kind" "${all_models[@]}") + + if [[ "${#candidates[@]}" -eq 0 ]]; then + candidates=("${all_models[@]}") + fi + + for model in "${candidates[@]}"; do + if ollama_model_name_matches "$model" "$default"; then + found_default=1 + detected_default="$model" + break + fi + done + + options+=("$default") + for model in "${candidates[@]}"; do + ollama_model_name_matches "$model" "$default" && continue + options+=("$model") + done + + say "Local Ollama ${kind} model choices" >&2 + if [[ "$found_default" -eq 1 ]]; then + if [[ "$detected_default" != "$default" ]]; then + info "1) $default (recommended, detected as $detected_default)" >&2 + else + info "1) $default (recommended, detected)" >&2 + fi + else + info "1) $default (recommended default, not detected locally)" >&2 + fi + + idx=2 + for model in "${options[@]:1}"; do + info "$idx) $model" >&2 + idx=$((idx + 1)) + done + if [[ "$found_default" -eq 0 ]]; then + info "If you choose a missing model, the installer will offer to download it before startup." >&2 + fi + info "Or type another model name, for example one you plan to pull before startup." >&2 + + read -r -p "$label [1: $default]: " answer + answer="${answer:-1}" + + if [[ "$answer" =~ ^[0-9]+$ ]]; then + if (( answer >= 1 && answer <= ${#options[@]} )); then + printf '%s\n' "${options[$((answer - 1))]}" + return + fi + warn "Selection '$answer' is not in the list; using $default." 
# Lexically test whether a path lies strictly below $INSTALL_DIR.
#
# Arguments:
#   $1  path to test (compared as a string; symlinks and ".." are NOT resolved
#       here -- safe_existing_path_within_install_dir does the resolved check).
# Returns:
#   0 when the path starts with "<install dir>/", 1 otherwise.
#
# An empty (or slash-only) INSTALL_DIR never matches: without this guard the
# pattern would degrade to "/*" and every absolute path would be reported as
# belonging to the install directory. Trailing slashes on INSTALL_DIR are
# ignored so "dir/" and "dir" behave the same.
path_within_install_dir() {
  local path="$1"
  local root="${INSTALL_DIR:-}"

  # Normalise "dir/" and "dir//" to "dir"; "/" collapses to "" and is rejected.
  while [[ "$root" == */ ]]; do
    root="${root%/}"
  done
  [[ -n "$root" ]] || return 1

  # The quoted prefix is matched literally, so glob characters in the
  # install directory name cannot widen the match.
  case "$path" in
    "$root"/*) return 0 ;;
    *) return 1 ;;
  esac
}
# Abort (via die, defined elsewhere in this script) unless $INSTALL_DIR is a
# safe directory to clean up.
#
# Refuses when INSTALL_DIR:
#   - is empty;
#   - is "/", "." or ".." once trailing slashes are ignored (so "//", "./" and
#     "../" are treated the same as their bare forms);
#   - exists but cannot be resolved to a physical path;
#   - resolves (symlinks followed) to the filesystem root, to the directory
#     holding this script ($SCRIPT_DIR), or to $HOME.
#
# Returns 0 without further checks when INSTALL_DIR does not exist yet, since
# there is nothing on disk that could be removed.
assert_safe_cleanup_target() {
  local resolved_install
  local resolved_script
  local resolved_home=""
  local trimmed_install

  [[ -n "$INSTALL_DIR" ]] || die "Install directory is empty; refusing cleanup."

  # Compare without trailing slashes so "./" or "//" cannot slip past the
  # literal check below.
  trimmed_install="$INSTALL_DIR"
  while [[ "$trimmed_install" == */ ]]; do
    trimmed_install="${trimmed_install%/}"
  done
  case "$trimmed_install" in
    ''|.|..) die "Install directory '$INSTALL_DIR' is too broad; refusing cleanup." ;;
  esac

  [[ -d "$INSTALL_DIR" ]] || return 0

  # Physical path: a symlinked or dotted spelling of a dangerous directory
  # must be caught too.
  resolved_install="$(cd "$INSTALL_DIR" && pwd -P)" || resolved_install=""
  [[ -n "$resolved_install" ]] || die "Install directory '$INSTALL_DIR' could not be resolved; refusing cleanup."
  [[ -n "${resolved_install//\//}" ]] || die "Install directory '$INSTALL_DIR' resolves to the filesystem root; refusing cleanup."

  resolved_script="$(cd "$SCRIPT_DIR" && pwd -P)"
  if [[ -n "${HOME:-}" && -d "$HOME" ]]; then
    resolved_home="$(cd "$HOME" && pwd -P)"
  fi

  [[ "$resolved_install" != "$resolved_script" ]] || die "Install directory resolves to the source checkout; refusing cleanup."
  [[ -z "$resolved_home" || "$resolved_install" != "$resolved_home" ]] || die "Install directory resolves to your home directory; refusing cleanup."
}
# Delete every path reported by installer_artifact_paths, then try to remove
# the (hopefully now empty) install directory itself.
#
# Each candidate is re-validated with safe_existing_path_within_install_dir
# right before deletion; anything that fails the check is skipped with a
# warning instead of being removed. Failure to rmdir $INSTALL_DIR (for example
# because user files remain in it) is ignored on purpose.
remove_installer_artifacts_only() {
  local artifact

  say "Removing installer-managed files"

  while IFS= read -r artifact; do
    if [[ -z "$artifact" ]]; then
      continue
    fi

    if safe_existing_path_within_install_dir "$artifact"; then
      info "Removing $artifact"
      rm -rf -- "$artifact"
    else
      warn "Skipping cleanup path outside install directory: $artifact"
    fi
  done < <(installer_artifact_paths)

  rmdir "$INSTALL_DIR" 2>/dev/null || true
}
# Tear down the compose deployment described by the given compose file as part
# of an uninstall.
#
# Arguments:
#   $1  compose file path; when empty the function returns 0 immediately.
#
# First tries `down --remove-orphans --volumes`. If that fails, it retries
# without --volumes so containers are at least stopped. Every failure is
# reported with warn only -- the function never aborts the uninstall.
remove_compose_stack_for_uninstall() {
  local stack_file="$1"

  if [[ -z "$stack_file" ]]; then
    return 0
  fi

  say "Stopping TrustGraph containers and volumes"

  if ! detect_compose_command; then
    warn "Could not find Docker Compose or podman-compose, so containers and volumes were not removed."
    return
  fi

  if "${COMPOSE_CMD[@]}" -f "$stack_file" down --remove-orphans --volumes; then
    info "Removed compose containers, networks, and compose-managed volumes."
  else
    # Fallback: same teardown but leaving volumes alone.
    warn "Compose did not accept the volume removal command; trying to stop containers without removing volumes."
    if "${COMPOSE_CMD[@]}" -f "$stack_file" down --remove-orphans; then
      warn "Containers were stopped, but compose volumes may remain."
    else
      warn "Could not stop the compose deployment. Installer-managed files will still be removed."
    fi
  fi
}
# Detect output left behind by a previous installer run and decide whether to
# wipe it.
#
# Skipped entirely (returns 0) when an existing compose file was supplied via
# USE_EXISTING_COMPOSE or when no installer artifacts are present.
# Otherwise the artifacts are listed and then:
#   - DRY_RUN=1        only explains what would happen;
#   - FRESH_INSTALL=1  runs cleanup_installer_artifacts without asking;
#   - otherwise        asks the user (default answer: no) before cleaning up.
handle_existing_install() {
  local artifact
  local listed_any=0
  local dry_run_note

  if [[ -n "$USE_EXISTING_COMPOSE" ]]; then
    return 0
  fi
  if ! installer_artifacts_present; then
    return 0
  fi

  say "Existing installer output detected"
  info "Install directory: $INSTALL_DIR"
  while IFS= read -r artifact; do
    if [[ -n "$artifact" ]]; then
      listed_any=1
      info "Found: $artifact"
    fi
  done < <(installer_artifact_paths)

  # Artifacts may have vanished between the presence check and the listing.
  if [[ "$listed_any" -ne 1 ]]; then
    return 0
  fi

  if [[ "$DRY_RUN" -eq 1 ]]; then
    dry_run_note="Dry run: existing files would be kept unless you choose --fresh."
    if [[ "$FRESH_INSTALL" -eq 1 ]]; then
      dry_run_note="Dry run: --fresh would remove the files listed above."
    fi
    info "$dry_run_note"
    return 0
  fi

  if [[ "$FRESH_INSTALL" -eq 1 ]]; then
    cleanup_installer_artifacts
    return 0
  fi

  if ! confirm "Treat this as a fresh install and delete only the installer-managed files listed above?" 0; then
    info "Continuing with the existing installer output."
  else
    cleanup_installer_artifacts
  fi
}
+ # shellcheck disable=SC1090 + source "$env_file" + + if [[ "$current_api_url" == "$DEFAULT_API_URL" && -n "${TRUSTGRAPH_URL:-}" ]]; then + API_URL="$TRUSTGRAPH_URL" + else + API_URL="$current_api_url" + fi + + if [[ "$current_ui_url" == "$DEFAULT_UI_URL" && -n "${TRUSTGRAPH_UI_URL:-}" ]]; then + UI_URL="$TRUSTGRAPH_UI_URL" + else + UI_URL="$current_ui_url" + fi + + if [[ -z "$current_auth_token" && -n "${TRUSTGRAPH_TOKEN:-}" ]]; then + AUTH_TOKEN="$TRUSTGRAPH_TOKEN" + elif [[ -z "$current_auth_token" && -n "${IAM_BOOTSTRAP_TOKEN:-}" ]]; then + AUTH_TOKEN="$IAM_BOOTSTRAP_TOKEN" + else + AUTH_TOKEN="$current_auth_token" + fi + + if [[ -z "${TG_VENV_DIR:-}" && -n "${current_venv_dir:-}" ]]; then + VENV_DIR="$current_venv_dir" + elif [[ -n "${TG_VENV_DIR:-}" ]]; then + VENV_DIR="$TG_VENV_DIR" + fi + + if [[ -z "${TG_NLTK_DATA_DIR:-}" && -n "$current_nltk_data_dir" ]]; then + NLTK_DATA_DIR="$current_nltk_data_dir" + elif [[ -n "${TG_NLTK_DATA_DIR:-}" ]]; then + NLTK_DATA_DIR="$TG_NLTK_DATA_DIR" + fi + + if [[ -z "${TIKTOKEN_CACHE_DIR:-}" && -n "$current_tiktoken_cache_dir" ]]; then + TIKTOKEN_CACHE_DIR_VALUE="$current_tiktoken_cache_dir" + elif [[ -n "${TIKTOKEN_CACHE_DIR:-}" ]]; then + TIKTOKEN_CACHE_DIR_VALUE="$TIKTOKEN_CACHE_DIR" + fi + + if [[ -z "$current_llm_mode" && -n "${TRUSTGRAPH_LLM_MODE:-}" ]]; then + LLM_MODE="$TRUSTGRAPH_LLM_MODE" + else + LLM_MODE="$current_llm_mode" + fi + + if [[ "$current_openai_base_url" == "https://api.openai.com/v1" && -n "${OPENAI_BASE_URL:-}" ]]; then + OPENAI_BASE_URL_VALUE="$OPENAI_BASE_URL" + else + OPENAI_BASE_URL_VALUE="$current_openai_base_url" + fi + + if [[ -z "$current_openai_token" && -n "${OPENAI_TOKEN:-}" ]]; then + OPENAI_TOKEN_VALUE="$OPENAI_TOKEN" + else + OPENAI_TOKEN_VALUE="$current_openai_token" + fi + + if [[ -z "$current_ollama_base_url" ]]; then + OLLAMA_BASE_URL_VALUE="${OLLAMA_HOST:-${OLLAMA_BASE_URL:-}}" + else + OLLAMA_BASE_URL_VALUE="$current_ollama_base_url" + fi + + if [[ "$current_ollama_model" == 
# Convert a byte count to whole gibibytes (bytes / 1024^3), rounded by awk's
# "%.0f" format, and print the result without a trailing newline.
#
# Arguments:
#   $1  byte count; a missing or empty value is treated as 0.
#
# The value is handed to awk with -v instead of being spliced into the program
# text, so an empty or malformed argument can neither break the awk syntax nor
# inject awk code.
bytes_to_gb() {
  local bytes="${1:-0}"
  awk -v bytes="$bytes" 'BEGIN { printf "%.0f", bytes / 1024 / 1024 / 1024 }'
}
# Succeed (return 0) when $1 is a non-empty string made up only of the digits
# 0-9; fail (return 1) for anything else, including the empty string, signs,
# decimal points and surrounding whitespace.
is_number() {
  case "$1" in
    '' | *[!0-9]*) return 1 ;;
    *) return 0 ;;
  esac
}
# Print a freshly generated TrustGraph API key of the form "tg_<random>".
#
# The random part comes from `openssl rand` (24 bytes, base64 made URL-safe)
# or, when openssl is missing or produced nothing, from Python's
# secrets.token_urlsafe(24). If neither source yields any random data the
# function calls die (defined elsewhere in this script) rather than printing
# a bare "tg_" key -- ensure_compliant_api_key accepts anything starting with
# "tg_", so an empty random part would otherwise be used as a real credential.
generate_token() {
  local random_part=""

  if command_exists openssl; then
    # '+' and '/' are mapped to URL-safe characters; padding and newlines are
    # dropped. A failing openssl simply leaves random_part empty.
    random_part="$(openssl rand -base64 24 2>/dev/null | tr '+/' '-_' | tr -d '=\n')" || random_part=""
  fi

  if [[ -z "$random_part" ]] && command_exists python3; then
    random_part="$(python3 -c 'import secrets; print(secrets.token_urlsafe(24))' 2>/dev/null)" || random_part=""
  fi

  [[ -n "$random_part" ]] || die "Need openssl or python3 to generate a secure TrustGraph API key."

  printf 'tg_%s\n' "$random_part"
}
+} + +collect_answers() { + local generated_token + local token_default + local token_mask + generated_token="$(generate_token)" + if [[ -n "$AUTH_TOKEN" ]]; then + token_default="$AUTH_TOKEN" + token_mask="set in environment" + else + token_default="$generated_token" + token_mask="generated tg_ key" + fi + AUTH_TOKEN="$(prompt_secret \ + "TrustGraph admin/bootstrap API key" \ + "$token_default" \ + "Recommendation: press Enter to use a generated TrustGraph API key beginning with tg_; it will be stored in the installer env file with restricted permissions." \ + "$token_mask")" + AUTH_TOKEN="$(ensure_compliant_api_key "$AUTH_TOKEN")" + + LLM_MODE="$(prompt_value \ + "LLM provider: ollama, openai, or none" \ + "${LLM_MODE:-$RECOMMENDED_LLM_MODE}" \ + "Recommendation: $RECOMMENDED_LLM_MODE. $RECOMMENDATION_REASON")" + LLM_MODE="$(printf '%s' "$LLM_MODE" | tr '[:upper:]' '[:lower:]')" + + case "$LLM_MODE" in + ollama) + OLLAMA_BASE_URL_VALUE="$(prompt_value \ + "Ollama base URL" \ + "$OLLAMA_BASE_URL_VALUE" \ + "If Ollama runs on your laptop and TrustGraph runs in Docker, host.docker.internal is usually the right host on macOS/Windows.")" + local ollama_models=() + local ollama_model + if [[ "$NON_INTERACTIVE" -ne 1 ]]; then + while IFS= read -r ollama_model; do + [[ -n "$ollama_model" ]] && ollama_models+=("$ollama_model") + done < <(list_ollama_models) + fi + if [[ "${#ollama_models[@]}" -gt 0 ]]; then + OLLAMA_MODEL="$(prompt_ollama_model_choice \ + "Ollama chat model" \ + "$OLLAMA_MODEL" \ + "chat" \ + "Recommendation from the local Ollama processor defaults: $DEFAULT_OLLAMA_MODEL for a quick first run." \ + "${ollama_models[@]}")" + OLLAMA_EMBEDDINGS_MODEL="$(prompt_ollama_model_choice \ + "Ollama embeddings model" \ + "$OLLAMA_EMBEDDINGS_MODEL" \ + "embeddings" \ + "Recommendation from the local Ollama embeddings defaults: $DEFAULT_OLLAMA_EMBEDDINGS_MODEL." 
\ + "${ollama_models[@]}")" + else + OLLAMA_MODEL="$(prompt_ollama_model_choice \ + "Ollama chat model" \ + "$OLLAMA_MODEL" \ + "chat" \ + "Recommendation from the local Ollama processor defaults: $DEFAULT_OLLAMA_MODEL for a quick first run.")" + OLLAMA_EMBEDDINGS_MODEL="$(prompt_ollama_model_choice \ + "Ollama embeddings model" \ + "$OLLAMA_EMBEDDINGS_MODEL" \ + "embeddings" \ + "Recommendation from the local Ollama embeddings defaults: $DEFAULT_OLLAMA_EMBEDDINGS_MODEL.")" + fi + OPENAI_BASE_URL_VALUE="${OLLAMA_BASE_URL_VALUE%/}/v1" + OPENAI_TOKEN_VALUE="${OPENAI_TOKEN_VALUE:-ollama}" + ;; + openai) + OPENAI_BASE_URL_VALUE="$(prompt_value \ + "OpenAI-compatible base URL" \ + "$OPENAI_BASE_URL_VALUE" \ + "Use https://api.openai.com/v1 for OpenAI, or your provider's OpenAI-compatible /v1 endpoint.")" + OPENAI_TOKEN_VALUE="$(prompt_secret \ + "OpenAI-compatible API key" \ + "$OPENAI_TOKEN_VALUE" \ + "Press Enter to reuse OPENAI_TOKEN if set; leave blank only if your endpoint does not require a key.")" + ;; + none|skip) + LLM_MODE="none" + warn "Continuing without an LLM key. The platform can start, but agent/RAG calls will need an LLM configured later." + ;; + *) + warn "Unknown LLM provider '$LLM_MODE'; using '$RECOMMENDED_LLM_MODE'." 
# Locate a usable compose front end and store it in the global COMPOSE_CMD
# array.
#
# Preference order:
#   1. `docker compose`   (docker present and the compose plugin responds)
#   2. `docker-compose`
#   3. `podman-compose`
#
# Returns 0 after setting COMPOSE_CMD, or 1 (leaving COMPOSE_CMD untouched)
# when none of them is available.
detect_compose_command() {
  local standalone

  if command_exists docker && docker compose version >/dev/null 2>&1; then
    COMPOSE_CMD=(docker compose)
    return 0
  fi

  # Stand-alone compose binaries, in priority order.
  for standalone in docker-compose podman-compose; do
    if command_exists "$standalone"; then
      COMPOSE_CMD=("$standalone")
      return 0
    fi
  done

  return 1
}
# Offer to start a Docker daemon that is installed but not running.
#
# Offers, in order:
#   1. Docker Desktop via `open -a Docker` (Darwin only, app bundle present),
#      then waits up to 90 seconds for `docker info` to succeed;
#   2. `systemctl start docker` through run_root_command (when systemctl
#      exists), then waits up to 60 seconds.
#
# Returns the result of wait_for_docker_ready for the accepted offer, or 1
# when no offer was accepted or applicable.
start_docker_runtime_if_possible() {
  local preview

  say "Docker is installed but not running"

  if [[ "$HW_OS" == "Darwin" ]] && command_exists open && [[ -d /Applications/Docker.app ]]; then
    preview="$(command_to_text open -a Docker)"
    if confirm_install_command "Start Docker Desktop now?" "$preview"; then
      open -a Docker
      wait_for_docker_ready 90
      return
    fi
  fi

  # Without systemctl there is nothing further to offer.
  command_exists systemctl || return 1

  preview="$(root_command_to_text systemctl start docker)"
  confirm_install_command "Start the Docker service now?" "$preview" || return 1

  run_root_command systemctl start docker
  wait_for_docker_ready 60
}
# Install one or more Homebrew formulae after the user confirms.
#
# Arguments:
#   $1     human-readable label used in prompts and in the log file name
#          (spaces in the label become dashes there);
#   $2...  formulae passed straight to `brew install`.
#
# Returns 1 when the user declines; otherwise the result of
# run_with_spinner_logged for the brew invocation.
install_with_brew() {
  local label="$1"
  shift

  local preview
  local brew_log
  local label_slug="${label// /-}"

  preview="$(command_to_text brew install "$@")"
  brew_log="$(installer_log_file "brew-install-${label_slug}")"

  confirm_install_command "Install $label with Homebrew now?" "$preview" || return 1

  run_with_spinner_logged "Installing $label with Homebrew" "$brew_log" brew install "$@"
}
# Install Python 3 with the first supported package manager found on the
# system. Managers are probed in this order: brew, apt-get, dnf, yum, pacman,
# zypper; each one gets the package names appropriate for it.
#
# Returns the result of the chosen install_with_* helper, or 1 (after a
# warning) when no supported package manager exists.
install_python3_prerequisite() {
  local manager=""
  local candidate

  # Pick the first available manager, preserving the priority order.
  for candidate in brew apt-get dnf yum pacman zypper; do
    if command_exists "$candidate"; then
      manager="$candidate"
      break
    fi
  done

  case "$manager" in
    brew)    install_with_brew "Python 3" python ;;
    apt-get) install_with_apt "Python 3" python3 python3-venv python3-pip ;;
    dnf)     install_with_dnf "Python 3" python3 python3-pip ;;
    yum)     install_with_yum "Python 3" python3 python3-pip ;;
    pacman)  install_with_pacman "Python 3" python ;;
    zypper)  install_with_zypper "Python 3" python3 python3-pip python3-venv ;;
    *)
      warn "No supported package manager was found. Install Python 3 manually, then run this installer again."
      return 1
      ;;
  esac
}
# Install a simple command-line tool whose package name equals its command
# name, using the first supported package manager that is available.
#
# Arguments:
#   $1  tool / package name (used both as prompt label and as package).
#
# Managers are probed in the order brew, apt-get, dnf, yum, pacman, zypper.
# Returns the result of the matching install_with_* helper, or 1 (after a
# warning) when none of the managers is present.
install_basic_tool_prerequisite() {
  local tool="$1"
  local manager

  for manager in brew apt-get dnf yum pacman zypper; do
    if command_exists "$manager"; then
      # Helper names follow install_with_<manager>; apt-get maps to
      # install_with_apt, hence the suffix strip.
      "install_with_${manager%-get}" "$tool" "$tool"
      return
    fi
  done

  warn "No supported package manager was found. Install $tool manually, then run this installer again."
  return 1
}
docker compose version >/dev/null 2>&1; then + if command_exists brew; then + install_with_brew "Docker Compose" docker-compose + elif command_exists apt-get; then + install_with_apt "Docker Compose plugin" docker-compose-plugin + elif command_exists dnf; then + install_with_dnf "Docker Compose plugin" docker-compose-plugin + elif command_exists yum; then + install_with_yum "Docker Compose plugin" docker-compose-plugin + elif command_exists pacman; then + install_with_pacman "Docker Compose" docker-compose + elif command_exists zypper; then + install_with_zypper "Docker Compose" docker-compose + else + warn "Install Docker Compose manually, then run this installer again." + return 1 + fi + return + fi + + if command_exists podman && ! command_exists podman-compose; then + if command_exists brew; then + install_with_brew "podman-compose" podman-compose + elif command_exists apt-get; then + install_with_apt "podman-compose" podman-compose + elif command_exists dnf; then + install_with_dnf "podman-compose" podman-compose + elif command_exists yum; then + install_with_yum "podman-compose" podman-compose + elif command_exists pacman; then + install_with_pacman "podman-compose" podman-compose + elif command_exists zypper; then + install_with_zypper "podman-compose" podman-compose + else + warn "Install podman-compose manually, then run this installer again." + return 1 + fi + start_podman_machine_if_needed || true + return + fi + + if command_exists brew; then + info "Docker Desktop also works well. The CLI-friendly fallback is Podman plus podman-compose." 
# Install the Ollama CLI.
#
# Uses Homebrew when available. Otherwise, on Linux with curl present, offers
# to run the install script from https://ollama.com/install.sh through `sh`
# (the exact command is shown to the user for confirmation first).
#
# Returns the installer's result, or 1 when the user declines or when no
# automated install path applies (a pointer to the download page is printed
# in that case).
install_ollama_prerequisite() {
  local script_pipeline

  if command_exists brew; then
    install_with_brew "Ollama" ollama
    return
  fi

  if [[ "$HW_OS" != "Linux" ]] || ! command_exists curl; then
    warn "Install Ollama from https://ollama.com/download, then run this installer again."
    return 1
  fi

  script_pipeline="curl -fsSL https://ollama.com/install.sh | sh"
  info "This uses Ollama's official Linux install script."
  confirm_install_command "Install Ollama now?" "$script_pipeline" || return 1

  # Run through sh -c so the pipe in the previewed command is honoured.
  sh -c "$script_pipeline"
}
Open a new terminal or install python3-venv manually." +} + +ensure_basic_tool_available() { + local tool="$1" + local reason="$2" + + command_exists "$tool" && return 0 + + say "$tool is missing" + info "$reason" + install_basic_tool_prerequisite "$tool" || die "$tool is required. Install it manually, then run this installer again." + command_exists "$tool" || die "$tool was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_npx_available() { + [[ -n "$USE_EXISTING_COMPOSE" ]] && return 0 + command_exists npx && return 0 + + say "npx is missing" + info "npx is required for the existing TrustGraph config generator: npx @trustgraph/config." + install_node_prerequisite || die "npx is required. Install Node.js/npm manually, then run this installer again." + command_exists npx || die "npx was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_compose_available() { + detect_compose_command && return 0 + + say "Container compose support is missing" + info "TrustGraph runs as a compose stack. Docker Compose or podman-compose is required." + install_compose_prerequisite || die "Docker Compose or podman-compose is required to start TrustGraph." + detect_compose_command || die "Compose support was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_ollama_available_if_needed() { + [[ "$LLM_MODE" == "ollama" ]] || return 0 + command_exists ollama && return 0 + + say "Ollama is missing" + info "Ollama was selected for local LLMs, so the Ollama CLI and service are needed before model setup." + install_ollama_prerequisite || die "Ollama is required for the selected local LLM path. Install it manually, then run this installer again." + command_exists ollama || die "Ollama was not found after installation. Open a new terminal or add it to PATH, then run this installer again." 
+} + +preflight() { + say "Checking prerequisites" + + ensure_python3_available + ensure_python_venv_available + ensure_basic_tool_available unzip "unzip is required to unpack deploy.zip from the config generator." + ensure_basic_tool_available curl "curl is required for startup health checks and local service probes." + ensure_npx_available + ensure_compose_available + ensure_ollama_available_if_needed + check_container_runtime_ready + + info "Compose command: ${COMPOSE_CMD[*]}" + info "Python: $(python3 --version 2>&1)" + if command_exists npx; then + info "npx: $(npx --version 2>/dev/null || printf unknown)" + fi +} + +write_env_file() { + mkdir -p "$INSTALL_DIR" + local env_file="$INSTALL_DIR/trustgraph-installer.env" + local grafana_admin_password="${GF_SECURITY_ADMIN_PASSWORD:-${GRAFANA_ADMIN_PASSWORD:-$AUTH_TOKEN}}" + + umask 077 + { + printf 'export TRUSTGRAPH_URL=%q\n' "$API_URL" + printf 'export TRUSTGRAPH_UI_URL=%q\n' "$UI_URL" + printf 'export TRUSTGRAPH_TOKEN=%q\n' "$AUTH_TOKEN" + printf 'export TRUSTGRAPH_BOOTSTRAP_TOKEN=%q\n' "$AUTH_TOKEN" + printf 'export IAM_BOOTSTRAP_TOKEN=%q\n' "$AUTH_TOKEN" + printf 'export GF_SECURITY_ADMIN_PASSWORD=%q\n' "$grafana_admin_password" + printf 'export TG_VENV_DIR=%q\n' "$VENV_DIR" + printf 'export TG_NLTK_DATA_DIR=%q\n' "$NLTK_DATA_DIR" + printf 'export NLTK_DATA=%q\n' "$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" + printf 'export TIKTOKEN_CACHE_DIR=%q\n' "$TIKTOKEN_CACHE_DIR_VALUE" + printf 'export TRUSTGRAPH_LLM_MODE=%q\n' "$LLM_MODE" + printf 'export OPENAI_BASE_URL=%q\n' "$OPENAI_BASE_URL_VALUE" + printf 'export OPENAI_TOKEN=%q\n' "$OPENAI_TOKEN_VALUE" + printf 'export OLLAMA_HOST=%q\n' "$OLLAMA_BASE_URL_VALUE" + printf 'export OLLAMA_BASE_URL=%q\n' "$OLLAMA_BASE_URL_VALUE" + printf 'export OLLAMA_MODEL=%q\n' "$OLLAMA_MODEL" + printf 'export OLLAMA_EMBEDDINGS_MODEL=%q\n' "$OLLAMA_EMBEDDINGS_MODEL" + } > "$env_file" + chmod 600 "$env_file" + + info "Saved answers to $env_file" +} + +prepare_python_env() { + say 
"Preparing Python environment" + mkdir -p "$INSTALL_DIR" + + if [[ ! -x "$VENV_DIR/bin/python" ]]; then + info "Creating venv at $VENV_DIR" + run_with_spinner "Creating Python venv" python3 -m venv "$VENV_DIR" + else + info "Using existing venv at $VENV_DIR" + fi + + PYTHON_BIN="$VENV_DIR/bin/python" + export PATH="$VENV_DIR/bin:$PATH" + info "Python venv: $($PYTHON_BIN --version 2>&1)" +} + +ensure_version_files() { + local version="${TRUSTGRAPH_LOCAL_VERSION:-2.5.0}" + local specs=( + "trustgraph-base/trustgraph/base_version.py:trustgraph.base_version" + "trustgraph-flow/trustgraph/flow_version.py:trustgraph.flow_version" + "trustgraph-vertexai/trustgraph/vertexai_version.py:trustgraph.vertexai_version" + "trustgraph-bedrock/trustgraph/bedrock_version.py:trustgraph.bedrock_version" + "trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py:trustgraph.embeddings_hf_version" + "trustgraph-cli/trustgraph/cli_version.py:trustgraph.cli_version" + "trustgraph-ocr/trustgraph/ocr_version.py:trustgraph.ocr_version" + "trustgraph-unstructured/trustgraph/unstructured_version.py:trustgraph.unstructured_version" + "trustgraph-mcp/trustgraph/mcp_version.py:trustgraph.mcp_version" + "trustgraph/trustgraph/trustgraph_version.py:trustgraph.trustgraph_version" + ) + + say "Ensuring local package version files" + for spec in "${specs[@]}"; do + local file="${spec%%:*}" + mkdir -p "$(dirname "$SCRIPT_DIR/$file")" + printf '__version__ = "%s"\n' "$version" > "$SCRIPT_DIR/$file" + info "Set $file to $version" + done +} + +local_package_pythonpath() { + local package_dirs=( + "$SCRIPT_DIR/trustgraph-flow" + "$SCRIPT_DIR/trustgraph-embeddings-hf" + "$SCRIPT_DIR/trustgraph-base" + "$SCRIPT_DIR/trustgraph-cli" + "$SCRIPT_DIR/trustgraph-bedrock" + "$SCRIPT_DIR/trustgraph-ocr" + "$SCRIPT_DIR/trustgraph-unstructured" + "$SCRIPT_DIR/trustgraph-mcp" + "$SCRIPT_DIR/trustgraph-vertexai" + "$SCRIPT_DIR/trustgraph" + ) + local joined="" + local dir + + for dir in "${package_dirs[@]}"; do + if 
[[ -d "$dir" ]]; then + if [[ -n "$joined" ]]; then + joined="$joined:$dir" + else + joined="$dir" + fi + fi + done + + printf '%s\n' "$joined" +} + +ensure_python_build_tools() { + say "Preparing Python build tools" + local pip_cache_dir="$INSTALL_DIR/pip_cache" + mkdir -p "$pip_cache_dir" + + if ! "$PYTHON_BIN" -m pip --version >/dev/null 2>&1; then + local ensurepip_log + ensurepip_log="$(installer_log_file "python-ensurepip")" + info "Installing pip into the Python venv" + run_with_spinner_logged \ + "Installing pip" \ + "$ensurepip_log" \ + "$PYTHON_BIN" -m ensurepip --upgrade \ + || die "Could not install pip into the Python venv." + fi + + if "$PYTHON_BIN" - <<'PY' >/dev/null 2>&1 +import setuptools.build_meta +PY + then + info "Python build backend available: setuptools.build_meta" + return + fi + + local log_file + log_file="$(installer_log_file "pip-build-tools")" + info "Installing setuptools and wheel into the Python venv" + run_with_spinner_logged \ + "Installing Python build tools" \ + "$log_file" \ + env \ + PIP_CACHE_DIR="$pip_cache_dir" \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + "$PYTHON_BIN" -m pip install "setuptools>=61" wheel \ + || die "Could not install setuptools/wheel. Check $log_file, then re-run the installer." 
+} + +install_test_packages() { + say "Installing local Python packages for tests" + local pip_cache_dir="$INSTALL_DIR/pip_cache" + mkdir -p "$pip_cache_dir" + ensure_python_build_tools + + local package_dirs=( + trustgraph-base + trustgraph-cli + trustgraph-flow + trustgraph-vertexai + trustgraph-bedrock + trustgraph-embeddings-hf + trustgraph-ocr + trustgraph-unstructured + trustgraph-mcp + ) + + for package_dir in "${package_dirs[@]}"; do + if [[ -d "$SCRIPT_DIR/$package_dir" ]]; then + local log_file + log_file="$(installer_log_file "pip-${package_dir}")" + info "Installing $package_dir" + run_with_spinner_logged \ + "Installing $package_dir" \ + "$log_file" \ + env \ + PIP_CACHE_DIR="$pip_cache_dir" \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + "$PYTHON_BIN" -m pip install --no-build-isolation "$SCRIPT_DIR/$package_dir" + fi + done + + if [[ -f "$SCRIPT_DIR/tests/requirements.txt" ]]; then + local log_file + log_file="$(installer_log_file "pip-test-requirements")" + info "Installing test requirements" + run_with_spinner_logged \ + "Installing test requirements" \ + "$log_file" \ + env \ + PIP_CACHE_DIR="$pip_cache_dir" \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + "$PYTHON_BIN" -m pip install -r "$SCRIPT_DIR/tests/requirements.txt" + fi +} + +ensure_tokenizer_cache() { + say "Preparing tokenizer cache" + mkdir -p "$TIKTOKEN_CACHE_DIR_VALUE" + info "tiktoken cache: $TIKTOKEN_CACHE_DIR_VALUE" + + TIKTOKEN_CACHE_DIR="$TIKTOKEN_CACHE_DIR_VALUE" "$PYTHON_BIN" - <<'PY' +import tiktoken + +tiktoken.get_encoding("cl100k_base") +print(" Cached tiktoken encoding: cl100k_base") +PY +} + +ensure_nltk_data() { + say "Preparing NLTK tokenizer data" + mkdir -p "$NLTK_DATA_DIR" + info "NLTK data: $NLTK_DATA_DIR" + + TG_NLTK_DATA_DIR="$NLTK_DATA_DIR" \ + NLTK_DATA="$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" \ + "$PYTHON_BIN" - <<'PY' +import os +import nltk + +target = os.environ["TG_NLTK_DATA_DIR"] +if target not in nltk.data.path: + nltk.data.path.insert(0, target) + +resources = ( + 
("punkt", "tokenizers/punkt"), + ("punkt_tab", "tokenizers/punkt_tab"), + ("averaged_perceptron_tagger_eng", "taggers/averaged_perceptron_tagger_eng"), +) + +for package, resource in resources: + try: + nltk.data.find(resource) + except LookupError: + print(f" Downloading NLTK resource: {package}") + if not nltk.download(package, download_dir=target, quiet=True): + raise SystemExit(f"Could not download NLTK resource: {package}") + else: + print(f" NLTK resource already available: {package}") +PY +} + +run_all_tests() { + if [[ "$RUN_TESTS" -ne 1 ]]; then + warn "Skipping tests because --skip-tests was supplied." + return + fi + + prepare_python_env + ensure_version_files + install_test_packages + ensure_tokenizer_cache + ensure_nltk_data + + say "Running all tests" + info "Command: $PYTHON_BIN -m pytest tests" + local test_log + test_log="$(installer_log_file "pytest")" + if spinner_enabled; then + info "Test output log: $test_log" + fi + ( + cd "$SCRIPT_DIR" + run_with_spinner_logged \ + "Running pytest tests" \ + "$test_log" \ + env \ + INSTALL_TRUSTGRAPH_SOURCE_ONLY= \ + TG_NO_SPINNER= \ + TG_FORCE_SPINNER= \ + NLTK_DATA="$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" \ + TIKTOKEN_CACHE_DIR="$TIKTOKEN_CACHE_DIR_VALUE" \ + TRUSTGRAPH_CASSANDRA_SKIP_ON_UNREADY=1 \ + "$PYTHON_BIN" -m pytest tests + ) +} + +show_config_guidance() { + say "Before the config wizard starts" + info "Choose a Docker/Podman compose deployment for local installation." + info "Keep the Workbench UI enabled; the existing UI default is port 8888." + info "Use the bundled infrastructure defaults: Cassandra, Qdrant, Garage, and RabbitMQ/Pulsar as offered." + if [[ -n "$AUTH_TOKEN" ]]; then + info "For IAM/auth, use token/bootstrap-token mode when offered." + info "Admin/bootstrap API key to enter if asked: $AUTH_TOKEN" + else + info "For IAM/auth, use token/bootstrap-token mode when offered and paste the API key saved by this installer." 
+ fi + if [[ "$LLM_MODE" == "ollama" ]]; then + info "For LLMs, choose Ollama or an OpenAI-compatible endpoint and use $OLLAMA_BASE_URL_VALUE." + elif [[ "$LLM_MODE" == "openai" ]]; then + info "For LLMs, choose OpenAI/OpenAI-compatible and use $OPENAI_BASE_URL_VALUE." + else + info "You can skip LLM configuration now and add it later in the Workbench." + fi +} + +run_config_generator() { + if [[ -n "$USE_EXISTING_COMPOSE" ]]; then + return + fi + + mkdir -p "$INSTALL_DIR" + + if [[ -f "$INSTALL_DIR/deploy.zip" ]]; then + if confirm "Existing deploy.zip found in $INSTALL_DIR. Reuse it and skip the config wizard?" 1; then + info "Using existing deployment archive: $INSTALL_DIR/deploy.zip" + return + fi + fi + + show_config_guidance + + if ! confirm "Start the TrustGraph config wizard now?" 1; then + die "Config generation cancelled." + fi + + say "Running TrustGraph config generator" + ( + cd "$INSTALL_DIR" + TRUSTGRAPH_TOKEN="$AUTH_TOKEN" \ + TRUSTGRAPH_BOOTSTRAP_TOKEN="$AUTH_TOKEN" \ + OPENAI_TOKEN="$OPENAI_TOKEN_VALUE" \ + OPENAI_BASE_URL="$OPENAI_BASE_URL_VALUE" \ + OLLAMA_HOST="$OLLAMA_BASE_URL_VALUE" \ + OLLAMA_BASE_URL="$OLLAMA_BASE_URL_VALUE" \ + OLLAMA_MODEL="$OLLAMA_MODEL" \ + OLLAMA_EMBEDDINGS_MODEL="$OLLAMA_EMBEDDINGS_MODEL" \ + NLTK_DATA="$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" \ + TIKTOKEN_CACHE_DIR="$TIKTOKEN_CACHE_DIR_VALUE" \ + npx @trustgraph/config + ) +} + +find_compose_file() { + if [[ -n "$USE_EXISTING_COMPOSE" ]]; then + [[ -f "$USE_EXISTING_COMPOSE" ]] || die "Compose file does not exist: $USE_EXISTING_COMPOSE" + printf '%s\n' "$USE_EXISTING_COMPOSE" + return + fi + + local deploy_zip="$INSTALL_DIR/deploy.zip" + local unpack_dir="$INSTALL_DIR/deploy" + + [[ -f "$deploy_zip" ]] || die "The config generator did not create $deploy_zip" + + rm -rf "$unpack_dir" + mkdir -p "$unpack_dir" + unzip -oq "$deploy_zip" -d "$unpack_dir" + + local compose_file + compose_file="$(find "$unpack_dir" "$INSTALL_DIR" \ + \( -name 'docker-compose.yaml' -o -name 
'docker-compose.yml' -o -name 'compose.yaml' -o -name 'compose.yml' \) \ + -type f | head -n 1)" + + [[ -n "$compose_file" ]] || die "Could not find a compose file in $deploy_zip" + printf '%s\n' "$compose_file" +} + +compose_dir_for() { + local compose_file="$1" + (cd "$(dirname "$compose_file")" && pwd -P) +} + +compose_env_file_for() { + local compose_file="$1" + local compose_dir + + compose_dir="$(compose_dir_for "$compose_file")" + printf '%s/.env\n' "$compose_dir" +} + +write_compose_env_file() { + local compose_file="$1" + local compose_env_file + local grafana_admin_password="${GF_SECURITY_ADMIN_PASSWORD:-${GRAFANA_ADMIN_PASSWORD:-$AUTH_TOKEN}}" + + [[ -n "$AUTH_TOKEN" ]] || die "TrustGraph API key is empty; cannot create compose environment." + + compose_env_file="$(compose_env_file_for "$compose_file")" + umask 077 + { + printf 'TRUSTGRAPH_TOKEN=%s\n' "$AUTH_TOKEN" + printf 'TRUSTGRAPH_BOOTSTRAP_TOKEN=%s\n' "$AUTH_TOKEN" + printf 'IAM_BOOTSTRAP_TOKEN=%s\n' "$AUTH_TOKEN" + printf 'GF_SECURITY_ADMIN_PASSWORD=%s\n' "$grafana_admin_password" + printf 'OLLAMA_HOST=%s\n' "$OLLAMA_BASE_URL_VALUE" + printf 'OLLAMA_BASE_URL=%s\n' "$OLLAMA_BASE_URL_VALUE" + printf 'OLLAMA_MODEL=%s\n' "$OLLAMA_MODEL" + printf 'OLLAMA_EMBEDDINGS_MODEL=%s\n' "$OLLAMA_EMBEDDINGS_MODEL" + } > "$compose_env_file" + chmod 600 "$compose_env_file" + + info "Compose environment: $compose_env_file" +} + +start_stack() { + local compose_file="$1" + local compose_dir + local compose_name + local log_file + + say "Starting TrustGraph" + info "Compose file: $compose_file" + write_compose_env_file "$compose_file" + + compose_dir="$(compose_dir_for "$compose_file")" + compose_name="$(basename "$compose_file")" + log_file="$(installer_log_file "compose-up")" + case "$log_file" in + /*) ;; + *) log_file="$SCRIPT_DIR/$log_file" ;; + esac + + ( + cd "$compose_dir" + run_with_spinner_logged \ + "Starting TrustGraph containers" \ + "$log_file" \ + "${COMPOSE_CMD[@]}" -f "$compose_name" up -d + ) +} + 
+http_status() { + local url="$1" + curl -sS -o /dev/null -w '%{http_code}' --max-time 5 "$url" 2>/dev/null || true +} + +http_status_with_bearer() { + local url="$1" + local token="$2" + + curl -sS -o /dev/null -w '%{http_code}' --max-time 5 \ + -H "Authorization: Bearer $token" \ + "$url" 2>/dev/null || true +} + +sha256_text() { + local value="$1" + + printf '%s' "$value" | python3 -c 'import hashlib, sys; print(hashlib.sha256(sys.stdin.buffer.read()).hexdigest())' +} + +repair_local_iam_api_key() { + local compose_file="$1" + local compose_dir + local compose_name + local key_hash + local key_suffix + local user_id + local username + local key_id + local prefix + local cql + local log_file + + [[ -n "$compose_file" && -f "$compose_file" ]] || return 1 + [[ -n "$AUTH_TOKEN" ]] || return 1 + command_exists python3 || return 1 + [[ "${#COMPOSE_CMD[@]}" -gt 0 ]] || return 1 + + key_hash="$(sha256_text "$AUTH_TOKEN")" + key_suffix="${key_hash:0:12}" + user_id="installer-admin-$key_suffix" + username="installer-admin-$key_suffix" + key_id="installer-key-$key_suffix" + prefix="$(printf '%s' "${AUTH_TOKEN:0:7}" | tr -cd 'a-zA-Z0-9_-')" + compose_dir="$(compose_dir_for "$compose_file")" + compose_name="$(basename "$compose_file")" + log_file="$(installer_log_file "iam-key-repair")" + + cql=" +USE iam; +INSERT INTO iam_users (id, workspace, username, name, email, password_hash, roles, enabled, must_change_password, created) +VALUES ('$user_id', 'default', '$username', 'Installer Admin', '', 'installer-repair', {'admin'}, true, false, toTimestamp(now())); +INSERT INTO iam_users_by_username (workspace, username, user_id) +VALUES ('default', '$username', '$user_id'); +INSERT INTO iam_api_keys (key_hash, id, user_id, name, prefix, expires, created, last_used) +VALUES ('$key_hash', '$key_id', '$user_id', 'installer-repair', '$prefix', null, toTimestamp(now()), null); +" + + say "Repairing local IAM API key" + info "Adding the saved installer key to the local installer-managed 
IAM database." + mkdir -p "$(dirname "$log_file")" + if ( + cd "$compose_dir" + printf '%s\n' "$cql" | "${COMPOSE_CMD[@]}" -f "$compose_name" exec -T cassandra cqlsh + ) >"$log_file" 2>&1; then + info "Local IAM API key repair completed." + return 0 + fi + + warn "Local IAM API key repair failed. Last log lines from $log_file:" + tail -n 40 "$log_file" >&2 || true + return 1 +} + +wait_for_gateway() { + local deadline=$((SECONDS + HEALTH_TIMEOUT)) + local next_notice=$((SECONDS + 15)) + local status="" + + say "Waiting for API gateway" + info "Checking $API_URL for up to ${HEALTH_TIMEOUT}s." + while (( SECONDS < deadline )); do + status="$(http_status "$API_URL")" + if [[ "$status" == "200" || "$status" == "401" || "$status" == "404" ]]; then + info "API gateway is responding with HTTP $status" + return 0 + fi + if (( SECONDS >= next_notice )); then + info "Still waiting; last HTTP status was ${status:-connection failed}." + next_notice=$((SECONDS + 15)) + fi + sleep 3 + done + + die "API gateway did not respond at $API_URL within ${HEALTH_TIMEOUT}s" +} + +verify_api_key_authentication() { + local compose_file="${1:-}" + local deadline=$((SECONDS + AUTH_CHECK_TIMEOUT)) + local metrics_url="${API_URL%/}/api/metrics/query?query=processor_info" + local status="" + + [[ -n "$AUTH_TOKEN" ]] || return 0 + + say "Checking API key authentication" + info "The API gateway root can return HTTP 404; that is normal. This checks an authenticated endpoint." + + while :; do + status="$(http_status_with_bearer "$metrics_url" "$AUTH_TOKEN")" + case "$status" in + 200) + info "Installer API key authenticated at the API gateway." + return 0 + ;; + 401|403) + ;; + "") + ;; + *) + info "Authentication probe returned HTTP $status; continuing to the full health checks." 
+ return 0 + ;; + esac + + (( SECONDS >= deadline )) && break + sleep 3 + done + + if [[ "$status" == "401" || "$status" == "403" ]]; then + if [[ -n "$compose_file" ]] && repair_local_iam_api_key "$compose_file"; then + status="$(http_status_with_bearer "$metrics_url" "$AUTH_TOKEN")" + if [[ "$status" == "200" ]]; then + info "Installer API key authenticated after local IAM repair." + return 0 + fi + fi + warn "The API gateway is running, but it rejected the installer API key." + info "Configured installer API key: $AUTH_TOKEN" + info "Saved environment: $INSTALL_DIR/trustgraph-installer.env" + warn "This usually means compose volumes contain IAM data from an earlier install. Run ./install_trustgraph.sh --remove-all to remove the installer-managed deployment and compose volumes, then reinstall; or rerun with the original TRUSTGRAPH_TOKEN if you know it." + return 1 + fi + + warn "Could not confirm API key authentication yet; continuing to the full health checks." +} + +bootstrap_iam_if_available() { + local bootstrap_output="" + local log_file="$INSTALL_DIR/iam-bootstrap.log" + + if ! command_exists tg-bootstrap-iam; then + warn "tg-bootstrap-iam is not on PATH; using the installer API key for health checks." + return + fi + + say "Checking IAM bootstrap" + if bootstrap_output="$(tg-bootstrap-iam --api-url "$API_URL" 2>"$log_file")"; then + if [[ -n "$bootstrap_output" ]]; then + AUTH_TOKEN="$bootstrap_output" + info "Captured the first-run admin API key from IAM bootstrap." + write_env_file + fi + else + info "IAM bootstrap did not issue a new key; this is normal for token mode or an already-bootstrapped system." 
+ info "Details: $log_file" + fi +} + +verify_system() { + local verify_cmd=() + + if command_exists tg-verify-system-status; then + verify_cmd=(tg-verify-system-status) + elif "$PYTHON_BIN" -c 'import trustgraph.cli.verify_system_status' >/dev/null 2>&1; then + verify_cmd=("$PYTHON_BIN" -m trustgraph.cli.verify_system_status) + else + say "Verifying TrustGraph health" + info "API gateway: $API_URL" + info "Workbench UI: $UI_URL" + [[ "$(http_status "$API_URL")" =~ ^(200|401|404)$ ]] || die "API gateway health check failed." + [[ "$(http_status "${UI_URL%/}/index.html")" == "200" ]] || warn "Workbench UI did not return HTTP 200 yet." + return + fi + + say "Verifying TrustGraph health" + verify_cmd+=( + --api-url "$API_URL" + --ui-url "$UI_URL" + --global-timeout "$HEALTH_TIMEOUT" + ) + if [[ -n "$AUTH_TOKEN" ]]; then + verify_cmd+=(--token "$AUTH_TOKEN") + fi + + "${verify_cmd[@]}" +} + +launch_ui() { + if [[ "$AUTO_LAUNCH" -ne 1 ]]; then + info "Workbench UI autolaunch disabled." + return + fi + + say "Opening Workbench UI" + if command_exists open; then + open "$UI_URL" + elif command_exists xdg-open; then + xdg-open "$UI_URL" + elif command_exists wslview; then + wslview "$UI_URL" + else + warn "Could not find a browser launcher. 
Open this URL manually: $UI_URL" + return + fi + info "Workbench UI: $UI_URL" +} + +print_ready_summary() { + local auth_status="${1:-0}" + + if [[ "$auth_status" -eq 0 ]]; then + say "TrustGraph is ready" + else + say "TrustGraph started with an authentication warning" + fi + info "Workbench UI: $UI_URL" + info "API gateway: $API_URL" + if [[ -n "$AUTH_TOKEN" ]]; then + info "Admin/bootstrap API key: $AUTH_TOKEN" + fi + info "Saved environment: $INSTALL_DIR/trustgraph-installer.env" +} + +main() { + parse_args "$@" + cd "$SCRIPT_DIR" + init_colors + + print_banner + if [[ "$REMOVE_ALL" -eq 1 ]]; then + say "$APP_NAME guided uninstaller" + load_saved_answers + remove_all_installation + return 0 + fi + + say "$APP_NAME guided installer" + handle_existing_install + load_saved_answers + detect_hardware + choose_recommendations + print_hardware_summary + + collect_answers + print_plan_summary + + if [[ "$DRY_RUN" -eq 1 ]]; then + say "Dry run complete" + return 0 + fi + + if ! confirm "Proceed with this install plan?" 1; then + die "Install cancelled." + fi + + preflight + offer_ollama_model_downloads + write_env_file + run_all_tests + run_config_generator + + local compose_file + compose_file="$(find_compose_file)" + start_stack "$compose_file" + wait_for_gateway + bootstrap_iam_if_available + local auth_status=0 + local verify_status=0 + + verify_api_key_authentication "$compose_file" || auth_status=$? + if [[ "$auth_status" -eq 0 ]]; then + verify_system || verify_status=$? + else + warn "Skipping authenticated health checks because the configured API key was rejected." 
+ fi + launch_ui + + print_ready_summary "$auth_status" + + if [[ "$auth_status" -ne 0 ]]; then + return "$auth_status" + fi + if [[ "$verify_status" -ne 0 ]]; then + return "$verify_status" + fi +} + +if [[ "${INSTALL_TRUSTGRAPH_SOURCE_ONLY:-0}" != "1" ]]; then + main "$@" +fi From 6887076ce0944576fc38ef8edd84fac71bed8e17 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 1 Jul 2026 16:50:47 +0100 Subject: [PATCH 6/9] feat: add dashscope variant for Alibaba Cloud DashScope API (#1010) DashScope uses enable_thinking as a top-level parameter rather than inside extra_body as the Qwen docs suggest. --- .../model/text_completion/openai/variants.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py index 0c314a04..7e650e37 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py @@ -118,6 +118,26 @@ class QwenVariant(Variant): return {} +class DashScopeVariant(Variant): + """Alibaba Cloud DashScope API (Qwen models via DashScope).""" + + name = "dashscope" + token_param = "max_completion_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, thinking): + enabled = thinking != "off" + kwargs = { + self.token_param: max_output, + "temperature": temperature, + "enable_thinking": enabled, + } + return kwargs + + def thinking_kwargs(self, effort): + return {} + + class MistralVariant(Variant): """Mistral API (Mistral Large, etc.).""" @@ -181,6 +201,7 @@ VARIANTS = { "deepseek": DeepSeekVariant, "qwen": QwenVariant, "mistral": MistralVariant, + "dashscope": DashScopeVariant, "glm": GlmVariant, "llama": LlamaVariant, } From f18d48dc3963edc3c2dc14d78550db99becc4599 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Jul 2026 09:12:55 +0100 Subject: [PATCH 7/9] fix: 
simplify dashscope variant and route API calls through variants (#1012) Replace the client.post()/httpx bypass with standard SDK extra_body, confirmed working against DashScope. Make DashScope the base variant with Qwen as a subclass alias. Route all API calls through variant create_completion/create_completion_stream methods. --- .../model/text_completion/openai/llm.py | 60 +++++++++---------- .../model/text_completion/openai/variants.py | 51 ++++++++-------- 2 files changed, 53 insertions(+), 58 deletions(-) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 01035bc9..57958bc0 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -96,20 +96,20 @@ class Processor(LlmService): api_kwargs = self._build_kwargs(model_name, effective_temperature) - resp = self.openai.chat.completions.create( - model=model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ], - **api_kwargs, + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] + + resp = self.variant.create_completion( + self.openai, model_name, messages, **api_kwargs, ) inputtokens = resp.usage.prompt_tokens @@ -176,28 +176,24 @@ class Processor(LlmService): try: api_kwargs = self._build_kwargs(model_name, effective_temperature) - response = self.openai.chat.completions.create( - model=model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ], - stream=True, - stream_options={"include_usage": True}, - **api_kwargs, - ) + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] total_input_tokens = 0 total_output_tokens = 0 - for chunk in response: + async for chunk in self.variant.create_completion_stream( + 
self.openai, model_name, messages, **api_kwargs, + ): if chunk.choices and chunk.choices[0].delta.content: yield LlmChunk( text=chunk.choices[0].delta.content, diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py index 7e650e37..87de725d 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py @@ -62,6 +62,20 @@ class Variant: """Extract thinking content from a streaming delta.""" return getattr(delta, "reasoning_content", None) + def create_completion(self, client, model, messages, **kwargs): + """Call the completions API. Override for non-standard SDKs.""" + return client.chat.completions.create( + model=model, messages=messages, **kwargs, + ) + + async def create_completion_stream(self, client, model, messages, **kwargs): + """Call the streaming completions API. Override for non-standard SDKs.""" + for chunk in client.chat.completions.create( + model=model, messages=messages, stream=True, + stream_options={"include_usage": True}, **kwargs, + ): + yield chunk + class OpenAIVariant(Variant): """Standard OpenAI API (GPT-4o, o1, o3, etc.).""" @@ -96,30 +110,8 @@ class DeepSeekVariant(Variant): return {} -class QwenVariant(Variant): - """Qwen / Alibaba Cloud API.""" - - name = "qwen" - token_param = "max_completion_tokens" - temperature_with_thinking = True - - def completion_kwargs(self, max_output, temperature, thinking): - enabled = thinking != "off" - kwargs = { - self.token_param: max_output, - "temperature": temperature, - "extra_body": { - "enable_thinking": enabled, - }, - } - return kwargs - - def thinking_kwargs(self, effort): - return {} - - class DashScopeVariant(Variant): - """Alibaba Cloud DashScope API (Qwen models via DashScope).""" + """Alibaba Cloud DashScope API (Qwen models).""" name = "dashscope" token_param = "max_completion_tokens" @@ -127,17 +119,24 
@@ class DashScopeVariant(Variant): def completion_kwargs(self, max_output, temperature, thinking): enabled = thinking != "off" - kwargs = { + return { self.token_param: max_output, "temperature": temperature, - "enable_thinking": enabled, + "extra_body": { + "enable_thinking": enabled, + }, } - return kwargs def thinking_kwargs(self, effort): return {} +class QwenVariant(DashScopeVariant): + """Qwen — alias for DashScope.""" + + name = "qwen" + + class MistralVariant(Variant): """Mistral API (Mistral Large, etc.).""" From 6c9a545a0673eef6843948efefa754fbb7802c71 Mon Sep 17 00:00:00 2001 From: Sunny Date: Thu, 2 Jul 2026 02:50:13 -0600 Subject: [PATCH 8/9] feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011) Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after vector retrieval, over-fetch a wider candidate pool, rerank with the cross-encoder, and keep the top doc_limit chunks for synthesis. Per maintainer review, the fetch and select sizes are two caller-controlled limits rather than one internal heuristic: - doc_limit: chunks selected into the synthesis prompt (unchanged meaning). - fetch_limit: candidate pool pulled from the vector store before reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are raised to it. Lets the caller control how hard the reranker has to work. Details: - schema: DocumentRagQuery.fetch_limit (additive, backward compatible). - document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors doc_limit); the core applies the heuristic default and derives synthesis provenance from the chunk-selection focus when reranking ran. - provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection). - request translator + client SDKs + CLI: fetch-limit / --fetch-limit, threaded exactly like doc_limit and the GraphRAG limits. 
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring. Reranking is skipped byte-identically when no reranker role is wired. Requires the companion trustgraph-templates change wiring the reranker topics into the document-rag flow (mirrors #279 for GraphRAG). --- .../unit/test_retrieval/test_document_rag.py | 18 +- .../test_document_rag_rerank.py | 478 ++++++++++++++++++ .../test_document_rag_reranker_wiring.py | 89 ++++ .../test_document_rag_service.py | 1 + trustgraph-base/trustgraph/api/async_flow.py | 8 +- .../trustgraph/api/async_socket_client.py | 4 +- trustgraph-base/trustgraph/api/flow.py | 7 +- .../trustgraph/api/socket_client.py | 4 + .../messaging/translators/retrieval.py | 2 + .../trustgraph/provenance/__init__.py | 10 + .../trustgraph/provenance/namespaces.py | 3 + .../trustgraph/provenance/triples.py | 76 ++- trustgraph-base/trustgraph/provenance/uris.py | 29 ++ .../trustgraph/provenance/vocabulary.py | 2 + .../trustgraph/schema/services/retrieval.py | 5 +- .../trustgraph/cli/invoke_document_rag.py | 21 +- .../retrieval/document_rag/document_rag.py | 88 +++- .../trustgraph/retrieval/document_rag/rag.py | 34 ++ 18 files changed, 853 insertions(+), 26 deletions(-) create mode 100644 tests/unit/test_retrieval/test_document_rag_rerank.py create mode 100644 tests/unit/test_retrieval/test_document_rag_reranker_wiring.py diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 7762b543..a08bc718 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -101,27 +101,27 @@ class TestQuery: assert query.rag == mock_rag assert query.collection == "test_collection" assert query.verbose is False - assert query.doc_limit == 20 # Default value + assert query.fetch_limit == 20 # Default value - def test_query_initialization_with_custom_doc_limit(self): - 
"""Test Query initialization with custom doc_limit""" + def test_query_initialization_with_custom_fetch_limit(self): + """Test Query initialization with custom fetch_limit""" # Create mock DocumentRag mock_rag = MagicMock() - # Initialize Query with custom doc_limit + # Initialize Query with custom fetch_limit query = Query( rag=mock_rag, workspace="test_workspace", collection="custom_collection", verbose=True, - doc_limit=50 + fetch_limit=50 ) # Verify initialization assert query.rag == mock_rag assert query.collection == "custom_collection" assert query.verbose is True - assert query.doc_limit == 50 + assert query.fetch_limit == 50 @pytest.mark.asyncio async def test_extract_concepts(self): @@ -224,7 +224,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=False, - doc_limit=15 + fetch_limit=15 ) # Call get_docs with concepts list @@ -377,7 +377,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=True, - doc_limit=5 + fetch_limit=5 ) # Call get_docs with concepts @@ -615,7 +615,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=False, - doc_limit=10 + fetch_limit=10 ) docs, chunk_ids = await query.get_docs(["concept A", "concept B"]) diff --git a/tests/unit/test_retrieval/test_document_rag_rerank.py b/tests/unit/test_retrieval/test_document_rag_rerank.py new file mode 100644 index 00000000..d711d57c --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_rerank.py @@ -0,0 +1,478 @@ +""" +Tests for the optional cross-encoder reranking pass in DocumentRag.query(). + +Two behaviours are covered: + + 1. No-op: when no reranker_client is wired (the default), query() must feed + the LLM the exact same chunks, in the same order, that retrieval produced + - byte-identical to the pre-reranker behaviour - and must NOT emit a + chunk-selection provenance event. + + 2. 
Rerank: when a reranker_client is wired, the retrieved chunks are reordered + and truncated according to the reranker's results, the LLM receives the + reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted + carrying the per-surviving-chunk scores and chunk references. + +These are pure orchestration tests - the reranker is a stub, so there is no +torch / network dependency. +""" + +import pytest +from unittest.mock import AsyncMock +from dataclasses import dataclass + +from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.base import PromptResult +from trustgraph.schema import RerankerResult + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_WAS_DERIVED_FROM, + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, + TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class ChunkMatch: + """Mimics the result from doc_embeddings_client.query().""" + chunk_id: str + + +# --------------------------------------------------------------------------- +# Fixtures: three retrievable chunks +# --------------------------------------------------------------------------- + 
+CHUNK_A = "urn:chunk:policy-doc-1:chunk-0" +CHUNK_B = "urn:chunk:policy-doc-1:chunk-1" +CHUNK_C = "urn:chunk:policy-doc-1:chunk-2" + +CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase." +CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays." +CHUNK_C_CONTENT = "Refunds are processed to the original payment method." + +# Retrieval (post-dedupe) order is A, B, C. +ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT] +ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C] + + +def build_mock_clients(): + """ + Build mock subsidiary clients for a document-rag query returning three + distinct chunks (A, B, C) in that order. + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock() + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="return policy\nrefund") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # Each concept query returns the same three chunks; dedupe keeps A, B, C. 
+ doc_embeddings_client.query.return_value = [ + ChunkMatch(chunk_id=CHUNK_A), + ChunkMatch(chunk_id=CHUNK_B), + ChunkMatch(chunk_id=CHUNK_C), + ] + + async def mock_fetch(chunk_id): + return { + CHUNK_A: CHUNK_A_CONTENT, + CHUNK_B: CHUNK_B_CONTENT, + CHUNK_C: CHUNK_C_CONTENT, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="Items can be returned within 30 days for a full refund.", + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + +class StubReranker: + """ + Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed, + pre-sorted, truncated list of RerankerResult - exactly the contract the + flashrank service guarantees (sorted desc by score, truncated to limit). + """ + + def __init__(self, results): + self._results = results + self.calls = [] + + async def rerank(self, queries, documents, limit=10, timeout=300): + self.calls.append( + {"queries": queries, "documents": documents, "limit": limit} + ) + return self._results + + +# --------------------------------------------------------------------------- +# 1. No-op: reranker_client=None must not change anything +# --------------------------------------------------------------------------- + +class TestRerankNoOp: + + @pytest.mark.asyncio + async def test_documents_passed_to_llm_are_unchanged(self): + """ + With no reranker wired, document_prompt must receive the retrieved + chunks in the original order and length. 
+ """ + clients = build_mock_clients() + rag = DocumentRag(*clients) # reranker_client defaults to None + + await rag.query(query="What is the return policy?") + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert passed_docs == ORDERED_CONTENT + + @pytest.mark.asyncio + async def test_no_chunk_selection_event_emitted(self): + """ + Without a reranker, the provenance chain is the original 4 stages: + question, grounding, exploration, synthesis - no focus stage. + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + assert len(events) == 4 + types = [ + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS, + ] + for i, expected in enumerate(types): + assert has_type(events[i]["triples"], events[i]["explain_id"], expected) + + # No chunk-selection entity anywhere. + for e in events: + assert not any( + t.o.iri == TG_CHUNK_SELECTION + for t in e["triples"] + if t.p.iri == RDF_TYPE + ) + + @pytest.mark.asyncio + async def test_synthesis_derives_from_exploration_when_no_rerank(self): + """ + No-op lineage is unchanged: synthesis derives from exploration + (there is no focus stage). Guards the conditional synthesis parent. 
+ """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + # events: question, grounding, exploration, synthesis + exp_uri = events[2]["explain_id"] + syn_event = events[3] + assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri + + +# --------------------------------------------------------------------------- +# 2. Rerank: reorder + truncate + provenance +# --------------------------------------------------------------------------- + +class TestRerankActive: + + def _reranker_keeping_C_then_A(self): + # Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped. + # Pre-sorted desc by score and truncated to limit, per the contract. + return StubReranker([ + RerankerResult(document_id="2", query_id="0", score=0.95), + RerankerResult(document_id="0", query_id="0", score=0.42), + ]) + + @pytest.mark.asyncio + async def test_documents_reordered_and_truncated(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?") + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT] + + @pytest.mark.asyncio + async def test_reranker_called_with_single_query_and_all_docs(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?", doc_limit=2) + + assert len(reranker.calls) == 1 + c = reranker.calls[0] + assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}] + assert c["documents"] == [ + {"id": "0", "text": 
CHUNK_A_CONTENT}, + {"id": "1", "text": CHUNK_B_CONTENT}, + {"id": "2", "text": CHUNK_C_CONTENT}, + ] + # The rerank narrows down to the final doc_limit, NOT fetch_limit + # (fetch_limit is the over-fetched candidate pool size). + assert c["limit"] == 2 + + @pytest.mark.asyncio + async def test_explicit_fetch_limit_over_fetches_then_narrows(self): + """ + Semantic guard for the value of reranking AND the maintainer's two-limit + contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider + candidate pool so the cross-encoder can surface chunks the bi-encoder + ranked outside the final doc_limit, then the rerank narrows the pool back + down to doc_limit. The fetch_limit is honoured directly (caller controls + how hard the reranker works), not overridden by any heuristic. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + # Candidate pool (fetch_limit=60) >> final doc_limit (6). + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query( + query="What is the return policy?", doc_limit=6, fetch_limit=60, + ) + + # Over-fetch: the embeddings store is queried with the fetch_limit + # budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit + # budget (6 // 2 = 3). This is the bug guard. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 30 + + # Narrow: the rerank keeps the final doc_limit (6), not fetch_limit. + assert reranker.calls[0]["limit"] == 6 + + @pytest.mark.asyncio + async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self): + """ + With no fetch_limit passed to query(), the candidate pool falls back to + the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with + doc_limit and reranking keeps its recall benefit out of the box. 
+ """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + # No fetch_limit -> heuristic default. + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?", doc_limit=20) + + # fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 30 + # Rerank narrows to the final doc_limit (20). + assert reranker.calls[0]["limit"] == 20 + + @pytest.mark.asyncio + async def test_fetch_limit_floored_at_doc_limit(self): + """ + A fetch_limit below doc_limit is floored up to doc_limit: retrieval must + never fetch fewer candidates than the rerank is asked to keep, else the + prompt could not be filled. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query( + query="What is the return policy?", doc_limit=10, fetch_limit=4, + ) + + # fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 5 + assert reranker.calls[0]["limit"] == 10 + + @pytest.mark.asyncio + async def test_chunk_selection_event_emitted(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + # Now 5 stages: question, grounding, exploration, focus, synthesis. 
+ assert len(events) == 5 + ordered_types = [ + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, + ] + for i, expected in enumerate(ordered_types): + assert has_type(events[i]["triples"], events[i]["explain_id"], expected) + + @pytest.mark.asyncio + async def test_chunk_selection_carries_scores_and_chunk_refs(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + focus_event = events[3] + foc_uri = focus_event["explain_id"] + triples = focus_event["triples"] + + # focus is derived from exploration + exp_uri = events[2]["explain_id"] + assert derived_from(triples, foc_uri) == exp_uri + + # Two ChunkSelection sub-entities, linked from focus. + sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri) + assert len(sel_links) == 2 + + # Each selection has a ChunkSelection type, a chunk document ref and a score. + chunk_refs = set() + scores = set() + for link in sel_links: + sel_uri = link.o.iri + assert has_type(triples, sel_uri, TG_CHUNK_SELECTION) + doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri) + assert doc_ref is not None + chunk_refs.add(doc_ref.o.iri) + score_t = find_triple(triples, TG_SCORE, sel_uri) + assert score_t is not None + scores.add(score_t.o.value) + + # Surviving chunks are C and A (B dropped), with the reranker scores. 
+ assert chunk_refs == {CHUNK_C, CHUNK_A} + assert scores == {"0.95", "0.42"} + + @pytest.mark.asyncio + async def test_all_focus_triples_in_retrieval_graph(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + for t in events[3]["triples"]: + assert t.g == "urn:graph:retrieval" + + @pytest.mark.asyncio + async def test_synthesis_derives_from_focus_when_reranking(self): + """ + When reranking runs, synthesis must derive from the focus node (the + reranked chunks actually fed to the LLM), mirroring GraphRAG - not from + exploration, which would leave focus as a dangling branch and + misrepresent what fed the answer. + """ + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + doc_limit=2, + explain_callback=explain_callback, + ) + + # events: question, grounding, exploration, focus, synthesis + foc_uri = events[3]["explain_id"] + syn_event = events[4] + assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri + + @pytest.mark.asyncio + async def test_empty_docs_skips_reranker(self): + """If retrieval returns no chunks, the reranker is never called.""" + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + doc_embeddings_client.query.return_value = [] # no matches + + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await 
rag.query(query="What is the return policy?") + + assert reranker.calls == [] diff --git a/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py b/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py new file mode 100644 index 00000000..bf4337b4 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py @@ -0,0 +1,89 @@ +""" +Cross-layer wiring contract for the Document-RAG reranker (issue #878). + +The Document-RAG processor registers a ``RerankerClientSpec`` for the +``reranker-request`` / ``reranker-response`` roles (see +``retrieval/document_rag/rag.py``). At flow construction every spec runs +``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add`` +resolves its topics via ``definition["topics"][name]`` - which raises +``KeyError`` if the flow blueprint does not provide those topics. + +This means the monorepo code change is only safe to deploy together with the +companion ``trustgraph-templates`` change that wires ``reranker-request`` / +``reranker-response`` into the Document-RAG flow (mirroring what templates +PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that +contract from the monorepo side: + + * with the reranker topics present (as the updated templates compile them), + the spec binds cleanly and registers the client; + * without them (the pre-companion blueprint), construction fails fast with a + KeyError naming the missing role - documenting exactly why the templates + change is required. + +No broker/network: the pub/sub backend is mocked (topics are bound at add() +time, connections happen later at start()). 
+""" + +import pytest +from unittest.mock import MagicMock + +from trustgraph.base import RerankerClientSpec + + +def _flow(): + f = MagicMock() + f.workspace = "ws" + f.name = "document-rag" + f.id = "proc1" + f.consumer = {} + return f + + +def _processor(): + p = MagicMock() + p.pubsub = MagicMock() + p.id = "proc1" + p.taskgroup = MagicMock() + return p + + +def _spec(): + return RerankerClientSpec( + request_name="reranker-request", + response_name="reranker-response", + ) + + +# Topics dict as the UPDATED document-store.jsonnet compiles them +# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}). +DEFINITION_WITH_RERANKER = { + "topics": { + "request": "request:tg:document-rag:ws:id", + "response": "response:tg:document-rag:ws:id", + "reranker-request": "request:tg:reranker:ws:id", + "reranker-response": "response:tg:reranker:ws:id", + } +} + +# Pre-companion blueprint: no reranker topics (document-rag before the templates change). +DEFINITION_WITHOUT_RERANKER = { + "topics": { + "request": "request:tg:document-rag:ws:id", + "response": "response:tg:document-rag:ws:id", + } +} + + +def test_reranker_client_binds_when_flow_provides_topics(): + flow = _flow() + _spec().add(flow, _processor(), DEFINITION_WITH_RERANKER) + # The client consumer is registered against the reranker role. + assert "reranker-request" in flow.consumer + + +def test_reranker_client_keyerrors_without_companion_template_topics(): + with pytest.raises(KeyError) as exc: + _spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER) + # Fails fast naming the missing role -> the trustgraph-templates companion + # change (wire reranker-request/response into the document-rag flow) is required. 
+ assert "reranker-request" in str(exc.value) diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index dde3acc1..2bdf3959 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -66,6 +66,7 @@ class TestDocumentRagService: workspace=ANY, # Workspace comes from flow.workspace (mock) collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, + fetch_limit=0, # Unset -> core derives the candidate pool explain_callback=ANY, # Explainability callback is always passed save_answer_callback=ANY, # Librarian save callback is always passed ) diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index de592b59..afd48f1b 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -527,7 +527,8 @@ class AsyncFlowInstance: return result.get("response", "") async def document_rag(self, query: str, collection: str, - doc_limit: int = 10, **kwargs: Any) -> str: + doc_limit: int = 10, fetch_limit: int = 0, + **kwargs: Any) -> str: """ Execute document-based RAG query (non-streaming). 
@@ -541,7 +542,9 @@ class AsyncFlowInstance: Args: query: User query text collection: Collection identifier containing documents - doc_limit: Maximum number of document chunks to retrieve (default: 10) + doc_limit: Document chunks selected into the prompt (default: 10) + fetch_limit: Candidate chunks fetched from the vector store before + reranking (default: 0 = derive from doc_limit) **kwargs: Additional service-specific parameters Returns: @@ -564,6 +567,7 @@ class AsyncFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": False } request_data.update(kwargs) diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 78b608a7..9eff3d60 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -379,12 +379,14 @@ class AsyncSocketFlowInstance: yield chunk.content async def document_rag(self, query: str, collection: str, - doc_limit: int = 10, streaming: bool = False, **kwargs): + doc_limit: int = 10, fetch_limit: int = 0, + streaming: bool = False, **kwargs): """Document RAG with optional streaming""" request = { "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": streaming } request.update(kwargs) diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 886306b3..b9e9487b 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -415,7 +415,7 @@ class FlowInstance: def document_rag( self, query,collection="default", - doc_limit=10, + doc_limit=10, fetch_limit=0, ): """ Execute document-based Retrieval-Augmented Generation (RAG) query. 
@@ -426,7 +426,9 @@ class FlowInstance: Args: query: Natural language query collection: Collection identifier (default: "default") - doc_limit: Maximum document chunks to retrieve (default: 10) + doc_limit: Document chunks selected into the prompt (default: 10) + fetch_limit: Candidate chunks fetched from the vector store before + reranking (default: 0 = derive from doc_limit) Returns: str: Generated response incorporating document context @@ -447,6 +449,7 @@ class FlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, } result = self.request( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 3a06e0d8..efa887a1 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -752,6 +752,7 @@ class SocketFlowInstance: query: str, collection: str, doc_limit: int = 10, + fetch_limit: int = 0, streaming: bool = False, **kwargs: Any ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: @@ -764,6 +765,7 @@ class SocketFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": streaming } request.update(kwargs) @@ -785,6 +787,7 @@ class SocketFlowInstance: query: str, collection: str, doc_limit: int = 10, + fetch_limit: int = 0, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: """Execute document-based RAG query with explainability support.""" @@ -792,6 +795,7 @@ class SocketFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": True, "explainable": True, } diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index fe766522..f2a0b29a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ 
b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator): query=data["query"], collection=data.get("collection", "default"), doc_limit=int(data.get("doc-limit", 20)), + fetch_limit=int(data.get("fetch-limit", 0)), streaming=data.get("streaming", False) ) @@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator): "query": obj.query, "collection": obj.collection, "doc-limit": obj.doc_limit, + "fetch-limit": obj.fetch_limit, "streaming": getattr(obj, "streaming", False) } diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index ce91a3cb..d96bad1e 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -64,6 +64,8 @@ from . uris import ( docrag_question_uri, docrag_grounding_uri, docrag_exploration_uri, + docrag_focus_uri, + chunk_selection_uri, docrag_synthesis_uri, ) @@ -94,6 +96,8 @@ from . namespaces import ( TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Chunk selection entity type + TG_CHUNK_SELECTION, # Explainability entity types TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANALYSIS, TG_CONCLUSION, @@ -132,6 +136,7 @@ from . 
triples import ( # Query-time provenance triple builders (DocumentRAG) docrag_question_triples, docrag_exploration_triples, + docrag_chunk_selection_triples, docrag_synthesis_triples, # Utility set_graph, @@ -196,6 +201,8 @@ __all__ = [ "docrag_question_uri", "docrag_grounding_uri", "docrag_exploration_uri", + "docrag_focus_uri", + "chunk_selection_uri", "docrag_synthesis_uri", # Namespaces "PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT", @@ -219,6 +226,8 @@ __all__ = [ "TG_EDGE_SELECTION", # Query-time provenance predicates (DocumentRAG) "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", + # Chunk selection entity type + "TG_CHUNK_SELECTION", # Explainability entity types "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", "TG_ANALYSIS", "TG_CONCLUSION", @@ -254,6 +263,7 @@ __all__ = [ # Query-time provenance triple builders (DocumentRAG) "docrag_question_triples", "docrag_exploration_triples", + "docrag_chunk_selection_triples", "docrag_synthesis_triples", # Agent provenance triple builders "agent_session_triples", diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 6f81f122..da6e30b2 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -76,6 +76,9 @@ TG_EDGE_SELECTION = TG + "EdgeSelection" TG_CHUNK_COUNT = TG + "chunkCount" TG_SELECTED_CHUNK = TG + "selectedChunk" +# Chunk selection entity type (cross-encoder reranked chunk in Focus) +TG_CHUNK_SELECTION = TG + "ChunkSelection" + # Extraction provenance entity types TG_DOCUMENT_TYPE = TG + "Document" TG_PAGE_TYPE = TG + "Page" diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 8e4871c3..d2374d54 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -30,6 +30,8 @@ from . 
namespaces import ( TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Chunk selection entity type + TG_CHUNK_SELECTION, # Explainability entity types TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, # Unifying types @@ -40,7 +42,10 @@ from . namespaces import ( TG_IN_TOKEN, TG_OUT_TOKEN, ) -from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri +from . uris import ( + activity_uri, agent_uri, subgraph_uri, edge_selection_uri, + chunk_selection_uri, +) def set_graph(triples: List[Triple], graph: str) -> List[Triple]: @@ -718,6 +723,75 @@ def docrag_exploration_triples( return triples +def docrag_chunk_selection_triples( + focus_uri: str, + exploration_uri: str, + selected_chunks_with_scores: List[dict], + session_id: str, +) -> List[Triple]: + """ + Build triples for a document RAG focus entity (chunks selected by the + cross-encoder reranker). + + Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity + derived from exploration, with one ChunkSelection sub-entity per surviving + chunk carrying the chunk reference and the reranker score. + + Structure: + {focus_uri} a tg:Focus ; prov:wasDerivedFrom {exploration_uri} . + {focus_uri} tg:selectedChunk {chunk_sel_uri} . + {chunk_sel_uri} a tg:ChunkSelection . + {chunk_sel_uri} tg:document {chunk_id} . + {chunk_sel_uri} tg:score "0.97" . 
+ + Args: + focus_uri: URI of the focus entity (from docrag_focus_uri) + exploration_uri: URI of the parent exploration entity + selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score' + session_id: Session UUID for generating chunk selection URIs + + Returns: + List of Triple objects + """ + triples = [ + _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)), + _triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")), + _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), + ] + + for idx, chunk_info in enumerate(selected_chunks_with_scores): + chunk_id = chunk_info.get("chunk_id") + if not chunk_id: + continue + + chunk_sel_uri = chunk_selection_uri(session_id, idx) + + # Link focus to chunk selection entity + triples.append( + _triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri)) + ) + + # Type the chunk selection entity + triples.append( + _triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION)) + ) + + # Reference the actual chunk (in librarian) + triples.append( + _triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id)) + ) + + # Cross-encoder score + score = chunk_info.get("score") + if score is not None: + triples.append( + _triple(chunk_sel_uri, TG_SCORE, _literal(str(score))) + ) + + return triples + + def docrag_synthesis_triples( synthesis_uri: str, exploration_uri: str, diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index a26ac867..00beacbe 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str: return f"urn:trustgraph:docrag:{session_id}/exploration" +def docrag_focus_uri(session_id: str) -> str: + """ + Generate URI for a document RAG focus entity (chunks selected by the + cross-encoder reranker). + + Args: + session_id: The session UUID. 
+ + Returns: + URN in format: urn:trustgraph:docrag:{uuid}/focus + """ + return f"urn:trustgraph:docrag:{session_id}/focus" + + +def chunk_selection_uri(session_id: str, chunk_index: int) -> str: + """ + Generate URI for a chunk selection item (links a reranked chunk to its + score). Mirrors edge_selection_uri for GraphRAG. + + Args: + session_id: The session UUID. + chunk_index: Index of this chunk in the selection (0-based). + + Returns: + URN in format: urn:trustgraph:prov:chunk:{uuid}:{index} + """ + return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}" + + def docrag_synthesis_uri(session_id: str) -> str: """ Generate URI for a document RAG synthesis entity (final answer). diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py index 1434d45d..f5139992 100644 --- a/trustgraph-base/trustgraph/provenance/vocabulary.py +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -30,6 +30,7 @@ from . namespaces import ( TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SUBAGENT_GOAL, TG_PLAN_STEP, TG_EDGE_SELECTION, TG_SCORE, + TG_CHUNK_SELECTION, ) @@ -95,6 +96,7 @@ TG_CLASS_LABELS = [ _label_triple(TG_PLAN_TYPE, "Plan"), _label_triple(TG_STEP_RESULT, "Step Result"), _label_triple(TG_EDGE_SELECTION, "Edge Selection"), + _label_triple(TG_CHUNK_SELECTION, "Chunk Selection"), ] # TrustGraph predicate labels diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index e937e720..2d4e01e1 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -40,7 +40,10 @@ class GraphRagResponse: class DocumentRagQuery: query: str = "" collection: str = "" - doc_limit: int = 0 + doc_limit: int = 0 # docs selected into the synthesis prompt + fetch_limit: int = 0 # candidate pool fetched from the vector store + # before reranking (0 = derive from 
doc_limit; + # values below doc_limit are raised to it) streaming: bool = False @dataclass diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 01512ac8..04f4deda 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_doc_limit = 10 +default_fetch_limit = 0 def question_explainable( - url, flow_id, question_text, collection, doc_limit, token=None, debug=False, + url, flow_id, question_text, collection, doc_limit, fetch_limit=0, + token=None, debug=False, workspace="default", ): """Execute document RAG with explainability - shows provenance events inline.""" @@ -39,6 +41,7 @@ def question_explainable( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, ): if isinstance(item, RAGChunk): # Print response content @@ -97,7 +100,7 @@ def question_explainable( def question( - url, flow_id, question_text, collection, doc_limit, + url, flow_id, question_text, collection, doc_limit, fetch_limit=0, streaming=True, token=None, explainable=False, debug=False, show_usage=False, workspace="default", ): @@ -109,6 +112,7 @@ def question( question_text=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, token=token, debug=debug, workspace=workspace, @@ -128,6 +132,7 @@ def question( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, streaming=True ) @@ -155,6 +160,7 @@ def question( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, ) print(result.text) @@ -214,7 +220,15 @@ def main(): '-d', '--doc-limit', type=int, default=default_doc_limit, - help=f'Document limit (default: {default_doc_limit})' + help=f'Documents 
selected into the prompt (default: {default_doc_limit})' + ) + + parser.add_argument( + '--fetch-limit', + type=int, + default=default_fetch_limit, + help='Candidate documents fetched from the vector store before ' + 'reranking (default: derive from doc-limit)' ) parser.add_argument( @@ -251,6 +265,7 @@ def main(): question_text=args.question, collection=args.collection, doc_limit=args.doc_limit, + fetch_limit=args.fetch_limit, streaming=not args.no_streaming, token=args.token, explainable=args.explainable, diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index ecfa7936..a3730eb9 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -9,10 +9,12 @@ from trustgraph.provenance import ( docrag_question_uri, docrag_grounding_uri, docrag_exploration_uri, + docrag_focus_uri, docrag_synthesis_uri, docrag_question_triples, grounding_triples, docrag_exploration_triples, + docrag_chunk_selection_triples, docrag_synthesis_triples, set_graph, GRAPH_RETRIEVAL, @@ -21,19 +23,25 @@ from trustgraph.provenance import ( # Module logger logger = logging.getLogger(__name__) +# When the caller does not specify a fetch_limit, reranking over-fetches this +# many times the final doc_limit as the candidate pool, so the cross-encoder can +# recover relevant chunks the bi-encoder ranked just outside the top doc_limit. +# This is only the fallback default: an explicit fetch_limit overrides it. 
+OVERFETCH_FACTOR = 3 + LABEL="http://www.w3.org/2000/01/rdf-schema#label" class Query: def __init__( self, rag, workspace, collection, verbose, - doc_limit=20, track_usage=None, + fetch_limit=20, track_usage=None, ): self.rag = rag self.workspace = workspace self.collection = collection self.verbose = verbose - self.doc_limit = doc_limit + self.fetch_limit = fetch_limit self.track_usage = track_usage async def extract_concepts(self, query): @@ -91,7 +99,7 @@ class Query: # Query chunk matches for each concept concurrently per_concept_limit = max( - 1, self.doc_limit // len(vectors) + 1, self.fetch_limit // len(vectors) ) async def query_concept(vec): @@ -140,6 +148,7 @@ class DocumentRag: def __init__( self, prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk, + reranker_client=None, verbose=False, ): @@ -150,12 +159,16 @@ class DocumentRag: self.doc_embeddings_client = doc_embeddings_client self.fetch_chunk = fetch_chunk + # Optional cross-encoder reranker. When None, the retrieval path is + # byte-identical to the pre-reranker behaviour. + self.reranker_client = reranker_client + if self.verbose: logger.debug("DocumentRag initialized") async def query( self, query, workspace="default", collection="default", - doc_limit=20, streaming=False, chunk_callback=None, + doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None, explain_callback=None, save_answer_callback=None, ): """ @@ -165,7 +178,10 @@ class DocumentRag: query: The query string workspace: Workspace for isolation (also scopes chunk lookup) collection: Collection identifier - doc_limit: Max chunks to retrieve + doc_limit: Chunks selected into the synthesis prompt (after rerank) + fetch_limit: Candidate pool fetched from the vector store before + reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a + reranker is wired, else doc_limit). 
streaming: Enable streaming LLM response chunk_callback: async def callback(chunk, end_of_stream) for streaming explain_callback: async def callback(triples, explain_id) for explainability @@ -197,6 +213,7 @@ class DocumentRag: q_uri = docrag_question_uri(session_id) gnd_uri = docrag_grounding_uri(session_id) exp_uri = docrag_exploration_uri(session_id) + foc_uri = docrag_focus_uri(session_id) syn_uri = docrag_synthesis_uri(session_id) timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") @@ -209,10 +226,21 @@ class DocumentRag: ) await explain_callback(q_triples, q_uri) + # Resolve the candidate-pool size fetched from the vector store. When a + # reranker is wired, honour an explicit fetch_limit; if unset, fall back + # to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit, + # else the rerank could not fill the prompt. Without a reranker, fetch + # doc_limit as before (byte-identical behaviour). + if self.reranker_client is not None: + fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit) + fetch_count = max(fl, doc_limit) + else: + fetch_count = doc_limit + q = Query( rag=self, workspace=workspace, collection=collection, verbose=self.verbose, - doc_limit=doc_limit, track_usage=track_usage, + fetch_limit=fetch_count, track_usage=track_usage, ) # Extract concepts from query (grounding step) @@ -235,6 +263,7 @@ class DocumentRag: docs, chunk_ids = await q.get_docs(concepts) # Emit exploration explainability after chunks retrieved + # (full candidate set, before any reranking) if explain_callback: exp_triples = set_graph( docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids), @@ -242,6 +271,45 @@ class DocumentRag: ) await explain_callback(exp_triples, exp_uri) + # Optional cross-encoder reranking pass between retrieval and + # synthesis. Mirrors GraphRAG's reranker usage but with a single + # query (the question). When no reranker is wired, this block is + # skipped entirely and behaviour is byte-identical to before. 
+ reranked = False + if self.reranker_client is not None and docs: + results = await self.reranker_client.rerank( + queries=[{"id": "0", "text": query}], + documents=[ + {"id": str(i), "text": d} for i, d in enumerate(docs) + ], + # Narrow the over-fetched candidate pool down to the final + # doc_limit requested for synthesis. + limit=doc_limit, + ) + + # results are sorted desc by score and truncated to limit by the + # reranker service, so order gives the surviving top-N directly. + order = [int(r.document_id) for r in results] + docs = [docs[i] for i in order] + chunk_ids = [chunk_ids[i] for i in order] + reranked = True + + # Emit chunk-selection (focus) explainability: surviving chunks + # with their cross-encoder scores, derived from exploration. + if explain_callback: + selected_chunks_with_scores = [ + {"chunk_id": chunk_ids[i], "score": r.score} + for i, r in enumerate(results) + ] + foc_triples = set_graph( + docrag_chunk_selection_triples( + foc_uri, exp_uri, + selected_chunks_with_scores, session_id, + ), + GRAPH_RETRIEVAL + ) + await explain_callback(foc_triples, foc_uri) + if self.verbose: logger.debug("Invoking LLM...") logger.debug(f"Documents: {docs}") @@ -291,9 +359,15 @@ class DocumentRag: logger.warning(f"Failed to save answer to librarian: {e}") synthesis_doc_id = None + # When reranking ran, synthesis derives from the focus (the + # reranked chunks actually fed to the LLM), as GraphRAG always does. + # When no reranker is wired, there is no focus stage, so synthesis + # derives from exploration (the unchanged no-op lineage) - a + # deliberate divergence from GraphRAG's always-on focus. 
+ syn_parent = foc_uri if reranked else exp_uri syn_triples = set_graph( docrag_synthesis_triples( - syn_uri, exp_uri, + syn_uri, syn_parent, document_id=synthesis_doc_id, in_token=synthesis_result.in_token if synthesis_result else None, out_token=synthesis_result.out_token if synthesis_result else None, diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index c80f4172..158cbefc 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -13,6 +13,7 @@ from . document_rag import DocumentRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec +from ... base import RerankerClientSpec from ... base import LibrarianSpec # Module logger @@ -28,14 +29,21 @@ class Processor(FlowProcessor): doc_limit = params.get("doc_limit", 5) + # Instance-default candidate-pool size fetched before cross-encoder + # reranking; the rerank step narrows it back down to doc_limit for the + # LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit). 
+ fetch_limit = params.get("fetch_limit", 0) + super(Processor, self).__init__( **params | { "id": id, "doc_limit": doc_limit, + "fetch_limit": fetch_limit, } ) self.doc_limit = doc_limit + self.fetch_limit = fetch_limit self.register_specification( ConsumerSpec( @@ -66,6 +74,13 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + RerankerClientSpec( + request_name = "reranker-request", + response_name = "reranker-response", + ) + ) + self.register_specification( ProducerSpec( name = "response", @@ -105,6 +120,7 @@ class Processor(FlowProcessor): doc_embeddings_client = flow("document-embeddings-request"), prompt_client = flow("prompt-request"), fetch_chunk = fetch_chunk, + reranker_client = flow("reranker-request"), verbose=True, ) @@ -113,6 +129,13 @@ class Processor(FlowProcessor): else: doc_limit = self.doc_limit + # Candidate-pool size: per-request override, else the instance + # default; 0 lets the core derive it from doc_limit. + if v.fetch_limit: + fetch_limit = v.fetch_limit + else: + fetch_limit = self.fetch_limit + async def send_explainability(triples, explain_id): await flow("explainability").send(Triples( metadata=Metadata( @@ -163,6 +186,7 @@ class Processor(FlowProcessor): workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, streaming=True, chunk_callback=send_chunk, explain_callback=send_explainability, @@ -188,6 +212,7 @@ class Processor(FlowProcessor): workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, explain_callback=send_explainability, save_answer_callback=save_answer, ) @@ -243,6 +268,15 @@ class Processor(FlowProcessor): help=f'Default document fetch limit (default: 10)' ) + parser.add_argument( + '--fetch-limit', + type=int, + default=0, + help='Candidate chunks to fetch from the vector store and rerank ' + 'before keeping the top doc-limit for the LLM ' + '(default: derive from doc-limit)' + ) + def run(): 
Processor.launch(default_ident, __doc__) From 9cf7dcb5782995ad405cd205750a4993fc8dba7b Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 2 Jul 2026 11:14:54 +0100 Subject: [PATCH 9/9] fix: wire variant into remaining streaming integration test mocks (#1013) Three more streaming tests were missing _wire_variant after the async for change in create_completion_stream. --- .../integration/test_text_completion_streaming_integration.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/test_text_completion_streaming_integration.py b/tests/integration/test_text_completion_streaming_integration.py index caa3ec9c..7d514522 100644 --- a/tests/integration/test_text_completion_streaming_integration.py +++ b/tests/integration/test_text_completion_streaming_integration.py @@ -270,6 +270,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act & Assert with pytest.raises(Exception) as exc_info: @@ -307,6 +308,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act chunks = [] @@ -330,6 +332,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) system_prompt = "You are an expert." user_prompt = "Explain quantum physics."