diff --git a/README.md b/README.md
index c366a3d9..b66edc70 100644
--- a/README.md
+++ b/README.md
@@ -11,11 +11,11 @@
-# The agent runtime platform
+# The semantic deployment platform
-TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
+TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads.
The platform:
- [x] Multi-model and multimodal database system
@@ -99,23 +99,21 @@ For a browser based configuration, try the [Configuration Terminal](https://conf
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
-## Workbench
+## Context Graph UI
-The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
+
-- **Vector Search**: Search the installed knowledge bases
-- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
-- **Relationships**: Analyze deep relationships in the installed knowledge bases
-- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
-- **Library**: Staging area for installing knowledge bases
-- **Flow Classes**: Workflow preset configurations
-- **Flows**: Create custom workflows and adjust LLM parameters during runtime
-- **Knowledge Cores**: Manage resuable knowledge bases
-- **Prompts**: Manage and adjust prompts during runtime
-- **Schemas**: Define custom schemas for structured data knowledge bases
-- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
-- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
-- **MCP Tools**: Connect to MCP servers
+The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
+
+- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
+- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
+- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
+- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
+- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
+- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
+- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout
+- **Workspace UX** — Workspace selection and management surfaced directly in the interface
+- **Prompt Editor** — A dedicated prompt editing workflow
## TypeScript Library for UIs
diff --git a/containers/Containerfile.unstructured b/containers/Containerfile.unstructured
index 6de8a800..2b9a18f7 100644
--- a/containers/Containerfile.unstructured
+++ b/containers/Containerfile.unstructured
@@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
-RUN dnf install -y python3.13 libxcb mesa-libGL && \
+RUN dnf install -y python3.13 libxcb mesa-libGL poppler-utils && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
diff --git a/docs/tech-specs/knowledge-core-completeness.md b/docs/tech-specs/knowledge-core-completeness.md
new file mode 100644
index 00000000..3ccb41f0
--- /dev/null
+++ b/docs/tech-specs/knowledge-core-completeness.md
@@ -0,0 +1,535 @@
+---
+layout: default
+title: "Knowledge Core Completeness"
+parent: "Tech Specs"
+---
+
+# Knowledge Core Completeness
+
+## Overview
+
+Knowledge cores are portable snapshots of extracted knowledge: triples, graph
+embeddings, and document embeddings stored in Cassandra's `knowledge` keyspace.
+They can be downloaded as files, transferred between TrustGraph instances, and
+loaded back into vector and graph stores.
+
+Recent additions to TrustGraph — explainability/provenance and named graphs —
+were not carried through to the knowledge core system. This means that
+exporting and re-importing a core loses provenance links, graph assignments,
+and source material, breaking the explainability chain.
+
+This specification addresses three gaps:
+
+1. **Named graphs not stored** — The `g` (graph name) field on triples is
+ silently dropped when writing to the core store and comes back as `None`
+ on read.
+2. **Provenance triples not captured** — Provenance triples (PROV-O) are
+ generated during extraction and flow to graph stores, but are not known
+ to reach the knowledge core store. Whether they arrive there at all, and
+ in what form, still needs verification.
+3. **Source material not included** — Documents, text pages, and chunks in
+ the librarian's bucket store are not part of the core. After loading a
+ core on a different instance, provenance links to source material point
+ at nothing.
+
+## Goals
+
+- **Self-contained cores**: A downloaded knowledge core file contains
+ everything needed to reconstruct the full knowledge graph including
+ provenance and source attribution on a fresh instance.
+- **Named graph preservation**: Round-tripping a core preserves graph
+ assignments on all triples.
+- **Backward compatibility**: Existing core files (without graph names or
+ source material) can still be uploaded and loaded. New fields are optional
+ on import.
+- **No change to core identity**: A core is still identified by its document
+ ID. The additional data is associated with the same core ID.
+- **Minimal file format changes**: Extend the existing msgpack record format
+ with new record types rather than restructuring existing ones.
+
+## Background
+
+### Current Lifecycle
+
+```
+Extraction pipeline
+ │
+ ├─ triples ──────────────────► knowledge core store (Cassandra)
+ ├─ graph embeddings ─────────► knowledge core store (Cassandra)
+ ├─ document embeddings ──────► knowledge core store (Cassandra)
+ ├─ provenance triples ───────► graph store (only)
+ └─ source documents ─────────► librarian bucket store (only)
+
+Download: Cassandra ──► knowledge manager ──► API gateway ──► client file
+Upload: client file ──► API gateway ──► knowledge manager ──► Cassandra
+Load: Cassandra ──► knowledge manager ──► Pulsar topics ──► graph/vector stores
+```
+
+### Current Core File Format (msgpack)
+
+A core file is a sequence of concatenated msgpack records. Each record is a
+2-element tuple: `(type_tag, payload)`.
+
+| Type tag | Payload | Description |
+|----------|---------|-------------|
+| `"t"` | `{"m": {id, root, collection}, "t": [triple_dicts]}` | Triple batch |
+| `"ge"` | `{"m": {id, root, collection}, "e": [{entity, vector}]}` | Graph embedding batch |
+
+### What's Missing
+
+#### Named Graphs
+
+The `Triple` dataclass has a `g: str | None` field (graph name IRI), used to
+separate provenance graphs (`urn:graph:source`, `urn:graph:retrieval`) from
+the default graph. However:
+
+- **Cassandra schema** (`knowledge.triples` table): stores a 6-tuple per
+ triple `(s_val, s_is_uri, p_val, p_is_uri, o_val, o_is_uri)` — no graph
+ field.
+- **`add_triples()`** (`tables/knowledge.py:231`): destructures only `s`,
+ `p`, `o` — `g` is discarded.
+- **`get_triples()`** (`tables/knowledge.py:396`): reconstructs `Triple`
+ with `g` defaulting to `None`.
+- **Core file format**: triple dicts do not include a graph field.
+
+#### Provenance Triples
+
+Provenance triples are generated in the extraction pipeline
+(`trustgraph-base/trustgraph/provenance/triples.py`) and published to graph
+store topics. They use named graphs (`urn:graph:source`,
+`urn:graph:retrieval`) and PROV-O vocabulary.
+
+The knowledge core store processor (`storage/knowledge/store.py`) listens on
+`triples-input` and `graph-embeddings-input`. Whether provenance triples
+arrive on the same `triples-input` topic or a separate one needs
+verification. Even if they do arrive, the graph name would be lost (per
+above).
+
+#### Source Material
+
+The librarian stores the full document hierarchy in a separate system:
+
+- **Blob store** (S3/MinIO): original documents, text pages, chunks —
+ keyed by object UUID under `doc/{object_id}`.
+- **Cassandra `library` keyspace**: document metadata including `id`,
+ `kind` (MIME type), `title`, `parent_id`, `document_type`
+ (`source`/`extracted`), `object_id` (blob reference).
+
+Provenance triples link extracted facts back to chunk/page/document IDs.
+Those IDs resolve through the librarian. When a core is loaded on a
+different instance, the librarian has no matching documents, so the entire
+provenance chain is broken.
+
+### Key Source Files
+
+| Component | File | Purpose |
+|-----------|------|---------|
+| Core Cassandra schema | `trustgraph-flow/trustgraph/tables/knowledge.py` | Table definitions, read/write |
+| Core manager | `trustgraph-flow/trustgraph/cores/knowledge.py` | API operations, load-to-store |
+| Core store processor | `trustgraph-flow/trustgraph/storage/knowledge/store.py` | Extraction → Cassandra |
+| CLI download | `trustgraph-cli/trustgraph/cli/get_kg_core.py` | Core → msgpack file |
+| CLI upload | `trustgraph-cli/trustgraph/cli/put_kg_core.py` | Msgpack file → core |
+| CLI load | `trustgraph-cli/trustgraph/cli/load_kg_core.py` | Core → graph/vector stores |
+| API client | `trustgraph-base/trustgraph/api/knowledge.py` | Client-side knowledge API |
+| Triple schema | `trustgraph-base/trustgraph/schema/core/primitives.py` | Triple dataclass with `g` field |
+| Provenance generation | `trustgraph-base/trustgraph/provenance/triples.py` | PROV-O triple creation |
+| Librarian | `trustgraph-flow/trustgraph/librarian/librarian.py` | Document storage service |
+| Library tables | `trustgraph-flow/trustgraph/tables/library.py` | Document metadata in Cassandra |
+| Blob store | `trustgraph-flow/trustgraph/librarian/blob_store.py` | S3/MinIO object storage |
+
+## Technical Design
+
+### Change 1: Named Graph Field in Core Storage
+
+#### Cassandra Schema
+
+Extend the `triples` tuple from 6 to 7 elements, adding the graph name:
+
+```
+triples list<tuple<text, boolean, text, boolean, text, boolean, text>>
+```
+
+**Migration**: Cassandra cannot change the type of an existing column in
+place, so widening the tuple needs one of the strategies in the Migration
+Plan below, such as a new table version. Existing rows with 6-element
+tuples must be handled gracefully on read: if the tuple has 6 elements,
+treat the graph as the default graph.
+
+#### Write Path (`add_triples`)
+
+Change `tables/knowledge.py:add_triples()` to include `triple.g`:
+
+```python
+triples = [
+    (
+        *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o),
+        v.g or ""
+    )
+    for v in m.triples
+]
+```
+
+#### Read Path (`get_triples`)
+
+Change `tables/knowledge.py:get_triples()` to restore the graph name:
+
+```python
+Triple(
+    s = tuple_to_term(elt[0], elt[1]),
+    p = tuple_to_term(elt[2], elt[3]),
+    o = tuple_to_term(elt[4], elt[5]),
+    g = elt[6] if len(elt) > 6 and elt[6] else None,
+)
+```
+
+The `len(elt) > 6` guard provides backward compatibility with existing
+6-element rows.
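+
+The guard can be checked in isolation. The following is a minimal sketch,
+not the TrustGraph implementation: `SketchTriple` and `row_to_triple` are
+hypothetical stand-ins, and terms are simplified to `(value, is_uri)` pairs
+instead of the real `Term` type.
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class SketchTriple:
+    # Hypothetical stand-in for Triple; terms are (value, is_uri) pairs.
+    s: tuple
+    p: tuple
+    o: tuple
+    g: str | None = None
+
+
+def row_to_triple(elt):
+    # Accept legacy 6-element rows and new 7-element rows alike.
+    return SketchTriple(
+        s=(elt[0], elt[1]),
+        p=(elt[2], elt[3]),
+        o=(elt[4], elt[5]),
+        # An empty string is what the write path stores for "no graph".
+        g=elt[6] if len(elt) > 6 and elt[6] else None,
+    )
+
+
+old_row = ("urn:s", True, "urn:p", True, "value", False)
+new_row = old_row + ("urn:graph:source",)
+empty_row = old_row + ("",)
+
+legacy = row_to_triple(old_row)
+named = row_to_triple(new_row)
+default = row_to_triple(empty_row)
+```
+
+Both a legacy row and a row written with an empty graph string come back
+with `g` set to `None`, so the default graph has a single representation
+after a round trip.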
+
+#### Core File Format
+
+Extend triple dicts in the `"t"` record to include the graph name:
+
+```python
+# In get_kg_core.py write_triple — each triple dict gains "g" key
+{"s": ..., "p": ..., "o": ..., "g": "urn:graph:source"}
+```
+
+On read (`put_kg_core.py`), treat missing `"g"` key as default graph for
+backward compatibility with old core files.
+
+### Change 2: Provenance Triples in Cores
+
+#### Investigation Required
+
+Before implementation, verify:
+
+1. Whether provenance triples arrive on the `triples-input` topic that the
+ knowledge core store processor already listens on.
+2. If not, which topic they use, and whether the store processor should
+ subscribe to it.
+
+#### If provenance triples already arrive at the store
+
+The only change needed is Change 1 (named graphs) — the provenance triples
+are already being stored, just without their graph name. Once graph names
+are preserved, provenance triples will round-trip correctly.
+
+#### If provenance triples do NOT arrive at the store
+
+Two options:
+
+**Option A — Route provenance to the existing store topic**: Configure the
+flow so provenance triples are published to the same `triples-input` topic.
+This is the simpler approach and keeps the store processor unchanged.
+
+**Option B — Add a subscription**: Add a new `ConsumerSpec` in the store
+processor for the provenance topic. This keeps provenance routing
+independent but adds complexity.
+
+Recommendation: Option A, unless there is a reason provenance triples are
+intentionally kept off the core store topic.
+
+### Change 3: Source Material in Cores
+
+This is the largest change. The goal is that when a core is loaded on a
+fresh instance, provenance links to source material resolve.
+
+#### Architecture
+
+Source material is **not stored in the knowledge core tables**. It lives in
+the librarian (Cassandra `library` keyspace + S3/MinIO blob store) and is
+fetched on demand via the librarian's existing service API.
+
+The knowledge manager acts as a **client of the librarian service** — it
+calls the librarian's request/response API over pub/sub to retrieve document
+metadata and content. It does not access the library's Cassandra tables or
+blob store directly.
+
+#### Transport
+
+The librarian's pub/sub API already handles chunking of large documents.
+This chunking is designed to be websocket-friendly, so library content
+flowing through the API gateway to external clients does not require
+re-chunking. The API gateway remains a transport layer.
+
+```
+Download:
+ Knowledge manager ──pub/sub──► Librarian (fetch metadata + content)
+ Knowledge manager ──pub/sub──► API gateway ──websocket──► Client
+
+Upload:
+ Client ──websocket──► API gateway ──pub/sub──► Knowledge manager
+ Knowledge manager ──pub/sub──► Librarian (store metadata + content)
+```
+
+#### What to Include
+
+The provenance chain links facts → chunks → pages → documents. For the
+chain to resolve, the core must include:
+
+1. **Document metadata** — the library record for each document in the
+ hierarchy (id, kind, title, parent_id, document_type, etc.)
+2. **Document content** — the blob data for each document (original file,
+ extracted text pages, text chunks)
+
+Including the full hierarchy is necessary because:
+- A user viewing provenance needs to traverse fact → chunk → page → document
+- The chunk text is needed to show what text a fact was extracted from
+- The page text provides broader context
+- The original document is needed for full source attribution
+
+#### Size Implications
+
+Source material will significantly increase core file sizes. A rough model:
+
+| Component | Typical size per document |
+|-----------|-------------------------|
+| Triples + embeddings (current) | 1-10 MB |
+| Chunk text (all chunks) | ~same as original document |
+| Page text (all pages) | ~same as original document |
+| Original document (PDF, etc.) | Varies widely (KB to hundreds of MB) |
+
+For a 10 MB PDF, the core could grow from ~5 MB to ~25 MB (original +
+derived text + existing data). For large document sets, cores could become
+very large.
+
+**Decision needed**: Whether to include original documents or just derived
+text (pages + chunks). Including only derived text still allows provenance
+display but loses the ability to serve the original file.
+
+#### New Core File Record Types
+
+Add new msgpack record types for library content:
+
+| Type tag | Payload | Description |
+|----------|---------|-------------|
+| `"lm"` | `{"id", "kind", "title", "parent_id", "document_type", "comments", "tags", "metadata"}` | Library document metadata |
+| `"lb"` | `{"id", "data"}` | Library document blob content (chunked by pub/sub layer) |
+
+These are emitted after the existing `"t"` and `"ge"` records during
+download and processed during upload.
+
+#### Download Path
+
+Extend `KnowledgeManager.get_kg_core()` to:
+
+1. Stream triples and graph embeddings from the core store (existing
+ behavior).
+2. Use the librarian service API to retrieve documents associated with
+ this core ID:
+ a. Fetch the root document metadata and content.
+ b. Use `list-children` to discover child documents (pages, chunks).
+ c. Recursively fetch metadata and content for each child.
+3. Stream each document as `"lm"` (metadata) and `"lb"` (content) records.
+
+The knowledge manager gains the librarian service as a pub/sub dependency.
+Large document content is chunked by the librarian's existing pub/sub
+transport — the knowledge manager receives and forwards these chunks without
+buffering the full blob in memory.
+
+#### Upload Path
+
+Extend `KnowledgeManager.put_kg_core()` to handle the new record types:
+
+1. For `"lm"` records: call the librarian service API to create/update
+ the document metadata.
+2. For `"lb"` records: call the librarian service API to store the
+ document content.
+
+Parent-child relationships are preserved because `parent_id` is stored in
+the metadata. Documents should be processed in hierarchy order (parent
+before child) to satisfy any ordering constraints.
+
+#### Load Path
+
+The load path (`_load_kg_core`) publishes triples and embeddings to Pulsar
+topics for ingestion into graph/vector stores. Source material does not need
+to flow through the load path — it is already in the librarian after the
+upload step and can be accessed directly by services that need it.
+
+No changes to the load path for source material.
+
+#### CLI Changes
+
+**`tg-get-kg-core`**: Add handling for `"lm"` and `"lb"` record types in
+the file writer.
+
+**`tg-put-kg-core`**: Add handling for `"lm"` and `"lb"` record types in
+the file reader. Send library records to the knowledge manager alongside
+triple/embedding records.
+
+#### Associating Documents with Cores
+
+The core ID is `metadata.root`, which is the root document ID from the
+librarian. This provides a natural join: the core's root document and all
+its children (pages, chunks) are the source material for that core.
+
+The librarian's `list-children` API provides the child documents. A
+recursive traversal from the root document collects the full hierarchy.
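+
+The traversal can be sketched as follows. This is an illustration only:
+`collect_hierarchy` and the `children` mapping are hypothetical, with the
+mapping standing in for calls to the librarian's `list-children` API.
+
+```python
+def collect_hierarchy(root_id, children):
+    # Depth-first walk that lists each parent before its children,
+    # which is also the order the upload path wants.
+    ordered = []
+    stack = [root_id]
+    seen = set()
+    while stack:
+        doc_id = stack.pop()
+        if doc_id in seen:
+            continue  # guard against cycles in malformed metadata
+        seen.add(doc_id)
+        ordered.append(doc_id)
+        # Reverse so children are visited in their listed order.
+        stack.extend(reversed(children.get(doc_id, [])))
+    return ordered
+
+
+children = {
+    "doc-1": ["page-1", "page-2"],
+    "page-1": ["chunk-1", "chunk-2"],
+    "page-2": ["chunk-3"],
+}
+order = collect_hierarchy("doc-1", children)
+```
+
+Because every parent appears before its children, the same ordering can
+be reused when emitting `"lm"` and `"lb"` records for upload.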
+
+### API Changes
+
+#### KnowledgeResponse Schema
+
+Add optional fields to `KnowledgeResponse` for library data:
+
+```python
+@dataclass
+class KnowledgeResponse:
+    error: Error | None = None
+    ids: list | None = None
+    eos: bool = False
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None
+    document_embeddings: DocumentEmbeddings | None = None
+    library_metadata: LibraryMetadata | None = None  # new
+    library_blob: LibraryBlob | None = None  # new
+```
+
+#### New Schema Types
+
+```python
+@dataclass
+class LibraryMetadata:
+    id: str
+    kind: str | None = None
+    title: str | None = None
+    parent_id: str | None = None
+    document_type: str | None = None
+    comments: str | None = None
+    tags: list[str] | None = None
+    metadata: list[Triple] | None = None
+
+@dataclass
+class LibraryBlob:
+    id: str
+    data: bytes
+```
+
+#### Socket API
+
+The existing streaming protocol for `get-kg-core` / `put-kg-core` carries
+these new fields naturally — responses already stream multiple record types.
+
+### Dependencies Between Changes
+
+```
+Change 1 (named graphs)
+ │
+ └── Change 2 (provenance triples) depends on Change 1
+
+Change 3 (source material) is independent
+```
+
+Change 1 is a prerequisite for Change 2 (provenance triples use named
+graphs). Change 3 is independent and can be implemented in parallel.
+
+## Security Considerations
+
+- **Workspace isolation**: Core download/upload must respect workspace
+ boundaries. Source material from the librarian must only be included if
+ it belongs to the same workspace as the core. This is already enforced
+ by the existing workspace-scoped queries.
+- **Large blob transfer**: Streaming large documents through the API
+ is handled by the librarian's existing pub/sub chunking, which is
+ designed to be websocket-friendly. No additional chunking layer is
+ needed.
+- **Cross-instance trust**: When uploading a core from an external source,
+ the library content should be treated as untrusted input. Document
+ metadata and blob content should be validated before insertion.
+
+## Performance Considerations
+
+- **Core file size**: Including source material will significantly increase
+ core file sizes. Consider adding a flag to download/upload commands to
+ optionally exclude source material for use cases where only the knowledge
+ graph is needed.
+- **Streaming**: All paths already use streaming (paged Cassandra queries,
+ msgpack record-at-a-time). Library content should follow the same pattern.
+- **Cassandra schema migration**: Changing the tuple width in the `triples`
+ table requires careful handling. Cassandra frozen tuples cannot be altered
+ in place — a migration strategy is needed (see Migration Plan).
+
+## Testing Strategy
+
+- **Unit tests**: Triple round-trip with graph name (write → read →
+ verify `g` field preserved). Backward compatibility with 6-element tuples.
+- **Integration tests**: Full lifecycle — extract with provenance → download
+ core → upload to fresh instance → load → verify provenance chain resolves.
+- **File format tests**: Read old-format core files (no graph name, no
+ library records) and verify they load without error.
+- **Library inclusion tests**: Download core with source material → upload →
+ verify documents accessible through librarian.
+
+## Migration Plan
+
+### Cassandra Schema
+
+The `triples` table stores tuples in a `list<tuple<...>>` column. Cassandra
+does not support altering the type of an existing column. Options:
+
+**Option A — New table**: Create a `triples_v2` table with the 7-element
+tuple. Migrate data from `triples` to `triples_v2`. The read path checks
+both tables during a transition period, then the old table is dropped.
+
+**Option B — Dual read**: Keep the existing table. The read path handles
+both 6-element and 7-element tuples by checking length. New writes use
+7-element tuples. This works if Cassandra accepts variable-length tuples in
+a list — **needs verification**.
+
+**Option C — Separate graph column**: Instead of extending the tuple, add a
+parallel `graphs list<text>` column where `graphs[i]` corresponds to
+`triples[i]`. This avoids tuple migration entirely but requires keeping the
+two lists in sync.
+
+Recommendation: Verify Option B first (simplest). Fall back to Option A if
+Cassandra rejects mixed tuple lengths.
+
+### Core File Format
+
+Backward compatible by design:
+- Old files lack `"g"` in triple dicts and have no `"lm"`/`"lb"` records →
+ handled by defaults.
+- New files read by old code → the existing `read_message` raises on
+ unknown record types, so old readers need a small fix to skip unknown
+ types gracefully.
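+
+A tolerant reader along these lines would cover both cases. This is a
+sketch under assumptions: `read_records` is a hypothetical helper that
+works on records already decoded from msgpack, and the set of known tags
+is taken from the record tables in this document.
+
+```python
+KNOWN_TAGS = {"t", "ge", "lm", "lb"}
+
+
+def read_records(records):
+    # Yield (tag, payload) pairs, skipping tags this reader does not know.
+    for tag, payload in records:
+        if tag not in KNOWN_TAGS:
+            continue  # forward compatibility: ignore future record types
+        if tag == "t":
+            # Old files have no "g" key; None means the default graph.
+            triples = [{**t, "g": t.get("g")} for t in payload["t"]]
+            payload = {**payload, "t": triples}
+        yield tag, payload
+
+
+old_triple = {"s": "a", "p": "b", "o": "c"}
+new_triple = {"s": "a", "p": "b", "o": "c", "g": "urn:graph:source"}
+records = [
+    ("t", {"m": {}, "t": [old_triple]}),
+    ("zz", {"unknown": True}),
+    ("t", {"m": {}, "t": [new_triple]}),
+]
+result = list(read_records(records))
+```
+
+Skipping unknown tags trades strictness for forward compatibility. A
+version marker, as raised in the open questions, would let a reader make
+that choice explicitly.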
+
+## Open Questions
+
+1. **Provenance topic routing**: Do provenance triples currently arrive at
+ the `triples-input` topic consumed by the knowledge core store? If not,
+ what topic are they on?
+
+2. **Include original documents?**: Should cores include the original
+ uploaded document (e.g. PDF), or only derived text (pages + chunks)?
+ Including originals makes cores fully self-contained but potentially
+ very large. Excluding them preserves provenance text display but loses
+ the ability to serve the original file.
+
+3. **Optional source material**: Should there be a flag on download/upload
+ to include or exclude source material? This would let users choose
+ between compact cores (knowledge only) and complete cores (knowledge +
+ sources).
+
+4. **Cassandra tuple migration**: Can Cassandra handle mixed-length tuples
+ in a `list<tuple<...>>` column, or is a table migration required?
+
+5. **Document embedding cores**: DE cores are managed alongside KG cores.
+ Do they need the same treatment (source material inclusion)? The
+ document embeddings reference chunk IDs — the same provenance chain
+ applies.
+
+6. **Core versioning**: Should the core file include a version marker so
+ readers can distinguish old-format from new-format files without
+ trial-and-error parsing?
+
+## References
+
+- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
+- Query-time explainability: `docs/tech-specs/query-time-explainability.md`
+- Agent explainability: `docs/tech-specs/agent-explainability.md`
+- Data ownership model: `docs/tech-specs/data-ownership-model.md`
diff --git a/tests/unit/test_base/test_cassandra_config.py b/tests/unit/test_base/test_cassandra_config.py
index a291434d..fe8a8379 100644
--- a/tests/unit/test_base/test_cassandra_config.py
+++ b/tests/unit/test_base/test_cassandra_config.py
@@ -409,4 +409,57 @@ class TestEdgeCases:
assert hosts == ['mixed-host']
assert username is None # Stays None
- assert password == 'mixed-pass'
\ No newline at end of file
+ assert password == 'mixed-pass'
+
+
+class TestReplicationFactorParamPath:
+
+    def test_explicit_kwarg(self):
+        with patch.dict(os.environ, {}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=3,
+            )
+            assert rf == 3
+
+    def test_kwarg_overrides_env(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=3,
+            )
+            assert rf == 3
+
+    def test_env_fallback_when_kwarg_none(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=None,
+            )
+            assert rf == 5
+
+    def test_default_when_no_kwarg_no_env(self):
+        with patch.dict(os.environ, {}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config()
+            assert rf == 1
+
+    def test_params_dict_path(self):
+        with patch.dict(os.environ, {}, clear=True):
+            params = {'cassandra_replication_factor': 3}
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=params.get('cassandra_replication_factor'),
+            )
+            assert rf == 3
+
+    def test_params_dict_overrides_env(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            params = {'cassandra_replication_factor': 3}
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=params.get('cassandra_replication_factor'),
+            )
+            assert rf == 3
+
+    def test_params_dict_missing_falls_to_env(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            params = {}
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=params.get('cassandra_replication_factor'),
+            )
+            assert rf == 5
\ No newline at end of file
diff --git a/tests/unit/test_base/test_qdrant_config.py b/tests/unit/test_base/test_qdrant_config.py
new file mode 100644
index 00000000..dbbe4214
--- /dev/null
+++ b/tests/unit/test_base/test_qdrant_config.py
@@ -0,0 +1,136 @@
+
+import os
+import pytest
+from unittest.mock import patch
+
+from trustgraph.base.qdrant_config import (
+    get_qdrant_defaults,
+    resolve_qdrant_config,
+)
+
+
+class TestGetQdrantDefaults:
+
+    def test_defaults_with_no_env_vars(self):
+        with patch.dict(os.environ, {}, clear=True):
+            defaults = get_qdrant_defaults()
+            assert defaults['url'] == 'http://localhost:6333'
+            assert defaults['api_key'] is None
+            assert defaults['replication_factor'] == 1
+            assert defaults['shard_number'] == 1
+
+    def test_defaults_from_env(self):
+        env = {
+            'QDRANT_URL': 'http://qdrant:6333',
+            'QDRANT_API_KEY': 'secret',
+            'QDRANT_REPLICATION_FACTOR': '3',
+            'QDRANT_SHARD_NUMBER': '5',
+        }
+        with patch.dict(os.environ, env, clear=True):
+            defaults = get_qdrant_defaults()
+            assert defaults['url'] == 'http://qdrant:6333'
+            assert defaults['api_key'] == 'secret'
+            assert defaults['replication_factor'] == 3
+            assert defaults['shard_number'] == 5
+
+
+class TestResolveQdrantConfig:
+
+    def test_defaults(self):
+        with patch.dict(os.environ, {}, clear=True):
+            url, api_key, rf, sn = resolve_qdrant_config()
+            assert url == 'http://localhost:6333'
+            assert api_key is None
+            assert rf == 1
+            assert sn == 1
+
+    def test_explicit_kwargs(self):
+        with patch.dict(os.environ, {}, clear=True):
+            url, api_key, rf, sn = resolve_qdrant_config(
+                url='http://custom:6333',
+                api_key='key',
+                replication_factor=3,
+                shard_number=5,
+            )
+            assert url == 'http://custom:6333'
+            assert api_key == 'key'
+            assert rf == 3
+            assert sn == 5
+
+    def test_kwargs_override_env(self):
+        env = {
+            'QDRANT_URL': 'http://env:6333',
+            'QDRANT_REPLICATION_FACTOR': '10',
+            'QDRANT_SHARD_NUMBER': '10',
+        }
+        with patch.dict(os.environ, env, clear=True):
+            url, _, rf, sn = resolve_qdrant_config(
+                url='http://explicit:6333',
+                replication_factor=3,
+                shard_number=5,
+            )
+            assert url == 'http://explicit:6333'
+            assert rf == 3
+            assert sn == 5
+
+    def test_env_fallback_when_kwargs_none(self):
+        env = {
+            'QDRANT_URL': 'http://env:6333',
+            'QDRANT_REPLICATION_FACTOR': '3',
+            'QDRANT_SHARD_NUMBER': '5',
+        }
+        with patch.dict(os.environ, env, clear=True):
+            url, _, rf, sn = resolve_qdrant_config()
+            assert url == 'http://env:6333'
+            assert rf == 3
+            assert sn == 5
+
+    def test_params_dict_path(self):
+        with patch.dict(os.environ, {}, clear=True):
+            params = {
+                'store_uri': 'http://params:6333',
+                'api_key': 'pkey',
+                'qdrant_replication_factor': 3,
+                'qdrant_shard_number': 5,
+            }
+            url, api_key, rf, sn = resolve_qdrant_config(
+                url=params.get('store_uri'),
+                api_key=params.get('api_key'),
+                replication_factor=params.get('qdrant_replication_factor'),
+                shard_number=params.get('qdrant_shard_number'),
+            )
+            assert url == 'http://params:6333'
+            assert api_key == 'pkey'
+            assert rf == 3
+            assert sn == 5
+
+    def test_params_dict_overrides_env(self):
+        env = {
+            'QDRANT_REPLICATION_FACTOR': '10',
+            'QDRANT_SHARD_NUMBER': '10',
+        }
+        with patch.dict(os.environ, env, clear=True):
+            params = {
+                'qdrant_replication_factor': 3,
+                'qdrant_shard_number': 5,
+            }
+            _, _, rf, sn = resolve_qdrant_config(
+                replication_factor=params.get('qdrant_replication_factor'),
+                shard_number=params.get('qdrant_shard_number'),
+            )
+            assert rf == 3
+            assert sn == 5
+
+    def test_params_dict_missing_falls_to_env(self):
+        env = {
+            'QDRANT_REPLICATION_FACTOR': '3',
+            'QDRANT_SHARD_NUMBER': '5',
+        }
+        with patch.dict(os.environ, env, clear=True):
+            params = {}
+            _, _, rf, sn = resolve_qdrant_config(
+                replication_factor=params.get('qdrant_replication_factor'),
+                shard_number=params.get('qdrant_shard_number'),
+            )
+            assert rf == 3
+            assert sn == 5
diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py
index 8f73dcc6..7797c9be 100644
--- a/tests/unit/test_cores/test_knowledge_manager.py
+++ b/tests/unit/test_cores/test_knowledge_manager.py
@@ -11,7 +11,12 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
from unittest.mock import call
from trustgraph.cores.knowledge import KnowledgeManager
-from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
+from trustgraph.schema import (
+ KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term,
+ EntityEmbeddings, IRI, LITERAL,
+ LibraryMetadata, LibraryBlob,
+ LibrarianResponse, DocumentMetadata,
+)
@pytest.fixture
@@ -373,11 +378,252 @@ class TestKnowledgeManagerOtherMethods:
mock_respond = AsyncMock()
await knowledge_manager.delete_kg_core(mock_request, mock_respond, "test-user")
-
+
# Verify table store was called correctly
knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id")
-
+
# Verify response
mock_respond.assert_called_once()
response = mock_respond.call_args[0][0]
- assert response.error is None
\ No newline at end of file
+ assert response.error is None
+
+
+class TestKnowledgeManagerLibraryDownload:
+ """Test get_kg_core streaming of library documents."""
+
+ @pytest.fixture
+ def manager_with_librarian(self, mock_flow_config):
+ with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
+ mock_librarian = AsyncMock()
+ manager = KnowledgeManager(
+ cassandra_host=["localhost"],
+ cassandra_username="test_user",
+ cassandra_password="test_pass",
+ keyspace="test_keyspace",
+ flow_config=mock_flow_config,
+ librarian=mock_librarian,
+ )
+ manager.table_store = AsyncMock()
+ return manager
+
+ @pytest.mark.asyncio
+ async def test_get_kg_core_streams_library_docs(self, manager_with_librarian):
+ mock_request = Mock()
+ mock_request.id = "root-doc"
+ mock_respond = AsyncMock()
+
+ manager_with_librarian.table_store.get_triples = AsyncMock()
+ manager_with_librarian.table_store.get_graph_embeddings = AsyncMock()
+
+ root_meta = DocumentMetadata(
+ id="root-doc", kind="application/pdf", title="Test PDF",
+ document_type="source",
+ )
+ child_meta = DocumentMetadata(
+ id="chunk-1", kind="text/plain", title="Chunk 1",
+ parent_id="root-doc", document_type="chunk",
+ )
+
+ manager_with_librarian.librarian.fetch_document_metadata.return_value = root_meta
+ manager_with_librarian.librarian.request.return_value = LibrarianResponse(
+ document_metadatas=[child_meta],
+ )
+ manager_with_librarian.librarian.fetch_document_content.side_effect = [
+ b"cm9vdCBjb250ZW50",
+ b"Y2h1bmsgY29udGVudA==",
+ ]
+
+ await manager_with_librarian.get_kg_core(
+ mock_request, mock_respond, "test-user"
+ )
+
+ responses = [c[0][0] for c in mock_respond.call_args_list]
+
+ lm_responses = [r for r in responses if r.library_metadata is not None]
+ lb_responses = [r for r in responses if r.library_blob is not None]
+ eos_responses = [r for r in responses if r.eos is True]
+
+ assert len(lm_responses) == 2
+ assert lm_responses[0].library_metadata.id == "root-doc"
+ assert lm_responses[0].library_metadata.document_type == "source"
+ assert lm_responses[1].library_metadata.id == "chunk-1"
+ assert lm_responses[1].library_metadata.parent_id == "root-doc"
+
+ assert len(lb_responses) == 2
+ assert lb_responses[0].library_blob.id == "root-doc"
+ assert lb_responses[0].library_blob.data == b"cm9vdCBjb250ZW50"
+ assert lb_responses[1].library_blob.id == "chunk-1"
+
+ assert len(eos_responses) == 1
+
+ @pytest.mark.asyncio
+ async def test_get_kg_core_no_librarian_skips_library(self, mock_flow_config):
+ with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
+ manager = KnowledgeManager(
+ cassandra_host=["localhost"],
+ cassandra_username="u", cassandra_password="p",
+ keyspace="ks", flow_config=mock_flow_config,
+ )
+ manager.table_store = AsyncMock()
+ manager.table_store.get_triples = AsyncMock()
+ manager.table_store.get_graph_embeddings = AsyncMock()
+
+ mock_request = Mock()
+ mock_request.id = "doc-1"
+ mock_respond = AsyncMock()
+
+ await manager.get_kg_core(mock_request, mock_respond, "w")
+
+ responses = [c[0][0] for c in mock_respond.call_args_list]
+ assert all(r.library_metadata is None for r in responses)
+ assert all(r.library_blob is None for r in responses)
+
+ @pytest.mark.asyncio
+ async def test_get_kg_core_librarian_metadata_failure_is_graceful(
+ self, manager_with_librarian,
+ ):
+ mock_request = Mock()
+ mock_request.id = "missing-doc"
+ mock_respond = AsyncMock()
+
+ manager_with_librarian.table_store.get_triples = AsyncMock()
+ manager_with_librarian.table_store.get_graph_embeddings = AsyncMock()
+ manager_with_librarian.librarian.fetch_document_metadata.side_effect = (
+ RuntimeError("not found")
+ )
+
+ await manager_with_librarian.get_kg_core(
+ mock_request, mock_respond, "test-user"
+ )
+
+ responses = [c[0][0] for c in mock_respond.call_args_list]
+ assert all(r.library_metadata is None for r in responses)
+ assert any(r.eos for r in responses)
+
+
+class TestKnowledgeManagerLibraryUpload:
+ """Test put_kg_core handling of library metadata and blob records."""
+
+ @pytest.fixture
+ def manager_with_librarian(self, mock_flow_config):
+ with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
+ mock_librarian = AsyncMock()
+ manager = KnowledgeManager(
+ cassandra_host=["localhost"],
+ cassandra_username="u", cassandra_password="p",
+ keyspace="ks", flow_config=mock_flow_config,
+ librarian=mock_librarian,
+ )
+ manager.table_store = AsyncMock()
+ return manager
+
+ @pytest.mark.asyncio
+ async def test_put_metadata_then_blob_calls_librarian(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+ manager_with_librarian.librarian.request.return_value = LibrarianResponse()
+
+ # First call: metadata
+ req_meta = Mock()
+ req_meta.triples = None
+ req_meta.graph_embeddings = None
+ req_meta.library_metadata = LibraryMetadata(
+ id="doc-1", kind="application/pdf", title="Test",
+ document_type="source",
+ )
+ req_meta.library_blob = None
+ await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
+
+ # Metadata is buffered, librarian not called yet
+ manager_with_librarian.librarian.request.assert_not_called()
+
+ # Second call: blob
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(
+ id="doc-1", data=b"dGVzdA==",
+ )
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ # Now librarian should have been called with add-document
+ manager_with_librarian.librarian.request.assert_called_once()
+ call_args = manager_with_librarian.librarian.request.call_args[0][0]
+ assert call_args.operation == "add-document"
+ assert call_args.document_metadata.id == "doc-1"
+ assert call_args.document_metadata.kind == "application/pdf"
+ assert call_args.content == b"dGVzdA=="
+
+ @pytest.mark.asyncio
+ async def test_put_child_document_uses_add_child_operation(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+ manager_with_librarian.librarian.request.return_value = LibrarianResponse()
+
+ req_meta = Mock()
+ req_meta.triples = None
+ req_meta.graph_embeddings = None
+ req_meta.library_metadata = LibraryMetadata(
+ id="chunk-1", kind="text/plain", title="Chunk",
+ parent_id="doc-1", document_type="chunk",
+ )
+ req_meta.library_blob = None
+ await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
+
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(id="chunk-1", data=b"Y2h1bms=")
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ call_args = manager_with_librarian.librarian.request.call_args[0][0]
+ assert call_args.operation == "add-child-document"
+ assert call_args.document_metadata.parent_id == "doc-1"
+
+ @pytest.mark.asyncio
+ async def test_put_blob_without_metadata_logs_warning(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(id="orphan", data=b"data")
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ # Librarian should not be called for orphan blob
+ manager_with_librarian.librarian.request.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_put_existing_document_is_graceful(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+ manager_with_librarian.librarian.request.side_effect = RuntimeError(
+ "Document already exists"
+ )
+
+ req_meta = Mock()
+ req_meta.triples = None
+ req_meta.graph_embeddings = None
+ req_meta.library_metadata = LibraryMetadata(
+ id="doc-1", kind="application/pdf", title="Test",
+ document_type="source",
+ )
+ req_meta.library_blob = None
+ await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
+
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(id="doc-1", data=b"data")
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ # Should not raise — "already exists" is handled gracefully
\ No newline at end of file
diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py
index 04807b20..641a9d78 100644
--- a/tests/unit/test_decoding/test_pdf_decoder.py
+++ b/tests/unit/test_decoding/test_pdf_decoder.py
@@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader
@@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
+ @patch('trustgraph.base.librarian_client.Consumer')
+ @patch('trustgraph.base.librarian_client.Producer')
+ @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
+ @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
+ async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
+ """Test rejecting non-PDF content before invoking the PDF loader"""
+ html_content = b"Not found"
+ html_base64 = base64.b64encode(html_content)
+
+ mock_metadata = Metadata(id="test-doc")
+ mock_document = Document(metadata=mock_metadata, document_id="doc-123")
+ mock_msg = MagicMock()
+ mock_msg.value.return_value = mock_document
+
+ mock_output_flow = AsyncMock()
+ mock_triples_flow = AsyncMock()
+ mock_flow = MagicMock(side_effect=lambda name: {
+ "output": mock_output_flow,
+ "triples": mock_triples_flow,
+ }.get(name))
+ mock_flow.librarian.fetch_document_metadata = AsyncMock(
+ return_value=MagicMock(kind="application/pdf")
+ )
+ mock_flow.librarian.fetch_document_content = AsyncMock(
+ return_value=html_base64
+ )
+ mock_flow.librarian.save_child_document = AsyncMock()
+
+ config = {
+ 'id': 'test-pdf-decoder',
+ 'taskgroup': AsyncMock()
+ }
+
+ processor = Processor(**config)
+
+ await processor.on_message(mock_msg, None, mock_flow)
+
+ mock_pdf_loader_class.assert_not_called()
+ mock_output_flow.send.assert_not_called()
+ mock_triples_flow.send.assert_not_called()
+ mock_flow.librarian.save_child_document.assert_not_called()
+
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
@@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
diff --git a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py
index aef6fc92..65837323 100644
--- a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py
+++ b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py
@@ -18,7 +18,7 @@ from trustgraph.embeddings.hf.hf import Processor
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test HuggingFace dynamic model loading and caching"""
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -39,7 +39,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert processor.cached_model_name == "test-model"
assert processor.embeddings is not None
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -63,7 +63,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
mock_hf_class.assert_not_called()
assert processor.cached_model_name == "test-model"
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -84,7 +84,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
mock_hf_class.assert_called_once_with(model_name="different-model")
assert processor.cached_model_name == "different-model"
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -107,7 +107,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -130,7 +130,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert processor.cached_model_name == "custom-model"
mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -164,7 +164,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert call_count_after_b == initial_call_count + 2 # Reload for model-b
assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -187,7 +187,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert mock_hf_class.call_count == initial_count
assert processor.cached_model_name == "test-model"
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py
index b61500a4..fb385f43 100644
--- a/tests/unit/test_query/test_rows_cassandra_query.py
+++ b/tests/unit/test_query/test_rows_cassandra_query.py
@@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
- @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
- async def test_query_with_index_match(self, mock_async_execute):
+ @patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock)
+ async def test_query_with_index_match(self, mock_async_execute_paged):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
@@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
- # Mock async_execute to return test data
+ # Mock async_execute_paged to return test data (list of pages)
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
- mock_async_execute.return_value = [mock_row]
+ mock_async_execute_paged.return_value = [[mock_row]]
schema = RowSchema(
name="products",
@@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
- mock_async_execute.assert_called_once()
+ mock_async_execute_paged.assert_called_once()
# Verify query structure - should query unified rows table
- call_args = mock_async_execute.call_args
+ call_args = mock_async_execute_paged.call_args
query = call_args[0][1]
params = call_args[0][2]
@@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
- @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
- async def test_query_without_index_match(self, mock_async_execute):
+ @patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock)
+ async def test_query_without_index_match(self, mock_async_scan):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
@@ -406,12 +406,10 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
- # Mock async_execute to return test data
+ # Mock async_scan to return filtered test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
- mock_row2 = MagicMock()
- mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
- mock_async_execute.return_value = [mock_row1, mock_row2]
+ mock_async_scan.return_value = [mock_row1]
schema = RowSchema(
name="products",
@@ -432,13 +430,16 @@ class TestUnifiedTableQueries:
limit=10
)
- # Query should use ALLOW FILTERING for scan
- call_args = mock_async_execute.call_args
+ # Verify async_scan was called
+ mock_async_scan.assert_called_once()
+
+ # Verify query structure
+ call_args = mock_async_scan.call_args
query = call_args[0][1]
assert "ALLOW FILTERING" in query
- # Should post-filter results
+ # Should return filtered results
assert len(results) == 1
assert results[0]["name"] == "Product A"
diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py
index dbe06b40..41d0f88b 100644
--- a/tests/unit/test_reliability/test_null_embedding_protection.py
+++ b/tests/unit/test_reliability/test_null_embedding_protection.py
@@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
+ proc.replication_factor = 1
+ proc.shard_number = 1
msg = MagicMock()
msg.metadata.collection = "graphs"
diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py
index 59d15b45..2d058733 100644
--- a/tests/unit/test_tables/test_knowledge_table_store.py
+++ b/tests/unit/test_tables/test_knowledge_table_store.py
@@ -35,9 +35,9 @@ def _make_store():
class TestGetGraphEmbeddings:
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_entity_embeddings_with_singular_vector(
- self, mock_async_execute
+ self, mock_async_execute_paged
):
"""
Cassandra rows return entities as a list of [entity_tuple, vector]
@@ -57,7 +57,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
- mock_async_execute.return_value = [fake_row]
+ mock_async_execute_paged.return_value = [[fake_row]]
received = []
@@ -66,7 +66,7 @@ class TestGetGraphEmbeddings:
await store.get_graph_embeddings("alice", "doc-1", receiver)
- mock_async_execute.assert_called_once_with(
+ mock_async_execute_paged.assert_called_once_with(
store.cassandra,
store.get_graph_embeddings_stmt,
("alice", "doc-1"),
@@ -96,8 +96,8 @@ class TestGetGraphEmbeddings:
assert ge.entities[2].entity.value == "a literal entity"
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
- async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
"""row[3] being None / empty must produce a GraphEmbeddings with
no entities, not raise."""
fake_row = (None, None, None, None)
@@ -105,7 +105,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
- mock_async_execute.return_value = [fake_row]
+ mock_async_execute_paged.return_value = [[fake_row]]
received = []
@@ -118,8 +118,8 @@ class TestGetGraphEmbeddings:
assert received[0].entities == []
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
- async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
fake_rows = [
(None, None, None, [
(("http://example.org/a", True), [1.0]),
@@ -132,7 +132,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
- mock_async_execute.return_value = fake_rows
+ mock_async_execute_paged.return_value = [fake_rows]
received = []
@@ -153,9 +153,9 @@ class TestGetTriples:
the same Metadata construction. Cover it for parity."""
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
- async def test_row_converts_to_triples(self, mock_async_execute):
- # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_row_converts_to_triples(self, mock_async_execute_paged):
+ # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri, graph)
fake_row = (
None, None, None,
[
@@ -163,6 +163,7 @@ class TestGetTriples:
"http://example.org/alice", True,
"http://example.org/knows", True,
"http://example.org/bob", True,
+ "urn:graph:source",
),
],
)
@@ -170,7 +171,7 @@ class TestGetTriples:
store = _make_store()
store.cassandra = Mock()
store.get_triples_stmt = Mock()
- mock_async_execute.return_value = [fake_row]
+ mock_async_execute_paged.return_value = [[fake_row]]
received = []
@@ -191,3 +192,33 @@ class TestGetTriples:
assert t.s.iri == "http://example.org/alice"
assert t.p.iri == "http://example.org/knows"
assert t.o.iri == "http://example.org/bob"
+ assert t.g == "urn:graph:source"
+
+ @pytest.mark.asyncio
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_empty_graph_name_becomes_none(self, mock_async_execute_paged):
+ fake_row = (
+ None, None, None,
+ [
+ (
+ "http://example.org/alice", True,
+ "http://example.org/knows", True,
+ "http://example.org/bob", True,
+ "",
+ ),
+ ],
+ )
+
+ store = _make_store()
+ store.cassandra = Mock()
+ store.get_triples_stmt = Mock()
+ mock_async_execute_paged.return_value = [[fake_row]]
+
+ received = []
+
+ async def receiver(msg):
+ received.append(msg)
+
+ await store.get_triples("w", "d", receiver)
+
+ assert received[0].triples[0].g is None
diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
index 437b83c8..af128f23 100644
--- a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
+++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
@@ -1,5 +1,6 @@
"""
-Round-trip unit tests for KnowledgeRequestTranslator.
+Round-trip unit tests for KnowledgeRequestTranslator and
+KnowledgeResponseTranslator.
Regression coverage: a previous version of the decode side constructed
EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
@@ -15,9 +16,13 @@ Triples breaks the test.
import pytest
-from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
+from trustgraph.messaging.translators.knowledge import (
+ KnowledgeRequestTranslator,
+ KnowledgeResponseTranslator,
+)
from trustgraph.schema import (
KnowledgeRequest,
+ KnowledgeResponse,
GraphEmbeddings,
EntityEmbeddings,
Triples,
@@ -25,6 +30,8 @@ from trustgraph.schema import (
Metadata,
Term,
IRI,
+ LibraryMetadata,
+ LibraryBlob,
)
@@ -145,3 +152,161 @@ class TestKnowledgeRequestTranslatorTriples:
assert t.s.iri == "http://example.org/alice"
assert t.p.iri == "http://example.org/knows"
assert t.o.iri == "http://example.org/bob"
+
+
+class TestKnowledgeRequestTranslatorLibrary:
+
+ def test_roundtrip_preserves_library_metadata(self, translator):
+ request = KnowledgeRequest(
+ operation="put-kg-core",
+ id="doc-1",
+ library_metadata=LibraryMetadata(
+ id="doc-1",
+ kind="application/pdf",
+ title="Test Document",
+ parent_id="",
+ document_type="source",
+ comments="test comments",
+ tags=["tag1", "tag2"],
+ ),
+ )
+
+ encoded = translator.encode(request)
+ assert "library-metadata" in encoded
+ lm = encoded["library-metadata"]
+ assert lm["id"] == "doc-1"
+ assert lm["kind"] == "application/pdf"
+ assert lm["title"] == "Test Document"
+ assert lm["parent-id"] == ""
+ assert lm["document-type"] == "source"
+ assert lm["comments"] == "test comments"
+ assert lm["tags"] == ["tag1", "tag2"]
+
+ decoded = translator.decode(encoded)
+ assert decoded.library_metadata is not None
+ assert decoded.library_metadata.id == "doc-1"
+ assert decoded.library_metadata.kind == "application/pdf"
+ assert decoded.library_metadata.title == "Test Document"
+ assert decoded.library_metadata.parent_id == ""
+ assert decoded.library_metadata.document_type == "source"
+ assert decoded.library_metadata.comments == "test comments"
+ assert decoded.library_metadata.tags == ["tag1", "tag2"]
+
+ def test_roundtrip_preserves_child_document_metadata(self, translator):
+ request = KnowledgeRequest(
+ operation="put-kg-core",
+ id="doc-1",
+ library_metadata=LibraryMetadata(
+ id="chunk-1",
+ kind="text/plain",
+ title="Chunk 1",
+ parent_id="doc-1",
+ document_type="chunk",
+ ),
+ )
+
+ encoded = translator.encode(request)
+ decoded = translator.decode(encoded)
+
+ assert decoded.library_metadata.parent_id == "doc-1"
+ assert decoded.library_metadata.document_type == "chunk"
+
+ def test_roundtrip_preserves_library_blob(self, translator):
+ request = KnowledgeRequest(
+ operation="put-kg-core",
+ id="doc-1",
+ library_blob=LibraryBlob(
+ id="doc-1",
+ data=b"SGVsbG8gV29ybGQ=",
+ ),
+ )
+
+ encoded = translator.encode(request)
+ assert "library-blob" in encoded
+ assert encoded["library-blob"]["id"] == "doc-1"
+ assert encoded["library-blob"]["data"] == "SGVsbG8gV29ybGQ="
+
+ decoded = translator.decode(encoded)
+ assert decoded.library_blob is not None
+ assert decoded.library_blob.id == "doc-1"
+ assert decoded.library_blob.data == "SGVsbG8gV29ybGQ="
+
+ def test_absent_library_fields_decode_as_none(self, translator):
+ decoded = translator.decode({
+ "operation": "get-kg-core",
+ "id": "doc-1",
+ })
+ assert decoded.library_metadata is None
+ assert decoded.library_blob is None
+
+
+class TestKnowledgeResponseTranslatorLibrary:
+
+ @pytest.fixture
+ def response_translator(self):
+ return KnowledgeResponseTranslator()
+
+ def test_encode_library_metadata(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_metadata=LibraryMetadata(
+ id="doc-1",
+ kind="application/pdf",
+ title="Test",
+ parent_id="",
+ document_type="source",
+ comments="",
+ tags=[],
+ ),
+ )
+ encoded = response_translator.encode(response)
+ assert "library-metadata" in encoded
+ assert encoded["library-metadata"]["id"] == "doc-1"
+ assert encoded["library-metadata"]["kind"] == "application/pdf"
+ assert encoded["library-metadata"]["document-type"] == "source"
+
+ def test_encode_library_blob_bytes_to_string(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_blob=LibraryBlob(
+ id="doc-1",
+ data=b"dGVzdCBkYXRh",
+ ),
+ )
+ encoded = response_translator.encode(response)
+ assert "library-blob" in encoded
+ assert encoded["library-blob"]["id"] == "doc-1"
+ assert encoded["library-blob"]["data"] == "dGVzdCBkYXRh"
+ assert isinstance(encoded["library-blob"]["data"], str)
+
+ def test_encode_library_blob_string_passthrough(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_blob=LibraryBlob(
+ id="doc-1",
+ data="already-a-string",
+ ),
+ )
+ encoded = response_translator.encode(response)
+ assert encoded["library-blob"]["data"] == "already-a-string"
+
+ def test_library_metadata_is_not_final(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_metadata=LibraryMetadata(id="doc-1"),
+ )
+ _, is_final = response_translator.encode_with_completion(response)
+ assert is_final is False
+
+ def test_library_blob_is_not_final(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_blob=LibraryBlob(id="doc-1", data=b"data"),
+ )
+ _, is_final = response_translator.encode_with_completion(response)
+ assert is_final is False
+
+ def test_eos_is_final(self, response_translator):
+ response = KnowledgeResponse(eos=True)
+ _, is_final = response_translator.encode_with_completion(response)
+ assert is_final is True
diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py
index 9074bac1..0190d3f5 100644
--- a/trustgraph-base/trustgraph/api/api.py
+++ b/trustgraph-base/trustgraph/api/api.py
@@ -337,7 +337,7 @@ class Api:
from . bulk_client import BulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
- self._bulk_client = BulkClient(base_url, self.timeout, self.token)
+ self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._bulk_client
def metrics(self):
@@ -462,7 +462,7 @@ class Api:
from . async_bulk_client import AsyncBulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
- self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
+ self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._async_bulk_client
def async_metrics(self):
diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py
index 9a6a49c3..f93ab667 100644
--- a/trustgraph-base/trustgraph/api/async_bulk_client.py
+++ b/trustgraph-base/trustgraph/api/async_bulk_client.py
@@ -9,10 +9,11 @@ from . types import Triple
class AsyncBulkClient:
"""Asynchronous bulk operations client"""
- def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+ def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
+ self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@@ -25,11 +26,21 @@ class AsyncBulkClient:
else:
return f"ws://{url}"
+ def _build_ws_url(self, path: str) -> str:
+ """Build a WebSocket URL with token and workspace query params."""
+ ws_url = f"{self.url}{path}"
+ params = []
+ if self.token:
+ params.append(f"token={self.token}")
+ if self.workspace:
+ params.append(f"workspace={self.workspace}")
+ if params:
+ ws_url = f"{ws_url}?{'&'.join(params)}"
+ return ws_url
+
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
"""Bulk import triples via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for triple in triples:
@@ -42,9 +53,7 @@ class AsyncBulkClient:
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
"""Bulk export triples via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -57,9 +66,7 @@ class AsyncBulkClient:
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import graph embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@@ -67,9 +74,7 @@ class AsyncBulkClient:
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export graph embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -77,9 +82,7 @@ class AsyncBulkClient:
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import document embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@@ -87,9 +90,7 @@ class AsyncBulkClient:
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export document embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -97,9 +98,7 @@ class AsyncBulkClient:
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import entity contexts via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for context in contexts:
@@ -107,9 +106,7 @@ class AsyncBulkClient:
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export entity contexts via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -117,9 +114,7 @@ class AsyncBulkClient:
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for row in rows:
diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py
index 0e49fc4e..ae185240 100644
--- a/trustgraph-base/trustgraph/api/bulk_client.py
+++ b/trustgraph-base/trustgraph/api/bulk_client.py
@@ -34,7 +34,7 @@ class BulkClient:
Note: For true async support, use AsyncBulkClient instead.
"""
- def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+ def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
"""
Initialize synchronous bulk client.
@@ -42,10 +42,12 @@ class BulkClient:
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
timeout: WebSocket timeout in seconds
token: Optional bearer token for authentication
+ workspace: Workspace for data isolation
"""
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
+ self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@@ -58,6 +60,18 @@ class BulkClient:
else:
return f"ws://{url}"
+ def _build_ws_url(self, path: str) -> str:
+ """Build a WebSocket URL with token and workspace query params."""
+ ws_url = f"{self.url}{path}"
+ params = []
+ if self.token:
+ params.append(f"token={self.token}")
+ if self.workspace:
+ params.append(f"workspace={self.workspace}")
+ if params:
+ ws_url = f"{ws_url}?{'&'.join(params)}"
+ return ws_url
+
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
"""Run async coroutine synchronously"""
try:
@@ -116,9 +130,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of triple import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@@ -194,9 +206,7 @@ class BulkClient:
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
"""Async implementation of triple export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -238,9 +248,7 @@ class BulkClient:
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of graph embeddings import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@@ -296,9 +304,7 @@ class BulkClient:
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of graph embeddings export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -336,9 +342,7 @@ class BulkClient:
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of document embeddings import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@@ -394,9 +398,7 @@ class BulkClient:
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of document embeddings export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -446,9 +448,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of entity contexts import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@@ -522,9 +522,7 @@ class BulkClient:
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of entity contexts export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -562,9 +560,7 @@ class BulkClient:
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for row in rows:
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py
index 6eeb95ff..91bc67a1 100644
--- a/trustgraph-base/trustgraph/api/socket_client.py
+++ b/trustgraph-base/trustgraph/api/socket_client.py
@@ -167,7 +167,8 @@ class SocketClient:
)
if resp.get("type") == "auth-ok":
- self.workspace = resp.get("workspace", self.workspace)
+ if self.workspace == "default":
+ self.workspace = resp.get("workspace", self.workspace)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(
@@ -501,6 +502,7 @@ class SocketClient:
def put_kg_core(
self, id: str, triples=None, graph_embeddings=None,
+ library_metadata=None, library_blob=None,
) -> Dict[str, Any]:
request = {
"operation": "put-kg-core",
@@ -511,6 +513,10 @@ class SocketClient:
request["triples"] = triples
if graph_embeddings is not None:
request["graph-embeddings"] = graph_embeddings
+ if library_metadata is not None:
+ request["library-metadata"] = library_metadata
+ if library_blob is not None:
+ request["library-blob"] = library_blob
return self._send_request_sync("knowledge", None, request)
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
diff --git a/trustgraph-base/trustgraph/base/cassandra_config.py b/trustgraph-base/trustgraph/base/cassandra_config.py
index 78505c68..b2e36fbd 100644
--- a/trustgraph-base/trustgraph/base/cassandra_config.py
+++ b/trustgraph-base/trustgraph/base/cassandra_config.py
@@ -103,35 +103,19 @@ def resolve_cassandra_config(
host: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
- default_keyspace: Optional[str] = None
+ default_keyspace: Optional[str] = None,
+ replication_factor: Optional[int] = None,
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
- """
- Resolve Cassandra configuration from various sources.
-
- Can accept either argparse args object or explicit parameters.
- Converts host string to list format for Cassandra driver.
-
- Args:
- args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
- host: Optional explicit host parameter (overrides args)
- username: Optional explicit username parameter (overrides args)
- password: Optional explicit password parameter (overrides args)
- default_keyspace: Optional default keyspace if not specified elsewhere
-
- Returns:
- tuple: (hosts_list, username, password, keyspace, replication_factor)
- """
- # If args provided, extract values
keyspace = None
- replication_factor = 1
if args is not None:
host = host or getattr(args, 'cassandra_host', None)
username = username or getattr(args, 'cassandra_username', None)
password = password or getattr(args, 'cassandra_password', None)
keyspace = getattr(args, 'cassandra_keyspace', None)
- replication_factor = getattr(args, 'cassandra_replication_factor', 1)
+ replication_factor = replication_factor or getattr(
+ args, 'cassandra_replication_factor', None
+ )
- # Apply defaults if still None
defaults = get_cassandra_defaults()
host = host or defaults['host']
username = username or defaults['username']
diff --git a/trustgraph-base/trustgraph/base/logging.py b/trustgraph-base/trustgraph/base/logging.py
index 9bf599b1..ff10c140 100644
--- a/trustgraph-base/trustgraph/base/logging.py
+++ b/trustgraph-base/trustgraph/base/logging.py
@@ -11,6 +11,7 @@ Supports dual output to console and Loki for centralized log aggregation.
import contextvars
import logging
import logging.handlers
+import uuid
from argparse import ArgumentParser
from queue import Queue
from typing import Any
@@ -132,14 +133,12 @@ def setup_logging(args: dict[str, Any]) -> None:
try:
from logging_loki import LokiHandler
- # Create Loki handler with optional authentication. The
- # processor label is NOT baked in here — it's stamped onto
- # each record by _ProcessorIdFilter reading the task-local
- # contextvar, and logging_loki's emitter reads record.tags
- # to build per-record Loki labels.
+ instance_id = str(uuid.uuid4())[:8]
+
loki_handler_kwargs = {
'url': loki_url,
'version': "1",
+ 'tags': {'instance': instance_id},
}
if loki_username and loki_password:
diff --git a/trustgraph-base/trustgraph/base/qdrant_config.py b/trustgraph-base/trustgraph/base/qdrant_config.py
new file mode 100644
index 00000000..f3e015ca
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/qdrant_config.py
@@ -0,0 +1,87 @@
+
+import os
+import argparse
+from typing import Optional, Any, Tuple
+
+
+def get_qdrant_defaults() -> dict:
+ return {
+ 'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
+ 'api_key': os.getenv('QDRANT_API_KEY'),
+ 'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')),
+ 'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')),
+ }
+
+
+def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
+ defaults = get_qdrant_defaults()
+
+ url_help = f"Qdrant URL (default: {defaults['url']})"
+ if 'QDRANT_URL' in os.environ:
+ url_help += " [from QDRANT_URL]"
+
+ api_key_help = "Qdrant API key"
+ if defaults['api_key']:
+ api_key_help += " (default: <set>)"
+ if 'QDRANT_API_KEY' in os.environ:
+ api_key_help += " [from QDRANT_API_KEY]"
+
+ replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
+ if 'QDRANT_REPLICATION_FACTOR' in os.environ:
+ replication_help += " [from QDRANT_REPLICATION_FACTOR]"
+
+ shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
+ if 'QDRANT_SHARD_NUMBER' in os.environ:
+ shard_help += " [from QDRANT_SHARD_NUMBER]"
+
+ parser.add_argument(
+ '--store-uri',
+ default=defaults['url'],
+ help=url_help,
+ )
+
+ parser.add_argument(
+ '--api-key',
+ default=defaults['api_key'],
+ help=api_key_help,
+ )
+
+ parser.add_argument(
+ '--qdrant-replication-factor',
+ type=int,
+ default=defaults['replication_factor'],
+ help=replication_help,
+ )
+
+ parser.add_argument(
+ '--qdrant-shard-number',
+ type=int,
+ default=defaults['shard_number'],
+ help=shard_help,
+ )
+
+
+def resolve_qdrant_config(
+ args: Optional[Any] = None,
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ replication_factor: Optional[int] = None,
+ shard_number: Optional[int] = None,
+) -> Tuple[str, Optional[str], int, int]:
+ if args is not None:
+ url = url or getattr(args, 'store_uri', None)
+ api_key = api_key or getattr(args, 'api_key', None)
+ replication_factor = replication_factor or getattr(
+ args, 'qdrant_replication_factor', None
+ )
+ shard_number = shard_number or getattr(
+ args, 'qdrant_shard_number', None
+ )
+
+ defaults = get_qdrant_defaults()
+ url = url or defaults['url']
+ api_key = api_key or defaults['api_key']
+ replication_factor = replication_factor or defaults['replication_factor']
+ shard_number = shard_number or defaults['shard_number']
+
+ return url, api_key, replication_factor, shard_number
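
Note on `resolve_qdrant_config`: values are resolved in the order explicit argument, then parsed CLI namespace, then environment-backed default. The sketch below isolates that chain for a single setting. `resolve_replication_factor` is a hypothetical reduction of the function in the diff, written only to illustrate the precedence; it is not part of the module.

```python
import os
from types import SimpleNamespace
from typing import Any, Optional


def resolve_replication_factor(
    args: Optional[Any] = None,
    replication_factor: Optional[int] = None,
) -> int:
    # 1. An explicit argument wins; 2. otherwise use the argparse namespace.
    if args is not None:
        replication_factor = replication_factor or getattr(
            args, "qdrant_replication_factor", None
        )
    # 3. Fall back to the environment variable, then to the literal 1.
    default = int(os.getenv("QDRANT_REPLICATION_FACTOR", "1"))
    return replication_factor or default
```

One consequence of chaining with `or` is that a falsy value such as `0` is treated as unset and falls through to the next source.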
diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py
index 3830bf59..3f09b41b 100644
--- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py
+++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py
@@ -2,7 +2,8 @@ from typing import Dict, Any, Tuple, Optional
from ...schema import (
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
DocumentEmbeddings, ChunkEmbeddings,
- Metadata, EntityEmbeddings
+ Metadata, EntityEmbeddings,
+ LibraryMetadata, LibraryBlob,
)
from .base import MessageTranslator
from .primitives import ValueTranslator, SubgraphTranslator
@@ -61,6 +62,27 @@ class KnowledgeRequestTranslator(MessageTranslator):
]
)
+ library_metadata = None
+ if "library-metadata" in data:
+ lm = data["library-metadata"]
+ library_metadata = LibraryMetadata(
+ id=lm.get("id", ""),
+ kind=lm.get("kind", ""),
+ title=lm.get("title", ""),
+ parent_id=lm.get("parent-id", ""),
+ document_type=lm.get("document-type", ""),
+ comments=lm.get("comments", ""),
+ tags=lm.get("tags", []),
+ )
+
+ library_blob = None
+ if "library-blob" in data:
+ lb = data["library-blob"]
+ library_blob = LibraryBlob(
+ id=lb.get("id", ""),
+ data=lb.get("data", b""),
+ )
+
return KnowledgeRequest(
operation=data.get("operation"),
id=data.get("id"),
@@ -69,6 +91,8 @@ class KnowledgeRequestTranslator(MessageTranslator):
triples=triples,
graph_embeddings=graph_embeddings,
document_embeddings=document_embeddings,
+ library_metadata=library_metadata,
+ library_blob=library_blob,
)
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
@@ -125,6 +149,26 @@ class KnowledgeRequestTranslator(MessageTranslator):
],
}
+ if obj.library_metadata:
+ result["library-metadata"] = {
+ "id": obj.library_metadata.id,
+ "kind": obj.library_metadata.kind,
+ "title": obj.library_metadata.title,
+ "parent-id": obj.library_metadata.parent_id,
+ "document-type": obj.library_metadata.document_type,
+ "comments": obj.library_metadata.comments,
+ "tags": obj.library_metadata.tags,
+ }
+
+ if obj.library_blob:
+ data = obj.library_blob.data
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ result["library-blob"] = {
+ "id": obj.library_blob.id,
+ "data": data,
+ }
+
return result
@@ -194,6 +238,32 @@ class KnowledgeResponseTranslator(MessageTranslator):
}
}
+ # Streaming library metadata response
+ if obj.library_metadata:
+ return {
+ "library-metadata": {
+ "id": obj.library_metadata.id,
+ "kind": obj.library_metadata.kind,
+ "title": obj.library_metadata.title,
+ "parent-id": obj.library_metadata.parent_id,
+ "document-type": obj.library_metadata.document_type,
+ "comments": obj.library_metadata.comments,
+ "tags": obj.library_metadata.tags,
+ }
+ }
+
+ # Streaming library blob response
+ if obj.library_blob:
+ data = obj.library_blob.data
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ return {
+ "library-blob": {
+ "id": obj.library_blob.id,
+ "data": data,
+ }
+ }
+
# End of stream marker
if obj.eos is True:
return {"eos": True}
@@ -209,7 +279,9 @@ class KnowledgeResponseTranslator(MessageTranslator):
is_final = (
obj.ids is not None or # List response
obj.eos is True or # End of stream
- (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response
+ (not obj.triples and not obj.graph_embeddings
+ and not obj.document_embeddings
+ and not obj.library_metadata and not obj.library_blob) # Empty response
)
return response, is_final
\ No newline at end of file
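
Note on library blob encoding: both translators turn `bytes` into `str` with a UTF-8 decode before placing the blob in a JSON-style dictionary. The sketch below assumes the blob already holds base64 text, which the test value `dGVzdCBkYXRh` suggests but the diff does not state. `LibraryBlob` is redeclared locally and `encode_blob` is a hypothetical helper, so the example runs without TrustGraph installed.

```python
import base64
from dataclasses import dataclass
from typing import Union


@dataclass
class LibraryBlob:
    """Local stand-in for the schema dataclass added in this change."""
    id: str = ""
    data: Union[bytes, str] = b""


def encode_blob(blob: LibraryBlob) -> dict:
    data = blob.data
    # bytes become text; a str passes through unchanged.
    if isinstance(data, bytes):
        data = data.decode("utf-8")
    return {"id": blob.id, "data": data}


raw = b"test data"
# Assumption: the producer base64-encodes the document before sending it.
blob = LibraryBlob(id="doc-1", data=base64.b64encode(raw))
encoded = encode_blob(blob)
```

If a blob carried raw binary that is not valid UTF-8, the decode step would raise `UnicodeDecodeError`, so base64 encoding upstream is what keeps this path safe under that assumption.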
diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py
index a3879103..4353065b 100644
--- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py
+++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py
@@ -21,6 +21,21 @@ from .embeddings import GraphEmbeddings, DocumentEmbeddings
# <- ()
# <- (error)
+@dataclass
+class LibraryMetadata:
+ id: str = ""
+ kind: str = ""
+ title: str = ""
+ parent_id: str = ""
+ document_type: str = ""
+ comments: str = ""
+ tags: list[str] = field(default_factory=list)
+
+@dataclass
+class LibraryBlob:
+ id: str = ""
+ data: bytes = b""
+
@dataclass
class KnowledgeRequest:
# get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
@@ -44,6 +59,10 @@ class KnowledgeRequest:
# put-de-core
document_embeddings: DocumentEmbeddings | None = None
+ # put-kg-core (source material)
+ library_metadata: LibraryMetadata | None = None
+ library_blob: LibraryBlob | None = None
+
@dataclass
class KnowledgeResponse:
error: Error | None = None
@@ -52,6 +71,8 @@ class KnowledgeResponse:
triples: Triples | None = None
graph_embeddings: GraphEmbeddings | None = None
document_embeddings: DocumentEmbeddings | None = None
+ library_metadata: LibraryMetadata | None = None
+ library_blob: LibraryBlob | None = None
knowledge_request_queue = queue('knowledge', cls='request')
knowledge_response_queue = queue('knowledge', cls='response')
diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py
index 62fa7ca2..f4d44cca 100644
--- a/trustgraph-cli/trustgraph/cli/get_document_content.py
+++ b/trustgraph-cli/trustgraph/cli/get_document_content.py
@@ -5,7 +5,7 @@ Gets document content from the library by document ID.
import argparse
import os
import sys
-from trustgraph.api import Api
+import requests
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@@ -13,15 +13,29 @@ default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def get_content(url, document_id, output_file, token=None, workspace="default"):
- api = Api(url, token=token, workspace=workspace).library()
+ stream_url = url.rstrip("/") + "/api/v1/document-stream"
- content = api.get_document_content(id=document_id)
+ params = {
+ "document-id": document_id,
+ "workspace": workspace,
+ }
+
+ headers = {}
+ if token:
+ headers["Authorization"] = f"Bearer {token}"
+
+ resp = requests.get(stream_url, params=params, headers=headers, stream=True)
+ resp.raise_for_status()
if output_file:
+ total = 0
with open(output_file, 'wb') as f:
- f.write(content)
- print(f"Written {len(content)} bytes to {output_file}")
+ for chunk in resp.iter_content(chunk_size=65536):
+ f.write(chunk)
+ total += len(chunk)
+ print(f"Written {total} bytes to {output_file}")
else:
+ content = resp.content
try:
text = content.decode('utf-8')
print(text)
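
Note on the streamed download: when an output file is given, the new `get_content` writes the response chunk by chunk and counts bytes as it goes, instead of holding the whole document in memory. The loop can be sketched independently of the HTTP layer. `write_chunks` is a hypothetical helper, and any iterable of `bytes` stands in for `resp.iter_content(chunk_size=65536)`.

```python
import io
from typing import BinaryIO, Iterable


def write_chunks(chunks: Iterable[bytes], f: BinaryIO) -> int:
    """Write each chunk to f and return the total number of bytes written."""
    total = 0
    for chunk in chunks:
        f.write(chunk)
        total += len(chunk)
    return total


# Usage with an in-memory buffer standing in for the output file.
buf = io.BytesIO()
written = write_chunks([b"ab", b"", b"cde"], buf)
```

Note that the stdout branch in the diff still reads `resp.content` in full, so only the file-output path benefits from streaming.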
diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py
index b4f37b81..2ff1a3cc 100644
--- a/trustgraph-cli/trustgraph/cli/get_kg_core.py
+++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py
@@ -47,6 +47,31 @@ def write_ge(f, data):
)
f.write(msgpack.packb(msg, use_bin_type=True))
+def write_library_metadata(f, data):
+ msg = (
+ "lm",
+ {
+ "i": data["id"],
+ "k": data.get("kind", ""),
+ "t": data.get("title", ""),
+ "p": data.get("parent-id", ""),
+ "d": data.get("document-type", ""),
+ "c": data.get("comments", ""),
+ "g": data.get("tags", []),
+ }
+ )
+ f.write(msgpack.packb(msg, use_bin_type=True))
+
+def write_library_blob(f, data):
+ msg = (
+ "lb",
+ {
+ "i": data["id"],
+ "d": data.get("data", b""),
+ }
+ )
+ f.write(msgpack.packb(msg, use_bin_type=True))
+
def fetch(url, workspace, id, output, token=None):
api = Api(url=url, token=token, workspace=workspace)
@@ -55,6 +80,8 @@ def fetch(url, workspace, id, output, token=None):
try:
ge = 0
t = 0
+ lm = 0
+ lb = 0
with open(output, "wb") as f:
@@ -68,7 +95,15 @@ def fetch(url, workspace, id, output, token=None):
ge += 1
write_ge(f, response["graph-embeddings"])
- print(f"Got: {t} triple, {ge} GE messages.")
+ if "library-metadata" in response:
+ lm += 1
+ write_library_metadata(f, response["library-metadata"])
+
+ if "library-blob" in response:
+ lb += 1
+ write_library_blob(f, response["library-blob"])
+
+ print(f"Got: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.")
finally:
socket.close()
diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py
index dccf548e..5649a5ae 100644
--- a/trustgraph-cli/trustgraph/cli/load_structured_data.py
+++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py
@@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...")
# Step 1: Auto-discover schema (reuse discover_schema logic)
- discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
+ discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not discovered_schema:
logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.")
@@ -90,7 +90,7 @@ def load_structured_data(
# Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...")
- auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
+ auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
if not auto_descriptor:
logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.")
@@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters")
# Use the helper function to discover schema (get raw response for display)
- response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
+ response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
if response:
# Debug: print response type and content
@@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first
if not schema_name:
logger.info("No schema specified, auto-discovering...")
- schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
+ schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not schema_name:
print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.")
@@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}")
# Generate descriptor using helper function
- descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
+ descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
if descriptor:
# Output the generated descriptor
@@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
# Helper functions for auto mode
-def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
+def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
"""Auto-discover the best matching schema for the input data
Args:
@@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
- api = Api(api_url, workspace=workspace)
+ api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get available schemas
@@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None
-def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
+def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema"""
try:
# Read sample data
@@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
- api = Api(api_url, workspace=workspace)
+ api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get schema definition
diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py
index fe0981a5..f4e0b3dd 100644
--- a/trustgraph-cli/trustgraph/cli/put_kg_core.py
+++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py
@@ -40,6 +40,23 @@ def read_message(unpacked, id):
},
"triples": msg["t"],
}
+ elif unpacked[0] == "lm":
+ msg = unpacked[1]
+ return "lm", {
+ "id": msg["i"],
+ "kind": msg.get("k", ""),
+ "title": msg.get("t", ""),
+ "parent-id": msg.get("p", ""),
+ "document-type": msg.get("d", ""),
+ "comments": msg.get("c", ""),
+ "tags": msg.get("g", []),
+ }
+ elif unpacked[0] == "lb":
+ msg = unpacked[1]
+ return "lb", {
+ "id": msg["i"],
+ "data": msg.get("d", b""),
+ }
else:
raise RuntimeError("Unpacked unexpected message type", unpacked[0])
@@ -51,6 +68,8 @@ def put(url, workspace, id, input, token=None):
try:
ge = 0
t = 0
+ lm = 0
+ lb = 0
with open(input, "rb") as f:
@@ -73,10 +92,18 @@ def put(url, workspace, id, input, token=None):
t += 1
socket.put_kg_core(id, triples=msg)
+ elif kind == "lm":
+ lm += 1
+ socket.put_kg_core(id, library_metadata=msg)
+
+ elif kind == "lb":
+ lb += 1
+ socket.put_kg_core(id, library_blob=msg)
+
else:
raise RuntimeError("Unexpected message kind", kind)
- print(f"Put: {t} triple, {ge} GE messages.")
+ print(f"Put: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.")
finally:
socket.close()
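The `lm` handling above maps single-letter wire keys to long-form field names. A minimal sketch of that mapping in both directions follows. The helper names `expand_library_metadata` and `pack_library_metadata` and the `LM_KEYS` table are hypothetical and not part of the change; only the key letters and defaults are taken from the diff.

```python
# Hypothetical helpers; only the key letters and defaults come from the diff.
LM_KEYS = {
    "k": ("kind", ""),
    "t": ("title", ""),
    "p": ("parent-id", ""),
    "d": ("document-type", ""),
    "c": ("comments", ""),
    "g": ("tags", []),
}


def _value(source, key, default):
    """Fetch a value, copying lists so no mutable default is shared."""
    value = source.get(key, default)
    return list(value) if isinstance(value, list) else value


def expand_library_metadata(msg):
    """Expand a packed 'lm' payload (short keys) into long-form keys."""
    out = {"id": msg["i"]}  # 'i' is required; a missing id raises KeyError
    for short, (long_name, default) in LM_KEYS.items():
        out[long_name] = _value(msg, short, default)
    return out


def pack_library_metadata(data):
    """Pack long-form metadata into the short-key form used on the wire."""
    out = {"i": data["id"]}
    for short, (long_name, default) in LM_KEYS.items():
        out[short] = _value(data, long_name, default)
    return out
```

Keeping the mapping in one table would let the importer, exporter, and CLI share the same key letters, though the change itself spells the keys out inline at each site.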
diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py
index c5fac198..725f1106 100644
--- a/trustgraph-flow/trustgraph/config/service/service.py
+++ b/trustgraph-flow/trustgraph/config/service/service.py
@@ -83,7 +83,8 @@ class Processor(AsyncProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
- default_keyspace="config"
+ default_keyspace="config",
+ replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration
diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py
index f1fa53f5..6f017c43 100644
--- a/trustgraph-flow/trustgraph/cores/knowledge.py
+++ b/trustgraph-flow/trustgraph/cores/knowledge.py
@@ -1,6 +1,7 @@
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
-from .. schema import DocumentEmbeddings
+from .. schema import DocumentEmbeddings, LibraryMetadata, LibraryBlob
+from .. schema import LibrarianRequest, DocumentMetadata
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.knowledge import KnowledgeTableStore
@@ -18,7 +19,7 @@ class KnowledgeManager:
def __init__(
self, cassandra_host, cassandra_username, cassandra_password,
- keyspace, flow_config, replication_factor=1,
+ keyspace, flow_config, librarian=None, replication_factor=1,
):
self.table_store = KnowledgeTableStore(
@@ -26,6 +27,9 @@ class KnowledgeManager:
replication_factor
)
+ self.librarian = librarian
+ self._pending_library_metadata = {}
+
self.loader_queue = asyncio.Queue(maxsize=20)
self.background_task = None
self.flow_config = flow_config
@@ -86,6 +90,9 @@ class KnowledgeManager:
publish_ge,
)
+ if self.librarian:
+ await self._stream_library_docs(request.id, respond)
+
logger.debug("Knowledge core retrieval complete")
await respond(
@@ -122,6 +129,12 @@ class KnowledgeManager:
workspace, request.graph_embeddings
)
+ if request.library_metadata and self.librarian:
+ await self._put_library_metadata(request.library_metadata, workspace)
+
+ if request.library_blob and self.librarian:
+ await self._put_library_blob(request.library_blob, workspace)
+
await respond(
KnowledgeResponse(
error = None,
@@ -250,6 +263,112 @@ class KnowledgeManager:
await self.loader_queue.put((request, respond, workspace))
+ async def _stream_library_docs(self, document_id, respond):
+
+ try:
+ root_meta = await self.librarian.fetch_document_metadata(
+ document_id
+ )
+ except Exception as e:
+ logger.warning(f"Could not fetch library metadata for {document_id}: {e}")
+ return
+
+ if root_meta is None:
+ return
+
+ await self._stream_one_doc(root_meta, respond)
+
+ try:
+ resp = await self.librarian.request(
+ LibrarianRequest(
+ operation="list-children",
+ document_id=document_id,
+ )
+ )
+ except Exception as e:
+ logger.warning(f"Could not list children for {document_id}: {e}")
+ return
+
+ for child_meta in resp.document_metadatas:
+ await self._stream_one_doc(child_meta, respond)
+
+ async def _stream_one_doc(self, doc_meta, respond):
+
+ lm = LibraryMetadata(
+ id=doc_meta.id,
+ kind=doc_meta.kind,
+ title=doc_meta.title,
+ parent_id=doc_meta.parent_id,
+ document_type=doc_meta.document_type,
+ comments=doc_meta.comments,
+ tags=doc_meta.tags or [],
+ )
+
+ await respond(
+ KnowledgeResponse(library_metadata=lm)
+ )
+
+ try:
+ content = await self.librarian.fetch_document_content(
+ doc_meta.id
+ )
+ except Exception as e:
+ logger.warning(f"Could not fetch content for {doc_meta.id}: {e}")
+ return
+
+ await respond(
+ KnowledgeResponse(
+ library_blob=LibraryBlob(
+ id=doc_meta.id,
+ data=content,
+ )
+ )
+ )
+
+ async def _put_library_metadata(self, lm, workspace):
+ self._pending_library_metadata[lm.id] = lm
+
+ async def _put_library_blob(self, lb, workspace):
+
+ lm = self._pending_library_metadata.pop(lb.id, None)
+ if lm is None:
+ logger.warning(
+ f"Received library blob for {lb.id} with no preceding metadata"
+ )
+ return
+
+ doc_meta = DocumentMetadata(
+ id=lm.id,
+ kind=lm.kind,
+ title=lm.title,
+ parent_id=lm.parent_id,
+ document_type=lm.document_type,
+ comments=lm.comments,
+ tags=lm.tags or [],
+ )
+
+ if lm.parent_id:
+ operation = "add-child-document"
+ else:
+ operation = "add-document"
+
+ try:
+ await self.librarian.request(
+ LibrarianRequest(
+ operation=operation,
+ document_id=lm.id,
+ document_metadata=doc_meta,
+ content=lb.data,
+ )
+ )
+ except RuntimeError as e:
+ if "already exists" in str(e):
+ logger.debug(f"Library document {lm.id} already exists, skipping")
+ else:
+ logger.warning(f"Could not save library document {lm.id}: {e}")
+ except Exception as e:
+ logger.warning(f"Could not save library document {lm.id}: {e}")
+
async def core_loader(self):
logger.info("Knowledge background processor running...")
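The import path above relies on each library blob arriving after its metadata: metadata is parked by document id, and the blob pops it and picks the librarian operation. A minimal sketch of that pairing, without the librarian or asyncio, might look like this. The class `PendingDocuments` and its plain-dict metadata are assumptions made for illustration only.

```python
class PendingDocuments:
    """Pairs a metadata message with the blob that follows it (illustrative)."""

    def __init__(self):
        self._pending = {}

    def put_metadata(self, doc_id, metadata):
        # Park the metadata until the matching blob arrives.
        self._pending[doc_id] = metadata

    def put_blob(self, doc_id, data):
        # A blob with no preceding metadata is dropped, as in the diff.
        metadata = self._pending.pop(doc_id, None)
        if metadata is None:
            return None
        # Documents with a parent are stored as children of that parent.
        if metadata.get("parent-id"):
            operation = "add-child-document"
        else:
            operation = "add-document"
        return operation, metadata, data
```

One consequence of this design is that metadata whose blob never arrives stays in the pending map for the lifetime of the manager.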
diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py
index a04e42ca..5c50c207 100755
--- a/trustgraph-flow/trustgraph/cores/service.py
+++ b/trustgraph-flow/trustgraph/cores/service.py
@@ -12,6 +12,7 @@ import logging
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
+from .. base import LibrarianClient
from .. schema import KnowledgeRequest, KnowledgeResponse, Error
from .. schema import knowledge_request_queue, knowledge_response_queue
@@ -60,7 +61,8 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
- default_keyspace="knowledge"
+ default_keyspace="knowledge",
+ replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts
@@ -77,12 +79,17 @@ class Processor(WorkspaceProcessor):
}
)
+ self.librarian_client = LibrarianClient(
+ id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+ )
+
self.knowledge = KnowledgeManager(
cassandra_host = self.cassandra_host,
cassandra_username = self.cassandra_username,
cassandra_password = self.cassandra_password,
keyspace = keyspace,
flow_config = self,
+ librarian = self.librarian_client,
replication_factor = replication_factor,
)
@@ -156,6 +163,7 @@ class Processor(WorkspaceProcessor):
async def start(self):
await super(Processor, self).start()
+ await self.librarian_client.start()
async def on_knowledge_config(self, workspace, config, version):
diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
index f214111d..40ecac8a 100755
--- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
+++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
@@ -219,7 +219,14 @@ class Processor(FlowProcessor):
source_doc_id = v.document_id or v.metadata.id
# Run OCR, get per-page markdown
- pages = self.ocr(blob)
+ try:
+ pages = self.ocr(blob)
+ except Exception as e:
+ logger.error(
+ f"Failed to decode PDF {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
for markdown, page_num in pages:
diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
index 209153f6..ae393028 100755
--- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
+++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder"
+def _looks_like_pdf(content):
+ return content.lstrip().startswith(b"%PDF-")
+
+
class Processor(FlowProcessor):
def __init__(self, **params):
@@ -94,33 +98,37 @@ class Processor(FlowProcessor):
)
return
- with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
+ # Check if we should fetch from librarian or use inline data
+ if v.document_id:
+ # Fetch from librarian via Pulsar
+ logger.info(f"Fetching document {v.document_id} from librarian...")
+
+ content = await flow.librarian.fetch_document_content(
+ document_id=v.document_id,
+
+ )
+
+ # Content is base64 encoded
+ if isinstance(content, str):
+ content = content.encode('utf-8')
+ decoded_content = base64.b64decode(content)
+
+ logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
+ else:
+ # Use inline data (backward compatibility)
+ decoded_content = base64.b64decode(v.data)
+
+ if not _looks_like_pdf(decoded_content):
+ logger.error(
+ f"Document {v.metadata.id} is not valid PDF content. "
+ f"Ignoring document."
+ )
+ return
+
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
temp_path = fp.name
-
- # Check if we should fetch from librarian or use inline data
- if v.document_id:
- # Fetch from librarian via Pulsar
- logger.info(f"Fetching document {v.document_id} from librarian...")
- fp.close()
-
- content = await flow.librarian.fetch_document_content(
- document_id=v.document_id,
-
- )
-
- # Content is base64 encoded
- if isinstance(content, str):
- content = content.encode('utf-8')
- decoded_content = base64.b64decode(content)
-
- with open(temp_path, 'wb') as f:
- f.write(decoded_content)
-
- logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
- else:
- # Use inline data (backward compatibility)
- fp.write(base64.b64decode(v.data))
- fp.close()
+ fp.write(decoded_content)
+ fp.close()
global PyPDFLoader
if PyPDFLoader is None:
@@ -129,7 +137,15 @@ class Processor(FlowProcessor):
)
PyPDFLoader = _cls
loader = PyPDFLoader(temp_path)
- pages = loader.load()
+ try:
+ pages = loader.load()
+ except Exception as e:
+ source_doc_id = v.document_id or v.metadata.id
+ logger.error(
+ f"Failed to decode PDF {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
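The decoder now decodes the base64 payload first and only writes a temp file if the bytes start with the PDF magic marker. A small standalone sketch of that check is shown below. The function `decode_and_check` is a hypothetical wrapper; `looks_like_pdf` mirrors `_looks_like_pdf` from the diff.

```python
import base64


def looks_like_pdf(content: bytes) -> bool:
    # Leading whitespace is tolerated before the '%PDF-' marker.
    return content.lstrip().startswith(b"%PDF-")


def decode_and_check(encoded):
    """Return decoded bytes if they look like a PDF, else None."""
    if isinstance(encoded, str):
        encoded = encoded.encode("utf-8")
    decoded = base64.b64decode(encoded)
    return decoded if looks_like_pdf(decoded) else None
```

This is only a header sniff. A file can pass it and still fail to parse, which is why the diff also wraps `loader.load()` in a `try`/`except`.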
diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py
index d7abd1a9..f1e4a577 100644
--- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py
+++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py
@@ -6,7 +6,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
from ..tables.cassandra_async import async_execute
@@ -41,13 +41,15 @@ class KnowledgeGraph:
def __init__(
self, hosts=None,
- keyspace="trustgraph", username=None, password=None
+ keyspace="trustgraph", username=None, password=None,
+ replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
+ self.replication_factor = replication_factor
self.username = username
# 7-table schema for quads with full query pattern support
@@ -68,7 +70,7 @@ class KnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@@ -92,7 +94,7 @@ class KnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
- 'replication_factor' : 1
+ 'replication_factor' : {self.replication_factor}
}};
""")
@@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:
def __init__(
self, hosts=None,
- keyspace="trustgraph", username=None, password=None
+ keyspace="trustgraph", username=None, password=None,
+ replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
+ self.replication_factor = replication_factor
self.username = username
# 2-table entity-centric schema
@@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
- 'replication_factor' : 1
+ 'replication_factor' : {self.replication_factor}
}};
""")
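Because the replication factor is now interpolated into the CQL string, it may be worth coercing it to an integer before formatting. The sketch below assumes a hypothetical helper `keyspace_cql` and assumes that `None` should fall back to 1; neither is confirmed by the diff.

```python
def keyspace_cql(keyspace: str, replication_factor=None) -> str:
    """Build a CREATE KEYSPACE statement with a validated replication factor."""
    # Assumption: an unset value falls back to 1, matching the old hardcoded value.
    rf = 1 if replication_factor is None else int(replication_factor)
    if rf < 1:
        raise ValueError(f"replication_factor must be >= 1, got {rf}")
    return (
        f"create keyspace if not exists {keyspace} "
        f"with replication = {{'class': 'SimpleStrategy', "
        f"'replication_factor': {rf}}};"
    )
```

Calling `int()` rejects non-numeric strings such as `"3; drop ..."` with a `ValueError` instead of letting them reach the statement.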
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
index 6696afbe..90080cc4 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
@@ -73,6 +73,39 @@ class CoreExport:
enc = msgpack.packb(msg)
await response.write(enc)
+ if "library-metadata" in resp:
+
+ data = resp["library-metadata"]
+ msg = (
+ "lm",
+ {
+ "i": data["id"],
+ "k": data.get("kind", ""),
+ "t": data.get("title", ""),
+ "p": data.get("parent-id", ""),
+ "d": data.get("document-type", ""),
+ "c": data.get("comments", ""),
+ "g": data.get("tags", []),
+ }
+ )
+
+ enc = msgpack.packb(msg)
+ await response.write(enc)
+
+ if "library-blob" in resp:
+
+ data = resp["library-blob"]
+ msg = (
+ "lb",
+ {
+ "i": data["id"],
+ "d": data.get("data", b""),
+ }
+ )
+
+ enc = msgpack.packb(msg, use_bin_type=True)
+ await response.write(enc)
+
await kr.process(
{
"operation": "get-kg-core",
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
index d03d4efd..bf660def 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
@@ -79,6 +79,39 @@ class CoreImport:
await kr.process(msg)
+ elif unpacked[0] == "lm":
+ msg = unpacked[1]
+ msg = {
+ "operation": "put-kg-core",
+ "workspace": workspace,
+ "id": id,
+ "library-metadata": {
+ "id": msg["i"],
+ "kind": msg.get("k", ""),
+ "title": msg.get("t", ""),
+ "parent-id": msg.get("p", ""),
+ "document-type": msg.get("d", ""),
+ "comments": msg.get("c", ""),
+ "tags": msg.get("g", []),
+ }
+ }
+
+ await kr.process(msg)
+
+ elif unpacked[0] == "lb":
+ msg = unpacked[1]
+ msg = {
+ "operation": "put-kg-core",
+ "workspace": workspace,
+ "id": id,
+ "library-blob": {
+ "id": msg["i"],
+ "data": msg.get("d", b""),
+ }
+ }
+
+ await kr.process(msg)
+
except Exception as e:
logger.error(f"Core import exception: {e}", exc_info=True)
await error(str(e))
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py
index 2992d99f..74b4d7df 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py
@@ -3,6 +3,7 @@ import asyncio
import uuid
import logging
from . librarian import LibrarianRequestor
+from ... schema import librarian_request_queue, librarian_response_queue
# Module logger
logger = logging.getLogger(__name__)
@@ -23,10 +24,13 @@ class DocumentStreamExport:
response = await ok()
+ uid = str(uuid.uuid4())
lr = LibrarianRequestor(
backend=self.backend,
- consumer="api-gateway-doc-stream-" + str(uuid.uuid4()),
- subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()),
+ consumer="api-gateway-doc-stream-" + uid,
+ subscriber="api-gateway-doc-stream-" + uid,
+ request_queue=f"{librarian_request_queue}:{workspace}",
+ response_queue=f"{librarian_response_queue}:{workspace}",
)
try:
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
index bdbd18d8..9b119f8e 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
@@ -4,6 +4,8 @@ import queue
import uuid
import logging
+from ..capabilities import PUBLIC, AUTHENTICATED
+
# Module logger
logger = logging.getLogger(__name__)
@@ -156,37 +158,41 @@ class Mux:
})
return
- # Resolve workspace first (default-fill from the caller's
- # bound workspace), then ask the regime to authorise the
- # service-level capability against the matched
- # operation's resource shape.
+ # Resolve workspace (default-fill from the caller's
+ # bound workspace). Workspace resolution applies to all
+ # operations regardless of capability level.
try:
await enforce_workspace(data, self.identity, self.auth)
if isinstance(inner, dict):
await enforce_workspace(inner, self.identity, self.auth)
- if data.get("flow"):
- resource = {
- "workspace": data.get("workspace", ""),
- "flow": data.get("flow", ""),
- }
- parameters = {}
- else:
- # Build a minimal RequestContext so the matched
- # operation's own extractors decide resource and
- # parameters — same path the HTTP endpoints take.
- from ..registry import RequestContext
- ctx = RequestContext(
- body=inner if isinstance(inner, dict) else {},
- match_info={},
- identity=self.identity,
- )
- resource = op.extract_resource(ctx)
- parameters = op.extract_parameters(ctx)
+ # Authorisation: capability sentinels short-circuit
+ # the regime call; capability strings go through
+ # authorise().
+ if op.capability not in (PUBLIC, AUTHENTICATED):
+ if data.get("flow"):
+ resource = {
+ "workspace": data.get("workspace", ""),
+ "flow": data.get("flow", ""),
+ }
+ parameters = {}
+ else:
+ # Build a minimal RequestContext so the matched
+ # operation's own extractors decide resource
+ # and parameters — same path the HTTP
+ # endpoints take.
+ from ..registry import RequestContext
+ ctx = RequestContext(
+ body=inner if isinstance(inner, dict) else {},
+ match_info={},
+ identity=self.identity,
+ )
+ resource = op.extract_resource(ctx)
+ parameters = op.extract_parameters(ctx)
- await self.auth.authorise(
- self.identity, op.capability, resource, parameters,
- )
+ await self.auth.authorise(
+ self.identity, op.capability, resource, parameters,
+ )
except _web.HTTPNotFound:
await self.ws.send_json({
"id": request_id,
@@ -288,6 +294,8 @@ class Mux:
await self.maybe_tidy_workers(workers)
async def responder(resp, fin):
+ if self.ws is None:
+ return
await self.ws.send_json({
"id": id,
"response": resp,
@@ -321,6 +329,8 @@ class Mux:
)
except Exception as e:
+ if self.ws is None:
+ return
await self.ws.send_json({
"id": id,
"error": {"message": str(e), "type": "error"},
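The authorisation change in `Mux` skips the regime call when the operation's capability is one of two sentinels. A minimal sketch of that short-circuit follows. The sentinel values and the function `maybe_authorise` are hypothetical; the real `PUBLIC` and `AUTHENTICATED` live in the gateway's `capabilities` module and may be defined differently.

```python
import asyncio

# Hypothetical sentinel values, for illustration only.
PUBLIC = "__public__"
AUTHENTICATED = "__authenticated__"


async def maybe_authorise(capability, authorise, identity, resource, parameters):
    """Call authorise() only for real capability strings.

    Returns True if the regime was consulted, False if a sentinel
    short-circuited the check.
    """
    if capability in (PUBLIC, AUTHENTICATED):
        return False
    await authorise(identity, capability, resource, parameters)
    return True
```

As in the diff, workspace resolution would still run before this step for every operation; only the capability check is skipped.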
diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py
index f53ad73b..af6183db 100644
--- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py
+++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py
@@ -117,8 +117,10 @@ class SocketEndpoint:
running = Running()
+ params = dict(request.query)
+ params.update(request.match_info)
dispatcher = await self.dispatcher(
- ws, running, request.match_info
+ ws, running, params
)
worker_task = tg.create_task(
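The socket endpoint now passes query-string parameters to the dispatcher as well as route parameters, with route parameters applied last. A sketch of that precedence using plain dicts (the real objects are aiohttp mappings, and `merge_params` is a hypothetical name):

```python
def merge_params(query, match_info):
    """Combine query and route parameters; route parameters win on conflict."""
    params = dict(query)
    params.update(match_info)
    return params
```

Applying `match_info` last means a caller cannot override a path segment by repeating its name in the query string.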
diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py
index 8ce22757..b2f3976d 100644
--- a/trustgraph-flow/trustgraph/iam/service/service.py
+++ b/trustgraph-flow/trustgraph/iam/service/service.py
@@ -101,6 +101,7 @@ class Processor(AsyncProcessor):
username=cassandra_username,
password=cassandra_password,
default_keyspace="iam",
+ replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts
diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py
index 1c4d010e..cc5f0bdf 100644
--- a/trustgraph-flow/trustgraph/librarian/librarian.py
+++ b/trustgraph-flow/trustgraph/librarian/librarian.py
@@ -162,6 +162,9 @@ class Librarian:
request.document_id
)
+ if object_id is None:
+ raise RequestError(f"Document not found: {request.document_id}")
+
content = await self.blob_store.get(
object_id
)
diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py
index cc5efdae..4d3efbfb 100755
--- a/trustgraph-flow/trustgraph/librarian/service.py
+++ b/trustgraph-flow/trustgraph/librarian/service.py
@@ -8,6 +8,7 @@ import asyncio
import base64
import json
import logging
+import os
from datetime import datetime
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
@@ -54,6 +55,16 @@ default_object_store_access_key = "object-user"
default_object_store_secret_key = "object-password"
default_object_store_use_ssl = False
default_object_store_region = None
+
+# Environment variables consulted as a fallback when the
+# corresponding params field is not set in the processor-group YAML
+# or via CLI. Intended for K8s Secret / env-var injection so
+# credentials never have to live in the YAML (and thus in git).
+ENV_OBJECT_STORE_ENDPOINT = "OBJECT_STORE_ENDPOINT"
+ENV_OBJECT_STORE_ACCESS_KEY = "OBJECT_STORE_ACCESS_KEY"
+ENV_OBJECT_STORE_SECRET_KEY = "OBJECT_STORE_SECRET_KEY"
+ENV_OBJECT_STORE_USE_SSL = "OBJECT_STORE_USE_SSL"
+ENV_OBJECT_STORE_REGION = "OBJECT_STORE_REGION"
default_cassandra_host = "cassandra"
default_min_chunk_size = 1 # No minimum by default (for Garage)
@@ -89,22 +100,36 @@ class Processor(WorkspaceProcessor):
"config_response_queue", default_config_response_queue
)
- object_store_endpoint = params.get("object_store_endpoint", default_object_store_endpoint)
- object_store_access_key = params.get(
- "object_store_access_key",
- default_object_store_access_key
+ # Resolve object-store config. Precedence: explicit params
+ # (CLI / processor-group YAML) → environment variable →
+ # hardcoded default. The env-var path lets K8s Secrets feed
+ # credentials without them appearing in the YAML.
+ object_store_endpoint = (
+ params.get("object_store_endpoint")
+ or os.environ.get(ENV_OBJECT_STORE_ENDPOINT)
+ or default_object_store_endpoint
)
- object_store_secret_key = params.get(
- "object_store_secret_key",
- default_object_store_secret_key
+ object_store_access_key = (
+ params.get("object_store_access_key")
+ or os.environ.get(ENV_OBJECT_STORE_ACCESS_KEY)
+ or default_object_store_access_key
)
- object_store_use_ssl = params.get(
- "object_store_use_ssl",
- default_object_store_use_ssl
+ object_store_secret_key = (
+ params.get("object_store_secret_key")
+ or os.environ.get(ENV_OBJECT_STORE_SECRET_KEY)
+ or default_object_store_secret_key
)
- object_store_region = params.get(
- "object_store_region",
- default_object_store_region
+ object_store_use_ssl = params.get("object_store_use_ssl")
+ if object_store_use_ssl is None:
+ env_ssl = os.environ.get(ENV_OBJECT_STORE_USE_SSL)
+ if env_ssl is not None:
+ object_store_use_ssl = env_ssl.lower() in ("true", "1", "yes")
+ else:
+ object_store_use_ssl = default_object_store_use_ssl
+ object_store_region = (
+ params.get("object_store_region")
+ or os.environ.get(ENV_OBJECT_STORE_REGION)
+ or default_object_store_region
)
min_chunk_size = params.get(
@@ -121,7 +146,8 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
- default_keyspace="librarian"
+ default_keyspace="librarian",
+ replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration
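The object-store settings above resolve in the order explicit params, then environment variable, then hardcoded default, with the SSL flag parsed separately because `False` is a meaningful value. A minimal sketch of both rules is below. The helper names `resolve_setting` and `resolve_bool` and the injectable `environ` argument are assumptions added to make the sketch testable.

```python
import os


def resolve_setting(params, key, env_var, default, environ=None):
    """String settings: first truthy of params, environment, default."""
    environ = os.environ if environ is None else environ
    return params.get(key) or environ.get(env_var) or default


def resolve_bool(params, key, env_var, default, environ=None):
    """Boolean settings: only None counts as 'not set' in params."""
    environ = os.environ if environ is None else environ
    value = params.get(key)
    if value is not None:
        return value
    env_value = environ.get(env_var)
    if env_value is not None:
        # Same accepted spellings as the diff: true / 1 / yes.
        return env_value.lower() in ("true", "1", "yes")
    return default
```

Note that the `or` chain treats an empty string in params as unset, so it falls through to the environment, whereas an explicit `False` for the SSL flag is kept.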
diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
index f6770744..de25a139 100755
--- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
@@ -12,31 +12,33 @@ from qdrant_client import QdrantClient
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-query"
-default_store_uri = 'http://localhost:6333'
-
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
- #optional api key
- api_key = params.get("api_key", None)
+ url, api_key, _, _ = resolve_qdrant_config(
+ url=store_uri,
+ api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
async def query_document_embeddings(self, workspace, msg):
@@ -85,18 +87,7 @@ class Processor(DocumentEmbeddingsQueryService):
def add_args(parser):
DocumentEmbeddingsQueryService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'API key for qdrant (default: None)'
- )
+ add_qdrant_args(parser)
def run():
diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
index 167130c9..aa93925d 100755
--- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
@@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-embeddings-query"
-default_store_uri = 'http://localhost:6333'
-
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
- #optional api key
- api_key = params.get("api_key", None)
+ url, api_key, _, _ = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService):
def add_args(parser):
GraphEmbeddingsQueryService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'API key for qdrant (default: None)'
- )
+ add_qdrant_args(parser)
def run():
diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
index b7f0f423..a9005ee4 100644
--- a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
+++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
@@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
# Create keyspace
self.session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
- WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
+ WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
""")
# Create triples table optimized for SPARQL queries
diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
index 1534c044..7e1a5851 100644
--- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
@@ -19,12 +19,12 @@ from .... schema import (
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
-default_store_uri = 'http://localhost:6333'
default_concurrency = 10
@@ -35,13 +35,17 @@ class Processor(FlowProcessor):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, _, _ = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
"id": id,
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
@@ -62,7 +66,7 @@ class Processor(FlowProcessor):
)
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@@ -192,21 +196,9 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- """Add command-line arguments"""
FlowProcessor.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help='API key for Qdrant (default: None)'
- )
+ add_qdrant_args(parser)
parser.add_argument(
'-c', '--concurrency',
diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
index 7157daae..f9868d67 100644
--- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
+++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
@@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
-from .... tables.cassandra_async import async_execute
+from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
from ... graphql import GraphQLSchemaBuilder, SortDirection
@@ -180,7 +180,7 @@ class Processor(FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
- indexed=field_def.get("indexed", False)
+ indexed=field_def.get("indexed", False),
)
fields.append(field)
@@ -232,6 +232,8 @@ class Processor(FlowProcessor):
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
+ if value == "" or value is None:
+ continue
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
@@ -282,11 +284,13 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}"
try:
- rows = await async_execute(self.session, query, params)
- for row in rows:
- # Convert data map to dict with proper field names
- row_dict = dict(row.data) if row.data else {}
- results.append(row_dict)
+ pages = await async_execute_paged(
+ self.session, query, params
+ )
+ for page in pages:
+ for row in page:
+ row_dict = dict(row.data) if row.data else {}
+ results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
@@ -308,8 +312,6 @@ class Processor(FlowProcessor):
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
- # We need to scan all values for this index
- # This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
@@ -320,17 +322,18 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index]
try:
- rows = await async_execute(self.session, query, params)
-
- for row in rows:
+ def row_filter(row):
row_dict = dict(row.data) if row.data else {}
+ return self._matches_filters(row_dict, filters, row_schema)
- # Apply post-filters
- if self._matches_filters(row_dict, filters, row_schema):
- results.append(row_dict)
-
- if limit and len(results) >= limit:
- break
+ matched_rows = await async_scan(
+ self.session, query, params,
+ row_filter=row_filter,
+ limit=limit,
+ )
+ for row in matched_rows:
+ row_dict = dict(row.data) if row.data else {}
+ results.append(row_dict)
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
@@ -363,7 +366,7 @@ class Processor(FlowProcessor):
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
- if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
+ if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
field_name = parts[0]
operator = parts[1]
else:
@@ -400,6 +403,18 @@ class Processor(FlowProcessor):
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
+ elif operator == 'not':
+ if str(row_value) == str(filter_value):
+ return False
+ elif operator == 'startsWith':
+ if not str(row_value).startswith(str(filter_value)):
+ return False
+ elif operator == 'endsWith':
+ if not str(row_value).endswith(str(filter_value)):
+ return False
+ elif operator == 'not_in':
+ if str(row_value) in [str(v) for v in filter_value]:
+ return False
except (ValueError, TypeError):
return False
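
Review note: the operator is taken from the text after the last underscore in the filter key (`filter_key.rsplit('_', 1)`), so a key such as `status_not_in` splits into `status_not` and `in`. As written, the `not_in` entry in the operator list may never match, and such a key would be evaluated as `in` against a field called `status_not`. One way to handle multi-part suffixes is to test the known operators longest first. A sketch, where `split_filter_key`, `OPERATORS`, and the `"eq"` fallback are hypothetical and not part of the diff:

```python
# Operators from the diff, sorted so longer suffixes ("not_in") win over "in".
OPERATORS = sorted(
    ["gt", "gte", "lt", "lte", "contains", "in",
     "not", "startsWith", "endsWith", "not_in"],
    key=len, reverse=True,
)


def split_filter_key(filter_key):
    """Return (field_name, operator); the operator falls back to "eq"."""
    for op in OPERATORS:
        suffix = "_" + op
        # Require a non-empty field name in front of the suffix.
        if filter_key.endswith(suffix) and len(filter_key) > len(suffix):
            return filter_key[: -len(suffix)], op
    return filter_key, "eq"
```

This keeps the existing ambiguity for field names that themselves end in an operator word (for example a field named `check_in`), so it narrows the problem without removing it.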
diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
index 2bfef99c..08d88849 100644
--- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
@@ -14,29 +14,36 @@ from qdrant_client.models import Distance, VectorParams
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-write"
-default_store_uri = 'http://localhost:6333'
-
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, replication_factor, shard_number = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ replication_factor=params.get("qdrant_replication_factor"),
+ shard_number=params.get("qdrant_shard_number"),
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
+ self.replication_factor = replication_factor
+ self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@@ -61,6 +68,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
+ replication_factor=self.replication_factor,
+ shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@@ -109,18 +118,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def add_args(parser):
DocumentEmbeddingsStoreService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'Qdrant API key (default: None)'
- )
+ add_qdrant_args(parser)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
index 13dcdba8..b6072bdc 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
@@ -14,6 +14,7 @@ from qdrant_client.models import Distance, VectorParams
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
from .... schema import IRI, LITERAL
# Module logger
@@ -29,29 +30,34 @@ def get_term_value(term):
elif term.type == LITERAL:
return term.value
else:
- # For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
-default_store_uri = 'http://localhost:6333'
-
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, replication_factor, shard_number = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ replication_factor=params.get("qdrant_replication_factor"),
+ shard_number=params.get("qdrant_shard_number"),
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
+ self.replication_factor = replication_factor
+ self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@@ -76,6 +82,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
+ replication_factor=self.replication_factor,
+ shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@@ -128,18 +136,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def add_args(parser):
GraphEmbeddingsStoreService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'Qdrant API key'
- )
+ add_qdrant_args(parser)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py
index 162a4057..f6e12a85 100644
--- a/trustgraph-flow/trustgraph/storage/knowledge/store.py
+++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py
@@ -27,7 +27,8 @@ class Processor(FlowProcessor):
host=params.get("cassandra_host"),
username=params.get("cassandra_username"),
password=params.get("cassandra_password"),
- default_keyspace='knowledge'
+ default_keyspace='knowledge',
+ replication_factor=params.get("cassandra_replication_factor"),
)
super(Processor, self).__init__(
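
Review note: `cassandra_replication_factor` is now threaded through from `params`, and the rows writer in this same change interpolates the resolved value into its `CREATE KEYSPACE` statement. It may be worth confirming the value is a positive integer before any CQL string is built. A minimal sketch, assuming the value can arrive as a string from the command line or environment; `keyspace_ddl` is a hypothetical helper, not code from the diff:

```python
def keyspace_ddl(safe_keyspace, replication_factor):
    """Build a CREATE KEYSPACE statement with a validated replication factor."""
    factor = int(replication_factor)  # raises ValueError on non-numeric input
    if factor < 1:
        raise ValueError("replication_factor must be >= 1")
    return (
        f"CREATE KEYSPACE IF NOT EXISTS {safe_keyspace} "
        f"WITH REPLICATION = {{'class': 'SimpleStrategy', "
        f"'replication_factor': {factor}}}"
    )
```

The keyspace name is assumed to have been sanitised already, as the `safe_keyspace` variable name in the writer suggests.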
diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
index a01629c5..4c65edb1 100644
--- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
@@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
-default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
@@ -41,13 +41,19 @@ class Processor(CollectionConfigHandler, FlowProcessor):
id = params.get("id", default_ident)
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, replication_factor, shard_number = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ replication_factor=params.get("qdrant_replication_factor"),
+ shard_number=params.get("qdrant_shard_number"),
+ )
super(Processor, self).__init__(
**params | {
"id": id,
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
@@ -63,7 +69,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"])
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
+ self.replication_factor = replication_factor
+ self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@@ -103,6 +111,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
size=dimension,
distance=Distance.COSINE
),
+ replication_factor=self.replication_factor,
+ shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@@ -249,21 +259,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
@staticmethod
def add_args(parser):
- """Add command-line arguments"""
FlowProcessor.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help='Qdrant API key (default: None)'
- )
+ add_qdrant_args(parser)
def run():
diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
index 65eeee06..31fc41a7 100755
--- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
@@ -47,16 +47,18 @@ class Processor(CollectionConfigHandler, FlowProcessor):
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
- hosts, username, password, keyspace, _ = resolve_cassandra_config(
+ hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
- password=cassandra_password
+ password=cassandra_password,
+ replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
+ self.replication_factor = replication_factor
# Config key for schemas
self.config_key = params.get("config_type", "schema")
@@ -170,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
- indexed=field_def.get("indexed", False)
+ indexed=field_def.get("indexed", False),
)
fields.append(field)
@@ -232,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
- 'replication_factor': 1
+ 'replication_factor': {self.replication_factor}
}}
"""
diff --git a/trustgraph-flow/trustgraph/tables/cassandra_async.py b/trustgraph-flow/trustgraph/tables/cassandra_async.py
index 2f497748..fe410a26 100644
--- a/trustgraph-flow/trustgraph/tables/cassandra_async.py
+++ b/trustgraph-flow/trustgraph/tables/cassandra_async.py
@@ -27,6 +27,8 @@ Notes:
import asyncio
+from cassandra.query import SimpleStatement
+
async def async_execute(session, query, parameters=None):
"""Execute a CQL statement asynchronously.
@@ -76,3 +78,83 @@ def _set_result_if_pending(fut, result):
def _set_exception_if_pending(fut, exc):
if not fut.done():
fut.set_exception(exc)
+
+
+async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
+ """Execute a CQL query with page-by-page iteration.
+
+ Runs the synchronous session.execute() inside run_in_executor so the
+ event loop is not blocked while the driver pages through the result.
+ All pages are fetched and held in memory before this returns.
+
+ Returns all pages as a list of lists.
+ """
+ loop = asyncio.get_running_loop()
+
+ if isinstance(query, str):
+ stmt = SimpleStatement(query, fetch_size=fetch_size)
+ else:
+ stmt = query
+ stmt.fetch_size = fetch_size
+
+ def _fetch_all_pages():
+ pages = []
+ result_set = session.execute(stmt, parameters)
+ while True:
+ pages.append(list(result_set.current_rows))
+ if result_set.has_more_pages:
+ result_set.fetch_next_page()
+ else:
+ break
+ return pages
+
+ return await loop.run_in_executor(
+ None, _fetch_all_pages
+ )
+
+
+async def async_scan(
+ session, query, parameters=None, row_filter=None,
+ limit=None, fetch_size=5000,
+):
+ """Scan a CQL query page-by-page, applying a filter and limit.
+
+ Only matching rows accumulate in memory. Each page is discarded
+ after processing, so peak memory is bounded by fetch_size plus
+ the number of matching rows (capped by limit).
+
+ Args:
+ session: cassandra.cluster.Session
+ query: CQL statement string
+ parameters: bind params
+ row_filter: callable(row) -> bool, or None to accept all
+ limit: max results to return, or None for unlimited
+ fetch_size: rows per Cassandra page fetch
+
+ Returns:
+ List of matching rows.
+ """
+ loop = asyncio.get_running_loop()
+
+ if isinstance(query, str):
+ stmt = SimpleStatement(query, fetch_size=fetch_size)
+ else:
+ stmt = query
+ stmt.fetch_size = fetch_size
+
+ def _scan():
+ results = []
+ result_set = session.execute(stmt, parameters)
+ while True:
+ for row in result_set.current_rows:
+ if row_filter is None or row_filter(row):
+ results.append(row)
+ if limit and len(results) >= limit:
+ return results
+ if result_set.has_more_pages:
+ result_set.fetch_next_page()
+ else:
+ break
+ return results
+
+ return await loop.run_in_executor(None, _scan)
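
Review note: the loop in `async_scan` can be exercised without a cluster by standing in for the driver's `ResultSet`. The sketch below assumes only the three members the diff itself uses (`current_rows`, `has_more_pages`, `fetch_next_page()`). `FakeResultSet` and `scan_pages` are hypothetical names for illustration, and the limit check mirrors the one in `_scan`.

```python
class FakeResultSet:
    """Stand-in exposing only the paging members used by _scan."""

    def __init__(self, pages):
        self._pages = list(pages)
        self._index = 0
        self.fetches = 0  # counts fetch_next_page() calls

    @property
    def current_rows(self):
        return self._pages[self._index]

    @property
    def has_more_pages(self):
        return self._index < len(self._pages) - 1

    def fetch_next_page(self):
        self._index += 1
        self.fetches += 1


def scan_pages(result_set, row_filter=None, limit=None):
    """Collect matching rows page by page, stopping once limit is reached."""
    results = []
    while True:
        for row in result_set.current_rows:
            if row_filter is None or row_filter(row):
                results.append(row)
                if limit and len(results) >= limit:
                    return results
        if not result_set.has_more_pages:
            return results
        result_set.fetch_next_page()


rs = FakeResultSet([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
evens = scan_pages(rs, row_filter=lambda r: r % 2 == 0, limit=2)
```

Because the limit is checked inside the row loop, the scan returns as soon as enough matches are found and the remaining pages are never requested, which is the memory and latency benefit over fetching every page first.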
diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py
index 74ceb6f4..c87cb3b5 100644
--- a/trustgraph-flow/trustgraph/tables/config.py
+++ b/trustgraph-flow/trustgraph/tables/config.py
@@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
import uuid
import time
@@ -33,7 +33,7 @@ class ConfigTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)
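
Review note: switching from `SSLContext(PROTOCOL_TLSv1_2)` to `ssl.create_default_context()` changes more than the protocol selection. The default context verifies the server certificate and checks the hostname, whereas a bare `SSLContext(PROTOCOL_TLSv1_2)` does neither, so clusters using self-signed or private-CA certificates may need a CA bundle supplied. A short sketch of the properties involved; `build_cassandra_ssl_context` and its `ca_file` parameter are hypothetical and not part of the diff:

```python
import ssl


def build_cassandra_ssl_context(ca_file=None):
    """Default client context, optionally trusting a private CA bundle."""
    # cafile=None falls back to the system trust store.
    ctx = ssl.create_default_context(cafile=ca_file)
    # Keep the TLS 1.2 floor that the previous code pinned explicitly.
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    return ctx


ctx = build_cassandra_ssl_context()
```

The same consideration applies to each table store in this change that makes the identical substitution.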
diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py
index d7bf5e3d..b60e9cff 100644
--- a/trustgraph-flow/trustgraph/tables/iam.py
+++ b/trustgraph-flow/trustgraph/tables/iam.py
@@ -15,7 +15,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
from . cassandra_async import async_execute
@@ -39,7 +39,7 @@ class IamTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password,
)
diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py
index cf085fdd..53a12b35 100644
--- a/trustgraph-flow/trustgraph/tables/knowledge.py
+++ b/trustgraph-flow/trustgraph/tables/knowledge.py
@@ -5,7 +5,7 @@ from .. schema import DocumentEmbeddings, ChunkEmbeddings
from cassandra.cluster import Cluster
-from . cassandra_async import async_execute
+from . cassandra_async import async_execute, async_execute_paged
def term_to_tuple(term):
@@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri):
else:
return Term(type=LITERAL, value=value)
from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
import uuid
import time
@@ -50,7 +50,7 @@ class KnowledgeTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)
@@ -98,7 +98,8 @@ class KnowledgeTableStore:
text, boolean, text, boolean, text, boolean
>>,
triples list>,
PRIMARY KEY ((workspace, document_id), id)
);
@@ -234,7 +235,8 @@ class KnowledgeTableStore:
triples = [
(
- *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
+ *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o),
+ v.g or ""
)
for v in m.triples
]
@@ -398,7 +400,7 @@ class KnowledgeTableStore:
logger.debug("Get triples...")
try:
- rows = await async_execute(
+ pages = await async_execute_paged(
self.cassandra,
self.get_triples_stmt,
(workspace, document_id),
@@ -407,29 +409,31 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
- for row in rows:
+ for page in pages:
+ for row in page:
- if row[3]:
- triples = [
- Triple(
- s = tuple_to_term(elt[0], elt[1]),
- p = tuple_to_term(elt[2], elt[3]),
- o = tuple_to_term(elt[4], elt[5]),
+ if row[3]:
+ triples = [
+ Triple(
+ s = tuple_to_term(elt[0], elt[1]),
+ p = tuple_to_term(elt[2], elt[3]),
+ o = tuple_to_term(elt[4], elt[5]),
+ g = elt[6] if elt[6] else None,
+ )
+ for elt in row[3]
+ ]
+ else:
+ triples = []
+
+ await receiver(
+ Triples(
+ metadata = Metadata(
+ id = document_id,
+ collection = "default",
+ ),
+ triples = triples
)
- for elt in row[3]
- ]
- else:
- triples = []
-
- await receiver(
- Triples(
- metadata = Metadata(
- id = document_id,
- collection = "default", # FIXME: What to put here?
- ),
- triples = triples
)
- )
logger.debug("Done")
@@ -438,7 +442,7 @@ class KnowledgeTableStore:
logger.debug("Get GE...")
try:
- rows = await async_execute(
+ pages = await async_execute_paged(
self.cassandra,
self.get_graph_embeddings_stmt,
(workspace, document_id),
@@ -447,28 +451,29 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
- for row in rows:
+ for page in pages:
+ for row in page:
- if row[3]:
- entities = [
- EntityEmbeddings(
- entity = tuple_to_term(ent[0][0], ent[0][1]),
- vector = ent[1]
+ if row[3]:
+ entities = [
+ EntityEmbeddings(
+ entity = tuple_to_term(ent[0][0], ent[0][1]),
+ vector = ent[1]
+ )
+ for ent in row[3]
+ ]
+ else:
+ entities = []
+
+ await receiver(
+ GraphEmbeddings(
+ metadata = Metadata(
+ id = document_id,
+ collection = "default",
+ ),
+ entities = entities
)
- for ent in row[3]
- ]
- else:
- entities = []
-
- await receiver(
- GraphEmbeddings(
- metadata = Metadata(
- id = document_id,
- collection = "default", # FIXME: What to put here?
- ),
- entities = entities
)
- )
logger.debug("Done")
@@ -477,7 +482,7 @@ class KnowledgeTableStore:
logger.debug("Get DE...")
try:
- rows = await async_execute(
+ pages = await async_execute_paged(
self.cassandra,
self.get_document_embeddings_stmt,
(workspace, document_id),
@@ -486,28 +491,29 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
- for row in rows:
+ for page in pages:
+ for row in page:
- if row[3]:
- chunks = [
- ChunkEmbeddings(
- chunk_id=ch[0],
- vector=ch[1],
+ if row[3]:
+ chunks = [
+ ChunkEmbeddings(
+ chunk_id=ch[0],
+ vector=ch[1],
+ )
+ for ch in row[3]
+ ]
+ else:
+ chunks = []
+
+ await receiver(
+ DocumentEmbeddings(
+ metadata = Metadata(
+ id = document_id,
+ collection = "default",
+ ),
+ chunks = chunks
)
- for ch in row[3]
- ]
- else:
- chunks = []
-
- await receiver(
- DocumentEmbeddings(
- metadata = Metadata(
- id = document_id,
- collection = "default",
- ),
- chunks = chunks
)
- )
logger.debug("Done")
diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py
index 58486f0e..5094e103 100644
--- a/trustgraph-flow/trustgraph/tables/library.py
+++ b/trustgraph-flow/trustgraph/tables/library.py
@@ -24,7 +24,7 @@ from .. exceptions import RequestError
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
import uuid
import time
@@ -53,7 +53,7 @@ class LibraryTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)
diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py
index 7378db64..11b975b2 100755
--- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py
+++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py
@@ -8,71 +8,180 @@ import logging
import json
import uuid
import argparse
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from collections.abc import AsyncIterator
from functools import partial
from mcp.server.fastmcp import FastMCP, Context
-from mcp.types import TextContent
-from websockets.asyncio.client import connect
+from mcp.server.auth.provider import AccessToken, TokenVerifier
+from mcp.server.auth.middleware.auth_context import get_access_token
from trustgraph.base.logging import add_logging_args, setup_logging
-from . tg_socket import WebSocketManager
+from . tg_socket import WebSocketManager, _token_key
+
+logger = logging.getLogger(__name__)
+
+
+# Wire-format Term type codes (match TermTranslator compact keys)
+_TERM_TYPES = {
+ "iri": "i",
+ "literal": "l",
+ "blank": "b",
+}
+
+
+def _make_term(value: str, term_type: str) -> dict:
+ """Build a compact-key Term dict for the gateway wire format.
+
+ Args:
+ value: The term value (IRI string, literal text, or blank node id).
+ term_type: One of "iri", "literal", "blank".
+ """
+ t = _TERM_TYPES.get(term_type)
+ if t is None:
+ raise ValueError(
+ f"Unknown term type '{term_type}' — "
+ f"expected one of: {', '.join(_TERM_TYPES)}"
+ )
+
+ if t == "i":
+ return {"t": t, "i": value}
+ elif t == "l":
+ return {"t": t, "v": value}
+ elif t == "b":
+ return {"t": t, "d": value}
+ return {"t": t}
+
+# ── Security boundary: MCP client → MCP server ──
+# The MCP client authenticates to this server via a Bearer token in the
+# HTTP Authorization header. The SDK's auth middleware extracts and
+# verifies the token before any tool handler runs.
+#
+# We implement a pass-through TokenVerifier: the gateway is the real
+# authority, so we accept any non-empty Bearer token here and forward
+# it to the gateway for validation. The gateway's in-band auth
+# protocol and IAM regime decide whether the token is valid.
+#
+# This means an invalid token will connect to the MCP server but will
+# fail when the first WebSocket auth frame is sent to the gateway.
+# That is intentional — the gateway is the single source of truth.
+
+
+class PassthroughTokenVerifier(TokenVerifier):
+ """Accept any non-empty Bearer token and forward it downstream.
+
+ The TrustGraph gateway is the authority for token validation, not
+ this MCP server. We store the raw token in the AccessToken so that
+ tool handlers can retrieve it via ``get_access_token().token`` and
+ forward it to the gateway.
+ """
+
+ async def verify_token(self, token: str) -> AccessToken | None:
+ if not token:
+ return None
+ return AccessToken(
+ token=token,
+ client_id="mcp-caller",
+ scopes=[],
+ )
+
@dataclass
class AppContext:
- sockets: dict[str, WebSocketManager]
- websocket_url: str
- gateway_token: str
+ sockets: dict[str, WebSocketManager] = field(default_factory=dict)
+ websocket_url: str = ""
+
@asynccontextmanager
-async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
+async def app_lifespan(
+ server: FastMCP,
+ websocket_url: str = "ws://api-gateway:8088/api/v1/socket",
+) -> AsyncIterator[AppContext]:
+ """Manage per-server state: the pool of per-caller WebSocket
+ connections to the gateway."""
- """
- Manage application lifecycle with type-safe context
- """
-
- # Initialize on startup
- sockets = {}
+ sockets: dict[str, WebSocketManager] = {}
try:
- yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
+ yield AppContext(sockets=sockets, websocket_url=websocket_url)
finally:
- # Cleanup on shutdown
- logging.info("Shutting down context")
+ logger.info("Shutting down — closing %d WebSocket(s)", len(sockets))
- for k, manager in sockets.items():
- logging.info(f"Closing socket for {k}")
- await manager.stop()
+ for key, manager in sockets.items():
+ try:
+ await manager.stop()
+ except Exception as e:
+ logger.warning("Error closing socket %s: %s", key, e)
- logging.info("Shutdown complete")
+ logger.info("Shutdown complete")
-async def get_socket_manager(ctx):
+
+def _require_token() -> str:
+ """Extract the caller's Bearer token from the MCP auth context.
+
+ Raises RuntimeError if no token is present (the caller did not
+ authenticate).
+ """
+ # ── Security boundary: token extraction ──
+ # get_access_token() reads the contextvar set by the SDK's
+ # AuthContextMiddleware. The token was placed there by
+ # PassthroughTokenVerifier.verify_token() and is the raw Bearer
+ # value from the MCP client's Authorization header.
+ access = get_access_token()
+ if access is None or not access.token:
+ raise RuntimeError(
+ "Authentication required — send a Bearer token in the "
+ "Authorization header"
+ )
+ return access.token
+
+
+async def get_socket_manager(ctx, token):
+ """Return (or create) an authenticated WebSocket for this token.
+
+ Each unique token gets its own WebSocket connection so that
+ gateway-side identity, workspace binding, and capability scoping
+ are preserved per caller.
+ """
lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url
- gateway_token = lifespan_context.gateway_token
- if "default" in sockets:
- logging.info("Return existing socket manager")
- return sockets["default"]
+ key = _token_key(token)
- logging.info(f"Opening socket to {websocket_url}...")
+ if key in sockets:
+ manager = sockets[key]
+ if manager.socket is not None:
+ return manager
+ # Socket was closed (e.g. server-side timeout) — reconnect.
+ del sockets[key]
- # Create manager with empty pending requests
- manager = WebSocketManager(websocket_url, token=gateway_token)
+ logger.info("Opening authenticated WebSocket to %s …", websocket_url)
- # Start reader task with the proper manager
+ manager = WebSocketManager(websocket_url, token=token)
await manager.start()
- sockets["default"] = manager
+ # Verify the token is valid by calling whoami. This confirms the
+ # gateway accepted the token and gives us the caller's identity.
+ try:
+ identity = await manager.whoami()
+ logger.info(
+ "WebSocket ready — caller: %s",
+ identity.get("handle", "unknown"),
+ )
+ except Exception as e:
+ await manager.stop()
+ raise RuntimeError(
+ f"Token rejected by gateway (whoami failed): {e}"
+ ) from e
- logging.info("Return new socket manager")
+ sockets[key] = manager
return manager
+
@dataclass
class EmbeddingsResponse:
vectors: List[List[float]]
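
Review note: `get_socket_manager` keeps one gateway connection per caller, keyed by `_token_key(token)`. That helper is imported from `tg_socket` and is not shown in this part of the diff. The following is a sketch of one plausible approach, hashing the token so the raw secret is not used as a dictionary key or written to logs. `token_key_sketch`, `get_or_create`, and the factory callable are hypothetical names, and the real helper may differ:

```python
import hashlib


def token_key_sketch(token: str) -> str:
    """Derive a stable, non-reversible pool key from a Bearer token."""
    return hashlib.sha256(token.encode("utf-8")).hexdigest()[:16]


def get_or_create(pool: dict, token: str, factory):
    """Return the pooled connection for this token, creating it if absent."""
    key = token_key_sketch(token)
    if key not in pool:
        pool[key] = factory(token)
    return pool[key]


pool = {}
a1 = get_or_create(pool, "token-a", lambda t: object())
a2 = get_or_create(pool, "token-a", lambda t: object())
b1 = get_or_create(pool, "token-b", lambda t: object())
```

Note that this sketch is synchronous. In the async handler, two concurrent first requests with the same token could both miss the pool and open a connection each, so a per-key lock may be worth considering.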
@@ -182,10 +291,23 @@ class PutConfigResponse:
class DeleteConfigResponse:
pass
+@dataclass
+class SparqlQueryResponse:
+ query_type: str
+ variables: List[str]
+ bindings: List[Dict[str, Any]]
+ ask_result: bool
+ triples: List[Dict[str, Any]]
+
+@dataclass
+class GraphQLQueryResponse:
+ data: Any
+ errors: List[Dict[str, Any]]
+
@dataclass
class GetPromptsResponse:
prompts: List[str]
-
+
@dataclass
class GetPromptResponse:
prompt: Dict[str, Any]
@@ -194,31 +316,61 @@ class GetPromptResponse:
class GetSystemPromptResponse:
prompt: str
+
class McpServer:
- def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
+ def __init__(
+ self,
+ host: str = "0.0.0.0",
+ port: int = 8000,
+ websocket_url: str = "ws://api-gateway:8088/api/v1/socket",
+ auth_issuer: str = "",
+ auth_resource_url: str = "",
+ ):
self.host = host
self.port = port
self.websocket_url = websocket_url
- self.gateway_token = gateway_token
- # Create a partial function to pass websocket_url to app_lifespan
- lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
-
+ lifespan_with_url = partial(
+ app_lifespan, websocket_url=websocket_url,
+ )
+
+ # ── Security: MCP-level auth configuration ──
+ # The SDK requires AuthSettings whenever a token_verifier is
+ # present. The issuer_url tells MCP clients where to obtain
+ # tokens; resource_server_url identifies this server in OAuth
+ # protected-resource metadata.
+ #
+ # The PassthroughTokenVerifier accepts any non-empty Bearer
+ # token — real validation happens at the gateway. This is
+ # intentional: the gateway is the single source of truth for
+ # identity and capability checks.
+ from mcp.server.auth.settings import AuthSettings
+
+ auth_settings = AuthSettings(
+ issuer_url=auth_issuer or f"http://{host}:{port}",
+ resource_server_url=auth_resource_url or f"http://{host}:{port}",
+ )
+
self.mcp = FastMCP(
- "TrustGraph", dependencies=["trustgraph-base"],
- host=self.host, port=self.port,
+ "TrustGraph",
+ dependencies=["trustgraph-base"],
+ host=self.host,
+ port=self.port,
lifespan=lifespan_with_url,
+ token_verifier=PassthroughTokenVerifier(),
+ auth=auth_settings,
)
self._register_tools()
-
+
def _register_tools(self):
"""Register all MCP tools"""
- # Register all the tools that were previously registered globally
self.mcp.tool()(self.embeddings)
self.mcp.tool()(self.text_completion)
self.mcp.tool()(self.graph_rag)
self.mcp.tool()(self.agent)
self.mcp.tool()(self.triples_query)
+ self.mcp.tool()(self.sparql_query)
+ self.mcp.tool()(self.graphql_query)
self.mcp.tool()(self.graph_embeddings_query)
self.mcp.tool()(self.get_config_all)
self.mcp.tool()(self.get_config)
@@ -243,67 +395,69 @@ class McpServer:
self.mcp.tool()(self.load_document)
self.mcp.tool()(self.remove_document)
self.mcp.tool()(self.add_processing)
-
+
def run(self):
"""Run the MCP server"""
self.mcp.run(transport="streamable-http")
+ async def _get_manager(self, ctx):
+ """Get an authenticated WebSocket manager for the current caller.
+
+ Extracts the Bearer token from the MCP auth context and returns
+ a per-token WebSocket connection to the gateway.
+ """
+ token = _require_token()
+ return await get_socket_manager(ctx, token)
+
async def embeddings(
self,
- text: str,
+ texts: List[str],
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> EmbeddingsResponse:
"""
- Generate vector embeddings for the given text using TrustGraph's embedding models.
-
+ Generate vector embeddings for the given texts using TrustGraph's embedding models.
+
This tool converts text into high-dimensional vectors that capture semantic meaning,
enabling similarity searches, clustering, and other vector-based operations.
-
+
Args:
- text: The input text to convert into embeddings. Can be a sentence, paragraph,
- or document. The text will be processed by the configured embedding model.
+ texts: List of input texts to convert into embeddings. Each text can be a
+ sentence, paragraph, or document.
flow_id: Optional flow identifier to use for processing (default: "default").
Different flows may use different embedding models or configurations.
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
- EmbeddingsResponse containing a list of vectors. Each vector is a list of floats
- representing the text's semantic embedding in the model's vector space.
-
- Example usage:
- - Convert a query into embeddings for similarity search
- - Generate embeddings for documents before storing them
- - Create embeddings for comparison with existing knowledge
+ EmbeddingsResponse containing a list of vectors, one per input text.
"""
- logging.info("Embeddings request made")
+ logger.info("Embeddings request")
if flow_id is None: flow_id = "default"
- manager = await get_socket_manager(ctx, "trustgraph")
+ manager = await self._get_manager(ctx)
- if ctx is None:
- raise RuntimeError("No context provided")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Computing embeddings via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- await ctx.session.send_log_message(
- level="info",
- data=f"Computing embeddings via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
+ request_data = {"texts": texts}
+
+ gen = manager.request(
+ "embeddings", request_data, flow_id, workspace=workspace,
)
- # Send websocket request
- request_data = {"text": text}
- logging.info("making request")
-
- gen = manager.request("embeddings", request_data, flow_id)
-
async for response in gen:
-
- # Extract vectors from response
vectors = response.get("vectors", [[]])
break
-
+
return EmbeddingsResponse(vectors=vectors)
async def text_completion(
@@ -311,62 +465,47 @@ class McpServer:
prompt: str,
system: str | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> TextCompletionResponse:
"""
Generate text completions using TrustGraph's language models.
-
- This tool sends prompts to configured language models and returns generated text.
- It supports both user prompts and system instructions for controlling generation.
-
+
Args:
prompt: The main prompt or question to send to the language model.
- This is the primary input that guides the model's response.
system: Optional system prompt that sets the context, role, or behavior
- for the AI assistant (e.g., "You are a helpful coding assistant").
- System prompts influence how the model interprets and responds.
- flow_id: Optional flow identifier (default: "default"). Different flows
- may use different models, parameters, or processing pipelines.
-
+ for the AI assistant.
+ flow_id: Optional flow identifier (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
TextCompletionResponse containing the generated text response from the model.
-
- Example usage:
- - Ask questions and get AI-generated answers
- - Generate code, documentation, or creative content
- - Perform text analysis, summarization, or transformation tasks
- - Use system prompts to control tone, style, or domain expertise
"""
if system is None: system = ""
if flow_id is None: flow_id = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- # Use websocket if context is available
- logging.info("Text completion request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Generating text completion via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Generating text completion via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Send websocket request
request_data = {"system": system, "prompt": prompt}
- gen = manager.request("text-completion", request_data, flow_id)
+ gen = manager.request(
+ "text-completion", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
-
- # Extract vectors from response
text = response.get("response", "")
break
-
+
return TextCompletionResponse(response=text)
async def graph_rag(
@@ -378,58 +517,43 @@ class McpServer:
max_subgraph_size: int | None = None,
max_path_length: int | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> GraphRagResponse:
"""
Perform Graph-based Retrieval Augmented Generation (GraphRAG) queries.
-
+
GraphRAG combines knowledge graph traversal with language model generation to provide
- contextually rich answers. It explores relationships between entities to build relevant
- context before generating responses.
-
+ contextually rich answers.
+
Args:
question: The question or query to answer using the knowledge graph.
- The system will find relevant entities and relationships to inform the response.
collection: Knowledge collection to query (default: "default").
- Different collections may contain domain-specific knowledge.
entity_limit: Maximum number of entities to retrieve during graph traversal.
- Higher limits provide more context but increase processing time.
triple_limit: Maximum number of relationship triples to consider.
- Controls the depth of relationship exploration.
max_subgraph_size: Maximum size of the subgraph to extract for context.
- Larger subgraphs provide richer context but use more resources.
max_path_length: Maximum path length to traverse in the knowledge graph.
- Longer paths can discover distant but relevant relationships.
flow_id: Processing flow to use (default: "default").
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
GraphRagResponse containing the generated answer informed by knowledge graph context.
-
- Example usage:
- - Answer complex questions requiring multi-hop reasoning
- - Explore relationships between entities in your knowledge base
- - Generate responses grounded in structured knowledge
- - Perform research queries across connected information
"""
if collection is None: collection = "default"
if flow_id is None: flow_id = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("GraphRAG request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing GraphRAG query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing GraphRAG query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data with all parameters
request_data = {
"query": question
}
@@ -440,20 +564,19 @@ class McpServer:
if max_subgraph_size: request_data["max_subgraph_size"] = max_subgraph_size
if max_path_length: request_data["max_path_length"] = max_path_length
- gen = manager.request("graph-rag", request_data, flow_id)
+ gen = manager.request(
+ "graph-rag", request_data, flow_id, workspace=workspace,
+ )
text_chunks = []
async for response in gen:
- # Handle new message format with message_type
message_type = response.get("message_type", "chunk")
- # Only collect text from chunk messages
if message_type == "chunk":
chunk_text = response.get("response", "")
if chunk_text:
text_chunks.append(chunk_text)
- # Check if session is complete
if response.get("end_of_session"):
break
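Review note on the streaming loop in `graph_rag`: it keeps only messages whose `message_type` is `"chunk"` (the default when the key is absent), skips empty text, and stops once a message carries `end_of_session`. The hunk elides what happens after the loop, so joining the chunks with an empty separator is an assumption here. `fake_stream`, `collect_chunks`, and the `"explain"` message type are hypothetical and exist only to exercise the loop without a gateway.

```python
import asyncio


async def fake_stream(messages):
    """Hypothetical stand-in for the generator returned by manager.request()."""
    for message in messages:
        yield message


async def collect_chunks(gen):
    """Collect text from chunk messages until end_of_session is seen."""
    text_chunks = []
    async for response in gen:
        # Messages without an explicit type are treated as chunks.
        message_type = response.get("message_type", "chunk")
        if message_type == "chunk":
            chunk_text = response.get("response", "")
            if chunk_text:
                text_chunks.append(chunk_text)
        if response.get("end_of_session"):
            break
    # Assumption: the final answer is the plain concatenation of chunks.
    return "".join(text_chunks)


if __name__ == "__main__":
    demo = [
        {"response": "Hello, "},
        {"message_type": "explain", "response": "not part of the answer"},
        {"message_type": "chunk", "response": "world", "end_of_session": True},
    ]
    print(asyncio.run(collect_chunks(fake_stream(demo))))
```

An empty stream yields an empty string in this sketch, which avoids the unbound-variable risk that the `async for ... break` pattern carries in tools that assign their result only inside the loop.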
@@ -464,404 +587,447 @@ class McpServer:
question: str,
collection: str | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> AgentResponse:
"""
Execute intelligent agent queries with reasoning and tool usage capabilities.
-
- The agent can perform complex multi-step reasoning, use tools, and provide
- detailed thought processes. It's designed for tasks requiring planning,
- analysis, and iterative problem-solving.
-
+
Args:
- question: The question or task for the agent to solve. Can be complex
- queries requiring multiple steps, analysis, or tool usage.
+ question: The question or task for the agent to solve.
collection: Knowledge collection the agent can access (default: "default").
- Determines what information and tools are available.
- flow_id: Agent workflow to use (default: "default"). Different flows
- may have different capabilities, tools, or reasoning strategies.
-
+ flow_id: Agent workflow to use (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
AgentResponse containing the final answer after the agent's reasoning process.
- During execution, you'll see intermediate thoughts and observations.
-
- Example usage:
- - Solve complex analytical problems requiring multiple steps
- - Perform research tasks across multiple information sources
- - Handle queries that need tool usage and decision-making
- - Get detailed explanations of reasoning processes
-
- Note: This tool provides real-time updates on the agent's thinking process
- through log messages, so you can follow its reasoning steps.
"""
if collection is None: collection = "default"
if flow_id is None: flow_id = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Agent request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing agent query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing agent query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data with all parameters
request_data = {
"question": question
}
if collection: request_data["collection"] = collection
- gen = manager.request("agent", request_data, flow_id)
+ gen = manager.request(
+ "agent", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
- logging.debug(f"Agent response: {response}")
+ logger.debug("Agent response: %s", response)
- if "thought" in response:
- await ctx.session.send_log_message(
- level="info",
- data=f"Thinking: {response['thought']}",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ if "thought" in response:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Thinking: {response['thought']}",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- if "observation" in response:
- await ctx.session.send_log_message(
- level="info",
- data=f"Observation: {response['observation']}",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if "observation" in response:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Observation: {response['observation']}",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- # Extract vectors from response
if "answer" in response:
answer = response.get("answer", "")
return AgentResponse(answer=answer)
async def triples_query(
self,
- s_v: str | None = None,
- s_e: bool | None = None,
- p_v: str | None = None,
- p_e: bool | None = None,
- o_v: str | None = None,
- o_e: bool | None = None,
+ s: str | None = None,
+ s_type: str | None = None,
+ p: str | None = None,
+ p_type: str | None = None,
+ o: str | None = None,
+ o_type: str | None = None,
+ collection: str | None = None,
+ graph: str | None = None,
limit: int | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> TriplesQueryResponse:
"""
Query knowledge graph triples using subject-predicate-object patterns.
-
- Knowledge graphs store information as triples (subject, predicate, object).
- This tool allows flexible querying by specifying any combination of these
- components, with wildcards for unspecified parts.
-
+
+ Each of s, p, o is an RDF term value. Use the corresponding _type
+ parameter to specify the term kind:
+ - "iri" (default for s and p): an IRI / entity reference
+ - "literal" (default for o): a plain literal value
+ - "blank": a blank node identifier
+
Args:
- s_v: Subject value to match (e.g., "John", "Apple Inc."). Leave None for wildcard.
- s_e: Whether subject should be treated as an entity (True) or literal (False).
- p_v: Predicate/relationship value (e.g., "works_for", "type_of"). Leave None for wildcard.
- p_e: Whether predicate should be treated as an entity (True) or literal (False).
- o_v: Object value to match (e.g., "Engineer", "Company"). Leave None for wildcard.
- o_e: Whether object should be treated as an entity (True) or literal (False).
+ s: Subject value to match. Leave None for wildcard.
+ s_type: Subject term type: "iri" (default), "literal", or "blank".
+ p: Predicate value to match. Leave None for wildcard.
+ p_type: Predicate term type: "iri" (default), "literal", or "blank".
+ o: Object value to match. Leave None for wildcard.
+ o_type: Object term type: "iri", "literal" (default), or "blank".
+ collection: Knowledge collection to query (default: "default").
+ graph: Named graph IRI to restrict the query. None = default graph,
+ "*" = all graphs.
limit: Maximum number of triples to return (default: 20).
flow_id: Processing flow identifier (default: "default").
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
TriplesQueryResponse containing matching triples from the knowledge graph.
-
- Example queries:
- - Find all relationships for an entity: s_v="John", others None
- - Find all instances of a relationship: p_v="works_for", others None
- - Find specific facts: s_v="John", p_v="works_for", o_v=None
- - Explore entity types: p_v="type_of", others None
-
- Use this for:
- - Exploring knowledge graph structure
- - Finding specific facts or relationships
- - Discovering connections between entities
- - Validating or debugging knowledge content
"""
if flow_id is None: flow_id = "default"
if limit is None: limit = 20
+ if collection is None: collection = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Triples query request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing triples query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing triples query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data with Value objects
request_data = {
- "limit": limit
+ "limit": limit,
+ "collection": collection,
}
- # Add subject if provided
- if s_v is not None:
- request_data["s"] = {"v": s_v, "e": s_e }
+ if s is not None:
+ request_data["s"] = _make_term(s, s_type or "iri")
- # Add predicate if provided
- if p_v is not None:
- request_data["p"] = {"v": p_v, "e": p_e }
+ if p is not None:
+ request_data["p"] = _make_term(p, p_type or "iri")
- # Add object if provided
- if o_v is not None:
- request_data["o"] = {"v": o_v, "e": o_e }
+ if o is not None:
+ request_data["o"] = _make_term(o, o_type or "literal")
- gen = manager.request("triples", request_data, flow_id)
+ if graph is not None:
+ request_data["g"] = graph
+
+ gen = manager.request(
+ "triples", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
- # Extract response data
triples = response.get("response", [])
break
-
+
return TriplesQueryResponse(triples=triples)
+ async def sparql_query(
+ self,
+ query: str,
+ collection: str | None = None,
+ limit: int | None = None,
+ flow_id: str | None = None,
+ workspace: str | None = None,
+ ctx: Context = None,
+ ) -> SparqlQueryResponse:
+ """
+ Execute a SPARQL query against the knowledge graph.
+
+ Supports SELECT, ASK, CONSTRUCT, and DESCRIBE query forms.
+
+ Args:
+ query: SPARQL query string (e.g. "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10").
+ collection: Knowledge collection to query (default: "default").
+ limit: Safety limit on the number of results returned (default: 10000).
+ flow_id: Processing flow identifier (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
+ Returns:
+ SparqlQueryResponse containing the query results. The structure depends
+ on query type:
+ - SELECT: variables (column names) and bindings (rows of Term values)
+ - ASK: ask_result (boolean)
+ - CONSTRUCT/DESCRIBE: triples
+ """
+
+ if collection is None: collection = "default"
+ if flow_id is None: flow_id = "default"
+ if limit is None: limit = 10000
+
+ manager = await self._get_manager(ctx)
+
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing SPARQL query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
+
+ request_data = {
+ "query": query,
+ "collection": collection,
+ "limit": limit,
+ }
+
+ gen = manager.request(
+ "sparql", request_data, flow_id, workspace=workspace,
+ )
+
+ async for response in gen:
+ query_type = response.get("query-type", "")
+ return SparqlQueryResponse(
+ query_type=query_type,
+ variables=response.get("variables", []),
+ bindings=response.get("bindings", []),
+ ask_result=response.get("ask-result", False),
+ triples=response.get("triples", []),
+ )
+
+ async def graphql_query(
+ self,
+ query: str,
+ collection: str | None = None,
+ variables: Dict[str, Any] | None = None,
+ operation_name: str | None = None,
+ flow_id: str | None = None,
+ workspace: str | None = None,
+ ctx: Context = None,
+ ) -> GraphQLQueryResponse:
+ """
+ Execute a GraphQL query against structured data (rows).
+
+ Queries structured data schemas that have been loaded into TrustGraph.
+ The available types and fields depend on the schemas configured in the
+ target workspace.
+
+ Args:
+ query: GraphQL query string (e.g. '{ customers(where: {status: {eq: "active"}}) { id name } }').
+ collection: Data collection to query (default: "default").
+ variables: Optional GraphQL variables as a dict.
+ operation_name: Optional operation name for multi-operation documents.
+ flow_id: Processing flow identifier (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
+ Returns:
+ GraphQLQueryResponse containing data (the query result) and errors
+ (any GraphQL field-level errors).
+ """
+
+ if collection is None: collection = "default"
+ if flow_id is None: flow_id = "default"
+
+ manager = await self._get_manager(ctx)
+
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing GraphQL query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
+
+ request_data = {
+ "query": query,
+ "collection": collection,
+ "variables": variables or {},
+ }
+
+ if operation_name is not None:
+ request_data["operation_name"] = operation_name
+
+ gen = manager.request(
+ "rows", request_data, flow_id, workspace=workspace,
+ )
+
+ async for response in gen:
+ return GraphQLQueryResponse(
+ data=response.get("data"),
+ errors=response.get("errors", []),
+ )
+
async def graph_embeddings_query(
self,
vectors: List[List[float]],
limit: int | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> GraphEmbeddingsQueryResponse:
"""
Find entities in the knowledge graph using vector similarity search.
-
- This tool performs semantic search by comparing embedding vectors to find
- the most similar entities in the knowledge graph. It's useful for finding
- conceptually related information even when exact text matches don't exist.
-
+
Args:
- vectors: List of embedding vectors to search with. Each vector should be
- a list of floats representing semantic embeddings (typically from
- the embeddings tool). Multiple vectors can be provided for batch queries.
+ vectors: List of embedding vectors to search with.
limit: Maximum number of similar entities to return (default: 20).
- Higher limits provide more results but may include less relevant matches.
flow_id: Processing flow identifier (default: "default").
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
- GraphEmbeddingsQueryResponse containing entities ranked by similarity to the
- input vectors, along with similarity scores and entity metadata.
-
- Example workflow:
- 1. Use the 'embeddings' tool to convert text to vectors
- 2. Use this tool to find similar entities in the knowledge graph
- 3. Explore the returned entities for relevant information
-
- Use this for:
- - Semantic search across knowledge entities
- - Finding conceptually similar content
- - Discovering related entities without exact keyword matches
- - Building recommendation systems based on entity similarity
+ GraphEmbeddingsQueryResponse containing entities ranked by similarity.
"""
if flow_id is None: flow_id = "default"
if limit is None: limit = 20
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Graph embeddings query request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing graph embeddings query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing graph embeddings query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data
request_data = {
"vectors": vectors,
"limit": limit
}
- gen = manager.request("graph-embeddings", request_data, flow_id)
+ gen = manager.request(
+ "graph-embeddings", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
- # Extract entities from response
entities = response.get("entities", [])
break
-
+
return GraphEmbeddingsQueryResponse(entities=entities)
async def get_config_all(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> ConfigResponse:
"""
Retrieve the complete TrustGraph system configuration.
-
- This tool returns all configuration settings for the TrustGraph system,
- including model configurations, API keys, flow definitions, and system parameters.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- ConfigResponse containing the full configuration as a nested dictionary
- with all system settings, organized by category (e.g., models, flows, storage).
-
- Use this for:
- - Inspecting current system configuration
- - Debugging configuration issues
- - Understanding available models and settings
- - Auditing system setup and parameters
+ ConfigResponse containing the full configuration as a nested dictionary.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get config all request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving all configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving all configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
break
-
+
return ConfigResponse(config=config)
async def get_config(
self,
keys: List[Dict[str, str]],
+ workspace: str | None = None,
ctx: Context = None,
) -> ConfigGetResponse:
"""
Retrieve specific configuration values by key.
-
- This tool allows you to fetch specific configuration settings without
- retrieving the entire configuration. Useful for checking particular
- settings or API keys.
-
+
Args:
- keys: List of configuration keys to retrieve. Each key should be a dict with:
- - 'type': Configuration category (e.g., 'llm', 'embeddings', 'storage')
- - 'key': Specific setting name within that category
-
+ keys: List of configuration keys to retrieve. Each key should be a dict with
+ 'type' and 'key' fields.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
ConfigGetResponse containing the requested configuration values.
-
- Example keys:
- - {'type': 'llm', 'key': 'openai.model'}
- - {'type': 'embeddings', 'key': 'default.model'}
- - {'type': 'storage', 'key': 'database.url'}
-
- Use this for:
- - Checking specific model configurations
- - Validating API key settings
- - Inspecting individual system parameters
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get config request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving specific configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving specific configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get",
"keys": keys
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
values = response.get("values", [])
break
-
+
return ConfigGetResponse(values=values)
async def put_config(
self,
values: List[Dict[str, str]],
+ workspace: str | None = None,
ctx: Context = None,
) -> PutConfigResponse:
"""
Update system configuration values.
-
- This tool allows you to modify TrustGraph system settings, such as
- model parameters, API keys, and system behavior configurations.
-
+
Args:
- values: List of configuration updates. Each update should be a dict with:
- - 'type': Configuration category (e.g., 'llm', 'embeddings')
- - 'key': Specific setting name to update
- - 'value': New value for the setting
-
+ values: List of configuration updates. Each should be a dict with
+ 'type', 'key', and 'value' fields.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
PutConfigResponse confirming the configuration update.
-
- Example updates:
- - {'type': 'llm', 'key': 'openai.model', 'value': 'gpt-4'}
- - {'type': 'embeddings', 'key': 'batch_size', 'value': '100'}
-
- Use this for:
- - Switching between different models
- - Updating API credentials
- - Modifying system behavior parameters
- - Configuring processing settings
-
- Note: Configuration changes may require system restart to take effect.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Put config request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Updating configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Updating configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "put",
"values": values
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
return PutConfigResponse()
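Review note on the config tools: `get_config_all`, `get_config`, `put_config`, and `delete_config` all send a `"config"` request with no flow ID and differ only in the payload. The sketch below gathers the four payload shapes used in this diff into one place. `build_config_request` is a hypothetical helper, not part of the module, and the key and value entries in the demo are made-up examples.

```python
def build_config_request(operation, keys=None, values=None):
    """Build the request_data dict for a config operation.

    Payload shapes follow the diff:
      "config"          -> {"operation": "config"}
      "get" / "delete"  -> {"operation": ..., "keys": [...]}
      "put"             -> {"operation": "put", "values": [...]}
    """
    request = {"operation": operation}
    if operation in ("get", "delete"):
        request["keys"] = list(keys or [])
    elif operation == "put":
        request["values"] = list(values or [])
    elif operation != "config":
        raise ValueError(f"unsupported config operation: {operation!r}")
    return request


if __name__ == "__main__":
    # Hypothetical entries, for illustration only.
    key = {"type": "example-type", "key": "example-key"}
    print(build_config_request("config"))
    print(build_config_request("get", keys=[key]))
    print(build_config_request("put", values=[{**key, "value": "42"}]))
    print(build_config_request("delete", keys=[key]))
```

Each tool then passes the payload as `manager.request("config", request_data, None, workspace=workspace)`, so the workspace is selected per call rather than being baked into the payload.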
@@ -869,97 +1035,73 @@ class McpServer:
async def delete_config(
self,
keys: List[Dict[str, str]],
+ workspace: str | None = None,
ctx: Context = None,
) -> DeleteConfigResponse:
"""
Delete specific configuration entries from the system.
-
- This tool removes configuration settings, reverting them to system defaults
- or disabling specific features.
-
+
Args:
- keys: List of configuration keys to delete. Each key should be a dict with:
- - 'type': Configuration category (e.g., 'llm', 'embeddings')
- - 'key': Specific setting name to remove
-
+ keys: List of configuration keys to delete. Each should be a dict with
+ 'type' and 'key' fields.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
DeleteConfigResponse confirming the deletion.
-
- Use this for:
- - Removing custom model configurations
- - Clearing API credentials
- - Resetting settings to defaults
- - Cleaning up obsolete configurations
-
- Warning: Deleting essential configuration may cause system functionality
- to be disabled until properly reconfigured.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Delete config request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Deleting configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Deleting configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "delete",
"keys": keys
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
return DeleteConfigResponse()
async def get_prompts(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetPromptsResponse:
"""
List all available prompt templates in the system.
-
- Prompt templates are reusable prompts that can be used with language models
- for consistent behavior across different queries and use cases.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
GetPromptsResponse containing a list of available prompt template IDs.
- Each ID can be used with get_prompt to retrieve the full template.
-
- Use this for:
- - Discovering available prompt templates
- - Exploring pre-configured prompts for different tasks
- - Finding templates for specific use cases
- - Understanding what prompt options are available
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get prompts request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving prompt templates via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving prompt templates via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # First get all config
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
@@ -971,49 +1113,36 @@ class McpServer:
async def get_prompt(
self,
prompt_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetPromptResponse:
"""
Retrieve a specific prompt template by ID.
-
- Prompt templates contain structured prompts with placeholders, instructions,
- and metadata for specific tasks or domains.
-
+
Args:
prompt_id: The unique identifier of the prompt template to retrieve.
- Use get_prompts to see available template IDs.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- GetPromptResponse containing the complete prompt template with its
- structure, placeholders, and usage instructions.
-
- Use this for:
- - Examining prompt template structure
- - Understanding how to use specific templates
- - Copying or modifying existing prompts
- - Learning prompt engineering patterns
+ GetPromptResponse containing the complete prompt template.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get prompt request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving prompt template '{prompt_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving prompt template '{prompt_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # First get all config
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
@@ -1025,44 +1154,35 @@ class McpServer:
async def get_system_prompt(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetSystemPromptResponse:
"""
Retrieve the current system prompt configuration.
-
- The system prompt defines the default behavior, personality, and instructions
- for language models across the TrustGraph system.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- GetSystemPromptResponse containing the system prompt text and configuration.
-
- Use this for:
- - Understanding default AI behavior settings
- - Checking current system-wide prompt configuration
- - Auditing AI personality and instruction settings
- - Debugging unexpected AI responses
+ GetSystemPromptResponse containing the system prompt text.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get system prompt request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving system prompt via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving system prompt via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # First get all config
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
@@ -1073,51 +1193,39 @@ class McpServer:
async def get_token_costs(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> ConfigTokenCostsResponse:
"""
Retrieve token pricing information for all configured AI models.
-
- This tool provides cost information for input and output tokens across
- different language models, helping with budget planning and cost optimization.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- ConfigTokenCostsResponse containing pricing data for each model including:
- - Model name/identifier
- - Input token cost (per token)
- - Output token cost (per token)
-
- Use this for:
- - Estimating costs for different models
- - Choosing cost-effective models for tasks
- - Budget planning and cost analysis
- - Monitoring and optimizing AI spending
+ ConfigTokenCostsResponse containing pricing data for each model.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get token costs request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving token costs via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving token costs via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "getvalues",
"type": "token-costs"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
values = response.get("values", [])
- # Transform to match TypeScript API format
costs = []
for item in values:
try:
@@ -1130,106 +1238,89 @@ class McpServer:
except (json.JSONDecodeError, AttributeError):
continue
break
-
+
return ConfigTokenCostsResponse(costs=costs)
async def get_knowledge_cores(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> KnowledgeCoresResponse:
"""
List all available knowledge graph cores in the current workspace.
- Knowledge cores are packaged collections of structured knowledge that can
- be loaded into the system for querying and reasoning. They contain entities,
- relationships, and facts organized as knowledge graphs.
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
KnowledgeCoresResponse containing a list of available knowledge core IDs.
-
- Use this for:
- - Discovering available knowledge collections
- - Understanding what knowledge domains are accessible
- - Planning which cores to load for specific tasks
- - Managing knowledge resources
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get knowledge cores request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving knowledge graph cores via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving knowledge graph cores via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-kg-cores",
}
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
ids = response.get("ids", [])
break
-
+
return KnowledgeCoresResponse(ids=ids)
async def delete_kg_core(
self,
core_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> DeleteKgCoreResponse:
"""
Permanently delete a knowledge graph core.
- This operation removes a knowledge core from storage. Use with caution
- as this action cannot be undone.
-
Args:
core_id: Unique identifier of the knowledge core to delete.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
DeleteKgCoreResponse confirming the deletion.
-
- Use this for:
- - Cleaning up obsolete knowledge cores
- - Removing test or experimental data
- - Managing storage space
- - Maintaining organized knowledge collections
-
- Warning: This permanently deletes the knowledge core and all its data.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Delete KG core request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Deleting knowledge graph core '{core_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Deleting knowledge graph core '{core_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "delete-kg-core",
"id": core_id,
}
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return DeleteKgCoreResponse()
async def load_kg_core(
@@ -1237,46 +1328,34 @@ class McpServer:
core_id: str,
flow: str,
collection: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> LoadKgCoreResponse:
"""
Load a knowledge graph core into the active system for querying.
- This operation makes a knowledge core available for GraphRAG queries,
- triple searches, and other knowledge-based operations.
-
Args:
core_id: Unique identifier of the knowledge core to load.
- flow: Processing flow to use for loading the core. Different flows
- may apply different processing, indexing, or optimization steps.
- collection: Target collection name (default: "default"). The loaded
- knowledge will be available under this collection name.
+ flow: Processing flow to use for loading the core.
+ collection: Target collection name (default: "default").
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
LoadKgCoreResponse confirming the core has been loaded.
-
- Use this for:
- - Making knowledge cores available for queries
- - Switching between different knowledge domains
- - Loading domain-specific knowledge for tasks
- - Preparing knowledge for GraphRAG operations
"""
if collection is None: collection = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Load KG core request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Loading knowledge graph core '{core_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Loading knowledge graph core '{core_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "load-kg-core",
@@ -1285,292 +1364,241 @@ class McpServer:
"collection": collection
}
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return LoadKgCoreResponse()
async def get_kg_core(
self,
core_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetKgCoreResponse:
"""
Download and retrieve the complete content of a knowledge graph core.
- This tool streams the entire content of a knowledge core, returning all
- entities, relationships, and metadata. Due to potentially large data sizes,
- the content is streamed in chunks.
-
Args:
core_id: Unique identifier of the knowledge core to retrieve.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
GetKgCoreResponse containing all chunks of the knowledge core data.
- Each chunk contains part of the knowledge graph structure.
-
- Use this for:
- - Examining knowledge core content and structure
- - Debugging knowledge graph data
- - Exporting knowledge for backup or analysis
- - Understanding the scope and quality of knowledge
-
- Note: Large knowledge cores may take significant time to download.
- Progress updates are provided through log messages during streaming.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get KG core request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving knowledge graph core '{core_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving knowledge graph core '{core_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get-kg-core",
"id": core_id,
}
- # Collect all streaming responses
chunks = []
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
- # Check for end of stream
if response.get("eos", False):
- await ctx.session.send_log_message(
- level="info",
- data=f"Completed streaming KG core data",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Completed streaming KG core data",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
break
else:
chunks.append(response)
- await ctx.session.send_log_message(
- level="info",
- data=f"Received KG core chunk ({len(chunks)} chunks so far)",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Received KG core chunk ({len(chunks)} chunks so far)",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
+
return GetKgCoreResponse(chunks=chunks)
async def get_flows(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowsResponse:
"""
List all available processing flows in the system.
-
- Flows define processing pipelines for different types of operations
- (e.g., document processing, knowledge extraction, query handling).
- Each flow encapsulates a specific workflow with configured steps.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
FlowsResponse containing a list of available flow identifiers.
-
- Use this for:
- - Discovering available processing workflows
- - Understanding what processing options are available
- - Choosing appropriate flows for specific tasks
- - Planning workflow-based operations
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flows request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving available flows via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving available flows via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-flows"
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
flow_ids = response.get("flow-ids", [])
break
-
+
return FlowsResponse(flow_ids=flow_ids)
async def get_flow(
self,
flow_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowResponse:
"""
Retrieve the complete definition of a specific processing flow.
-
- This tool returns the detailed configuration, steps, and parameters
- of a processing flow, showing how it processes data and what operations it performs.
-
+
Args:
flow_id: Unique identifier of the flow to retrieve.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- FlowResponse containing the complete flow definition including:
- - Flow configuration and parameters
- - Processing steps and their order
- - Input/output specifications
- - Dependencies and requirements
-
- Use this for:
- - Understanding how specific flows work
- - Debugging flow processing issues
- - Learning flow configuration patterns
- - Customizing or duplicating flows
+ FlowResponse containing the complete flow definition.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flow request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving flow definition for '{flow_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving flow definition for '{flow_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get-flow",
"flow-id": flow_id,
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
flow_data = response.get("flow", "{}")
- # Parse JSON flow definition as done in TypeScript
flow = json.loads(flow_data) if isinstance(flow_data, str) else flow_data
break
-
+
return FlowResponse(flow=flow)
async def get_flow_classes(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowClassesResponse:
"""
List all available flow class templates.
-
- Flow classes are templates that define types of processing workflows.
- They serve as blueprints for creating specific flow instances with
- customized parameters.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
FlowClassesResponse containing a list of available flow class names.
-
- Use this for:
- - Discovering available flow templates
- - Understanding what types of processing are supported
- - Planning new flow creation
- - Exploring system capabilities
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flow classes request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving flow classes via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving flow classes via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-classes"
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
class_names = response.get("class-names", [])
break
-
+
return FlowClassesResponse(class_names=class_names)
async def get_flow_class(
self,
class_name: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowClassResponse:
"""
Retrieve the definition of a specific flow class template.
-
- Flow classes define the structure, parameters, and capabilities of
- flow types. This tool returns the class specification including
- configurable parameters and processing logic.
-
+
Args:
class_name: Name of the flow class to retrieve.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- FlowClassResponse containing the flow class definition with:
- - Class parameters and configuration options
- - Processing capabilities and requirements
- - Usage instructions and examples
-
- Use this for:
- - Understanding flow class capabilities
- - Learning how to configure new flows
- - Troubleshooting flow creation issues
- - Exploring advanced flow features
+ FlowClassResponse containing the flow class definition.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flow class request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving flow class definition for '{class_name}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving flow class definition for '{class_name}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get-class",
"class-name": class_name
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
class_def_data = response.get("class-definition", "{}")
- # Parse JSON class definition as done in TypeScript
class_definition = json.loads(class_def_data) if isinstance(class_def_data, str) else class_def_data
break
-
+
return FlowClassResponse(class_definition=class_definition)
async def start_flow(
@@ -1578,43 +1606,32 @@ class McpServer:
flow_id: str,
class_name: str,
description: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> StartFlowResponse:
"""
Create and start a new processing flow instance.
-
- This tool creates a new flow based on a flow class template and starts
- it running. The flow will begin processing according to its configuration.
-
+
Args:
flow_id: Unique identifier for the new flow instance.
class_name: Flow class template to use for creating the flow.
- Use get_flow_classes to see available classes.
description: Human-readable description of the flow's purpose.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
StartFlowResponse confirming the flow has been started.
-
- Use this for:
- - Creating new processing workflows
- - Starting automated processing tasks
- - Launching background operations
- - Initiating data processing pipelines
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Start flow request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "start-flow",
@@ -1623,162 +1640,135 @@ class McpServer:
"description": description
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return StartFlowResponse()
async def stop_flow(
self,
flow_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> StopFlowResponse:
"""
Stop a running flow instance.
-
- This tool gracefully stops a running flow, allowing it to complete
- current operations before shutting down.
-
+
Args:
flow_id: Unique identifier of the flow instance to stop.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
StopFlowResponse confirming the flow has been stopped.
-
- Use this for:
- - Stopping unwanted or completed flows
- - Managing system resources
- - Interrupting long-running processes
- - Maintaining flow lifecycle
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Stop flow request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Stopping flow '{flow_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Stopping flow '{flow_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "stop-flow",
"flow-id": flow_id
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return StopFlowResponse()
async def get_documents(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> DocumentsResponse:
"""
List all documents stored in the TrustGraph document library.
- This tool returns metadata for all documents that have been uploaded
- to the system, including their processing status and properties.
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
- DocumentsResponse containing metadata for each document including:
- - Document ID and title
- - Upload timestamp
- - MIME type and size information
- - Tags and custom metadata
- - Processing status
-
- Use this for:
- - Browsing available documents
- - Managing document collections
- - Finding documents by metadata
- - Auditing document storage
+ DocumentsResponse containing metadata for each document.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get documents request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving documents list via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving documents list via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-documents",
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
document_metadatas = response.get("document-metadatas", [])
break
-
+
return DocumentsResponse(document_metadatas=document_metadatas)
async def get_processing(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> ProcessingResponse:
"""
List all documents currently in the processing queue.
- This tool shows documents that are being processed or waiting to be
- processed, along with their processing status and configuration.
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
- ProcessingResponse containing processing metadata including:
- - Processing job ID and document ID
- - Processing flow and status
- - Target collection
- - Timestamp and progress information
-
- Use this for:
- - Monitoring document processing progress
- - Debugging processing issues
- - Managing processing queues
- - Understanding system workload
+ ProcessingResponse containing processing metadata.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get processing request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving processing list via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving processing list via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-processing",
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
processing_metadatas = response.get("processing-metadatas", [])
break
-
+
return ProcessingResponse(processing_metadatas=processing_metadatas)
async def load_document(
@@ -1790,50 +1780,39 @@ class McpServer:
title: str = "",
comments: str = "",
tags: List[str] | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> LoadDocumentResponse:
"""
Upload a document to the TrustGraph document library.
- This tool stores documents with rich metadata for later processing,
- search, and knowledge extraction. Documents can be text files, PDFs,
- or other supported formats.
-
Args:
document: The document content as a string. For binary files,
this should be base64-encoded content.
document_id: Optional unique identifier. If not provided, one will be generated.
metadata: Optional list of custom metadata key-value pairs.
- mime_type: MIME type of the document (e.g., 'text/plain', 'application/pdf').
+ mime_type: MIME type of the document.
title: Human-readable title for the document.
comments: Optional description or notes about the document.
- tags: List of tags for categorizing and finding the document.
+ tags: List of tags for categorizing the document.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
LoadDocumentResponse confirming the document has been stored.
-
- Use this for:
- - Adding new documents to the knowledge base
- - Storing reference materials and data sources
- - Building document collections for processing
- - Importing external content for analysis
"""
if tags is None: tags = []
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Load document request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Loading document to library via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Loading document to library via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
import time
timestamp = int(time.time())
@@ -1852,63 +1831,55 @@ class McpServer:
"content": document
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return LoadDocumentResponse()
async def remove_document(
self,
document_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> RemoveDocumentResponse:
"""
Permanently remove a document from the library.
- This operation deletes a document and all its associated metadata.
- Use with caution as this action cannot be undone.
-
Args:
document_id: Unique identifier of the document to remove.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
RemoveDocumentResponse confirming the document has been deleted.
-
- Use this for:
- - Cleaning up obsolete or incorrect documents
- - Managing storage space
- - Removing sensitive or inappropriate content
- - Maintaining organized document collections
-
- Warning: This permanently deletes the document and all its metadata.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Remove document request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Removing document '{document_id}' from library via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Removing document '{document_id}' from library via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "remove-document",
"document-id": document_id,
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return RemoveDocumentResponse()
async def add_processing(
@@ -1918,53 +1889,37 @@ class McpServer:
flow: str,
collection: str | None = None,
tags: List[str] | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> AddProcessingResponse:
"""
Queue a document for processing through a specific workflow.
- This tool adds a document to the processing queue where it will be
- processed by the specified flow to extract knowledge, create embeddings,
- or perform other analysis operations.
-
Args:
processing_id: Unique identifier for this processing job.
document_id: ID of the document to process (must exist in library).
- flow: Processing flow to use. Different flows perform different
- types of analysis (e.g., knowledge extraction, summarization).
+ flow: Processing flow to use.
collection: Target collection for processed knowledge (default: "default").
- Results will be stored under this collection name.
tags: Optional tags for categorizing this processing job.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
AddProcessingResponse confirming the document has been queued.
-
- Use this for:
- - Processing uploaded documents into knowledge
- - Extracting entities and relationships from text
- - Creating searchable embeddings
- - Converting documents into structured knowledge
-
- Note: Processing may take time depending on document size and flow complexity.
- Use get_processing to monitor progress.
"""
if collection is None: collection = "default"
if tags is None: tags = []
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Add processing request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Adding document '{document_id}' to processing queue via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Adding document '{document_id}' to processing queue via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
import time
timestamp = int(time.time())
@@ -1981,38 +1936,61 @@ class McpServer:
}
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return AddProcessingResponse()
+
def main():
parser = argparse.ArgumentParser(description='TrustGraph MCP Server')
- parser.add_argument('--host', default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)')
- parser.add_argument('--port', type=int, default=8000, help='Port to bind to (default: 8000)')
- parser.add_argument('--websocket-url', default='ws://api-gateway:8088/api/v1/socket', help='WebSocket URL to connect to (default: ws://api-gateway:8088/api/v1/socket)')
+ parser.add_argument(
+ '--host', default='0.0.0.0',
+ help='Host to bind to (default: 0.0.0.0)',
+ )
+ parser.add_argument(
+ '--port', type=int, default=8000,
+ help='Port to bind to (default: 8000)',
+ )
+ parser.add_argument(
+ '--websocket-url',
+ default='ws://api-gateway:8088/api/v1/socket',
+ help='WebSocket URL for the TrustGraph gateway',
+ )
+ parser.add_argument(
+ '--auth-issuer',
+ default=os.environ.get("AUTH_ISSUER", ""),
+ help='OAuth issuer URL for MCP auth metadata discovery',
+ )
+ parser.add_argument(
+ '--auth-resource-url',
+ default=os.environ.get("AUTH_RESOURCE_URL", ""),
+ help='Resource server URL for OAuth protected resource metadata',
+ )
- # Add logging arguments
add_logging_args(parser)
args = parser.parse_args()
- # Setup logging before creating server
setup_logging(vars(args))
- # Read gateway auth token from environment
- gateway_token = os.environ.get("GATEWAY_SECRET", "")
-
- # Create and run the MCP server
- server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
+ server = McpServer(
+ host=args.host,
+ port=args.port,
+ websocket_url=args.websocket_url,
+ auth_issuer=args.auth_issuer,
+ auth_resource_url=args.auth_resource_url,
+ )
server.run()
+
def run():
- """Legacy function for backward compatibility"""
main()
+
if __name__ == "__main__":
main()
-
diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
index bff8ae75..9fbf7459 100644
--- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
+++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
@@ -1,49 +1,110 @@
-from dataclasses import dataclass
from websockets.asyncio.client import connect
-from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio
import logging
import json
import uuid
-import time
+import hashlib
+
+logger = logging.getLogger(__name__)
+
+
+def _token_key(token):
+ """Derive a dict key from a token without storing the raw secret."""
+ return hashlib.sha256(token.encode()).hexdigest()[:16]
+
class WebSocketManager:
+ """Manages an authenticated WebSocket connection to the TrustGraph
+ gateway on behalf of a single caller.
- def __init__(self, url, token=None):
+ Each caller token gets its own WebSocketManager so that gateway-side
+ identity, workspace, and capability scoping are preserved end-to-end.
+ """
+
+ def __init__(self, url, token):
self.url = url
+ # ── Security boundary: token storage ──
+ # This is the MCP caller's Bearer token, forwarded verbatim to
+ # the gateway. It MUST NOT be logged, persisted, or shared
+ # across callers. It is held only for the lifetime of this
+ # connection so that re-auth (e.g. after a reconnect) is
+ # possible.
self.token = token
self.socket = None
-
- # FIXME: authentication is broken. The /api/v1/socket endpoint uses
- # in-band auth (first-frame protocol via the Mux dispatcher), not
- # query-parameter tokens. This query-string token is silently ignored.
- # Fix: after connect(), send an auth frame with the bearer token as
- # the first message, matching the gateway's in-band auth protocol.
- def _build_url(self):
- if not self.token:
- return self.url
- parsed = urlparse(self.url)
- params = parse_qs(parsed.query)
- params["token"] = [self.token]
- new_query = urlencode(params, doseq=True)
- return urlunparse(parsed._replace(query=new_query))
+ self.identity = None
+ self.last_used = None
async def start(self):
- self.socket = await connect(self._build_url())
+ """Connect and authenticate via the gateway's in-band auth
+ protocol. Raises on auth failure."""
+
+ # ── Security boundary: MCP server → gateway ──
+ # The WebSocket connects to the gateway and authenticates using
+ # the caller's Bearer token via the in-band first-frame auth
+ # protocol. The token belongs to the MCP client — we forward
+ # it as-is and never interpret its contents.
+ self.socket = await connect(self.url)
self.pending_requests = {}
self.running = True
+
+ await self._authenticate()
+
self.reader_task = asyncio.create_task(self.reader())
+ async def _authenticate(self):
+ """Send in-band auth frame and wait for auth-ok / auth-failed.
+
+ The gateway expects ``{"type": "auth", "token": "..."}`` as the
+ first frame on a new WebSocket. Any service frame sent before
+ auth-ok is rejected.
+ """
+ await self.socket.send(json.dumps({
+ "type": "auth",
+ "token": self.token,
+ }))
+
+ response_text = await asyncio.wait_for(self.socket.recv(), 10)
+ response = json.loads(response_text)
+
+ if response.get("type") == "auth-ok":
+ logger.info(
+ "WebSocket authenticated, default workspace: %s",
+ response.get("workspace"),
+ )
+ return
+
+ # Auth failed — close immediately, do not leave an
+ # unauthenticated socket open.
+ await self.socket.close()
+ self.socket = None
+
+ if response.get("type") == "auth-failed":
+ raise RuntimeError(
+ "Gateway rejected the authentication token"
+ )
+
+ raise RuntimeError(
+ f"Unexpected auth response type: {response.get('type')}"
+ )
+
+ async def whoami(self):
+ """Verify the token by calling the gateway's whoami endpoint.
+ Returns the identity dict and caches it on ``self.identity``.
+ """
+ gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
+ async for response in gen:
+ self.identity = response
+ return response
+
async def stop(self):
self.running = False
- await self.reader_task
+ if hasattr(self, "reader_task"):
+ await self.reader_task
async def reader(self):
- """
- Background task to read websocket responses and route to correct
- request
- """
+ """Background task: read WebSocket frames and route them to the
+ correct pending-request queue by ``id``."""
while self.running:
try:
@@ -59,23 +120,21 @@ class WebSocketManager:
request_id = response.get("id")
if request_id and request_id in self.pending_requests:
- # Put the response in the queue
queue = self.pending_requests[request_id]
await queue.put(response)
else:
- logging.warning(
- f"Response for unknown request ID: {request_id}"
+ logger.warning(
+ "Response for unknown request ID: %s", request_id
)
except Exception as e:
- logging.error(f"Error in websocket reader: {e}")
+ logger.error("Error in websocket reader: %s", e)
- # Put error in all pending queues
for queue in self.pending_requests.values():
try:
await queue.put({"error": str(e)})
- except:
+ except Exception:
pass
self.pending_requests.clear()
@@ -86,25 +145,29 @@ class WebSocketManager:
async def request(
self, service, request_data, flow_id="default",
+ workspace=None,
):
- """
- Send a request via websocket and handle single or streaming responses
+ """Send a request via WebSocket and yield responses.
+
+ Args:
+ service: Gateway service name (e.g. "graph-rag", "config").
+ request_data: Inner request payload.
+ flow_id: Optional flow identifier. ``None`` omits the field
+ (workspace-level services don't use flows).
+ workspace: Optional workspace override. When ``None`` the
+ gateway uses the caller's default workspace.
"""
- # Generate unique request ID
+ import time
+ self.last_used = time.monotonic()
+
request_id = f"{uuid.uuid4()}"
- # Determine if this service streams responses
- streaming_services = {"agent"}
- is_streaming = service in streaming_services
-
- # Create a queue for all responses (streaming and single)
response_queue = asyncio.Queue()
self.pending_requests[request_id] = response_queue
try:
- # Build request message
message = {
"id": request_id,
"service": service,
@@ -114,7 +177,16 @@ class WebSocketManager:
if flow_id is not None:
message["flow"] = flow_id
- # Send request
+ # ── Security boundary: workspace scoping ──
+ # When the caller supplies a workspace, we set it on the
+ # message envelope. The gateway's enforce_workspace()
+ # validates that the authenticated identity is permitted
+ # to access the target workspace — we MUST NOT skip or
+ # override that check. When workspace is None, the
+ # gateway default-fills from the identity's bound workspace.
+ if workspace is not None:
+ message["workspace"] = workspace
+
await self.socket.send(json.dumps(message))
while self.running:
@@ -127,19 +199,17 @@ class WebSocketManager:
continue
if "error" in response:
- if "message" in response["error"]:
- raise RuntimeError(response["error"]["text"])
+ if isinstance(response["error"], dict):
+ raise RuntimeError(
+ response["error"].get("message", str(response["error"]))
+ )
else:
raise RuntimeError(str(response["error"]))
yield response["response"]
- if "complete" in response:
- if response["complete"]:
- break
+ if response.get("complete"):
+ break
- except Exception as e:
- # Clean up on error
+ finally:
self.pending_requests.pop(request_id, None)
- raise e
-
diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
index 1b4815c6..0d5101df 100755
--- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
+++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
@@ -107,7 +107,14 @@ class Processor(FlowProcessor):
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
- pages = convert_from_bytes(blob)
+ try:
+ pages = convert_from_bytes(blob)
+ except Exception as e:
+ logger.error(
+ f"Failed to decode PDF {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
for ix, page in enumerate(pages):
diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py
index b4936786..deedb7b4 100644
--- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py
+++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py
@@ -418,7 +418,14 @@ class Processor(FlowProcessor):
doc_uri_str = document_uri(source_doc_id)
# Extract elements using unstructured
- elements = self.extract_elements(blob, mime_type)
+ try:
+ elements = self.extract_elements(blob, mime_type)
+ except Exception as e:
+ logger.error(
+ f"Failed to extract elements from {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
if not elements:
logger.warning("No elements extracted from document")