diff --git a/README.md b/README.md index c366a3d9..b66edc70 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,11 @@ trustgraph-ai%2Ftrustgraph | Trendshift -# The agent runtime platform +# The semantic deployment platform -TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads. +TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads. The platform: - [x] Multi-model and multimodal database system @@ -99,23 +99,21 @@ For a browser based configuration, try the [Configuration Terminal](https://conf - [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference) - [**Deployment Guides**](https://docs.trustgraph.ai/deployment) -## Workbench +## Context Graph UI -The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default. 
+Image -- **Vector Search**: Search the installed knowledge bases -- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs -- **Relationships**: Analyze deep relationships in the installed knowledge bases -- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases -- **Library**: Staging area for installing knowledge bases -- **Flow Classes**: Workflow preset configurations -- **Flows**: Create custom workflows and adjust LLM parameters during runtime -- **Knowledge Cores**: Manage resuable knowledge bases -- **Prompts**: Manage and adjust prompts during runtime -- **Schemas**: Define custom schemas for structured data knowledge bases -- **Ontologies**: Define custom ontologies for unstructured data knowledge bases -- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups -- **MCP Tools**: Connect to MCP servers +The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default. 
+ +- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time +- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from +- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views +- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing +- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs +- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management +- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout +- **Workspace UX** — Workspace selection and management surfaced directly in the interface +- **Prompt Editor** — A dedicated prompt editing workflow ## TypeScript Library for UIs diff --git a/containers/Containerfile.unstructured b/containers/Containerfile.unstructured index 6de8a800..2b9a18f7 100644 --- a/containers/Containerfile.unstructured +++ b/containers/Containerfile.unstructured @@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.13 libxcb mesa-libGL && \ +RUN dnf install -y python3.13 libxcb mesa-libGL poppler-utils && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \ diff --git a/docs/tech-specs/knowledge-core-completeness.md 
b/docs/tech-specs/knowledge-core-completeness.md new file mode 100644 index 00000000..3ccb41f0 --- /dev/null +++ b/docs/tech-specs/knowledge-core-completeness.md @@ -0,0 +1,535 @@ +--- +layout: default +title: "Knowledge Core Completeness" +parent: "Tech Specs" +--- + +# Knowledge Core Completeness + +## Overview + +Knowledge cores are portable snapshots of extracted knowledge: triples, graph +embeddings, and document embeddings stored in Cassandra's `knowledge` keyspace. +They can be downloaded as files, transferred between TrustGraph instances, and +loaded back into vector and graph stores. + +Recent additions to TrustGraph — explainability/provenance and named graphs — +were not carried through to the knowledge core system. This means that +exporting and re-importing a core loses provenance links, graph assignments, +and source material, breaking the explainability chain. + +This specification addresses three gaps: + +1. **Named graphs not stored** — The `g` (graph name) field on triples is + silently dropped when writing to the core store and comes back as `None` + on read. +2. **Provenance triples not captured** — Provenance triples (PROV-O) are + generated during extraction and flow to graph stores, but never enter + the knowledge core store. It is unclear whether they arrive at the store + in the correct form. +3. **Source material not included** — Documents, text pages, and chunks in + the librarian's bucket store are not part of the core. After loading a + core on a different instance, provenance links to source material point + at nothing. + +## Goals + +- **Self-contained cores**: A downloaded knowledge core file contains + everything needed to reconstruct the full knowledge graph including + provenance and source attribution on a fresh instance. +- **Named graph preservation**: Round-tripping a core preserves graph + assignments on all triples. 
+- **Backward compatibility**: Existing core files (without graph names or + source material) can still be uploaded and loaded. New fields are optional + on import. +- **No change to core identity**: A core is still identified by its document + ID. The additional data is associated with the same core ID. +- **Minimal file format changes**: Extend the existing msgpack record format + with new record types rather than restructuring existing ones. + +## Background + +### Current Lifecycle + +``` +Extraction pipeline + │ + ├─ triples ──────────────────► knowledge core store (Cassandra) + ├─ graph embeddings ─────────► knowledge core store (Cassandra) + ├─ document embeddings ──────► knowledge core store (Cassandra) + ├─ provenance triples ───────► graph store (only) + └─ source documents ─────────► librarian bucket store (only) + +Download: Cassandra ──► knowledge manager ──► API gateway ──► client file +Upload: client file ──► API gateway ──► knowledge manager ──► Cassandra +Load: Cassandra ──► knowledge manager ──► Pulsar topics ──► graph/vector stores +``` + +### Current Core File Format (msgpack) + +A core file is a sequence of concatenated msgpack records. Each record is a +2-element tuple: `(type_tag, payload)`. + +| Type tag | Payload | Description | +|----------|---------|-------------| +| `"t"` | `{"m": {id, root, collection}, "t": [triple_dicts]}` | Triple batch | +| `"ge"` | `{"m": {id, root, collection}, "e": [{entity, vector}]}` | Graph embedding batch | + +### What's Missing + +#### Named Graphs + +The `Triple` dataclass has a `g: str | None` field (graph name IRI), used to +separate provenance graphs (`urn:graph:source`, `urn:graph:retrieval`) from +the default graph. However: + +- **Cassandra schema** (`knowledge.triples` table): stores a 6-tuple per + triple `(s_val, s_is_uri, p_val, p_is_uri, o_val, o_is_uri)` — no graph + field. +- **`add_triples()`** (`tables/knowledge.py:231`): destructures only `s`, + `p`, `o` — `g` is discarded. 
+- **`get_triples()`** (`tables/knowledge.py:396`): reconstructs `Triple` + with `g` defaulting to `None`. +- **Core file format**: triple dicts do not include a graph field. + +#### Provenance Triples + +Provenance triples are generated in the extraction pipeline +(`trustgraph-base/trustgraph/provenance/triples.py`) and published to graph +store topics. They use named graphs (`urn:graph:source`, +`urn:graph:retrieval`) and PROV-O vocabulary. + +The knowledge core store processor (`storage/knowledge/store.py`) listens on +`triples-input` and `graph-embeddings-input`. Whether provenance triples +arrive on the same `triples-input` topic or a separate one needs +verification. Even if they do arrive, the graph name would be lost (per +above). + +#### Source Material + +The librarian stores the full document hierarchy in a separate system: + +- **Blob store** (S3/MinIO): original documents, text pages, chunks — + keyed by object UUID under `doc/{object_id}`. +- **Cassandra `library` keyspace**: document metadata including `id`, + `kind` (MIME type), `title`, `parent_id`, `document_type` + (`source`/`extracted`), `object_id` (blob reference). + +Provenance triples link extracted facts back to chunk/page/document IDs. +Those IDs resolve through the librarian. When a core is loaded on a +different instance, the librarian has no matching documents, so the entire +provenance chain is broken. 
+ +### Key Source Files + +| Component | File | Purpose | +|-----------|------|---------| +| Core Cassandra schema | `trustgraph-flow/trustgraph/tables/knowledge.py` | Table definitions, read/write | +| Core manager | `trustgraph-flow/trustgraph/cores/knowledge.py` | API operations, load-to-store | +| Core store processor | `trustgraph-flow/trustgraph/storage/knowledge/store.py` | Extraction → Cassandra | +| CLI download | `trustgraph-cli/trustgraph/cli/get_kg_core.py` | Core → msgpack file | +| CLI upload | `trustgraph-cli/trustgraph/cli/put_kg_core.py` | Msgpack file → core | +| CLI load | `trustgraph-cli/trustgraph/cli/load_kg_core.py` | Core → graph/vector stores | +| API client | `trustgraph-base/trustgraph/api/knowledge.py` | Client-side knowledge API | +| Triple schema | `trustgraph-base/trustgraph/schema/core/primitives.py` | Triple dataclass with `g` field | +| Provenance generation | `trustgraph-base/trustgraph/provenance/triples.py` | PROV-O triple creation | +| Librarian | `trustgraph-flow/trustgraph/librarian/librarian.py` | Document storage service | +| Library tables | `trustgraph-flow/trustgraph/tables/library.py` | Document metadata in Cassandra | +| Blob store | `trustgraph-flow/trustgraph/librarian/blob_store.py` | S3/MinIO object storage | + +## Technical Design + +### Change 1: Named Graph Field in Core Storage + +#### Cassandra Schema + +Extend the `triples` tuple from 6 to 7 elements, adding the graph name: + +``` +triples list<tuple<text, boolean, text, boolean, text, boolean, text>> +``` + +**Migration**: The schema change uses `ALTER TABLE` or is handled by +creating a new table version. Existing rows with 6-element tuples must be +handled gracefully on read — if the tuple has 6 elements, treat graph as +default. 
+ +#### Write Path (`add_triples`) + +Change `tables/knowledge.py:add_triples()` to include `triple.g`: + +```python +triples = [ + ( + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o), + v.g or "" + ) + for v in m.triples +] +``` + +#### Read Path (`get_triples`) + +Change `tables/knowledge.py:get_triples()` to restore the graph name: + +```python +Triple( + s = tuple_to_term(elt[0], elt[1]), + p = tuple_to_term(elt[2], elt[3]), + o = tuple_to_term(elt[4], elt[5]), + g = elt[6] if len(elt) > 6 and elt[6] else None, +) +``` + +The `len(elt) > 6` guard provides backward compatibility with existing +6-element rows. + +#### Core File Format + +Extend triple dicts in the `"t"` record to include the graph name: + +```python +# In get_kg_core.py write_triple — each triple dict gains "g" key +{"s": ..., "p": ..., "o": ..., "g": "urn:graph:source"} +``` + +On read (`put_kg_core.py`), treat missing `"g"` key as default graph for +backward compatibility with old core files. + +### Change 2: Provenance Triples in Cores + +#### Investigation Required + +Before implementation, verify: + +1. Whether provenance triples arrive on the `triples-input` topic that the + knowledge core store processor already listens on. +2. If not, which topic they use, and whether the store processor should + subscribe to it. + +#### If provenance triples already arrive at the store + +The only change needed is Change 1 (named graphs) — the provenance triples +are already being stored, just without their graph name. Once graph names +are preserved, provenance triples will round-trip correctly. + +#### If provenance triples do NOT arrive at the store + +Two options: + +**Option A — Route provenance to the existing store topic**: Configure the +flow so provenance triples are published to the same `triples-input` topic. +This is the simpler approach and keeps the store processor unchanged. 
+ +**Option B — Add a subscription**: Add a new `ConsumerSpec` in the store +processor for the provenance topic. This keeps provenance routing +independent but adds complexity. + +Recommendation: Option A, unless there is a reason provenance triples are +intentionally kept off the core store topic. + +### Change 3: Source Material in Cores + +This is the largest change. The goal is that when a core is loaded on a +fresh instance, provenance links to source material resolve. + +#### Architecture + +Source material is **not stored in the knowledge core tables**. It lives in +the librarian (Cassandra `library` keyspace + S3/MinIO blob store) and is +fetched on demand via the librarian's existing service API. + +The knowledge manager acts as a **client of the librarian service** — it +calls the librarian's request/response API over pub/sub to retrieve document +metadata and content. It does not access the library's Cassandra tables or +blob store directly. + +#### Transport + +The librarian's pub/sub API already handles chunking of large documents. +This chunking is designed to be websocket-friendly, so library content +flowing through the API gateway to external clients does not require +re-chunking. The API gateway remains a transport layer. + +``` +Download: + Knowledge manager ──pub/sub──► Librarian (fetch metadata + content) + Knowledge manager ──pub/sub──► API gateway ──websocket──► Client + +Upload: + Client ──websocket──► API gateway ──pub/sub──► Knowledge manager + Knowledge manager ──pub/sub──► Librarian (store metadata + content) +``` + +#### What to Include + +The provenance chain links facts → chunks → pages → documents. For the +chain to resolve, the core must include: + +1. **Document metadata** — the library record for each document in the + hierarchy (id, kind, title, parent_id, document_type, etc.) +2. 
**Document content** — the blob data for each document (original file, + extracted text pages, text chunks) + +Including the full hierarchy is necessary because: +- A user viewing provenance needs to traverse fact → chunk → page → document +- The chunk text is needed to show what text a fact was extracted from +- The page text provides broader context +- The original document is needed for full source attribution + +#### Size Implications + +Source material will significantly increase core file sizes. A rough model: + +| Component | Typical size per document | +|-----------|-------------------------| +| Triples + embeddings (current) | 1-10 MB | +| Chunk text (all chunks) | ~same as original document | +| Page text (all pages) | ~same as original document | +| Original document (PDF, etc.) | Varies widely (KB to hundreds of MB) | + +For a 10 MB PDF, the core could grow from ~5 MB to ~25 MB (original + +derived text + existing data). For large document sets, cores could become +very large. + +**Decision needed**: Whether to include original documents or just derived +text (pages + chunks). Including only derived text still allows provenance +display but loses the ability to serve the original file. + +#### New Core File Record Types + +Add new msgpack record types for library content: + +| Type tag | Payload | Description | +|----------|---------|-------------| +| `"lm"` | `{"id", "kind", "title", "parent_id", "document_type", "comments", "tags", "metadata"}` | Library document metadata | +| `"lb"` | `{"id", "data"}` | Library document blob content (chunked by pub/sub layer) | + +These are emitted after the existing `"t"` and `"ge"` records during +download and processed during upload. + +#### Download Path + +Extend `KnowledgeManager.get_kg_core()` to: + +1. Stream triples and graph embeddings from the core store (existing + behavior). +2. Use the librarian service API to retrieve documents associated with + this core ID: + a. 
Fetch the root document metadata and content. + b. Use `list-children` to discover child documents (pages, chunks). + c. Recursively fetch metadata and content for each child. +3. Stream each document as `"lm"` (metadata) and `"lb"` (content) records. + +The knowledge manager gains the librarian service as a pub/sub dependency. +Large document content is chunked by the librarian's existing pub/sub +transport — the knowledge manager receives and forwards these chunks without +buffering the full blob in memory. + +#### Upload Path + +Extend `KnowledgeManager.put_kg_core()` to handle the new record types: + +1. For `"lm"` records: call the librarian service API to create/update + the document metadata. +2. For `"lb"` records: call the librarian service API to store the + document content. + +Parent-child relationships are preserved because `parent_id` is stored in +the metadata. Documents should be processed in hierarchy order (parent +before child) to satisfy any ordering constraints. + +#### Load Path + +The load path (`_load_kg_core`) publishes triples and embeddings to Pulsar +topics for ingestion into graph/vector stores. Source material does not need +to flow through the load path — it is already in the librarian after the +upload step and can be accessed directly by services that need it. + +No changes to the load path for source material. + +#### CLI Changes + +**`tg-get-kg-core`**: Add handling for `"lm"` and `"lb"` record types in +the file writer. + +**`tg-put-kg-core`**: Add handling for `"lm"` and `"lb"` record types in +the file reader. Send library records to the knowledge manager alongside +triple/embedding records. + +#### Associating Documents with Cores + +The core ID is `metadata.root`, which is the root document ID from the +librarian. This provides a natural join: the core's root document and all +its children (pages, chunks) are the source material for that core. + +The librarian's `list-children` API provides the child documents. 
A +recursive traversal from the root document collects the full hierarchy. + +### API Changes + +#### KnowledgeResponse Schema + +Add optional fields to `KnowledgeResponse` for library data: + +```python +@dataclass +class KnowledgeResponse: + error: Error | None = None + ids: list | None = None + eos: bool = False + triples: Triples | None = None + graph_embeddings: GraphEmbeddings | None = None + document_embeddings: DocumentEmbeddings | None = None + library_metadata: LibraryMetadata | None = None # new + library_blob: LibraryBlob | None = None # new +``` + +#### New Schema Types + +```python +@dataclass +class LibraryMetadata: + id: str + kind: str | None = None + title: str | None = None + parent_id: str | None = None + document_type: str | None = None + comments: str | None = None + tags: list[str] | None = None + metadata: list[Triple] | None = None + +@dataclass +class LibraryBlob: + id: str + data: bytes +``` + +#### Socket API + +The existing streaming protocol for `get-kg-core` / `put-kg-core` carries +these new fields naturally — responses already stream multiple record types. + +### Dependencies Between Changes + +``` +Change 1 (named graphs) ◄── Change 2 depends on this + │ + └── Change 2 (provenance triples) + │ + └── Change 3 (source material) is independent +``` + +Change 1 is a prerequisite for Change 2 (provenance triples use named +graphs). Change 3 is independent and can be implemented in parallel. + +## Security Considerations + +- **Workspace isolation**: Core download/upload must respect workspace + boundaries. Source material from the librarian must only be included if + it belongs to the same workspace as the core. This is already enforced + by the existing workspace-scoped queries. +- **Large blob transfer**: Streaming large documents through the API + is handled by the librarian's existing pub/sub chunking, which is + designed to be websocket-friendly. No additional chunking layer is + needed. 
+- **Cross-instance trust**: When uploading a core from an external source, + the library content should be treated as untrusted input. Document + metadata and blob content should be validated before insertion. + +## Performance Considerations + +- **Core file size**: Including source material will significantly increase + core file sizes. Consider adding a flag to download/upload commands to + optionally exclude source material for use cases where only the knowledge + graph is needed. +- **Streaming**: All paths already use streaming (paged Cassandra queries, + msgpack record-at-a-time). Library content should follow the same pattern. +- **Cassandra schema migration**: Changing the tuple width in the `triples` + table requires careful handling. Cassandra frozen tuples cannot be altered + in place — a migration strategy is needed (see Migration Plan). + +## Testing Strategy + +- **Unit tests**: Triple round-trip with graph name (write → read → + verify `g` field preserved). Backward compatibility with 6-element tuples. +- **Integration tests**: Full lifecycle — extract with provenance → download + core → upload to fresh instance → load → verify provenance chain resolves. +- **File format tests**: Read old-format core files (no graph name, no + library records) and verify they load without error. +- **Library inclusion tests**: Download core with source material → upload → + verify documents accessible through librarian. + +## Migration Plan + +### Cassandra Schema + +The `triples` table stores tuples in a `list<tuple<...>>` column. Cassandra +does not support altering the type of an existing column. Options: + +**Option A — New table**: Create a `triples_v2` table with the 7-element +tuple. Migrate data from `triples` to `triples_v2`. The read path checks +both tables during a transition period, then the old table is dropped. + +**Option B — Dual read**: Keep the existing table. The read path handles +both 6-element and 7-element tuples by checking length. 
New writes use +7-element tuples. This works if Cassandra accepts variable-length tuples in +a list — **needs verification**. + +**Option C — Separate graph column**: Instead of extending the tuple, add a +parallel `graphs list<text>` column where `graphs[i]` corresponds to +`triples[i]`. This avoids tuple migration entirely but requires keeping the +two lists in sync. + +Recommendation: Verify Option B first (simplest). Fall back to Option A if +Cassandra rejects mixed tuple lengths. + +### Core File Format + +Backward compatible by design: +- Old files lack `"g"` in triple dicts and have no `"lm"`/`"lb"` records → + handled by defaults. +- New files read by old code → readers must skip unknown record types; the + existing `read_message` currently raises on them, so it needs a small + fix to skip unknown types gracefully. + +## Open Questions + +1. **Provenance topic routing**: Do provenance triples currently arrive at + the `triples-input` topic consumed by the knowledge core store? If not, + what topic are they on? + +2. **Include original documents?**: Should cores include the original + uploaded document (e.g. PDF), or only derived text (pages + chunks)? + Including originals makes cores fully self-contained but potentially + very large. Excluding them preserves provenance text display but loses + the ability to serve the original file. + +3. **Optional source material**: Should there be a flag on download/upload + to include or exclude source material? This would let users choose + between compact cores (knowledge only) and complete cores (knowledge + + sources). + +4. **Cassandra tuple migration**: Can Cassandra handle mixed-length tuples + in a `list<tuple<...>>` column, or is a table migration required? + +5. **Document embedding cores**: DE cores are managed alongside KG cores. + Do they need the same treatment (source material inclusion)? The + document embeddings reference chunk IDs — the same provenance chain + applies. + +6. 
**Core versioning**: Should the core file include a version marker so + readers can distinguish old-format from new-format files without + trial-and-error parsing? + +## References + +- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md` +- Query-time explainability: `docs/tech-specs/query-time-explainability.md` +- Agent explainability: `docs/tech-specs/agent-explainability.md` +- Data ownership model: `docs/tech-specs/data-ownership-model.md` diff --git a/tests/unit/test_base/test_cassandra_config.py b/tests/unit/test_base/test_cassandra_config.py index a291434d..fe8a8379 100644 --- a/tests/unit/test_base/test_cassandra_config.py +++ b/tests/unit/test_base/test_cassandra_config.py @@ -409,4 +409,57 @@ class TestEdgeCases: assert hosts == ['mixed-host'] assert username is None # Stays None - assert password == 'mixed-pass' \ No newline at end of file + assert password == 'mixed-pass' + + +class TestReplicationFactorParamPath: + + def test_explicit_kwarg(self): + with patch.dict(os.environ, {}, clear=True): + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=3, + ) + assert rf == 3 + + def test_kwarg_overrides_env(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=3, + ) + assert rf == 3 + + def test_env_fallback_when_kwarg_none(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=None, + ) + assert rf == 5 + + def test_default_when_no_kwarg_no_env(self): + with patch.dict(os.environ, {}, clear=True): + _, _, _, _, rf = resolve_cassandra_config() + assert rf == 1 + + def test_params_dict_path(self): + with patch.dict(os.environ, {}, clear=True): + params = {'cassandra_replication_factor': 3} + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=params.get('cassandra_replication_factor'), + ) + assert rf == 3 + + 
def test_params_dict_overrides_env(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + params = {'cassandra_replication_factor': 3} + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=params.get('cassandra_replication_factor'), + ) + assert rf == 3 + + def test_params_dict_missing_falls_to_env(self): + with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True): + params = {} + _, _, _, _, rf = resolve_cassandra_config( + replication_factor=params.get('cassandra_replication_factor'), + ) + assert rf == 5 \ No newline at end of file diff --git a/tests/unit/test_base/test_qdrant_config.py b/tests/unit/test_base/test_qdrant_config.py new file mode 100644 index 00000000..dbbe4214 --- /dev/null +++ b/tests/unit/test_base/test_qdrant_config.py @@ -0,0 +1,136 @@ + +import os +import pytest +from unittest.mock import patch + +from trustgraph.base.qdrant_config import ( + get_qdrant_defaults, + resolve_qdrant_config, +) + + +class TestGetQdrantDefaults: + + def test_defaults_with_no_env_vars(self): + with patch.dict(os.environ, {}, clear=True): + defaults = get_qdrant_defaults() + assert defaults['url'] == 'http://localhost:6333' + assert defaults['api_key'] is None + assert defaults['replication_factor'] == 1 + assert defaults['shard_number'] == 1 + + def test_defaults_from_env(self): + env = { + 'QDRANT_URL': 'http://qdrant:6333', + 'QDRANT_API_KEY': 'secret', + 'QDRANT_REPLICATION_FACTOR': '3', + 'QDRANT_SHARD_NUMBER': '5', + } + with patch.dict(os.environ, env, clear=True): + defaults = get_qdrant_defaults() + assert defaults['url'] == 'http://qdrant:6333' + assert defaults['api_key'] == 'secret' + assert defaults['replication_factor'] == 3 + assert defaults['shard_number'] == 5 + + +class TestResolveQdrantConfig: + + def test_defaults(self): + with patch.dict(os.environ, {}, clear=True): + url, api_key, rf, sn = resolve_qdrant_config() + assert url == 'http://localhost:6333' + assert api_key is 
None + assert rf == 1 + assert sn == 1 + + def test_explicit_kwargs(self): + with patch.dict(os.environ, {}, clear=True): + url, api_key, rf, sn = resolve_qdrant_config( + url='http://custom:6333', + api_key='key', + replication_factor=3, + shard_number=5, + ) + assert url == 'http://custom:6333' + assert api_key == 'key' + assert rf == 3 + assert sn == 5 + + def test_kwargs_override_env(self): + env = { + 'QDRANT_URL': 'http://env:6333', + 'QDRANT_REPLICATION_FACTOR': '10', + 'QDRANT_SHARD_NUMBER': '10', + } + with patch.dict(os.environ, env, clear=True): + url, _, rf, sn = resolve_qdrant_config( + url='http://explicit:6333', + replication_factor=3, + shard_number=5, + ) + assert url == 'http://explicit:6333' + assert rf == 3 + assert sn == 5 + + def test_env_fallback_when_kwargs_none(self): + env = { + 'QDRANT_URL': 'http://env:6333', + 'QDRANT_REPLICATION_FACTOR': '3', + 'QDRANT_SHARD_NUMBER': '5', + } + with patch.dict(os.environ, env, clear=True): + url, _, rf, sn = resolve_qdrant_config() + assert url == 'http://env:6333' + assert rf == 3 + assert sn == 5 + + def test_params_dict_path(self): + with patch.dict(os.environ, {}, clear=True): + params = { + 'store_uri': 'http://params:6333', + 'api_key': 'pkey', + 'qdrant_replication_factor': 3, + 'qdrant_shard_number': 5, + } + url, api_key, rf, sn = resolve_qdrant_config( + url=params.get('store_uri'), + api_key=params.get('api_key'), + replication_factor=params.get('qdrant_replication_factor'), + shard_number=params.get('qdrant_shard_number'), + ) + assert url == 'http://params:6333' + assert api_key == 'pkey' + assert rf == 3 + assert sn == 5 + + def test_params_dict_overrides_env(self): + env = { + 'QDRANT_REPLICATION_FACTOR': '10', + 'QDRANT_SHARD_NUMBER': '10', + } + with patch.dict(os.environ, env, clear=True): + params = { + 'qdrant_replication_factor': 3, + 'qdrant_shard_number': 5, + } + _, _, rf, sn = resolve_qdrant_config( + replication_factor=params.get('qdrant_replication_factor'), + 
shard_number=params.get('qdrant_shard_number'), + ) + assert rf == 3 + assert sn == 5 + + def test_params_dict_missing_falls_to_env(self): + env = { + 'QDRANT_REPLICATION_FACTOR': '3', + 'QDRANT_SHARD_NUMBER': '5', + } + with patch.dict(os.environ, env, clear=True): + params = {} + _, _, rf, sn = resolve_qdrant_config( + replication_factor=params.get('qdrant_replication_factor'), + shard_number=params.get('qdrant_shard_number'), + ) + assert rf == 3 + assert sn == 5 diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index 8f73dcc6..7797c9be 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -11,7 +11,12 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock from unittest.mock import call from trustgraph.cores.knowledge import KnowledgeManager -from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL +from trustgraph.schema import ( + KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, + EntityEmbeddings, IRI, LITERAL, + LibraryMetadata, LibraryBlob, + LibrarianResponse, DocumentMetadata, +) @pytest.fixture @@ -373,11 +378,252 @@ class TestKnowledgeManagerOtherMethods: mock_respond = AsyncMock() await knowledge_manager.delete_kg_core(mock_request, mock_respond, "test-user") - + # Verify table store was called correctly knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id") - + # Verify response mock_respond.assert_called_once() response = mock_respond.call_args[0][0] - assert response.error is None \ No newline at end of file + assert response.error is None + + +class TestKnowledgeManagerLibraryDownload: + """Test get_kg_core streaming of library documents.""" + + @pytest.fixture + def manager_with_librarian(self, mock_flow_config): + with patch('trustgraph.cores.knowledge.KnowledgeTableStore'): + 
mock_librarian = AsyncMock() + manager = KnowledgeManager( + cassandra_host=["localhost"], + cassandra_username="test_user", + cassandra_password="test_pass", + keyspace="test_keyspace", + flow_config=mock_flow_config, + librarian=mock_librarian, + ) + manager.table_store = AsyncMock() + return manager + + @pytest.mark.asyncio + async def test_get_kg_core_streams_library_docs(self, manager_with_librarian): + mock_request = Mock() + mock_request.id = "root-doc" + mock_respond = AsyncMock() + + manager_with_librarian.table_store.get_triples = AsyncMock() + manager_with_librarian.table_store.get_graph_embeddings = AsyncMock() + + root_meta = DocumentMetadata( + id="root-doc", kind="application/pdf", title="Test PDF", + document_type="source", + ) + child_meta = DocumentMetadata( + id="chunk-1", kind="text/plain", title="Chunk 1", + parent_id="root-doc", document_type="chunk", + ) + + manager_with_librarian.librarian.fetch_document_metadata.return_value = root_meta + manager_with_librarian.librarian.request.return_value = LibrarianResponse( + document_metadatas=[child_meta], + ) + manager_with_librarian.librarian.fetch_document_content.side_effect = [ + b"cm9vdCBjb250ZW50", + b"Y2h1bmsgY29udGVudA==", + ] + + await manager_with_librarian.get_kg_core( + mock_request, mock_respond, "test-user" + ) + + responses = [c[0][0] for c in mock_respond.call_args_list] + + lm_responses = [r for r in responses if r.library_metadata is not None] + lb_responses = [r for r in responses if r.library_blob is not None] + eos_responses = [r for r in responses if r.eos is True] + + assert len(lm_responses) == 2 + assert lm_responses[0].library_metadata.id == "root-doc" + assert lm_responses[0].library_metadata.document_type == "source" + assert lm_responses[1].library_metadata.id == "chunk-1" + assert lm_responses[1].library_metadata.parent_id == "root-doc" + + assert len(lb_responses) == 2 + assert lb_responses[0].library_blob.id == "root-doc" + assert lb_responses[0].library_blob.data == 
b"cm9vdCBjb250ZW50" + assert lb_responses[1].library_blob.id == "chunk-1" + + assert len(eos_responses) == 1 + + @pytest.mark.asyncio + async def test_get_kg_core_no_librarian_skips_library(self, mock_flow_config): + with patch('trustgraph.cores.knowledge.KnowledgeTableStore'): + manager = KnowledgeManager( + cassandra_host=["localhost"], + cassandra_username="u", cassandra_password="p", + keyspace="ks", flow_config=mock_flow_config, + ) + manager.table_store = AsyncMock() + manager.table_store.get_triples = AsyncMock() + manager.table_store.get_graph_embeddings = AsyncMock() + + mock_request = Mock() + mock_request.id = "doc-1" + mock_respond = AsyncMock() + + await manager.get_kg_core(mock_request, mock_respond, "w") + + responses = [c[0][0] for c in mock_respond.call_args_list] + assert all(r.library_metadata is None for r in responses) + assert all(r.library_blob is None for r in responses) + + @pytest.mark.asyncio + async def test_get_kg_core_librarian_metadata_failure_is_graceful( + self, manager_with_librarian, + ): + mock_request = Mock() + mock_request.id = "missing-doc" + mock_respond = AsyncMock() + + manager_with_librarian.table_store.get_triples = AsyncMock() + manager_with_librarian.table_store.get_graph_embeddings = AsyncMock() + manager_with_librarian.librarian.fetch_document_metadata.side_effect = ( + RuntimeError("not found") + ) + + await manager_with_librarian.get_kg_core( + mock_request, mock_respond, "test-user" + ) + + responses = [c[0][0] for c in mock_respond.call_args_list] + assert all(r.library_metadata is None for r in responses) + assert any(r.eos for r in responses) + + +class TestKnowledgeManagerLibraryUpload: + """Test put_kg_core handling of library metadata and blob records.""" + + @pytest.fixture + def manager_with_librarian(self, mock_flow_config): + with patch('trustgraph.cores.knowledge.KnowledgeTableStore'): + mock_librarian = AsyncMock() + manager = KnowledgeManager( + cassandra_host=["localhost"], + cassandra_username="u", 
cassandra_password="p", + keyspace="ks", flow_config=mock_flow_config, + librarian=mock_librarian, + ) + manager.table_store = AsyncMock() + return manager + + @pytest.mark.asyncio + async def test_put_metadata_then_blob_calls_librarian( + self, manager_with_librarian, + ): + mock_respond = AsyncMock() + manager_with_librarian.librarian.request.return_value = LibrarianResponse() + + # First call: metadata + req_meta = Mock() + req_meta.triples = None + req_meta.graph_embeddings = None + req_meta.library_metadata = LibraryMetadata( + id="doc-1", kind="application/pdf", title="Test", + document_type="source", + ) + req_meta.library_blob = None + await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws") + + # Metadata is buffered, librarian not called yet + manager_with_librarian.librarian.request.assert_not_called() + + # Second call: blob + req_blob = Mock() + req_blob.triples = None + req_blob.graph_embeddings = None + req_blob.library_metadata = None + req_blob.library_blob = LibraryBlob( + id="doc-1", data=b"dGVzdA==", + ) + await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws") + + # Now librarian should have been called with add-document + manager_with_librarian.librarian.request.assert_called_once() + call_args = manager_with_librarian.librarian.request.call_args[0][0] + assert call_args.operation == "add-document" + assert call_args.document_metadata.id == "doc-1" + assert call_args.document_metadata.kind == "application/pdf" + assert call_args.content == b"dGVzdA==" + + @pytest.mark.asyncio + async def test_put_child_document_uses_add_child_operation( + self, manager_with_librarian, + ): + mock_respond = AsyncMock() + manager_with_librarian.librarian.request.return_value = LibrarianResponse() + + req_meta = Mock() + req_meta.triples = None + req_meta.graph_embeddings = None + req_meta.library_metadata = LibraryMetadata( + id="chunk-1", kind="text/plain", title="Chunk", + parent_id="doc-1", document_type="chunk", + ) + 
req_meta.library_blob = None + await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws") + + req_blob = Mock() + req_blob.triples = None + req_blob.graph_embeddings = None + req_blob.library_metadata = None + req_blob.library_blob = LibraryBlob(id="chunk-1", data=b"Y2h1bms=") + await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws") + + call_args = manager_with_librarian.librarian.request.call_args[0][0] + assert call_args.operation == "add-child-document" + assert call_args.document_metadata.parent_id == "doc-1" + + @pytest.mark.asyncio + async def test_put_blob_without_metadata_logs_warning( + self, manager_with_librarian, + ): + mock_respond = AsyncMock() + + req_blob = Mock() + req_blob.triples = None + req_blob.graph_embeddings = None + req_blob.library_metadata = None + req_blob.library_blob = LibraryBlob(id="orphan", data=b"data") + await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws") + + # Librarian should not be called for orphan blob + manager_with_librarian.librarian.request.assert_not_called() + + @pytest.mark.asyncio + async def test_put_existing_document_is_graceful( + self, manager_with_librarian, + ): + mock_respond = AsyncMock() + manager_with_librarian.librarian.request.side_effect = RuntimeError( + "Document already exists" + ) + + req_meta = Mock() + req_meta.triples = None + req_meta.graph_embeddings = None + req_meta.library_metadata = LibraryMetadata( + id="doc-1", kind="application/pdf", title="Test", + document_type="source", + ) + req_meta.library_blob = None + await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws") + + req_blob = Mock() + req_blob.triples = None + req_blob.graph_embeddings = None + req_blob.library_metadata = None + req_blob.library_blob = LibraryBlob(id="doc-1", data=b"data") + await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws") + + # Should not raise — "already exists" is handled gracefully \ No newline at end of file diff --git 
a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py index 04807b20..641a9d78 100644 --- a/tests/unit/test_decoding/test_pdf_decoder.py +++ b/tests/unit/test_decoding/test_pdf_decoder.py @@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test successful PDF processing""" # Mock PDF content - pdf_content = b"fake pdf content" + pdf_content = b"%PDF-1.7\nfake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') # Mock PyPDFLoader @@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): # Verify triples were sent for each page (provenance) assert mock_triples_flow.send.call_count == 2 + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') + @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer): + """Test rejecting non-PDF content before invoking the PDF loader""" + html_content = b"Not found" + html_base64 = base64.b64encode(html_content) + + mock_metadata = Metadata(id="test-doc") + mock_document = Document(metadata=mock_metadata, document_id="doc-123") + mock_msg = MagicMock() + mock_msg.value.return_value = mock_document + + mock_output_flow = AsyncMock() + mock_triples_flow = AsyncMock() + mock_flow = MagicMock(side_effect=lambda name: { + "output": mock_output_flow, + "triples": mock_triples_flow, + }.get(name)) + mock_flow.librarian.fetch_document_metadata = AsyncMock( + return_value=MagicMock(kind="application/pdf") + ) + mock_flow.librarian.fetch_document_content = AsyncMock( + return_value=html_base64 + ) + mock_flow.librarian.save_child_document = AsyncMock() + + config = { + 'id': 'test-pdf-decoder', + 
'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + await processor.on_message(mock_msg, None, mock_flow) + + mock_pdf_loader_class.assert_not_called() + mock_output_flow.send.assert_not_called() + mock_triples_flow.send.assert_not_called() + mock_flow.librarian.save_child_document.assert_not_called() + @patch('trustgraph.base.librarian_client.Consumer') @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test handling of empty PDF""" - pdf_content = b"fake pdf content" + pdf_content = b"%PDF-1.7\nfake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') mock_loader = MagicMock() @@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test handling of unicode content in PDF""" - pdf_content = b"fake pdf content" + pdf_content = b"%PDF-1.7\nfake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') mock_loader = MagicMock() diff --git a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py index aef6fc92..65837323 100644 --- a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py +++ b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py @@ -18,7 +18,7 @@ from trustgraph.embeddings.hf.hf import Processor class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): """Test HuggingFace dynamic model loading and caching""" - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') 
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -39,7 +39,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): assert processor.cached_model_name == "test-model" assert processor.embeddings is not None - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -63,7 +63,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): mock_hf_class.assert_not_called() assert processor.cached_model_name == "test-model" - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -84,7 +84,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): mock_hf_class.assert_called_once_with(model_name="different-model") assert processor.cached_model_name == "different-model" - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -107,7 +107,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): assert 
processor.cached_model_name == "test-model" # Still using default assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]] - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -130,7 +130,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): assert processor.cached_model_name == "custom-model" mock_hf_instance.embed_documents.assert_called_once_with(["test text"]) - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -164,7 +164,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): assert call_count_after_b == initial_call_count + 2 # Reload for model-b assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class): @@ -187,7 +187,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase): assert mock_hf_class.call_count == initial_count assert processor.cached_model_name == "test-model" - @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings') + @patch('langchain_huggingface.HuggingFaceEmbeddings') 
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__') async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class): diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index b61500a4..fb385f43 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -333,8 +333,8 @@ class TestUnifiedTableQueries: """Test queries against the unified rows table""" @pytest.mark.asyncio - @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock) - async def test_query_with_index_match(self, mock_async_execute): + @patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock) + async def test_query_with_index_match(self, mock_async_execute_paged): """Test query execution with matching index""" processor = MagicMock() processor.session = MagicMock() @@ -344,10 +344,10 @@ class TestUnifiedTableQueries: processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) - # Mock async_execute to return test data + # Mock async_execute_paged to return test data (list of pages) mock_row = MagicMock() mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"} - mock_async_execute.return_value = [mock_row] + mock_async_execute_paged.return_value = [[mock_row]] schema = RowSchema( name="products", @@ -370,10 +370,10 @@ class TestUnifiedTableQueries: # Verify Cassandra was connected and queried processor.connect_cassandra.assert_called_once() - mock_async_execute.assert_called_once() + mock_async_execute_paged.assert_called_once() # Verify query structure - should query unified rows table - call_args = mock_async_execute.call_args + 
call_args = mock_async_execute_paged.call_args query = call_args[0][1] params = call_args[0][2] @@ -394,8 +394,8 @@ class TestUnifiedTableQueries: assert results[0]["category"] == "electronics" @pytest.mark.asyncio - @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock) - async def test_query_without_index_match(self, mock_async_execute): + @patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock) + async def test_query_without_index_match(self, mock_async_scan): """Test query execution without matching index (scan mode)""" processor = MagicMock() processor.session = MagicMock() @@ -406,12 +406,10 @@ class TestUnifiedTableQueries: processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) - # Mock async_execute to return test data + # Mock async_scan to return filtered test data mock_row1 = MagicMock() mock_row1.data = {"id": "1", "name": "Product A", "price": "100"} - mock_row2 = MagicMock() - mock_row2.data = {"id": "2", "name": "Product B", "price": "200"} - mock_async_execute.return_value = [mock_row1, mock_row2] + mock_async_scan.return_value = [mock_row1] schema = RowSchema( name="products", @@ -432,13 +430,16 @@ class TestUnifiedTableQueries: limit=10 ) - # Query should use ALLOW FILTERING for scan - call_args = mock_async_execute.call_args + # Verify async_scan was called + mock_async_scan.assert_called_once() + + # Verify query structure + call_args = mock_async_scan.call_args query = call_args[0][1] assert "ALLOW FILTERING" in query - # Should post-filter results + # Should return filtered results assert len(results) == 1 assert results[0]["name"] == "Product A" diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index dbe06b40..41d0f88b 100644 --- 
a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection: proc.collection_exists = MagicMock(return_value=True) proc._cache_lock = asyncio.Lock() proc._known_collections = set() + proc.replication_factor = 1 + proc.shard_number = 1 msg = MagicMock() msg.metadata.collection = "graphs" diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py index 59d15b45..2d058733 100644 --- a/tests/unit/test_tables/test_knowledge_table_store.py +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -35,9 +35,9 @@ def _make_store(): class TestGetGraphEmbeddings: @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) async def test_row_converts_to_entity_embeddings_with_singular_vector( - self, mock_async_execute + self, mock_async_execute_paged ): """ Cassandra rows return entities as a list of [entity_tuple, vector] @@ -57,7 +57,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -66,7 +66,7 @@ class TestGetGraphEmbeddings: await store.get_graph_embeddings("alice", "doc-1", receiver) - mock_async_execute.assert_called_once_with( + mock_async_execute_paged.assert_called_once_with( store.cassandra, store.get_graph_embeddings_stmt, ("alice", "doc-1"), @@ -96,8 +96,8 @@ class TestGetGraphEmbeddings: assert ge.entities[2].entity.value == "a literal entity" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute): + 
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged): """row[3] being None / empty must produce a GraphEmbeddings with no entities, not raise.""" fake_row = (None, None, None, None) @@ -105,7 +105,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -118,8 +118,8 @@ class TestGetGraphEmbeddings: assert received[0].entities == [] @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_multiple_rows_each_emit_one_message(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged): fake_rows = [ (None, None, None, [ (("http://example.org/a", True), [1.0]), @@ -132,7 +132,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = fake_rows + mock_async_execute_paged.return_value = [fake_rows] received = [] @@ -153,9 +153,9 @@ class TestGetTriples: the same Metadata construction. 
Cover it for parity.""" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_row_converts_to_triples(self, mock_async_execute): - # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri) + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_row_converts_to_triples(self, mock_async_execute_paged): + # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri, graph) fake_row = ( None, None, None, [ @@ -163,6 +163,7 @@ class TestGetTriples: "http://example.org/alice", True, "http://example.org/knows", True, "http://example.org/bob", True, + "urn:graph:source", ), ], ) @@ -170,7 +171,7 @@ class TestGetTriples: store = _make_store() store.cassandra = Mock() store.get_triples_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -191,3 +192,33 @@ class TestGetTriples: assert t.s.iri == "http://example.org/alice" assert t.p.iri == "http://example.org/knows" assert t.o.iri == "http://example.org/bob" + assert t.g == "urn:graph:source" + + @pytest.mark.asyncio + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_empty_graph_name_becomes_none(self, mock_async_execute_paged): + fake_row = ( + None, None, None, + [ + ( + "http://example.org/alice", True, + "http://example.org/knows", True, + "http://example.org/bob", True, + "", + ), + ], + ) + + store = _make_store() + store.cassandra = Mock() + store.get_triples_stmt = Mock() + mock_async_execute_paged.return_value = [[fake_row]] + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_triples("w", "d", receiver) + + assert received[0].triples[0].g is None diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py index 437b83c8..af128f23 100644 --- 
a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py +++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py @@ -1,5 +1,6 @@ """ -Round-trip unit tests for KnowledgeRequestTranslator. +Round-trip unit tests for KnowledgeRequestTranslator and +KnowledgeResponseTranslator. Regression coverage: a previous version of the decode side constructed EntityEmbeddings(vectors=...) — the schema field is `vector` (singular), @@ -15,9 +16,13 @@ Triples breaks the test. import pytest -from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator +from trustgraph.messaging.translators.knowledge import ( + KnowledgeRequestTranslator, + KnowledgeResponseTranslator, +) from trustgraph.schema import ( KnowledgeRequest, + KnowledgeResponse, GraphEmbeddings, EntityEmbeddings, Triples, @@ -25,6 +30,8 @@ from trustgraph.schema import ( Metadata, Term, IRI, + LibraryMetadata, + LibraryBlob, ) @@ -145,3 +152,161 @@ class TestKnowledgeRequestTranslatorTriples: assert t.s.iri == "http://example.org/alice" assert t.p.iri == "http://example.org/knows" assert t.o.iri == "http://example.org/bob" + + +class TestKnowledgeRequestTranslatorLibrary: + + def test_roundtrip_preserves_library_metadata(self, translator): + request = KnowledgeRequest( + operation="put-kg-core", + id="doc-1", + library_metadata=LibraryMetadata( + id="doc-1", + kind="application/pdf", + title="Test Document", + parent_id="", + document_type="source", + comments="test comments", + tags=["tag1", "tag2"], + ), + ) + + encoded = translator.encode(request) + assert "library-metadata" in encoded + lm = encoded["library-metadata"] + assert lm["id"] == "doc-1" + assert lm["kind"] == "application/pdf" + assert lm["title"] == "Test Document" + assert lm["parent-id"] == "" + assert lm["document-type"] == "source" + assert lm["comments"] == "test comments" + assert lm["tags"] == ["tag1", "tag2"] + + decoded = translator.decode(encoded) + assert decoded.library_metadata is not None + 
assert decoded.library_metadata.id == "doc-1" + assert decoded.library_metadata.kind == "application/pdf" + assert decoded.library_metadata.title == "Test Document" + assert decoded.library_metadata.parent_id == "" + assert decoded.library_metadata.document_type == "source" + assert decoded.library_metadata.comments == "test comments" + assert decoded.library_metadata.tags == ["tag1", "tag2"] + + def test_roundtrip_preserves_child_document_metadata(self, translator): + request = KnowledgeRequest( + operation="put-kg-core", + id="doc-1", + library_metadata=LibraryMetadata( + id="chunk-1", + kind="text/plain", + title="Chunk 1", + parent_id="doc-1", + document_type="chunk", + ), + ) + + encoded = translator.encode(request) + decoded = translator.decode(encoded) + + assert decoded.library_metadata.parent_id == "doc-1" + assert decoded.library_metadata.document_type == "chunk" + + def test_roundtrip_preserves_library_blob(self, translator): + request = KnowledgeRequest( + operation="put-kg-core", + id="doc-1", + library_blob=LibraryBlob( + id="doc-1", + data=b"SGVsbG8gV29ybGQ=", + ), + ) + + encoded = translator.encode(request) + assert "library-blob" in encoded + assert encoded["library-blob"]["id"] == "doc-1" + assert encoded["library-blob"]["data"] == "SGVsbG8gV29ybGQ=" + + decoded = translator.decode(encoded) + assert decoded.library_blob is not None + assert decoded.library_blob.id == "doc-1" + assert decoded.library_blob.data == "SGVsbG8gV29ybGQ=" + + def test_absent_library_fields_decode_as_none(self, translator): + decoded = translator.decode({ + "operation": "get-kg-core", + "id": "doc-1", + }) + assert decoded.library_metadata is None + assert decoded.library_blob is None + + +class TestKnowledgeResponseTranslatorLibrary: + + @pytest.fixture + def response_translator(self): + return KnowledgeResponseTranslator() + + def test_encode_library_metadata(self, response_translator): + response = KnowledgeResponse( + ids=None, + library_metadata=LibraryMetadata( + 
id="doc-1", + kind="application/pdf", + title="Test", + parent_id="", + document_type="source", + comments="", + tags=[], + ), + ) + encoded = response_translator.encode(response) + assert "library-metadata" in encoded + assert encoded["library-metadata"]["id"] == "doc-1" + assert encoded["library-metadata"]["kind"] == "application/pdf" + assert encoded["library-metadata"]["document-type"] == "source" + + def test_encode_library_blob_bytes_to_string(self, response_translator): + response = KnowledgeResponse( + ids=None, + library_blob=LibraryBlob( + id="doc-1", + data=b"dGVzdCBkYXRh", + ), + ) + encoded = response_translator.encode(response) + assert "library-blob" in encoded + assert encoded["library-blob"]["id"] == "doc-1" + assert encoded["library-blob"]["data"] == "dGVzdCBkYXRh" + assert isinstance(encoded["library-blob"]["data"], str) + + def test_encode_library_blob_string_passthrough(self, response_translator): + response = KnowledgeResponse( + ids=None, + library_blob=LibraryBlob( + id="doc-1", + data="already-a-string", + ), + ) + encoded = response_translator.encode(response) + assert encoded["library-blob"]["data"] == "already-a-string" + + def test_library_metadata_is_not_final(self, response_translator): + response = KnowledgeResponse( + ids=None, + library_metadata=LibraryMetadata(id="doc-1"), + ) + _, is_final = response_translator.encode_with_completion(response) + assert is_final is False + + def test_library_blob_is_not_final(self, response_translator): + response = KnowledgeResponse( + ids=None, + library_blob=LibraryBlob(id="doc-1", data=b"data"), + ) + _, is_final = response_translator.encode_with_completion(response) + assert is_final is False + + def test_eos_is_final(self, response_translator): + response = KnowledgeResponse(eos=True) + _, is_final = response_translator.encode_with_completion(response) + assert is_final is True diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index 9074bac1..0190d3f5 
100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -337,7 +337,7 @@ class Api: from . bulk_client import BulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._bulk_client = BulkClient(base_url, self.timeout, self.token) + self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._bulk_client def metrics(self): @@ -462,7 +462,7 @@ class Api: from . async_bulk_client import AsyncBulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token) + self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._async_bulk_client def async_metrics(self): diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py index 9a6a49c3..f93ab667 100644 --- a/trustgraph-base/trustgraph/api/async_bulk_client.py +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -9,10 +9,11 @@ from . 
types import Triple class AsyncBulkClient: """Asynchronous bulk operations client""" - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -25,11 +26,21 @@ class AsyncBulkClient: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None: """Bulk import triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for triple in triples: @@ -42,9 +53,7 @@ class AsyncBulkClient: async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]: """Bulk export triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -57,9 +66,7 @@ class AsyncBulkClient: async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: 
Any) -> None: """Bulk import graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -67,9 +74,7 @@ class AsyncBulkClient: async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -77,9 +82,7 @@ class AsyncBulkClient: async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -87,9 +90,7 @@ class AsyncBulkClient: async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for 
raw_message in websocket: @@ -97,9 +98,7 @@ class AsyncBulkClient: async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for context in contexts: @@ -107,9 +106,7 @@ class AsyncBulkClient: async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -117,9 +114,7 @@ class AsyncBulkClient: async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import rows via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for row in rows: diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index d18bee34..7b38a4b1 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -30,6 +30,7 @@ class AsyncSocketClient: self.timeout = timeout self.token = token self.workspace = workspace + self._workspace_explicit = workspace != 
"default" self._request_counter = 0 self._socket = None self._connect_cm = None @@ -92,7 +93,8 @@ class AsyncSocketClient: ) if resp.get("type") == "auth-ok": - self.workspace = resp.get("workspace", self.workspace) + if not self._workspace_explicit: + self.workspace = resp.get("workspace", self.workspace) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 0e49fc4e..ae185240 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -34,7 +34,7 @@ class BulkClient: Note: For true async support, use AsyncBulkClient instead. """ - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: """ Initialize synchronous bulk client. @@ -42,10 +42,12 @@ class BulkClient: url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS) timeout: WebSocket timeout in seconds token: Optional bearer token for authentication + workspace: Workspace for data isolation """ self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -58,6 +60,18 @@ class BulkClient: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any: """Run async coroutine synchronously""" try: @@ -116,9 +130,7 @@ class BulkClient: 
metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of triple import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -194,9 +206,7 @@ class BulkClient: async def _export_triples_async(self, flow: str) -> Iterator[Triple]: """Async implementation of triple export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -238,9 +248,7 @@ class BulkClient: async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of graph embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -296,9 +304,7 @@ class BulkClient: async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of graph embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -336,9 +342,7 @@ class BulkClient: async def _import_document_embeddings_async(self, flow: str, 
embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of document embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -394,9 +398,7 @@ class BulkClient: async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of document embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -446,9 +448,7 @@ class BulkClient: metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of entity contexts import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -522,9 +522,7 @@ class BulkClient: async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of entity contexts export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -562,9 +560,7 @@ class BulkClient: async def _import_rows_async(self, 
flow: str, rows: Iterator[Dict[str, Any]]) -> None: """Async implementation of rows import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for row in rows: diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 6eeb95ff..91bc67a1 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -167,7 +167,8 @@ class SocketClient: ) if resp.get("type") == "auth-ok": - self.workspace = resp.get("workspace", self.workspace) + if self.workspace == "default": + self.workspace = resp.get("workspace", self.workspace) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( @@ -501,6 +502,7 @@ class SocketClient: def put_kg_core( self, id: str, triples=None, graph_embeddings=None, + library_metadata=None, library_blob=None, ) -> Dict[str, Any]: request = { "operation": "put-kg-core", @@ -511,6 +513,10 @@ class SocketClient: request["triples"] = triples if graph_embeddings is not None: request["graph-embeddings"] = graph_embeddings + if library_metadata is not None: + request["library-metadata"] = library_metadata + if library_blob is not None: + request["library-blob"] = library_blob return self._send_request_sync("knowledge", None, request) def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]: diff --git a/trustgraph-base/trustgraph/base/cassandra_config.py b/trustgraph-base/trustgraph/base/cassandra_config.py index 78505c68..b2e36fbd 100644 --- a/trustgraph-base/trustgraph/base/cassandra_config.py +++ b/trustgraph-base/trustgraph/base/cassandra_config.py @@ -103,35 +103,19 @@ def resolve_cassandra_config( host: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = 
None, - default_keyspace: Optional[str] = None + default_keyspace: Optional[str] = None, + replication_factor: Optional[int] = None, ) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]: - """ - Resolve Cassandra configuration from various sources. - - Can accept either argparse args object or explicit parameters. - Converts host string to list format for Cassandra driver. - - Args: - args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor - host: Optional explicit host parameter (overrides args) - username: Optional explicit username parameter (overrides args) - password: Optional explicit password parameter (overrides args) - default_keyspace: Optional default keyspace if not specified elsewhere - - Returns: - tuple: (hosts_list, username, password, keyspace, replication_factor) - """ - # If args provided, extract values keyspace = None - replication_factor = 1 if args is not None: host = host or getattr(args, 'cassandra_host', None) username = username or getattr(args, 'cassandra_username', None) password = password or getattr(args, 'cassandra_password', None) keyspace = getattr(args, 'cassandra_keyspace', None) - replication_factor = getattr(args, 'cassandra_replication_factor', 1) + replication_factor = replication_factor or getattr( + args, 'cassandra_replication_factor', None + ) - # Apply defaults if still None defaults = get_cassandra_defaults() host = host or defaults['host'] username = username or defaults['username'] diff --git a/trustgraph-base/trustgraph/base/logging.py b/trustgraph-base/trustgraph/base/logging.py index 9bf599b1..ff10c140 100644 --- a/trustgraph-base/trustgraph/base/logging.py +++ b/trustgraph-base/trustgraph/base/logging.py @@ -11,6 +11,7 @@ Supports dual output to console and Loki for centralized log aggregation. 
import contextvars import logging import logging.handlers +import uuid from argparse import ArgumentParser from queue import Queue from typing import Any @@ -132,14 +133,12 @@ def setup_logging(args: dict[str, Any]) -> None: try: from logging_loki import LokiHandler - # Create Loki handler with optional authentication. The - # processor label is NOT baked in here — it's stamped onto - # each record by _ProcessorIdFilter reading the task-local - # contextvar, and logging_loki's emitter reads record.tags - # to build per-record Loki labels. + instance_id = str(uuid.uuid4())[:8] + loki_handler_kwargs = { 'url': loki_url, 'version': "1", + 'tags': {'instance': instance_id}, } if loki_username and loki_password: diff --git a/trustgraph-base/trustgraph/base/qdrant_config.py b/trustgraph-base/trustgraph/base/qdrant_config.py new file mode 100644 index 00000000..f3e015ca --- /dev/null +++ b/trustgraph-base/trustgraph/base/qdrant_config.py @@ -0,0 +1,87 @@ + +import os +import argparse +from typing import Optional, Any, Tuple + + +def get_qdrant_defaults() -> dict: + return { + 'url': os.getenv('QDRANT_URL', 'http://localhost:6333'), + 'api_key': os.getenv('QDRANT_API_KEY'), + 'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')), + 'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')), + } + + +def add_qdrant_args(parser: argparse.ArgumentParser) -> None: + defaults = get_qdrant_defaults() + + url_help = f"Qdrant URL (default: {defaults['url']})" + if 'QDRANT_URL' in os.environ: + url_help += " [from QDRANT_URL]" + + api_key_help = "Qdrant API key" + if defaults['api_key']: + api_key_help += " (default: )" + if 'QDRANT_API_KEY' in os.environ: + api_key_help += " [from QDRANT_API_KEY]" + + replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})" + if 'QDRANT_REPLICATION_FACTOR' in os.environ: + replication_help += " [from QDRANT_REPLICATION_FACTOR]" + + shard_help = f"Qdrant collection shard number 
(default: {defaults['shard_number']})" + if 'QDRANT_SHARD_NUMBER' in os.environ: + shard_help += " [from QDRANT_SHARD_NUMBER]" + + parser.add_argument( + '--store-uri', + default=defaults['url'], + help=url_help, + ) + + parser.add_argument( + '--api-key', + default=defaults['api_key'], + help=api_key_help, + ) + + parser.add_argument( + '--qdrant-replication-factor', + type=int, + default=defaults['replication_factor'], + help=replication_help, + ) + + parser.add_argument( + '--qdrant-shard-number', + type=int, + default=defaults['shard_number'], + help=shard_help, + ) + + +def resolve_qdrant_config( + args: Optional[Any] = None, + url: Optional[str] = None, + api_key: Optional[str] = None, + replication_factor: Optional[int] = None, + shard_number: Optional[int] = None, +) -> Tuple[str, Optional[str], int, int]: + if args is not None: + url = url or getattr(args, 'store_uri', None) + api_key = api_key or getattr(args, 'api_key', None) + replication_factor = replication_factor or getattr( + args, 'qdrant_replication_factor', None + ) + shard_number = shard_number or getattr( + args, 'qdrant_shard_number', None + ) + + defaults = get_qdrant_defaults() + url = url or defaults['url'] + api_key = api_key or defaults['api_key'] + replication_factor = replication_factor or defaults['replication_factor'] + shard_number = shard_number or defaults['shard_number'] + + return url, api_key, replication_factor, shard_number diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index 3830bf59..3f09b41b 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -2,7 +2,8 @@ from typing import Dict, Any, Tuple, Optional from ...schema import ( KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, DocumentEmbeddings, ChunkEmbeddings, - Metadata, EntityEmbeddings + Metadata, EntityEmbeddings, + 
LibraryMetadata, LibraryBlob, ) from .base import MessageTranslator from .primitives import ValueTranslator, SubgraphTranslator @@ -61,6 +62,27 @@ class KnowledgeRequestTranslator(MessageTranslator): ] ) + library_metadata = None + if "library-metadata" in data: + lm = data["library-metadata"] + library_metadata = LibraryMetadata( + id=lm.get("id", ""), + kind=lm.get("kind", ""), + title=lm.get("title", ""), + parent_id=lm.get("parent-id", ""), + document_type=lm.get("document-type", ""), + comments=lm.get("comments", ""), + tags=lm.get("tags", []), + ) + + library_blob = None + if "library-blob" in data: + lb = data["library-blob"] + library_blob = LibraryBlob( + id=lb.get("id", ""), + data=lb.get("data", b""), + ) + return KnowledgeRequest( operation=data.get("operation"), id=data.get("id"), @@ -69,6 +91,8 @@ class KnowledgeRequestTranslator(MessageTranslator): triples=triples, graph_embeddings=graph_embeddings, document_embeddings=document_embeddings, + library_metadata=library_metadata, + library_blob=library_blob, ) def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]: @@ -125,6 +149,26 @@ class KnowledgeRequestTranslator(MessageTranslator): ], } + if obj.library_metadata: + result["library-metadata"] = { + "id": obj.library_metadata.id, + "kind": obj.library_metadata.kind, + "title": obj.library_metadata.title, + "parent-id": obj.library_metadata.parent_id, + "document-type": obj.library_metadata.document_type, + "comments": obj.library_metadata.comments, + "tags": obj.library_metadata.tags, + } + + if obj.library_blob: + data = obj.library_blob.data + if isinstance(data, bytes): + data = data.decode("utf-8") + result["library-blob"] = { + "id": obj.library_blob.id, + "data": data, + } + return result @@ -194,6 +238,32 @@ class KnowledgeResponseTranslator(MessageTranslator): } } + # Streaming library metadata response + if obj.library_metadata: + return { + "library-metadata": { + "id": obj.library_metadata.id, + "kind": obj.library_metadata.kind, + 
"title": obj.library_metadata.title, + "parent-id": obj.library_metadata.parent_id, + "document-type": obj.library_metadata.document_type, + "comments": obj.library_metadata.comments, + "tags": obj.library_metadata.tags, + } + } + + # Streaming library blob response + if obj.library_blob: + data = obj.library_blob.data + if isinstance(data, bytes): + data = data.decode("utf-8") + return { + "library-blob": { + "id": obj.library_blob.id, + "data": data, + } + } + # End of stream marker if obj.eos is True: return {"eos": True} @@ -209,7 +279,9 @@ class KnowledgeResponseTranslator(MessageTranslator): is_final = ( obj.ids is not None or # List response obj.eos is True or # End of stream - (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response + (not obj.triples and not obj.graph_embeddings + and not obj.document_embeddings + and not obj.library_metadata and not obj.library_blob) # Empty response ) return response, is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index a3879103..4353065b 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -21,6 +21,21 @@ from .embeddings import GraphEmbeddings, DocumentEmbeddings # <- () # <- (error) +@dataclass +class LibraryMetadata: + id: str = "" + kind: str = "" + title: str = "" + parent_id: str = "" + document_type: str = "" + comments: str = "" + tags: list[str] = field(default_factory=list) + +@dataclass +class LibraryBlob: + id: str = "" + data: bytes = b"" + @dataclass class KnowledgeRequest: # get-kg-core, delete-kg-core, list-kg-cores, put-kg-core @@ -44,6 +59,10 @@ class KnowledgeRequest: # put-de-core document_embeddings: DocumentEmbeddings | None = None + # put-kg-core (source material) + library_metadata: LibraryMetadata | None = None + library_blob: LibraryBlob | None = None + 
@dataclass class KnowledgeResponse: error: Error | None = None @@ -52,6 +71,8 @@ class KnowledgeResponse: triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None document_embeddings: DocumentEmbeddings | None = None + library_metadata: LibraryMetadata | None = None + library_blob: LibraryBlob | None = None knowledge_request_queue = queue('knowledge', cls='request') knowledge_response_queue = queue('knowledge', cls='response') diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py index 62fa7ca2..f4d44cca 100644 --- a/trustgraph-cli/trustgraph/cli/get_document_content.py +++ b/trustgraph-cli/trustgraph/cli/get_document_content.py @@ -5,7 +5,7 @@ Gets document content from the library by document ID. import argparse import os import sys -from trustgraph.api import Api +import requests default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -13,15 +13,29 @@ default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") def get_content(url, document_id, output_file, token=None, workspace="default"): - api = Api(url, token=token, workspace=workspace).library() + stream_url = url.rstrip("/") + "/api/v1/document-stream" - content = api.get_document_content(id=document_id) + params = { + "document-id": document_id, + "workspace": workspace, + } + + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.get(stream_url, params=params, headers=headers, stream=True) + resp.raise_for_status() if output_file: + total = 0 with open(output_file, 'wb') as f: - f.write(content) - print(f"Written {len(content)} bytes to {output_file}") + for chunk in resp.iter_content(chunk_size=65536): + f.write(chunk) + total += len(chunk) + print(f"Written {total} bytes to {output_file}") else: + content = resp.content try: text = content.decode('utf-8') print(text) diff --git 
a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index b4f37b81..2ff1a3cc 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -47,6 +47,31 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) +def write_library_metadata(f, data): + msg = ( + "lm", + { + "i": data["id"], + "k": data.get("kind", ""), + "t": data.get("title", ""), + "p": data.get("parent-id", ""), + "d": data.get("document-type", ""), + "c": data.get("comments", ""), + "g": data.get("tags", []), + } + ) + f.write(msgpack.packb(msg, use_bin_type=True)) + +def write_library_blob(f, data): + msg = ( + "lb", + { + "i": data["id"], + "d": data.get("data", b""), + } + ) + f.write(msgpack.packb(msg, use_bin_type=True)) + def fetch(url, workspace, id, output, token=None): api = Api(url=url, token=token, workspace=workspace) @@ -55,6 +80,8 @@ def fetch(url, workspace, id, output, token=None): try: ge = 0 t = 0 + lm = 0 + lb = 0 with open(output, "wb") as f: @@ -68,7 +95,15 @@ def fetch(url, workspace, id, output, token=None): ge += 1 write_ge(f, response["graph-embeddings"]) - print(f"Got: {t} triple, {ge} GE messages.") + if "library-metadata" in response: + lm += 1 + write_library_metadata(f, response["library-metadata"]) + + if "library-blob" in response: + lb += 1 + write_library_blob(f, response["library-blob"]) + + print(f"Got: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.") finally: socket.close() diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index dccf548e..5649a5ae 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -78,7 +78,7 @@ def load_structured_data( logger.info("Step 1: Analyzing data to discover best matching schema...") # Step 1: Auto-discover schema (reuse discover_schema logic) - 
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) + discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace) if not discovered_schema: logger.error("Failed to discover suitable schema automatically") print("❌ Could not automatically determine the best schema for your data.") @@ -90,7 +90,7 @@ def load_structured_data( # Step 2: Auto-generate descriptor logger.info("Step 2: Generating descriptor configuration...") - auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace) + auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace) if not auto_descriptor: logger.error("Failed to generate descriptor automatically") print("❌ Could not automatically generate descriptor configuration.") @@ -172,7 +172,7 @@ def load_structured_data( logger.info(f"Sample chars: {sample_chars} characters") # Use the helper function to discover schema (get raw response for display) - response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace) + response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace) if response: # Debug: print response type and content @@ -203,7 +203,7 @@ def load_structured_data( # If no schema specified, discover it first if not schema_name: logger.info("No schema specified, auto-discovering...") - schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace) + schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace) if not schema_name: print("Error: Could not determine schema automatically.") print("Please specify a schema 
using --schema-name or run --discover-schema first.") @@ -213,7 +213,7 @@ def load_structured_data( logger.info(f"Target schema: {schema_name}") # Generate descriptor using helper function - descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace) + descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace) if descriptor: # Output the generated descriptor @@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp # Helper functions for auto mode -def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"): +def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"): """Auto-discover the best matching schema for the input data Args: @@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url, workspace=workspace) + api = Api(api_url, token=token, workspace=workspace) config_api = api.config() # Get available schemas @@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur return None -def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"): +def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"): """Auto-generate descriptor configuration for the discovered schema""" try: # Read sample data @@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl # Import API modules from trustgraph.api import Api from trustgraph.api.types import ConfigKey - api = Api(api_url, workspace=workspace) + api 
= Api(api_url, token=token, workspace=workspace) config_api = api.config() # Get schema definition diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index fe0981a5..f4e0b3dd 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -40,6 +40,23 @@ def read_message(unpacked, id): }, "triples": msg["t"], } + elif unpacked[0] == "lm": + msg = unpacked[1] + return "lm", { + "id": msg["i"], + "kind": msg.get("k", ""), + "title": msg.get("t", ""), + "parent-id": msg.get("p", ""), + "document-type": msg.get("d", ""), + "comments": msg.get("c", ""), + "tags": msg.get("g", []), + } + elif unpacked[0] == "lb": + msg = unpacked[1] + return "lb", { + "id": msg["i"], + "data": msg.get("d", b""), + } else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) @@ -51,6 +68,8 @@ def put(url, workspace, id, input, token=None): try: ge = 0 t = 0 + lm = 0 + lb = 0 with open(input, "rb") as f: @@ -73,10 +92,18 @@ def put(url, workspace, id, input, token=None): t += 1 socket.put_kg_core(id, triples=msg) + elif kind == "lm": + lm += 1 + socket.put_kg_core(id, library_metadata=msg) + + elif kind == "lb": + lb += 1 + socket.put_kg_core(id, library_blob=msg) + else: raise RuntimeError("Unexpected message kind", kind) - print(f"Put: {t} triple, {ge} GE messages.") + print(f"Put: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.") finally: socket.close() diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index c5fac198..725f1106 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -83,7 +83,8 @@ class Processor(AsyncProcessor): host=cassandra_host, username=cassandra_username, password=cassandra_password, - default_keyspace="config" + default_keyspace="config", + 
replication_factor=params.get("cassandra_replication_factor"), ) # Store resolved configuration diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index f1fa53f5..6f017c43 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -1,6 +1,7 @@ from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings -from .. schema import DocumentEmbeddings +from .. schema import DocumentEmbeddings, LibraryMetadata, LibraryBlob +from .. schema import LibrarianRequest, DocumentMetadata from .. knowledge import hash from .. exceptions import RequestError from .. tables.knowledge import KnowledgeTableStore @@ -18,7 +19,7 @@ class KnowledgeManager: def __init__( self, cassandra_host, cassandra_username, cassandra_password, - keyspace, flow_config, replication_factor=1, + keyspace, flow_config, librarian=None, replication_factor=1, ): self.table_store = KnowledgeTableStore( @@ -26,6 +27,9 @@ class KnowledgeManager: replication_factor ) + self.librarian = librarian + self._pending_library_metadata = {} + self.loader_queue = asyncio.Queue(maxsize=20) self.background_task = None self.flow_config = flow_config @@ -86,6 +90,9 @@ class KnowledgeManager: publish_ge, ) + if self.librarian: + await self._stream_library_docs(request.id, respond) + logger.debug("Knowledge core retrieval complete") await respond( @@ -122,6 +129,12 @@ class KnowledgeManager: workspace, request.graph_embeddings ) + if request.library_metadata and self.librarian: + await self._put_library_metadata(request.library_metadata, workspace) + + if request.library_blob and self.librarian: + await self._put_library_blob(request.library_blob, workspace) + await respond( KnowledgeResponse( error = None, @@ -250,6 +263,112 @@ class KnowledgeManager: await self.loader_queue.put((request, respond, workspace)) + async def _stream_library_docs(self, document_id, respond): + + try: + root_meta = await 
self.librarian.fetch_document_metadata( + document_id + ) + except Exception as e: + logger.warning(f"Could not fetch library metadata for {document_id}: {e}") + return + + if root_meta is None: + return + + await self._stream_one_doc(root_meta, respond) + + try: + resp = await self.librarian.request( + LibrarianRequest( + operation="list-children", + document_id=document_id, + ) + ) + except Exception as e: + logger.warning(f"Could not list children for {document_id}: {e}") + return + + for child_meta in resp.document_metadatas: + await self._stream_one_doc(child_meta, respond) + + async def _stream_one_doc(self, doc_meta, respond): + + lm = LibraryMetadata( + id=doc_meta.id, + kind=doc_meta.kind, + title=doc_meta.title, + parent_id=doc_meta.parent_id, + document_type=doc_meta.document_type, + comments=doc_meta.comments, + tags=doc_meta.tags or [], + ) + + await respond( + KnowledgeResponse(library_metadata=lm) + ) + + try: + content = await self.librarian.fetch_document_content( + doc_meta.id + ) + except Exception as e: + logger.warning(f"Could not fetch content for {doc_meta.id}: {e}") + return + + await respond( + KnowledgeResponse( + library_blob=LibraryBlob( + id=doc_meta.id, + data=content, + ) + ) + ) + + async def _put_library_metadata(self, lm, workspace): + self._pending_library_metadata[lm.id] = lm + + async def _put_library_blob(self, lb, workspace): + + lm = self._pending_library_metadata.pop(lb.id, None) + if lm is None: + logger.warning( + f"Received library blob for {lb.id} with no preceding metadata" + ) + return + + doc_meta = DocumentMetadata( + id=lm.id, + kind=lm.kind, + title=lm.title, + parent_id=lm.parent_id, + document_type=lm.document_type, + comments=lm.comments, + tags=lm.tags or [], + ) + + if lm.parent_id: + operation = "add-child-document" + else: + operation = "add-document" + + try: + await self.librarian.request( + LibrarianRequest( + operation=operation, + document_id=lm.id, + document_metadata=doc_meta, + content=lb.data, + ) + 
) + except RuntimeError as e: + if "already exists" in str(e): + logger.debug(f"Library document {lm.id} already exists, skipping") + else: + logger.warning(f"Could not save library document {lm.id}: {e}") + except Exception as e: + logger.warning(f"Could not save library document {lm.id}: {e}") + async def core_loader(self): logger.info("Knowledge background processor running...") diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index a04e42ca..5c50c207 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -12,6 +12,7 @@ import logging from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber from .. base import ConsumerMetrics, ProducerMetrics from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .. base import LibrarianClient from .. schema import KnowledgeRequest, KnowledgeResponse, Error from .. schema import knowledge_request_queue, knowledge_response_queue @@ -60,7 +61,8 @@ class Processor(WorkspaceProcessor): host=cassandra_host, username=cassandra_username, password=cassandra_password, - default_keyspace="knowledge" + default_keyspace="knowledge", + replication_factor=params.get("cassandra_replication_factor"), ) self.cassandra_host = hosts @@ -77,12 +79,17 @@ class Processor(WorkspaceProcessor): } ) + self.librarian_client = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, + ) + self.knowledge = KnowledgeManager( cassandra_host = self.cassandra_host, cassandra_username = self.cassandra_username, cassandra_password = self.cassandra_password, keyspace = keyspace, flow_config = self, + librarian = self.librarian_client, replication_factor = replication_factor, ) @@ -156,6 +163,7 @@ class Processor(WorkspaceProcessor): async def start(self): await super(Processor, self).start() + await self.librarian_client.start() async def on_knowledge_config(self, workspace, config, 
version): diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index f214111d..40ecac8a 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -219,7 +219,14 @@ class Processor(FlowProcessor): source_doc_id = v.document_id or v.metadata.id # Run OCR, get per-page markdown - pages = self.ocr(blob) + try: + pages = self.ocr(blob) + except Exception as e: + logger.error( + f"Failed to decode PDF {source_doc_id}: " + f"{type(e).__name__}: {e}" + ) + return for markdown, page_num in pages: diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 209153f6..ae393028 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -32,6 +32,10 @@ logger = logging.getLogger(__name__) default_ident = "document-decoder" +def _looks_like_pdf(content): + return content.lstrip().startswith(b"%PDF-") + + class Processor(FlowProcessor): def __init__(self, **params): @@ -94,33 +98,37 @@ class Processor(FlowProcessor): ) return - with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp: + # Check if we should fetch from librarian or use inline data + if v.document_id: + # Fetch from librarian via Pulsar + logger.info(f"Fetching document {v.document_id} from librarian...") + + content = await flow.librarian.fetch_document_content( + document_id=v.document_id, + + ) + + # Content is base64 encoded + if isinstance(content, str): + content = content.encode('utf-8') + decoded_content = base64.b64decode(content) + + logger.info(f"Fetched {len(decoded_content)} bytes from librarian") + else: + # Use inline data (backward compatibility) + decoded_content = base64.b64decode(v.data) + + if not _looks_like_pdf(decoded_content): + logger.error( + f"Document 
{v.metadata.id} is not valid PDF content. " + f"Ignoring document." + ) + return + + with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp: temp_path = fp.name - - # Check if we should fetch from librarian or use inline data - if v.document_id: - # Fetch from librarian via Pulsar - logger.info(f"Fetching document {v.document_id} from librarian...") - fp.close() - - content = await flow.librarian.fetch_document_content( - document_id=v.document_id, - - ) - - # Content is base64 encoded - if isinstance(content, str): - content = content.encode('utf-8') - decoded_content = base64.b64decode(content) - - with open(temp_path, 'wb') as f: - f.write(decoded_content) - - logger.info(f"Fetched {len(decoded_content)} bytes from librarian") - else: - # Use inline data (backward compatibility) - fp.write(base64.b64decode(v.data)) - fp.close() + fp.write(decoded_content) + fp.close() global PyPDFLoader if PyPDFLoader is None: @@ -129,7 +137,15 @@ class Processor(FlowProcessor): ) PyPDFLoader = _cls loader = PyPDFLoader(temp_path) - pages = loader.load() + try: + pages = loader.load() + except Exception as e: + source_doc_id = v.document_id or v.metadata.id + logger.error( + f"Failed to decode PDF {source_doc_id}: " + f"{type(e).__name__}: {e}" + ) + return # Get the source document ID source_doc_id = v.document_id or v.metadata.id diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index d7abd1a9..f1e4a577 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -6,7 +6,7 @@ import logging from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import BatchStatement, SimpleStatement -from ssl import SSLContext, PROTOCOL_TLSv1_2 +import ssl from ..tables.cassandra_async import async_execute @@ -41,13 +41,15 @@ class KnowledgeGraph: def __init__( self, hosts=None, - keyspace="trustgraph", 
username=None, password=None + keyspace="trustgraph", username=None, password=None, + replication_factor=1, ): if hosts is None: hosts = ["localhost"] self.keyspace = keyspace + self.replication_factor = replication_factor self.username = username # 7-table schema for quads with full query pattern support @@ -68,7 +70,7 @@ class KnowledgeGraph: self.collection_metadata_table = "collection_metadata" if username and password: - ssl_context = SSLContext(PROTOCOL_TLSv1_2) + ssl_context = ssl.create_default_context() auth_provider = PlainTextAuthProvider(username=username, password=password) self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context) else: @@ -92,7 +94,7 @@ class KnowledgeGraph: create keyspace if not exists {self.keyspace} with replication = {{ 'class' : 'SimpleStrategy', - 'replication_factor' : 1 + 'replication_factor' : {self.replication_factor} }}; """) @@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph: def __init__( self, hosts=None, - keyspace="trustgraph", username=None, password=None + keyspace="trustgraph", username=None, password=None, + replication_factor=1, ): if hosts is None: hosts = ["localhost"] self.keyspace = keyspace + self.replication_factor = replication_factor self.username = username # 2-table entity-centric schema @@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph: self.collection_metadata_table = "collection_metadata" if username and password: - ssl_context = SSLContext(PROTOCOL_TLSv1_2) + ssl_context = ssl.create_default_context() auth_provider = PlainTextAuthProvider(username=username, password=password) self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context) else: @@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph: create keyspace if not exists {self.keyspace} with replication = {{ 'class' : 'SimpleStrategy', - 'replication_factor' : 1 + 'replication_factor' : {self.replication_factor} }}; """) diff --git 
a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 6696afbe..90080cc4 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -73,6 +73,39 @@ class CoreExport: enc = msgpack.packb(msg) await response.write(enc) + if "library-metadata" in resp: + + data = resp["library-metadata"] + msg = ( + "lm", + { + "i": data["id"], + "k": data.get("kind", ""), + "t": data.get("title", ""), + "p": data.get("parent-id", ""), + "d": data.get("document-type", ""), + "c": data.get("comments", ""), + "g": data.get("tags", []), + } + ) + + enc = msgpack.packb(msg) + await response.write(enc) + + if "library-blob" in resp: + + data = resp["library-blob"] + msg = ( + "lb", + { + "i": data["id"], + "d": data.get("data", b""), + } + ) + + enc = msgpack.packb(msg, use_bin_type=True) + await response.write(enc) + await kr.process( { "operation": "get-kg-core", diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index d03d4efd..bf660def 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -79,6 +79,39 @@ class CoreImport: await kr.process(msg) + elif unpacked[0] == "lm": + msg = unpacked[1] + msg = { + "operation": "put-kg-core", + "workspace": workspace, + "id": id, + "library-metadata": { + "id": msg["i"], + "kind": msg.get("k", ""), + "title": msg.get("t", ""), + "parent-id": msg.get("p", ""), + "document-type": msg.get("d", ""), + "comments": msg.get("c", ""), + "tags": msg.get("g", []), + } + } + + await kr.process(msg) + + elif unpacked[0] == "lb": + msg = unpacked[1] + msg = { + "operation": "put-kg-core", + "workspace": workspace, + "id": id, + "library-blob": { + "id": msg["i"], + "data": msg.get("d", b""), + } + } + + await kr.process(msg) + except Exception 
as e: logger.error(f"Core import exception: {e}", exc_info=True) await error(str(e)) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py index 2992d99f..74b4d7df 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py @@ -3,6 +3,7 @@ import asyncio import uuid import logging from . librarian import LibrarianRequestor +from ... schema import librarian_request_queue, librarian_response_queue # Module logger logger = logging.getLogger(__name__) @@ -23,10 +24,13 @@ class DocumentStreamExport: response = await ok() + uid = str(uuid.uuid4()) lr = LibrarianRequestor( backend=self.backend, - consumer="api-gateway-doc-stream-" + str(uuid.uuid4()), - subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()), + consumer="api-gateway-doc-stream-" + uid, + subscriber="api-gateway-doc-stream-" + uid, + request_queue=f"{librarian_request_queue}:{workspace}", + response_queue=f"{librarian_response_queue}:{workspace}", ) try: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index bdbd18d8..9b119f8e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -4,6 +4,8 @@ import queue import uuid import logging +from ..capabilities import PUBLIC, AUTHENTICATED + # Module logger logger = logging.getLogger(__name__) @@ -156,37 +158,41 @@ class Mux: }) return - # Resolve workspace first (default-fill from the caller's - # bound workspace), then ask the regime to authorise the - # service-level capability against the matched - # operation's resource shape. + # Resolve workspace (default-fill from the caller's + # bound workspace). Workspace resolution applies to all + # operations regardless of capability level. 
try: await enforce_workspace(data, self.identity, self.auth) if isinstance(inner, dict): await enforce_workspace(inner, self.identity, self.auth) - if data.get("flow"): - resource = { - "workspace": data.get("workspace", ""), - "flow": data.get("flow", ""), - } - parameters = {} - else: - # Build a minimal RequestContext so the matched - # operation's own extractors decide resource and - # parameters — same path the HTTP endpoints take. - from ..registry import RequestContext - ctx = RequestContext( - body=inner if isinstance(inner, dict) else {}, - match_info={}, - identity=self.identity, - ) - resource = op.extract_resource(ctx) - parameters = op.extract_parameters(ctx) + # Authorisation: capability sentinels short-circuit + # the regime call; capability strings go through + # authorise(). + if op.capability not in (PUBLIC, AUTHENTICATED): + if data.get("flow"): + resource = { + "workspace": data.get("workspace", ""), + "flow": data.get("flow", ""), + } + parameters = {} + else: + # Build a minimal RequestContext so the matched + # operation's own extractors decide resource + # and parameters — same path the HTTP + # endpoints take. 
+ from ..registry import RequestContext + ctx = RequestContext( + body=inner if isinstance(inner, dict) else {}, + match_info={}, + identity=self.identity, + ) + resource = op.extract_resource(ctx) + parameters = op.extract_parameters(ctx) - await self.auth.authorise( - self.identity, op.capability, resource, parameters, - ) + await self.auth.authorise( + self.identity, op.capability, resource, parameters, + ) except _web.HTTPNotFound: await self.ws.send_json({ "id": request_id, @@ -288,6 +294,8 @@ class Mux: await self.maybe_tidy_workers(workers) async def responder(resp, fin): + if self.ws is None: + return await self.ws.send_json({ "id": id, "response": resp, @@ -321,6 +329,8 @@ class Mux: ) except Exception as e: + if self.ws is None: + return await self.ws.send_json({ "id": id, "error": {"message": str(e), "type": "error"}, diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index f53ad73b..af6183db 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -117,8 +117,10 @@ class SocketEndpoint: running = Running() + params = dict(request.query) + params.update(request.match_info) dispatcher = await self.dispatcher( - ws, running, request.match_info + ws, running, params ) worker_task = tg.create_task( diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py index 8ce22757..b2f3976d 100644 --- a/trustgraph-flow/trustgraph/iam/service/service.py +++ b/trustgraph-flow/trustgraph/iam/service/service.py @@ -101,6 +101,7 @@ class Processor(AsyncProcessor): username=cassandra_username, password=cassandra_password, default_keyspace="iam", + replication_factor=params.get("cassandra_replication_factor"), ) self.cassandra_host = hosts diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py index 1c4d010e..cc5f0bdf 
100644 --- a/trustgraph-flow/trustgraph/librarian/librarian.py +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -162,6 +162,9 @@ class Librarian: request.document_id ) + if object_id is None: + raise RequestError(f"Document not found: {request.document_id}") + content = await self.blob_store.get( object_id ) diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index cc5efdae..4d3efbfb 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -8,6 +8,7 @@ import asyncio import base64 import json import logging +import os from datetime import datetime from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber @@ -54,6 +55,16 @@ default_object_store_access_key = "object-user" default_object_store_secret_key = "object-password" default_object_store_use_ssl = False default_object_store_region = None + +# Environment variables consulted as a fallback when the +# corresponding params field is not set in the processor-group YAML +# or via CLI. Intended for K8s Secret / env-var injection so +# credentials never have to live in the YAML (and thus in git). +ENV_OBJECT_STORE_ENDPOINT = "OBJECT_STORE_ENDPOINT" +ENV_OBJECT_STORE_ACCESS_KEY = "OBJECT_STORE_ACCESS_KEY" +ENV_OBJECT_STORE_SECRET_KEY = "OBJECT_STORE_SECRET_KEY" +ENV_OBJECT_STORE_USE_SSL = "OBJECT_STORE_USE_SSL" +ENV_OBJECT_STORE_REGION = "OBJECT_STORE_REGION" default_cassandra_host = "cassandra" default_min_chunk_size = 1 # No minimum by default (for Garage) @@ -89,22 +100,36 @@ class Processor(WorkspaceProcessor): "config_response_queue", default_config_response_queue ) - object_store_endpoint = params.get("object_store_endpoint", default_object_store_endpoint) - object_store_access_key = params.get( - "object_store_access_key", - default_object_store_access_key + # Resolve object-store config. 
Precedence: explicit params + # (CLI / processor-group YAML) → environment variable → + # hardcoded default. The env-var path lets K8s Secrets feed + # credentials without them appearing in the YAML. + object_store_endpoint = ( + params.get("object_store_endpoint") + or os.environ.get(ENV_OBJECT_STORE_ENDPOINT) + or default_object_store_endpoint ) - object_store_secret_key = params.get( - "object_store_secret_key", - default_object_store_secret_key + object_store_access_key = ( + params.get("object_store_access_key") + or os.environ.get(ENV_OBJECT_STORE_ACCESS_KEY) + or default_object_store_access_key ) - object_store_use_ssl = params.get( - "object_store_use_ssl", - default_object_store_use_ssl + object_store_secret_key = ( + params.get("object_store_secret_key") + or os.environ.get(ENV_OBJECT_STORE_SECRET_KEY) + or default_object_store_secret_key ) - object_store_region = params.get( - "object_store_region", - default_object_store_region + object_store_use_ssl = params.get("object_store_use_ssl") + if object_store_use_ssl is None: + env_ssl = os.environ.get(ENV_OBJECT_STORE_USE_SSL) + if env_ssl is not None: + object_store_use_ssl = env_ssl.lower() in ("true", "1", "yes") + else: + object_store_use_ssl = default_object_store_use_ssl + object_store_region = ( + params.get("object_store_region") + or os.environ.get(ENV_OBJECT_STORE_REGION) + or default_object_store_region ) min_chunk_size = params.get( @@ -121,7 +146,8 @@ class Processor(WorkspaceProcessor): host=cassandra_host, username=cassandra_username, password=cassandra_password, - default_keyspace="librarian" + default_keyspace="librarian", + replication_factor=params.get("cassandra_replication_factor"), ) # Store resolved configuration diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index f6770744..de25a139 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ 
b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -12,31 +12,33 @@ from qdrant_client import QdrantClient from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error from .... base import DocumentEmbeddingsQueryService +from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config # Module logger logger = logging.getLogger(__name__) default_ident = "doc-embeddings-query" -default_store_uri = 'http://localhost:6333' - class Processor(DocumentEmbeddingsQueryService): def __init__(self, **params): - store_uri = params.get("store_uri", default_store_uri) + store_uri = params.get("store_uri") + api_key = params.get("api_key") - #optional api key - api_key = params.get("api_key", None) + url, api_key, _, _ = resolve_qdrant_config( + url=store_uri, + api_key=api_key, + ) super(Processor, self).__init__( **params | { - "store_uri": store_uri, + "store_uri": url, "api_key": api_key, } ) - self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=url, api_key=api_key) async def query_document_embeddings(self, workspace, msg): @@ -85,18 +87,7 @@ class Processor(DocumentEmbeddingsQueryService): def add_args(parser): DocumentEmbeddingsQueryService.add_args(parser) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Qdrant store URI (default: {default_store_uri})' - ) - - parser.add_argument( - '-k', '--api-key', - default=None, - help=f'API key for qdrant (default: None)' - ) + add_qdrant_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 167130c9..aa93925d 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -12,31 +12,32 @@ from qdrant_client import QdrantClient from .... 
schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL from .... base import GraphEmbeddingsQueryService +from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config # Module logger logger = logging.getLogger(__name__) default_ident = "graph-embeddings-query" -default_store_uri = 'http://localhost:6333' - class Processor(GraphEmbeddingsQueryService): def __init__(self, **params): - store_uri = params.get("store_uri", default_store_uri) + store_uri = params.get("store_uri") + api_key = params.get("api_key") - #optional api key - api_key = params.get("api_key", None) + url, api_key, _, _ = resolve_qdrant_config( + url=store_uri, api_key=api_key, + ) super(Processor, self).__init__( **params | { - "store_uri": store_uri, + "store_uri": url, "api_key": api_key, } ) - self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=url, api_key=api_key) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService): def add_args(parser): GraphEmbeddingsQueryService.add_args(parser) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Qdrant store URI (default: {default_store_uri})' - ) - - parser.add_argument( - '-k', '--api-key', - default=None, - help=f'API key for qdrant (default: None)' - ) + add_qdrant_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py index b7f0f423..a9005ee4 100644 --- a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py @@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object): # Create keyspace self.session.execute(f""" CREATE KEYSPACE IF NOT EXISTS {self.keyspace} - WITH replication = {{'class': 'SimpleStrategy', 
'replication_factor': 1}} + WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}} """) # Create triples table optimized for SPARQL queries diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 1534c044..7e1a5851 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -19,12 +19,12 @@ from .... schema import ( RowIndexMatch, Error ) from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config # Module logger logger = logging.getLogger(__name__) default_ident = "row-embeddings-query" -default_store_uri = 'http://localhost:6333' default_concurrency = 10 @@ -35,13 +35,17 @@ class Processor(FlowProcessor): id = params.get("id", default_ident) concurrency = params.get("concurrency", default_concurrency) - store_uri = params.get("store_uri", default_store_uri) - api_key = params.get("api_key", None) + store_uri = params.get("store_uri") + api_key = params.get("api_key") + + url, api_key, _, _ = resolve_qdrant_config( + url=store_uri, api_key=api_key, + ) super(Processor, self).__init__( **params | { "id": id, - "store_uri": store_uri, + "store_uri": url, "api_key": api_key, } ) @@ -62,7 +66,7 @@ class Processor(FlowProcessor): ) ) - self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=url, api_key=api_key) def sanitize_name(self, name: str) -> str: """Sanitize names for Qdrant collection naming""" @@ -192,21 +196,9 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): - """Add command-line arguments""" FlowProcessor.add_args(parser) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Qdrant store URI (default: {default_store_uri})' - ) - - 
parser.add_argument( - '-k', '--api-key', - default=None, - help='API key for Qdrant (default: None)' - ) + add_qdrant_args(parser) parser.add_argument( '-c', '--concurrency', diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 7157daae..f9868d67 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError from .... schema import Error, RowSchema, Field as SchemaField from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config -from .... tables.cassandra_async import async_execute +from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan from ... graphql import GraphQLSchemaBuilder, SortDirection @@ -180,7 +180,7 @@ class Processor(FlowProcessor): description=field_def.get("description", ""), required=field_def.get("required", False), enum_values=field_def.get("enum", []), - indexed=field_def.get("indexed", False) + indexed=field_def.get("indexed", False), ) fields.append(field) @@ -232,6 +232,8 @@ class Processor(FlowProcessor): for index_name in index_names: if index_name in filters: value = filters[index_name] + if value == "" or value is None: + continue # Single field index -> single element list index_value = [str(value)] return (index_name, index_value) @@ -282,11 +284,13 @@ class Processor(FlowProcessor): query += f" LIMIT {limit}" try: - rows = await async_execute(self.session, query, params) - for row in rows: - # Convert data map to dict with proper field names - row_dict = dict(row.data) if row.data else {} - results.append(row_dict) + pages = await async_execute_paged( + self.session, query, params + ) + for page in pages: + for row in page: + row_dict = dict(row.data) if row.data else 
{} + results.append(row_dict) except Exception as e: logger.error(f"Failed to query rows: {e}", exc_info=True) raise @@ -308,8 +312,6 @@ class Processor(FlowProcessor): # Query using the first index (arbitrary choice for scan) primary_index = index_names[0] - # We need to scan all values for this index - # This requires ALLOW FILTERING or a different approach query = f""" SELECT data, source FROM {safe_keyspace}.rows WHERE collection = %s @@ -320,17 +322,18 @@ class Processor(FlowProcessor): params = [collection, schema_name, primary_index] try: - rows = await async_execute(self.session, query, params) - - for row in rows: + def row_filter(row): row_dict = dict(row.data) if row.data else {} + return self._matches_filters(row_dict, filters, row_schema) - # Apply post-filters - if self._matches_filters(row_dict, filters, row_schema): - results.append(row_dict) - - if limit and len(results) >= limit: - break + matched_rows = await async_scan( + self.session, query, params, + row_filter=row_filter, + limit=limit, + ) + for row in matched_rows: + row_dict = dict(row.data) if row.data else {} + results.append(row_dict) except Exception as e: logger.error(f"Failed to scan rows: {e}", exc_info=True) @@ -363,7 +366,7 @@ class Processor(FlowProcessor): # Parse filter key for operator if '_' in filter_key: parts = filter_key.rsplit('_', 1) - if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']: + if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']: field_name = parts[0] operator = parts[1] else: @@ -400,6 +403,18 @@ class Processor(FlowProcessor): elif operator == 'in': if str(row_value) not in [str(v) for v in filter_value]: return False + elif operator == 'not': + if str(row_value) == str(filter_value): + return False + elif operator == 'startsWith': + if not str(row_value).startswith(str(filter_value)): + return False + elif operator == 'endsWith': + if not str(row_value).endswith(str(filter_value)): + return False 
+ elif operator == 'not_in': + if str(row_value) in [str(v) for v in filter_value]: + return False except (ValueError, TypeError): return False diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 2bfef99c..08d88849 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -14,29 +14,36 @@ from qdrant_client.models import Distance, VectorParams from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config # Module logger logger = logging.getLogger(__name__) default_ident = "doc-embeddings-write" -default_store_uri = 'http://localhost:6333' - class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): def __init__(self, **params): - store_uri = params.get("store_uri", default_store_uri) - api_key = params.get("api_key", None) + store_uri = params.get("store_uri") + api_key = params.get("api_key") + + url, api_key, replication_factor, shard_number = resolve_qdrant_config( + url=store_uri, api_key=api_key, + replication_factor=params.get("qdrant_replication_factor"), + shard_number=params.get("qdrant_shard_number"), + ) super(Processor, self).__init__( **params | { - "store_uri": store_uri, + "store_uri": url, "api_key": api_key, } ) - self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=url, api_key=api_key) + self.replication_factor = replication_factor + self.shard_number = shard_number self._cache_lock = asyncio.Lock() self._known_collections: set[str] = set() @@ -61,6 +68,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): vectors_config=VectorParams( size=dim, distance=Distance.COSINE ), + 
replication_factor=self.replication_factor, + shard_number=self.shard_number, ) self._known_collections.add(collection_name) @@ -109,18 +118,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): def add_args(parser): DocumentEmbeddingsStoreService.add_args(parser) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Qdrant URI (default: {default_store_uri})' - ) - - parser.add_argument( - '-k', '--api-key', - default=None, - help=f'Qdrant API key (default: None)' - ) + add_qdrant_args(parser) async def create_collection(self, workspace: str, collection: str, metadata: dict): """ diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 13dcdba8..b6072bdc 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -14,6 +14,7 @@ from qdrant_client.models import Distance, VectorParams from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics +from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config from .... 
schema import IRI, LITERAL # Module logger @@ -29,29 +30,34 @@ def get_term_value(term): elif term.type == LITERAL: return term.value else: - # For blank nodes or other types, use id or value return term.id or term.value default_ident = "graph-embeddings-write" -default_store_uri = 'http://localhost:6333' - class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): def __init__(self, **params): - store_uri = params.get("store_uri", default_store_uri) - api_key = params.get("api_key", None) + store_uri = params.get("store_uri") + api_key = params.get("api_key") + + url, api_key, replication_factor, shard_number = resolve_qdrant_config( + url=store_uri, api_key=api_key, + replication_factor=params.get("qdrant_replication_factor"), + shard_number=params.get("qdrant_shard_number"), + ) super(Processor, self).__init__( **params | { - "store_uri": store_uri, + "store_uri": url, "api_key": api_key, } ) - self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=url, api_key=api_key) + self.replication_factor = replication_factor + self.shard_number = shard_number self._cache_lock = asyncio.Lock() self._known_collections: set[str] = set() @@ -76,6 +82,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): vectors_config=VectorParams( size=dim, distance=Distance.COSINE ), + replication_factor=self.replication_factor, + shard_number=self.shard_number, ) self._known_collections.add(collection_name) @@ -128,18 +136,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): def add_args(parser): GraphEmbeddingsStoreService.add_args(parser) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Qdrant store URI (default: {default_store_uri})' - ) - - parser.add_argument( - '-k', '--api-key', - default=None, - help=f'Qdrant API key' - ) + add_qdrant_args(parser) async def create_collection(self, workspace: str, collection: str, metadata: dict): """ diff --git 
a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index 162a4057..f6e12a85 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -27,7 +27,8 @@ class Processor(FlowProcessor): host=params.get("cassandra_host"), username=params.get("cassandra_username"), password=params.get("cassandra_password"), - default_keyspace='knowledge' + default_keyspace='knowledge', + replication_factor=params.get("cassandra_replication_factor"), ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index a01629c5..4c65edb1 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams from .... schema import RowEmbeddings from .... base import FlowProcessor, ConsumerSpec from .... base import CollectionConfigHandler +from .... 
base.qdrant_config import add_qdrant_args, resolve_qdrant_config # Module logger logger = logging.getLogger(__name__) default_ident = "row-embeddings-write" -default_store_uri = 'http://localhost:6333' class Processor(CollectionConfigHandler, FlowProcessor): @@ -41,13 +41,19 @@ class Processor(CollectionConfigHandler, FlowProcessor): id = params.get("id", default_ident) - store_uri = params.get("store_uri", default_store_uri) - api_key = params.get("api_key", None) + store_uri = params.get("store_uri") + api_key = params.get("api_key") + + url, api_key, replication_factor, shard_number = resolve_qdrant_config( + url=store_uri, api_key=api_key, + replication_factor=params.get("qdrant_replication_factor"), + shard_number=params.get("qdrant_shard_number"), + ) super(Processor, self).__init__( **params | { "id": id, - "store_uri": store_uri, + "store_uri": url, "api_key": api_key, } ) @@ -63,7 +69,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register config handler for collection management self.register_config_handler(self.on_collection_config, types=["collection"]) - self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self.qdrant = QdrantClient(url=url, api_key=api_key) + self.replication_factor = replication_factor + self.shard_number = shard_number self._cache_lock = asyncio.Lock() self._known_collections: set[str] = set() @@ -103,6 +111,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): size=dimension, distance=Distance.COSINE ), + replication_factor=self.replication_factor, + shard_number=self.shard_number, ) self._known_collections.add(collection_name) @@ -249,21 +259,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): @staticmethod def add_args(parser): - """Add command-line arguments""" FlowProcessor.add_args(parser) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Qdrant URI (default: {default_store_uri})' - ) - - parser.add_argument( - '-k', '--api-key', - default=None, - 
help='Qdrant API key (default: None)' - ) + add_qdrant_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index 65eeee06..31fc41a7 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -47,16 +47,18 @@ class Processor(CollectionConfigHandler, FlowProcessor): cassandra_password = params.get("cassandra_password") # Resolve configuration with environment variable fallback - hosts, username, password, keyspace, _ = resolve_cassandra_config( + hosts, username, password, keyspace, replication_factor = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, - password=cassandra_password + password=cassandra_password, + replication_factor=params.get("cassandra_replication_factor"), ) # Store resolved configuration with proper names self.cassandra_host = hosts # Store as list self.cassandra_username = username self.cassandra_password = password + self.replication_factor = replication_factor # Config key for schemas self.config_key = params.get("config_type", "schema") @@ -170,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): description=field_def.get("description", ""), required=field_def.get("required", False), enum_values=field_def.get("enum", []), - indexed=field_def.get("indexed", False) + indexed=field_def.get("indexed", False), ) fields.append(field) @@ -232,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): CREATE KEYSPACE IF NOT EXISTS {safe_keyspace} WITH REPLICATION = {{ 'class': 'SimpleStrategy', - 'replication_factor': 1 + 'replication_factor': {self.replication_factor} }} """ diff --git a/trustgraph-flow/trustgraph/tables/cassandra_async.py b/trustgraph-flow/trustgraph/tables/cassandra_async.py index 2f497748..fe410a26 100644 --- a/trustgraph-flow/trustgraph/tables/cassandra_async.py +++ 
b/trustgraph-flow/trustgraph/tables/cassandra_async.py @@ -27,6 +27,8 @@ Notes: import asyncio +from cassandra.query import SimpleStatement + async def async_execute(session, query, parameters=None): """Execute a CQL statement asynchronously. @@ -76,3 +78,83 @@ def _set_result_if_pending(fut, result): def _set_exception_if_pending(fut, exc): if not fut.done(): fut.set_exception(exc) + + +async def async_execute_paged(session, query, parameters=None, fetch_size=5000): + """Execute a CQL query with page-by-page iteration. + + Uses synchronous session.execute() inside run_in_executor so that + the driver's ResultSet paging works correctly without materialising + the entire result set in memory. + + Returns all pages as a list of lists. + """ + loop = asyncio.get_running_loop() + + if isinstance(query, str): + stmt = SimpleStatement(query, fetch_size=fetch_size) + else: + stmt = query + stmt.fetch_size = fetch_size + + def _fetch_all_pages(): + pages = [] + result_set = session.execute(stmt, parameters) + while True: + pages.append(list(result_set.current_rows)) + if result_set.has_more_pages: + result_set.fetch_next_page() + else: + break + return pages + + return await loop.run_in_executor( + None, _fetch_all_pages + ) + + +async def async_scan( + session, query, parameters=None, row_filter=None, + limit=None, fetch_size=5000, +): + """Scan a CQL query page-by-page, applying a filter and limit. + + Only matching rows accumulate in memory. Each page is discarded + after processing, so peak memory is bounded by fetch_size plus + the number of matching rows (capped by limit). + + Args: + session: cassandra.cluster.Session + query: CQL statement string + parameters: bind params + row_filter: callable(row) -> bool, or None to accept all + limit: max results to return, or None for unlimited + fetch_size: rows per Cassandra page fetch + + Returns: + List of matching rows. 
+ """ + loop = asyncio.get_running_loop() + + if isinstance(query, str): + stmt = SimpleStatement(query, fetch_size=fetch_size) + else: + stmt = query + stmt.fetch_size = fetch_size + + def _scan(): + results = [] + result_set = session.execute(stmt, parameters) + while True: + for row in result_set.current_rows: + if row_filter is None or row_filter(row): + results.append(row) + if limit and len(results) >= limit: + return results + if result_set.has_more_pages: + result_set.fetch_next_page() + else: + break + return results + + return await loop.run_in_executor(None, _scan) diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py index 74ceb6f4..c87cb3b5 100644 --- a/trustgraph-flow/trustgraph/tables/config.py +++ b/trustgraph-flow/trustgraph/tables/config.py @@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider -from ssl import SSLContext, PROTOCOL_TLSv1_2 +import ssl import uuid import time @@ -33,7 +33,7 @@ class ConfigTableStore: cassandra_host = [h.strip() for h in cassandra_host.split(',')] if cassandra_username and cassandra_password: - ssl_context = SSLContext(PROTOCOL_TLSv1_2) + ssl_context = ssl.create_default_context() auth_provider = PlainTextAuthProvider( username=cassandra_username, password=cassandra_password ) diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py index d7bf5e3d..b60e9cff 100644 --- a/trustgraph-flow/trustgraph/tables/iam.py +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -15,7 +15,7 @@ import logging from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider -from ssl import SSLContext, PROTOCOL_TLSv1_2 +import ssl from . 
cassandra_async import async_execute @@ -39,7 +39,7 @@ class IamTableStore: cassandra_host = [h.strip() for h in cassandra_host.split(",")] if cassandra_username and cassandra_password: - ssl_context = SSLContext(PROTOCOL_TLSv1_2) + ssl_context = ssl.create_default_context() auth_provider = PlainTextAuthProvider( username=cassandra_username, password=cassandra_password, ) diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index cf085fdd..53a12b35 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -5,7 +5,7 @@ from .. schema import DocumentEmbeddings, ChunkEmbeddings from cassandra.cluster import Cluster -from . cassandra_async import async_execute +from . cassandra_async import async_execute, async_execute_paged def term_to_tuple(term): @@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri): else: return Term(type=LITERAL, value=value) from cassandra.auth import PlainTextAuthProvider -from ssl import SSLContext, PROTOCOL_TLSv1_2 +import ssl import uuid import time @@ -50,7 +50,7 @@ class KnowledgeTableStore: cassandra_host = [h.strip() for h in cassandra_host.split(',')] if cassandra_username and cassandra_password: - ssl_context = SSLContext(PROTOCOL_TLSv1_2) + ssl_context = ssl.create_default_context() auth_provider = PlainTextAuthProvider( username=cassandra_username, password=cassandra_password ) @@ -98,7 +98,8 @@ class KnowledgeTableStore: text, boolean, text, boolean, text, boolean >>, triples list>, PRIMARY KEY ((workspace, document_id), id) ); @@ -234,7 +235,8 @@ class KnowledgeTableStore: triples = [ ( - *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o) + *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o), + v.g or "" ) for v in m.triples ] @@ -398,7 +400,7 @@ class KnowledgeTableStore: logger.debug("Get triples...") try: - rows = await async_execute( + pages = await async_execute_paged( self.cassandra, 
self.get_triples_stmt, (workspace, document_id), @@ -407,29 +409,31 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - for row in rows: + for page in pages: + for row in page: - if row[3]: - triples = [ - Triple( - s = tuple_to_term(elt[0], elt[1]), - p = tuple_to_term(elt[2], elt[3]), - o = tuple_to_term(elt[4], elt[5]), + if row[3]: + triples = [ + Triple( + s = tuple_to_term(elt[0], elt[1]), + p = tuple_to_term(elt[2], elt[3]), + o = tuple_to_term(elt[4], elt[5]), + g = elt[6] if elt[6] else None, + ) + for elt in row[3] + ] + else: + triples = [] + + await receiver( + Triples( + metadata = Metadata( + id = document_id, + collection = "default", + ), + triples = triples ) - for elt in row[3] - ] - else: - triples = [] - - await receiver( - Triples( - metadata = Metadata( - id = document_id, - collection = "default", # FIXME: What to put here? - ), - triples = triples ) - ) logger.debug("Done") @@ -438,7 +442,7 @@ class KnowledgeTableStore: logger.debug("Get GE...") try: - rows = await async_execute( + pages = await async_execute_paged( self.cassandra, self.get_graph_embeddings_stmt, (workspace, document_id), @@ -447,28 +451,29 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - for row in rows: + for page in pages: + for row in page: - if row[3]: - entities = [ - EntityEmbeddings( - entity = tuple_to_term(ent[0][0], ent[0][1]), - vector = ent[1] + if row[3]: + entities = [ + EntityEmbeddings( + entity = tuple_to_term(ent[0][0], ent[0][1]), + vector = ent[1] + ) + for ent in row[3] + ] + else: + entities = [] + + await receiver( + GraphEmbeddings( + metadata = Metadata( + id = document_id, + collection = "default", + ), + entities = entities ) - for ent in row[3] - ] - else: - entities = [] - - await receiver( - GraphEmbeddings( - metadata = Metadata( - id = document_id, - collection = "default", # FIXME: What to put here? 
- ), - entities = entities ) - ) logger.debug("Done") @@ -477,7 +482,7 @@ class KnowledgeTableStore: logger.debug("Get DE...") try: - rows = await async_execute( + pages = await async_execute_paged( self.cassandra, self.get_document_embeddings_stmt, (workspace, document_id), @@ -486,28 +491,29 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise - for row in rows: + for page in pages: + for row in page: - if row[3]: - chunks = [ - ChunkEmbeddings( - chunk_id=ch[0], - vector=ch[1], + if row[3]: + chunks = [ + ChunkEmbeddings( + chunk_id=ch[0], + vector=ch[1], + ) + for ch in row[3] + ] + else: + chunks = [] + + await receiver( + DocumentEmbeddings( + metadata = Metadata( + id = document_id, + collection = "default", + ), + chunks = chunks ) - for ch in row[3] - ] - else: - chunks = [] - - await receiver( - DocumentEmbeddings( - metadata = Metadata( - id = document_id, - collection = "default", - ), - chunks = chunks ) - ) logger.debug("Done") diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index 58486f0e..5094e103 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -24,7 +24,7 @@ from .. 
exceptions import RequestError from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import BatchStatement -from ssl import SSLContext, PROTOCOL_TLSv1_2 +import ssl import uuid import time @@ -53,7 +53,7 @@ class LibraryTableStore: cassandra_host = [h.strip() for h in cassandra_host.split(',')] if cassandra_username and cassandra_password: - ssl_context = SSLContext(PROTOCOL_TLSv1_2) + ssl_context = ssl.create_default_context() auth_provider = PlainTextAuthProvider( username=cassandra_username, password=cassandra_password ) diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index 7378db64..11b975b2 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -8,71 +8,180 @@ import logging import json import uuid import argparse -from dataclasses import dataclass +from dataclasses import dataclass, field from collections.abc import AsyncIterator from functools import partial from mcp.server.fastmcp import FastMCP, Context -from mcp.types import TextContent -from websockets.asyncio.client import connect +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.server.auth.middleware.auth_context import get_access_token from trustgraph.base.logging import add_logging_args, setup_logging -from . tg_socket import WebSocketManager +from . tg_socket import WebSocketManager, _token_key + +logger = logging.getLogger(__name__) + + +# Wire-format Term type codes (match TermTranslator compact keys) +_TERM_TYPES = { + "iri": "i", + "literal": "l", + "blank": "b", +} + + +def _make_term(value: str, term_type: str) -> dict: + """Build a compact-key Term dict for the gateway wire format. + + Args: + value: The term value (IRI string, literal text, or blank node id). + term_type: One of "iri", "literal", "blank". 
+ """ + t = _TERM_TYPES.get(term_type) + if t is None: + raise ValueError( + f"Unknown term type '{term_type}' — " + f"expected one of: {', '.join(_TERM_TYPES)}" + ) + + if t == "i": + return {"t": t, "i": value} + elif t == "l": + return {"t": t, "v": value} + elif t == "b": + return {"t": t, "d": value} + return {"t": t} + +# ── Security boundary: MCP client → MCP server ── +# The MCP client authenticates to this server via a Bearer token in the +# HTTP Authorization header. The SDK's auth middleware extracts and +# verifies the token before any tool handler runs. +# +# We implement a pass-through TokenVerifier: the gateway is the real +# authority, so we accept any non-empty Bearer token here and forward +# it to the gateway for validation. The gateway's in-band auth +# protocol and IAM regime decide whether the token is valid. +# +# This means an invalid token will connect to the MCP server but will +# fail when the first WebSocket auth frame is sent to the gateway. +# That is intentional — the gateway is the single source of truth. + + +class PassthroughTokenVerifier(TokenVerifier): + """Accept any non-empty Bearer token and forward it downstream. + + The TrustGraph gateway is the authority for token validation, not + this MCP server. We store the raw token in the AccessToken so that + tool handlers can retrieve it via ``get_access_token().token`` and + forward it to the gateway. 
+ """ + + async def verify_token(self, token: str) -> AccessToken | None: + if not token: + return None + return AccessToken( + token=token, + client_id="mcp-caller", + scopes=[], + ) + @dataclass class AppContext: - sockets: dict[str, WebSocketManager] - websocket_url: str - gateway_token: str + sockets: dict[str, WebSocketManager] = field(default_factory=dict) + websocket_url: str = "" + @asynccontextmanager -async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]: +async def app_lifespan( + server: FastMCP, + websocket_url: str = "ws://api-gateway:8088/api/v1/socket", +) -> AsyncIterator[AppContext]: + """Manage per-server state: the pool of per-caller WebSocket + connections to the gateway.""" - """ - Manage application lifecycle with type-safe context - """ - - # Initialize on startup - sockets = {} + sockets: dict[str, WebSocketManager] = {} try: - yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token) + yield AppContext(sockets=sockets, websocket_url=websocket_url) finally: - # Cleanup on shutdown - logging.info("Shutting down context") + logger.info("Shutting down — closing %d WebSocket(s)", len(sockets)) - for k, manager in sockets.items(): - logging.info(f"Closing socket for {k}") - await manager.stop() + for key, manager in sockets.items(): + try: + await manager.stop() + except Exception as e: + logger.warning("Error closing socket %s: %s", key, e) - logging.info("Shutdown complete") + logger.info("Shutdown complete") -async def get_socket_manager(ctx): + +def _require_token() -> str: + """Extract the caller's Bearer token from the MCP auth context. + + Raises RuntimeError if no token is present (the caller did not + authenticate). + """ + # ── Security boundary: token extraction ── + # get_access_token() reads the contextvar set by the SDK's + # AuthContextMiddleware. 
The token was placed there by + # PassthroughTokenVerifier.verify_token() and is the raw Bearer + # value from the MCP client's Authorization header. + access = get_access_token() + if access is None or not access.token: + raise RuntimeError( + "Authentication required — send a Bearer token in the " + "Authorization header" + ) + return access.token + + +async def get_socket_manager(ctx, token): + """Return (or create) an authenticated WebSocket for this token. + + Each unique token gets its own WebSocket connection so that + gateway-side identity, workspace binding, and capability scoping + are preserved per caller. + """ lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url - gateway_token = lifespan_context.gateway_token - if "default" in sockets: - logging.info("Return existing socket manager") - return sockets["default"] + key = _token_key(token) - logging.info(f"Opening socket to {websocket_url}...") + if key in sockets: + manager = sockets[key] + if manager.socket is not None: + return manager + # Socket was closed (e.g. server-side timeout) — reconnect. + del sockets[key] - # Create manager with empty pending requests - manager = WebSocketManager(websocket_url, token=gateway_token) + logger.info("Opening authenticated WebSocket to %s …", websocket_url) - # Start reader task with the proper manager + manager = WebSocketManager(websocket_url, token=token) await manager.start() - sockets["default"] = manager + # Verify the token is valid by calling whoami. This confirms the + # gateway accepted the token and gives us the caller's identity. 
+ try: + identity = await manager.whoami() + logger.info( + "WebSocket ready — caller: %s", + identity.get("handle", "unknown"), + ) + except Exception as e: + await manager.stop() + raise RuntimeError( + f"Token rejected by gateway (whoami failed): {e}" + ) from e - logging.info("Return new socket manager") + sockets[key] = manager return manager + @dataclass class EmbeddingsResponse: vectors: List[List[float]] @@ -182,10 +291,23 @@ class PutConfigResponse: class DeleteConfigResponse: pass +@dataclass +class SparqlQueryResponse: + query_type: str + variables: List[str] + bindings: List[Dict[str, Any]] + ask_result: bool + triples: List[Dict[str, Any]] + +@dataclass +class GraphQLQueryResponse: + data: Any + errors: List[Dict[str, Any]] + @dataclass class GetPromptsResponse: prompts: List[str] - + @dataclass class GetPromptResponse: prompt: Dict[str, Any] @@ -194,31 +316,61 @@ class GetPromptResponse: class GetSystemPromptResponse: prompt: str + class McpServer: - def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""): + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8000, + websocket_url: str = "ws://api-gateway:8088/api/v1/socket", + auth_issuer: str = "", + auth_resource_url: str = "", + ): self.host = host self.port = port self.websocket_url = websocket_url - self.gateway_token = gateway_token - # Create a partial function to pass websocket_url to app_lifespan - lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token) - + lifespan_with_url = partial( + app_lifespan, websocket_url=websocket_url, + ) + + # ── Security: MCP-level auth configuration ── + # The SDK requires AuthSettings whenever a token_verifier is + # present. The issuer_url tells MCP clients where to obtain + # tokens; resource_server_url identifies this server in OAuth + # protected-resource metadata. 
+ # + # The PassthroughTokenVerifier accepts any non-empty Bearer + # token — real validation happens at the gateway. This is + # intentional: the gateway is the single source of truth for + # identity and capability checks. + from mcp.server.auth.settings import AuthSettings + + auth_settings = AuthSettings( + issuer_url=auth_issuer or f"http://{host}:{port}", + resource_server_url=auth_resource_url or f"http://{host}:{port}", + ) + self.mcp = FastMCP( - "TrustGraph", dependencies=["trustgraph-base"], - host=self.host, port=self.port, + "TrustGraph", + dependencies=["trustgraph-base"], + host=self.host, + port=self.port, lifespan=lifespan_with_url, + token_verifier=PassthroughTokenVerifier(), + auth=auth_settings, ) self._register_tools() - + def _register_tools(self): """Register all MCP tools""" - # Register all the tools that were previously registered globally self.mcp.tool()(self.embeddings) self.mcp.tool()(self.text_completion) self.mcp.tool()(self.graph_rag) self.mcp.tool()(self.agent) self.mcp.tool()(self.triples_query) + self.mcp.tool()(self.sparql_query) + self.mcp.tool()(self.graphql_query) self.mcp.tool()(self.graph_embeddings_query) self.mcp.tool()(self.get_config_all) self.mcp.tool()(self.get_config) @@ -243,67 +395,69 @@ class McpServer: self.mcp.tool()(self.load_document) self.mcp.tool()(self.remove_document) self.mcp.tool()(self.add_processing) - + def run(self): """Run the MCP server""" self.mcp.run(transport="streamable-http") + async def _get_manager(self, ctx): + """Get an authenticated WebSocket manager for the current caller. + + Extracts the Bearer token from the MCP auth context and returns + a per-token WebSocket connection to the gateway. 
+ """ + token = _require_token() + return await get_socket_manager(ctx, token) + async def embeddings( self, - text: str, + texts: List[str], flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> EmbeddingsResponse: """ - Generate vector embeddings for the given text using TrustGraph's embedding models. - + Generate vector embeddings for the given texts using TrustGraph's embedding models. + This tool converts text into high-dimensional vectors that capture semantic meaning, enabling similarity searches, clustering, and other vector-based operations. - + Args: - text: The input text to convert into embeddings. Can be a sentence, paragraph, - or document. The text will be processed by the configured embedding model. + texts: List of input texts to convert into embeddings. Each text can be a + sentence, paragraph, or document. flow_id: Optional flow identifier to use for processing (default: "default"). Different flows may use different embedding models or configurations. - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: - EmbeddingsResponse containing a list of vectors. Each vector is a list of floats - representing the text's semantic embedding in the model's vector space. - - Example usage: - - Convert a query into embeddings for similarity search - - Generate embeddings for documents before storing them - - Create embeddings for comparison with existing knowledge + EmbeddingsResponse containing a list of vectors, one per input text. 
""" - logging.info("Embeddings request made") + logger.info("Embeddings request") if flow_id is None: flow_id = "default" - manager = await get_socket_manager(ctx, "trustgraph") + manager = await self._get_manager(ctx) - if ctx is None: - raise RuntimeError("No context provided") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Computing embeddings via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - await ctx.session.send_log_message( - level="info", - data=f"Computing embeddings via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, + request_data = {"texts": texts} + + gen = manager.request( + "embeddings", request_data, flow_id, workspace=workspace, ) - # Send websocket request - request_data = {"text": text} - logging.info("making request") - - gen = manager.request("embeddings", request_data, flow_id) - async for response in gen: - - # Extract vectors from response vectors = response.get("vectors", [[]]) break - + return EmbeddingsResponse(vectors=vectors) async def text_completion( @@ -311,62 +465,47 @@ class McpServer: prompt: str, system: str | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> TextCompletionResponse: """ Generate text completions using TrustGraph's language models. - - This tool sends prompts to configured language models and returns generated text. - It supports both user prompts and system instructions for controlling generation. - + Args: prompt: The main prompt or question to send to the language model. - This is the primary input that guides the model's response. system: Optional system prompt that sets the context, role, or behavior - for the AI assistant (e.g., "You are a helpful coding assistant"). - System prompts influence how the model interprets and responds. - flow_id: Optional flow identifier (default: "default"). 
Different flows - may use different models, parameters, or processing pipelines. - + for the AI assistant. + flow_id: Optional flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: TextCompletionResponse containing the generated text response from the model. - - Example usage: - - Ask questions and get AI-generated answers - - Generate code, documentation, or creative content - - Perform text analysis, summarization, or transformation tasks - - Use system prompts to control tone, style, or domain expertise """ if system is None: system = "" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - # Use websocket if context is available - logging.info("Text completion request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Generating text completion via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Generating text completion via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Send websocket request request_data = {"system": system, "prompt": prompt} - gen = manager.request("text-completion", request_data, flow_id) + gen = manager.request( + "text-completion", request_data, flow_id, workspace=workspace, + ) async for response in gen: - - # Extract vectors from response text = response.get("response", "") break - + return TextCompletionResponse(response=text) async def graph_rag( @@ -378,58 +517,43 @@ class McpServer: max_subgraph_size: int | None = None, max_path_length: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> GraphRagResponse: """ Perform Graph-based Retrieval Augmented Generation 
(GraphRAG) queries. - + GraphRAG combines knowledge graph traversal with language model generation to provide - contextually rich answers. It explores relationships between entities to build relevant - context before generating responses. - + contextually rich answers. + Args: question: The question or query to answer using the knowledge graph. - The system will find relevant entities and relationships to inform the response. collection: Knowledge collection to query (default: "default"). - Different collections may contain domain-specific knowledge. entity_limit: Maximum number of entities to retrieve during graph traversal. - Higher limits provide more context but increase processing time. triple_limit: Maximum number of relationship triples to consider. - Controls the depth of relationship exploration. max_subgraph_size: Maximum size of the subgraph to extract for context. - Larger subgraphs provide richer context but use more resources. max_path_length: Maximum path length to traverse in the knowledge graph. - Longer paths can discover distant but relevant relationships. flow_id: Processing flow to use (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: GraphRagResponse containing the generated answer informed by knowledge graph context. 
- - Example usage: - - Answer complex questions requiring multi-hop reasoning - - Explore relationships between entities in your knowledge base - - Generate responses grounded in structured knowledge - - Perform research queries across connected information """ if collection is None: collection = "default" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("GraphRAG request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing GraphRAG query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Processing GraphRAG query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with all parameters request_data = { "query": question } @@ -440,20 +564,19 @@ class McpServer: if max_subgraph_size: request_data["max_subgraph_size"] = max_subgraph_size if max_path_length: request_data["max_path_length"] = max_path_length - gen = manager.request("graph-rag", request_data, flow_id) + gen = manager.request( + "graph-rag", request_data, flow_id, workspace=workspace, + ) text_chunks = [] async for response in gen: - # Handle new message format with message_type message_type = response.get("message_type", "chunk") - # Only collect text from chunk messages if message_type == "chunk": chunk_text = response.get("response", "") if chunk_text: text_chunks.append(chunk_text) - # Check if session is complete if response.get("end_of_session"): break @@ -464,404 +587,447 @@ class McpServer: question: str, collection: str | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> AgentResponse: """ Execute intelligent agent queries with reasoning and tool usage capabilities. 
- - The agent can perform complex multi-step reasoning, use tools, and provide - detailed thought processes. It's designed for tasks requiring planning, - analysis, and iterative problem-solving. - + Args: - question: The question or task for the agent to solve. Can be complex - queries requiring multiple steps, analysis, or tool usage. + question: The question or task for the agent to solve. collection: Knowledge collection the agent can access (default: "default"). - Determines what information and tools are available. - flow_id: Agent workflow to use (default: "default"). Different flows - may have different capabilities, tools, or reasoning strategies. - + flow_id: Agent workflow to use (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: AgentResponse containing the final answer after the agent's reasoning process. - During execution, you'll see intermediate thoughts and observations. - - Example usage: - - Solve complex analytical problems requiring multiple steps - - Perform research tasks across multiple information sources - - Handle queries that need tool usage and decision-making - - Get detailed explanations of reasoning processes - - Note: This tool provides real-time updates on the agent's thinking process - through log messages, so you can follow its reasoning steps. 
""" if collection is None: collection = "default" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Agent request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing agent query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Processing agent query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with all parameters request_data = { "question": question } if collection: request_data["collection"] = collection - gen = manager.request("agent", request_data, flow_id) + gen = manager.request( + "agent", request_data, flow_id, workspace=workspace, + ) async for response in gen: - logging.debug(f"Agent response: {response}") + logger.debug("Agent response: %s", response) - if "thought" in response: - await ctx.session.send_log_message( - level="info", - data=f"Thinking: {response['thought']}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + if "thought" in response: + await ctx.session.send_log_message( + level="info", + data=f"Thinking: {response['thought']}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - if "observation" in response: - await ctx.session.send_log_message( - level="info", - data=f"Observation: {response['observation']}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if "observation" in response: + await ctx.session.send_log_message( + level="info", + data=f"Observation: {response['observation']}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - # Extract vectors from response if "answer" in response: answer = response.get("answer", "") return 
AgentResponse(answer=answer) async def triples_query( self, - s_v: str | None = None, - s_e: bool | None = None, - p_v: str | None = None, - p_e: bool | None = None, - o_v: str | None = None, - o_e: bool | None = None, + s: str | None = None, + s_type: str | None = None, + p: str | None = None, + p_type: str | None = None, + o: str | None = None, + o_type: str | None = None, + collection: str | None = None, + graph: str | None = None, limit: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> TriplesQueryResponse: """ Query knowledge graph triples using subject-predicate-object patterns. - - Knowledge graphs store information as triples (subject, predicate, object). - This tool allows flexible querying by specifying any combination of these - components, with wildcards for unspecified parts. - + + Each of s, p, o is an RDF term value. Use the corresponding _type + parameter to specify the term kind: + - "iri" (default for s and p): an IRI / entity reference + - "literal" (default for o): a plain literal value + - "blank": a blank node identifier + Args: - s_v: Subject value to match (e.g., "John", "Apple Inc."). Leave None for wildcard. - s_e: Whether subject should be treated as an entity (True) or literal (False). - p_v: Predicate/relationship value (e.g., "works_for", "type_of"). Leave None for wildcard. - p_e: Whether predicate should be treated as an entity (True) or literal (False). - o_v: Object value to match (e.g., "Engineer", "Company"). Leave None for wildcard. - o_e: Whether object should be treated as an entity (True) or literal (False). + s: Subject value to match. Leave None for wildcard. + s_type: Subject term type: "iri" (default), "literal", or "blank". + p: Predicate value to match. Leave None for wildcard. + p_type: Predicate term type: "iri" (default), "literal", or "blank". + o: Object value to match. Leave None for wildcard. 
+ o_type: Object term type: "iri", "literal" (default), or "blank". + collection: Knowledge collection to query (default: "default"). + graph: Named graph IRI to restrict the query. None = default graph, + "*" = all graphs. limit: Maximum number of triples to return (default: 20). flow_id: Processing flow identifier (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: TriplesQueryResponse containing matching triples from the knowledge graph. - - Example queries: - - Find all relationships for an entity: s_v="John", others None - - Find all instances of a relationship: p_v="works_for", others None - - Find specific facts: s_v="John", p_v="works_for", o_v=None - - Explore entity types: p_v="type_of", others None - - Use this for: - - Exploring knowledge graph structure - - Finding specific facts or relationships - - Discovering connections between entities - - Validating or debugging knowledge content """ if flow_id is None: flow_id = "default" if limit is None: limit = 20 + if collection is None: collection = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Triples query request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing triples query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Processing triples query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with Value objects request_data = { - "limit": limit + "limit": limit, + "collection": collection, } - # Add subject if provided - if s_v is not None: - request_data["s"] = {"v": s_v, "e": s_e } + if s is not None: + request_data["s"] = _make_term(s, s_type or "iri") - # Add 
predicate if provided - if p_v is not None: - request_data["p"] = {"v": p_v, "e": p_e } + if p is not None: + request_data["p"] = _make_term(p, p_type or "iri") - # Add object if provided - if o_v is not None: - request_data["o"] = {"v": o_v, "e": o_e } + if o is not None: + request_data["o"] = _make_term(o, o_type or "literal") - gen = manager.request("triples", request_data, flow_id) + if graph is not None: + request_data["g"] = graph + + gen = manager.request( + "triples", request_data, flow_id, workspace=workspace, + ) async for response in gen: - # Extract response data triples = response.get("response", []) break - + return TriplesQueryResponse(triples=triples) + async def sparql_query( + self, + query: str, + collection: str | None = None, + limit: int | None = None, + flow_id: str | None = None, + workspace: str | None = None, + ctx: Context = None, + ) -> SparqlQueryResponse: + """ + Execute a SPARQL query against the knowledge graph. + + Supports SELECT, ASK, CONSTRUCT, and DESCRIBE query forms. + + Args: + query: SPARQL query string (e.g. "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"). + collection: Knowledge collection to query (default: "default"). + limit: Safety limit on number of results (default: 10000). + flow_id: Processing flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + + Returns: + SparqlQueryResponse containing the query results. 
The structure depends + on query type: + - SELECT: variables (column names) and bindings (rows of Term values) + - ASK: ask_result (boolean) + - CONSTRUCT/DESCRIBE: triples + """ + + if collection is None: collection = "default" + if flow_id is None: flow_id = "default" + if limit is None: limit = 10000 + + manager = await self._get_manager(ctx) + + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing SPARQL query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + + request_data = { + "query": query, + "collection": collection, + "limit": limit, + } + + gen = manager.request( + "sparql", request_data, flow_id, workspace=workspace, + ) + + async for response in gen: + query_type = response.get("query-type", "") + return SparqlQueryResponse( + query_type=query_type, + variables=response.get("variables", []), + bindings=response.get("bindings", []), + ask_result=response.get("ask-result", False), + triples=response.get("triples", []), + ) + + async def graphql_query( + self, + query: str, + collection: str | None = None, + variables: Dict[str, Any] | None = None, + operation_name: str | None = None, + flow_id: str | None = None, + workspace: str | None = None, + ctx: Context = None, + ) -> GraphQLQueryResponse: + """ + Execute a GraphQL query against structured data (rows). + + Queries structured data schemas that have been loaded into TrustGraph. + The available types and fields depend on the schemas configured in the + target workspace. + + Args: + query: GraphQL query string (e.g. '{ customers(where: {status: {eq: "active"}}) { id name } }'). + collection: Data collection to query (default: "default"). + variables: Optional GraphQL variables as a dict. + operation_name: Optional operation name for multi-operation documents. + flow_id: Processing flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. 
+ + Returns: + GraphQLQueryResponse containing data (the query result) and errors + (any GraphQL field-level errors). + """ + + if collection is None: collection = "default" + if flow_id is None: flow_id = "default" + + manager = await self._get_manager(ctx) + + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing GraphQL query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + + request_data = { + "query": query, + "collection": collection, + "variables": variables or {}, + } + + if operation_name is not None: + request_data["operation_name"] = operation_name + + gen = manager.request( + "rows", request_data, flow_id, workspace=workspace, + ) + + async for response in gen: + return GraphQLQueryResponse( + data=response.get("data"), + errors=response.get("errors", []), + ) + async def graph_embeddings_query( self, vectors: List[List[float]], limit: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> GraphEmbeddingsQueryResponse: """ Find entities in the knowledge graph using vector similarity search. - - This tool performs semantic search by comparing embedding vectors to find - the most similar entities in the knowledge graph. It's useful for finding - conceptually related information even when exact text matches don't exist. - + Args: - vectors: List of embedding vectors to search with. Each vector should be - a list of floats representing semantic embeddings (typically from - the embeddings tool). Multiple vectors can be provided for batch queries. + vectors: List of embedding vectors to search with. limit: Maximum number of similar entities to return (default: 20). - Higher limits provide more results but may include less relevant matches. flow_id: Processing flow identifier (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. 
+ Returns: - GraphEmbeddingsQueryResponse containing entities ranked by similarity to the - input vectors, along with similarity scores and entity metadata. - - Example workflow: - 1. Use the 'embeddings' tool to convert text to vectors - 2. Use this tool to find similar entities in the knowledge graph - 3. Explore the returned entities for relevant information - - Use this for: - - Semantic search across knowledge entities - - Finding conceptually similar content - - Discovering related entities without exact keyword matches - - Building recommendation systems based on entity similarity + GraphEmbeddingsQueryResponse containing entities ranked by similarity. """ if flow_id is None: flow_id = "default" if limit is None: limit = 20 - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Graph embeddings query request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing graph embeddings query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Processing graph embeddings query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data request_data = { "vectors": vectors, "limit": limit } - gen = manager.request("graph-embeddings", request_data, flow_id) + gen = manager.request( + "graph-embeddings", request_data, flow_id, workspace=workspace, + ) async for response in gen: - # Extract entities from response entities = response.get("entities", []) break - + return GraphEmbeddingsQueryResponse(entities=entities) async def get_config_all( self, + workspace: str | None = None, ctx: Context = None, ) -> ConfigResponse: """ Retrieve the complete TrustGraph system configuration. 
- - This tool returns all configuration settings for the TrustGraph system, - including model configurations, API keys, flow definitions, and system parameters. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - ConfigResponse containing the full configuration as a nested dictionary - with all system settings, organized by category (e.g., models, flows, storage). - - Use this for: - - Inspecting current system configuration - - Debugging configuration issues - - Understanding available models and settings - - Auditing system setup and parameters + ConfigResponse containing the full configuration as a nested dictionary. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get config all request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving all configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving all configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) break - + return ConfigResponse(config=config) async def get_config( self, keys: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> ConfigGetResponse: """ Retrieve specific configuration values by key. - - This tool allows you to fetch specific configuration settings without - retrieving the entire configuration. Useful for checking particular - settings or API keys. - + Args: - keys: List of configuration keys to retrieve. 
Each key should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings', 'storage') - - 'key': Specific setting name within that category - + keys: List of configuration keys to retrieve. Each key should be a dict with + 'type' and 'key' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: ConfigGetResponse containing the requested configuration values. - - Example keys: - - {'type': 'llm', 'key': 'openai.model'} - - {'type': 'embeddings', 'key': 'default.model'} - - {'type': 'storage', 'key': 'database.url'} - - Use this for: - - Checking specific model configurations - - Validating API key settings - - Inspecting individual system parameters """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving specific configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving specific configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get", "keys": keys } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: values = response.get("values", []) break - + return ConfigGetResponse(values=values) async def put_config( self, values: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> PutConfigResponse: """ Update system configuration values. - - This tool allows you to modify TrustGraph system settings, such as - model parameters, API keys, and system behavior configurations. - + Args: - values: List of configuration updates. 
Each update should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings') - - 'key': Specific setting name to update - - 'value': New value for the setting - + values: List of configuration updates. Each should be a dict with + 'type', 'key', and 'value' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: PutConfigResponse confirming the configuration update. - - Example updates: - - {'type': 'llm', 'key': 'openai.model', 'value': 'gpt-4'} - - {'type': 'embeddings', 'key': 'batch_size', 'value': '100'} - - Use this for: - - Switching between different models - - Updating API credentials - - Modifying system behavior parameters - - Configuring processing settings - - Note: Configuration changes may require system restart to take effect. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Put config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Updating configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Updating configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "put", "values": values } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: return PutConfigResponse() @@ -869,97 +1035,73 @@ class McpServer: async def delete_config( self, keys: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> DeleteConfigResponse: """ Delete specific configuration entries from the system. - - This tool removes configuration settings, reverting them to system defaults - or disabling specific features. 
- + Args: - keys: List of configuration keys to delete. Each key should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings') - - 'key': Specific setting name to remove - + keys: List of configuration keys to delete. Each should be a dict with + 'type' and 'key' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: DeleteConfigResponse confirming the deletion. - - Use this for: - - Removing custom model configurations - - Clearing API credentials - - Resetting settings to defaults - - Cleaning up obsolete configurations - - Warning: Deleting essential configuration may cause system functionality - to be disabled until properly reconfigured. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Delete config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Deleting configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Deleting configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "delete", "keys": keys } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: return DeleteConfigResponse() async def get_prompts( self, + workspace: str | None = None, ctx: Context = None, ) -> GetPromptsResponse: """ List all available prompt templates in the system. - - Prompt templates are reusable prompts that can be used with language models - for consistent behavior across different queries and use cases. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. 
+ Returns: GetPromptsResponse containing a list of available prompt template IDs. - Each ID can be used with get_prompt to retrieve the full template. - - Use this for: - - Discovering available prompt templates - - Exploring pre-configured prompts for different tasks - - Finding templates for specific use cases - - Understanding what prompt options are available """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get prompts request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving prompt templates via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving prompt templates via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -971,49 +1113,36 @@ class McpServer: async def get_prompt( self, prompt_id: str, + workspace: str | None = None, ctx: Context = None, ) -> GetPromptResponse: """ Retrieve a specific prompt template by ID. - - Prompt templates contain structured prompts with placeholders, instructions, - and metadata for specific tasks or domains. - + Args: prompt_id: The unique identifier of the prompt template to retrieve. - Use get_prompts to see available template IDs. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - GetPromptResponse containing the complete prompt template with its - structure, placeholders, and usage instructions. 
- - Use this for: - - Examining prompt template structure - - Understanding how to use specific templates - - Copying or modifying existing prompts - - Learning prompt engineering patterns + GetPromptResponse containing the complete prompt template. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get prompt request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving prompt template '{prompt_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving prompt template '{prompt_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -1025,44 +1154,35 @@ class McpServer: async def get_system_prompt( self, + workspace: str | None = None, ctx: Context = None, ) -> GetSystemPromptResponse: """ Retrieve the current system prompt configuration. - - The system prompt defines the default behavior, personality, and instructions - for language models across the TrustGraph system. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - GetSystemPromptResponse containing the system prompt text and configuration. - - Use this for: - - Understanding default AI behavior settings - - Checking current system-wide prompt configuration - - Auditing AI personality and instruction settings - - Debugging unexpected AI responses + GetSystemPromptResponse containing the system prompt text. 
""" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get system prompt request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving system prompt via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving system prompt via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -1073,51 +1193,39 @@ class McpServer: async def get_token_costs( self, + workspace: str | None = None, ctx: Context = None, ) -> ConfigTokenCostsResponse: """ Retrieve token pricing information for all configured AI models. - - This tool provides cost information for input and output tokens across - different language models, helping with budget planning and cost optimization. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - ConfigTokenCostsResponse containing pricing data for each model including: - - Model name/identifier - - Input token cost (per token) - - Output token cost (per token) - - Use this for: - - Estimating costs for different models - - Choosing cost-effective models for tasks - - Budget planning and cost analysis - - Monitoring and optimizing AI spending + ConfigTokenCostsResponse containing pricing data for each model. 
""" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get token costs request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving token costs via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving token costs via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "getvalues", "type": "token-costs" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: values = response.get("values", []) - # Transform to match TypeScript API format costs = [] for item in values: try: @@ -1130,106 +1238,89 @@ class McpServer: except (json.JSONDecodeError, AttributeError): continue break - + return ConfigTokenCostsResponse(costs=costs) async def get_knowledge_cores( self, + workspace: str | None = None, ctx: Context = None, ) -> KnowledgeCoresResponse: """ List all available knowledge graph cores in the current workspace. - Knowledge cores are packaged collections of structured knowledge that can - be loaded into the system for querying and reasoning. They contain entities, - relationships, and facts organized as knowledge graphs. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: KnowledgeCoresResponse containing a list of available knowledge core IDs. 
- - Use this for: - - Discovering available knowledge collections - - Understanding what knowledge domains are accessible - - Planning which cores to load for specific tasks - - Managing knowledge resources """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get knowledge cores request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving knowledge graph cores via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving knowledge graph cores via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-kg-cores", } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: ids = response.get("ids", []) break - + return KnowledgeCoresResponse(ids=ids) async def delete_kg_core( self, core_id: str, + workspace: str | None = None, ctx: Context = None, ) -> DeleteKgCoreResponse: """ Permanently delete a knowledge graph core. - This operation removes a knowledge core from storage. Use with caution - as this action cannot be undone. - Args: core_id: Unique identifier of the knowledge core to delete. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: DeleteKgCoreResponse confirming the deletion. - - Use this for: - - Cleaning up obsolete knowledge cores - - Removing test or experimental data - - Managing storage space - - Maintaining organized knowledge collections - - Warning: This permanently deletes the knowledge core and all its data. 
""" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Delete KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Deleting knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Deleting knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "delete-kg-core", "id": core_id, } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: break - + return DeleteKgCoreResponse() async def load_kg_core( @@ -1237,46 +1328,34 @@ class McpServer: core_id: str, flow: str, collection: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> LoadKgCoreResponse: """ Load a knowledge graph core into the active system for querying. - This operation makes a knowledge core available for GraphRAG queries, - triple searches, and other knowledge-based operations. - Args: core_id: Unique identifier of the knowledge core to load. - flow: Processing flow to use for loading the core. Different flows - may apply different processing, indexing, or optimization steps. - collection: Target collection name (default: "default"). The loaded - knowledge will be available under this collection name. + flow: Processing flow to use for loading the core. + collection: Target collection name (default: "default"). + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: LoadKgCoreResponse confirming the core has been loaded. 
- - Use this for: - - Making knowledge cores available for queries - - Switching between different knowledge domains - - Loading domain-specific knowledge for tasks - - Preparing knowledge for GraphRAG operations """ if collection is None: collection = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Load KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Loading knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Loading knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "load-kg-core", @@ -1285,292 +1364,241 @@ class McpServer: "collection": collection } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: break - + return LoadKgCoreResponse() async def get_kg_core( self, core_id: str, + workspace: str | None = None, ctx: Context = None, ) -> GetKgCoreResponse: """ Download and retrieve the complete content of a knowledge graph core. - This tool streams the entire content of a knowledge core, returning all - entities, relationships, and metadata. Due to potentially large data sizes, - the content is streamed in chunks. - Args: core_id: Unique identifier of the knowledge core to retrieve. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: GetKgCoreResponse containing all chunks of the knowledge core data. - Each chunk contains part of the knowledge graph structure. 
- - Use this for: - - Examining knowledge core content and structure - - Debugging knowledge graph data - - Exporting knowledge for backup or analysis - - Understanding the scope and quality of knowledge - - Note: Large knowledge cores may take significant time to download. - Progress updates are provided through log messages during streaming. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-kg-core", "id": core_id, } - # Collect all streaming responses chunks = [] - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: - # Check for end of stream if response.get("eos", False): - await ctx.session.send_log_message( - level="info", - data=f"Completed streaming KG core data", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Completed streaming KG core data", + logger="notification_stream", + related_request_id=ctx.request_id, + ) break else: chunks.append(response) - await ctx.session.send_log_message( - level="info", - data=f"Received KG core chunk ({len(chunks)} chunks so far)", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Received KG core chunk ({len(chunks)} chunks 
so far)", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + return GetKgCoreResponse(chunks=chunks) async def get_flows( self, + workspace: str | None = None, ctx: Context = None, ) -> FlowsResponse: """ List all available processing flows in the system. - - Flows define processing pipelines for different types of operations - (e.g., document processing, knowledge extraction, query handling). - Each flow encapsulates a specific workflow with configured steps. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: FlowsResponse containing a list of available flow identifiers. - - Use this for: - - Discovering available processing workflows - - Understanding what processing options are available - - Choosing appropriate flows for specific tasks - - Planning workflow-based operations """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flows request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving available flows via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving available flows via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-flows" } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: flow_ids = response.get("flow-ids", []) break - + return FlowsResponse(flow_ids=flow_ids) async def get_flow( self, flow_id: str, + workspace: str | None = None, ctx: Context = None, ) -> FlowResponse: """ Retrieve the complete definition of a specific processing flow. 
- - This tool returns the detailed configuration, steps, and parameters - of a processing flow, showing how it processes data and what operations it performs. - + Args: flow_id: Unique identifier of the flow to retrieve. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - FlowResponse containing the complete flow definition including: - - Flow configuration and parameters - - Processing steps and their order - - Input/output specifications - - Dependencies and requirements - - Use this for: - - Understanding how specific flows work - - Debugging flow processing issues - - Learning flow configuration patterns - - Customizing or duplicating flows + FlowResponse containing the complete flow definition. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow definition for '{flow_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving flow definition for '{flow_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-flow", "flow-id": flow_id, } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: flow_data = response.get("flow", "{}") - # Parse JSON flow definition as done in TypeScript flow = json.loads(flow_data) if isinstance(flow_data, str) else flow_data break - + return FlowResponse(flow=flow) async def get_flow_classes( self, + workspace: str | None = None, ctx: Context = None, ) -> FlowClassesResponse: """ List all available flow class templates. 
- - Flow classes are templates that define types of processing workflows. - They serve as blueprints for creating specific flow instances with - customized parameters. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: FlowClassesResponse containing a list of available flow class names. - - Use this for: - - Discovering available flow templates - - Understanding what types of processing are supported - - Planning new flow creation - - Exploring system capabilities """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow classes request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow classes via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving flow classes via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-classes" } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: class_names = response.get("class-names", []) break - + return FlowClassesResponse(class_names=class_names) async def get_flow_class( self, class_name: str, + workspace: str | None = None, ctx: Context = None, ) -> FlowClassResponse: """ Retrieve the definition of a specific flow class template. - - Flow classes define the structure, parameters, and capabilities of - flow types. This tool returns the class specification including - configurable parameters and processing logic. - + Args: class_name: Name of the flow class to retrieve. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. 
+ Returns: - FlowClassResponse containing the flow class definition with: - - Class parameters and configuration options - - Processing capabilities and requirements - - Usage instructions and examples - - Use this for: - - Understanding flow class capabilities - - Learning how to configure new flows - - Troubleshooting flow creation issues - - Exploring advanced flow features + FlowClassResponse containing the flow class definition. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow class request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow class definition for '{class_name}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving flow class definition for '{class_name}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-class", "class-name": class_name } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: class_def_data = response.get("class-definition", "{}") - # Parse JSON class definition as done in TypeScript class_definition = json.loads(class_def_data) if isinstance(class_def_data, str) else class_def_data break - + return FlowClassResponse(class_definition=class_definition) async def start_flow( @@ -1578,43 +1606,32 @@ class McpServer: flow_id: str, class_name: str, description: str, + workspace: str | None = None, ctx: Context = None, ) -> StartFlowResponse: """ Create and start a new processing flow instance. - - This tool creates a new flow based on a flow class template and starts - it running. The flow will begin processing according to its configuration. 
- + Args: flow_id: Unique identifier for the new flow instance. class_name: Flow class template to use for creating the flow. - Use get_flow_classes to see available classes. description: Human-readable description of the flow's purpose. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: StartFlowResponse confirming the flow has been started. - - Use this for: - - Creating new processing workflows - - Starting automated processing tasks - - Launching background operations - - Initiating data processing pipelines """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Start flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "start-flow", @@ -1623,162 +1640,135 @@ class McpServer: "description": description } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: break - + return StartFlowResponse() async def stop_flow( self, flow_id: str, + workspace: str | None = None, ctx: Context = None, ) -> StopFlowResponse: """ Stop a running flow instance. - - This tool gracefully stops a running flow, allowing it to complete - current operations before shutting down. - + Args: flow_id: Unique identifier of the flow instance to stop. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. 
+ Returns: StopFlowResponse confirming the flow has been stopped. - - Use this for: - - Stopping unwanted or completed flows - - Managing system resources - - Interrupting long-running processes - - Maintaining flow lifecycle """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Stop flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Stopping flow '{flow_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Stopping flow '{flow_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "stop-flow", "flow-id": flow_id } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: break - + return StopFlowResponse() async def get_documents( self, + workspace: str | None = None, ctx: Context = None, ) -> DocumentsResponse: """ List all documents stored in the TrustGraph document library. - This tool returns metadata for all documents that have been uploaded - to the system, including their processing status and properties. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: - DocumentsResponse containing metadata for each document including: - - Document ID and title - - Upload timestamp - - MIME type and size information - - Tags and custom metadata - - Processing status - - Use this for: - - Browsing available documents - - Managing document collections - - Finding documents by metadata - - Auditing document storage + DocumentsResponse containing metadata for each document. 
""" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get documents request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving documents list via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving documents list via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-documents", } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: document_metadatas = response.get("document-metadatas", []) break - + return DocumentsResponse(document_metadatas=document_metadatas) async def get_processing( self, + workspace: str | None = None, ctx: Context = None, ) -> ProcessingResponse: """ List all documents currently in the processing queue. - This tool shows documents that are being processed or waiting to be - processed, along with their processing status and configuration. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: - ProcessingResponse containing processing metadata including: - - Processing job ID and document ID - - Processing flow and status - - Target collection - - Timestamp and progress information - - Use this for: - - Monitoring document processing progress - - Debugging processing issues - - Managing processing queues - - Understanding system workload + ProcessingResponse containing processing metadata. 
""" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get processing request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving processing list via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving processing list via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-processing", } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: processing_metadatas = response.get("processing-metadatas", []) break - + return ProcessingResponse(processing_metadatas=processing_metadatas) async def load_document( @@ -1790,50 +1780,39 @@ class McpServer: title: str = "", comments: str = "", tags: List[str] | None = None, + workspace: str | None = None, ctx: Context = None, ) -> LoadDocumentResponse: """ Upload a document to the TrustGraph document library. - This tool stores documents with rich metadata for later processing, - search, and knowledge extraction. Documents can be text files, PDFs, - or other supported formats. - Args: document: The document content as a string. For binary files, this should be base64-encoded content. document_id: Optional unique identifier. If not provided, one will be generated. metadata: Optional list of custom metadata key-value pairs. - mime_type: MIME type of the document (e.g., 'text/plain', 'application/pdf'). + mime_type: MIME type of the document. title: Human-readable title for the document. comments: Optional description or notes about the document. - tags: List of tags for categorizing and finding the document. + tags: List of tags for categorizing the document. 
+ workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: LoadDocumentResponse confirming the document has been stored. - - Use this for: - - Adding new documents to the knowledge base - - Storing reference materials and data sources - - Building document collections for processing - - Importing external content for analysis """ if tags is None: tags = [] - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Load document request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Loading document to library via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Loading document to library via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) import time timestamp = int(time.time()) @@ -1852,63 +1831,55 @@ class McpServer: "content": document } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return LoadDocumentResponse() async def remove_document( self, document_id: str, + workspace: str | None = None, ctx: Context = None, ) -> RemoveDocumentResponse: """ Permanently remove a document from the library. - This operation deletes a document and all its associated metadata. - Use with caution as this action cannot be undone. - Args: document_id: Unique identifier of the document to remove. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: RemoveDocumentResponse confirming the document has been deleted. 
- - Use this for: - - Cleaning up obsolete or incorrect documents - - Managing storage space - - Removing sensitive or inappropriate content - - Maintaining organized document collections - - Warning: This permanently deletes the document and all its metadata. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Remove document request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Removing document '{document_id}' from library via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Removing document '{document_id}' from library via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "remove-document", "document-id": document_id, } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return RemoveDocumentResponse() async def add_processing( @@ -1918,53 +1889,37 @@ class McpServer: flow: str, collection: str | None = None, tags: List[str] | None = None, + workspace: str | None = None, ctx: Context = None, ) -> AddProcessingResponse: """ Queue a document for processing through a specific workflow. - This tool adds a document to the processing queue where it will be - processed by the specified flow to extract knowledge, create embeddings, - or perform other analysis operations. - Args: processing_id: Unique identifier for this processing job. document_id: ID of the document to process (must exist in library). - flow: Processing flow to use. Different flows perform different - types of analysis (e.g., knowledge extraction, summarization). + flow: Processing flow to use. 
collection: Target collection for processed knowledge (default: "default"). - Results will be stored under this collection name. tags: Optional tags for categorizing this processing job. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: AddProcessingResponse confirming the document has been queued. - - Use this for: - - Processing uploaded documents into knowledge - - Extracting entities and relationships from text - - Creating searchable embeddings - - Converting documents into structured knowledge - - Note: Processing may take time depending on document size and flow complexity. - Use get_processing to monitor progress. """ if collection is None: collection = "default" if tags is None: tags = [] - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Add processing request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Adding document '{document_id}' to processing queue via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Adding document '{document_id}' to processing queue via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) import time timestamp = int(time.time()) @@ -1981,38 +1936,61 @@ class McpServer: } } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return AddProcessingResponse() + def main(): parser = argparse.ArgumentParser(description='TrustGraph MCP Server') - parser.add_argument('--host', default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)') - parser.add_argument('--port', type=int, default=8000, help='Port to bind to (default: 8000)') - parser.add_argument('--websocket-url', 
default='ws://api-gateway:8088/api/v1/socket', help='WebSocket URL to connect to (default: ws://api-gateway:8088/api/v1/socket)') + parser.add_argument( + '--host', default='0.0.0.0', + help='Host to bind to (default: 0.0.0.0)', + ) + parser.add_argument( + '--port', type=int, default=8000, + help='Port to bind to (default: 8000)', + ) + parser.add_argument( + '--websocket-url', + default='ws://api-gateway:8088/api/v1/socket', + help='WebSocket URL for the TrustGraph gateway', + ) + parser.add_argument( + '--auth-issuer', + default=os.environ.get("AUTH_ISSUER", ""), + help='OAuth issuer URL for MCP auth metadata discovery', + ) + parser.add_argument( + '--auth-resource-url', + default=os.environ.get("AUTH_RESOURCE_URL", ""), + help='Resource server URL for OAuth protected resource metadata', + ) - # Add logging arguments add_logging_args(parser) args = parser.parse_args() - # Setup logging before creating server setup_logging(vars(args)) - # Read gateway auth token from environment - gateway_token = os.environ.get("GATEWAY_SECRET", "") - - # Create and run the MCP server - server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token) + server = McpServer( + host=args.host, + port=args.port, + websocket_url=args.websocket_url, + auth_issuer=args.auth_issuer, + auth_resource_url=args.auth_resource_url, + ) server.run() + def run(): - """Legacy function for backward compatibility""" main() + if __name__ == "__main__": main() - diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py index bff8ae75..9fbf7459 100644 --- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py +++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py @@ -1,49 +1,110 @@ -from dataclasses import dataclass from websockets.asyncio.client import connect -from urllib.parse import urlencode, urlparse, urlunparse, parse_qs import asyncio import logging import json import uuid -import time 
+import hashlib + +logger = logging.getLogger(__name__) + + +def _token_key(token): + """Derive a dict key from a token without storing the raw secret.""" + return hashlib.sha256(token.encode()).hexdigest()[:16] + class WebSocketManager: + """Manages an authenticated WebSocket connection to the TrustGraph + gateway on behalf of a single caller. - def __init__(self, url, token=None): + Each caller token gets its own WebSocketManager so that gateway-side + identity, workspace, and capability scoping are preserved end-to-end. + """ + + def __init__(self, url, token): self.url = url + # ── Security boundary: token storage ── + # This is the MCP caller's Bearer token, forwarded verbatim to + # the gateway. It MUST NOT be logged, persisted, or shared + # across callers. It is held only for the lifetime of this + # connection so that re-auth (e.g. after a reconnect) is + # possible. self.token = token self.socket = None - - # FIXME: authentication is broken. The /api/v1/socket endpoint uses - # in-band auth (first-frame protocol via the Mux dispatcher), not - # query-parameter tokens. This query-string token is silently ignored. - # Fix: after connect(), send an auth frame with the bearer token as - # the first message, matching the gateway's in-band auth protocol. - def _build_url(self): - if not self.token: - return self.url - parsed = urlparse(self.url) - params = parse_qs(parsed.query) - params["token"] = [self.token] - new_query = urlencode(params, doseq=True) - return urlunparse(parsed._replace(query=new_query)) + self.identity = None + self.last_used = None async def start(self): - self.socket = await connect(self._build_url()) + """Connect and authenticate via the gateway's in-band auth + protocol. Raises on auth failure.""" + + # ── Security boundary: MCP server → gateway ── + # The WebSocket connects to the gateway and authenticates using + # the caller's Bearer token via the in-band first-frame auth + # protocol. 
The token belongs to the MCP client — we forward + # it as-is and never interpret its contents. + self.socket = await connect(self.url) self.pending_requests = {} self.running = True + + await self._authenticate() + self.reader_task = asyncio.create_task(self.reader()) + async def _authenticate(self): + """Send in-band auth frame and wait for auth-ok / auth-failed. + + The gateway expects ``{"type": "auth", "token": "..."}`` as the + first frame on a new WebSocket. Any service frame sent before + auth-ok is rejected. + """ + await self.socket.send(json.dumps({ + "type": "auth", + "token": self.token, + })) + + response_text = await asyncio.wait_for(self.socket.recv(), 10) + response = json.loads(response_text) + + if response.get("type") == "auth-ok": + logger.info( + "WebSocket authenticated, default workspace: %s", + response.get("workspace"), + ) + return + + # Auth failed — close immediately, do not leave an + # unauthenticated socket open. + await self.socket.close() + self.socket = None + + if response.get("type") == "auth-failed": + raise RuntimeError( + "Gateway rejected the authentication token" + ) + + raise RuntimeError( + f"Unexpected auth response type: {response.get('type')}" + ) + + async def whoami(self): + """Verify the token by calling the gateway's whoami endpoint. + Returns the identity dict and caches it on ``self.identity``. 
+ """ + gen = self.request("iam", {"operation": "whoami"}, flow_id=None) + async for response in gen: + self.identity = response + return response + async def stop(self): self.running = False - await self.reader_task + if hasattr(self, "reader_task"): + await self.reader_task async def reader(self): - """ - Background task to read websocket responses and route to correct - request - """ + """Background task: read WebSocket frames and route them to the + correct pending-request queue by ``id``.""" while self.running: try: @@ -59,23 +120,21 @@ class WebSocketManager: request_id = response.get("id") if request_id and request_id in self.pending_requests: - # Put the response in the queue queue = self.pending_requests[request_id] await queue.put(response) else: - logging.warning( - f"Response for unknown request ID: {request_id}" + logger.warning( + "Response for unknown request ID: %s", request_id ) except Exception as e: - logging.error(f"Error in websocket reader: {e}") + logger.error("Error in websocket reader: %s", e) - # Put error in all pending queues for queue in self.pending_requests.values(): try: await queue.put({"error": str(e)}) - except: + except Exception: pass self.pending_requests.clear() @@ -86,25 +145,29 @@ class WebSocketManager: async def request( self, service, request_data, flow_id="default", + workspace=None, ): - """ - Send a request via websocket and handle single or streaming responses + """Send a request via WebSocket and yield responses. + + Args: + service: Gateway service name (e.g. "graph-rag", "config"). + request_data: Inner request payload. + flow_id: Optional flow identifier. ``None`` omits the field + (workspace-level services don't use flows). + workspace: Optional workspace override. When ``None`` the + gateway uses the caller's default workspace. 
""" - # Generate unique request ID + import time + self.last_used = time.monotonic() + request_id = f"{uuid.uuid4()}" - # Determine if this service streams responses - streaming_services = {"agent"} - is_streaming = service in streaming_services - - # Create a queue for all responses (streaming and single) response_queue = asyncio.Queue() self.pending_requests[request_id] = response_queue try: - # Build request message message = { "id": request_id, "service": service, @@ -114,7 +177,16 @@ class WebSocketManager: if flow_id is not None: message["flow"] = flow_id - # Send request + # ── Security boundary: workspace scoping ── + # When the caller supplies a workspace, we set it on the + # message envelope. The gateway's enforce_workspace() + # validates that the authenticated identity is permitted + # to access the target workspace — we MUST NOT skip or + # override that check. When workspace is None, the + # gateway default-fills from the identity's bound workspace. + if workspace is not None: + message["workspace"] = workspace + await self.socket.send(json.dumps(message)) while self.running: @@ -127,19 +199,17 @@ class WebSocketManager: continue if "error" in response: - if "message" in response["error"]: - raise RuntimeError(response["error"]["text"]) + if isinstance(response["error"], dict): + raise RuntimeError( + response["error"].get("message", str(response["error"])) + ) else: raise RuntimeError(str(response["error"])) yield response["response"] - if "complete" in response: - if response["complete"]: - break + if response.get("complete"): + break - except Exception as e: - # Clean up on error + finally: self.pending_requests.pop(request_id, None) - raise e - diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index 1b4815c6..0d5101df 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -107,7 +107,14 @@ class 
Processor(FlowProcessor): # Get the source document ID source_doc_id = v.document_id or v.metadata.id - pages = convert_from_bytes(blob) + try: + pages = convert_from_bytes(blob) + except Exception as e: + logger.error( + f"Failed to decode PDF {source_doc_id}: " + f"{type(e).__name__}: {e}" + ) + return for ix, page in enumerate(pages): diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py index b4936786..deedb7b4 100644 --- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py +++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py @@ -418,7 +418,14 @@ class Processor(FlowProcessor): doc_uri_str = document_uri(source_doc_id) # Extract elements using unstructured - elements = self.extract_elements(blob, mime_type) + try: + elements = self.extract_elements(blob, mime_type) + except Exception as e: + logger.error( + f"Failed to extract elements from {source_doc_id}: " + f"{type(e).__name__}: {e}" + ) + return if not elements: logger.warning("No elements extracted from document")