mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-12 00:05:13 +02:00
Compare commits
20 commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 627c669097 | | |
| | 81d57826c8 | | |
| | 79d7ef6a90 | | |
| | 28a51c244f | | |
| | fa5ebe2393 | | |
| | e1c9351454 | | |
| | dbc21c0bb9 | | |
| | 08bfec1539 | | |
| | 4913f8c2eb | | |
| | acf182c265 | | |
| | 6df7471a55 | | |
| | aa158e1ba3 | | |
| | 60f861bac4 | | |
| | 00bb964e93 | | |
| | 6b1dd16f9f | | |
| | 97453d9b83 | | |
| | 7e1fb76bc9 | | |
| | 6dfa47aac8 | | |
| | dcee842455 | | |
| | 36eadbda3a | | |
58 changed files with 3301 additions and 1455 deletions
32
README.md
|
|
@ -11,11 +11,11 @@
|
||||||
|
|
||||||
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
# The agent runtime platform
|
# The semantic deployment platform
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
|
TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads.
|
||||||
|
|
||||||
The platform:
|
The platform:
|
||||||
- [x] Multi-model and multimodal database system
|
- [x] Multi-model and multimodal database system
|
||||||
|
|
@ -99,23 +99,21 @@ For a browser based configuration, try the [Configuration Terminal](https://conf
|
||||||
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
|
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
|
||||||
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
|
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
|
||||||
|
|
||||||
## Workbench
|
## Context Graph UI
|
||||||
|
|
||||||
The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
|
<img width="1389" height="961" alt="Image" src="https://github.com/user-attachments/assets/35c9250d-0f01-40cb-9294-1ee8fd9a1b56" />
|
||||||
|
|
||||||
- **Vector Search**: Search the installed knowledge bases
|
The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
|
||||||
- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
|
|
||||||
- **Relationships**: Analyze deep relationships in the installed knowledge bases
|
- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
|
||||||
- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
|
- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
|
||||||
- **Library**: Staging area for installing knowledge bases
|
- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
|
||||||
- **Flow Classes**: Workflow preset configurations
|
- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
|
||||||
- **Flows**: Create custom workflows and adjust LLM parameters during runtime
|
- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
|
||||||
- **Knowledge Cores**: Manage reusable knowledge bases
|
- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
|
||||||
- **Prompts**: Manage and adjust prompts during runtime
|
- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout
|
||||||
- **Schemas**: Define custom schemas for structured data knowledge bases
|
- **Workspace UX** — Workspace selection and management surfaced directly in the interface
|
||||||
- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
|
- **Prompt Editor** — A dedicated prompt editing workflow
|
||||||
- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
|
|
||||||
- **MCP Tools**: Connect to MCP servers
|
|
||||||
|
|
||||||
## TypeScript Library for UIs
|
## TypeScript Library for UIs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base
|
||||||
|
|
||||||
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
||||||
|
|
||||||
RUN dnf install -y python3.13 libxcb mesa-libGL && \
|
RUN dnf install -y python3.13 libxcb mesa-libGL poppler-utils && \
|
||||||
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
|
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
|
||||||
python -m ensurepip --upgrade && \
|
python -m ensurepip --upgrade && \
|
||||||
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
|
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
|
||||||
|
|
|
||||||
535
docs/tech-specs/knowledge-core-completeness.md
Normal file
|
|
@ -0,0 +1,535 @@
|
||||||
|
---
|
||||||
|
layout: default
|
||||||
|
title: "Knowledge Core Completeness"
|
||||||
|
parent: "Tech Specs"
|
||||||
|
---
|
||||||
|
|
||||||
|
# Knowledge Core Completeness
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Knowledge cores are portable snapshots of extracted knowledge: triples, graph
|
||||||
|
embeddings, and document embeddings stored in Cassandra's `knowledge` keyspace.
|
||||||
|
They can be downloaded as files, transferred between TrustGraph instances, and
|
||||||
|
loaded back into vector and graph stores.
|
||||||
|
|
||||||
|
Recent additions to TrustGraph — explainability/provenance and named graphs —
|
||||||
|
were not carried through to the knowledge core system. This means that
|
||||||
|
exporting and re-importing a core loses provenance links, graph assignments,
|
||||||
|
and source material, breaking the explainability chain.
|
||||||
|
|
||||||
|
This specification addresses three gaps:
|
||||||
|
|
||||||
|
1. **Named graphs not stored** — The `g` (graph name) field on triples is
|
||||||
|
silently dropped when writing to the core store and comes back as `None`
|
||||||
|
on read.
|
||||||
|
2. **Provenance triples not captured** — Provenance triples (PROV-O) are
|
||||||
|
generated during extraction and flow to graph stores, but never enter
|
||||||
|
the knowledge core store. It is unclear whether they arrive at the store
|
||||||
|
in the correct form.
|
||||||
|
3. **Source material not included** — Documents, text pages, and chunks in
|
||||||
|
the librarian's bucket store are not part of the core. After loading a
|
||||||
|
core on a different instance, provenance links to source material point
|
||||||
|
at nothing.
|
||||||
|
|
||||||
|
## Goals
|
||||||
|
|
||||||
|
- **Self-contained cores**: A downloaded knowledge core file contains
|
||||||
|
everything needed to reconstruct the full knowledge graph including
|
||||||
|
provenance and source attribution on a fresh instance.
|
||||||
|
- **Named graph preservation**: Round-tripping a core preserves graph
|
||||||
|
assignments on all triples.
|
||||||
|
- **Backward compatibility**: Existing core files (without graph names or
|
||||||
|
source material) can still be uploaded and loaded. New fields are optional
|
||||||
|
on import.
|
||||||
|
- **No change to core identity**: A core is still identified by its document
|
||||||
|
ID. The additional data is associated with the same core ID.
|
||||||
|
- **Minimal file format changes**: Extend the existing msgpack record format
|
||||||
|
with new record types rather than restructuring existing ones.
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
### Current Lifecycle
|
||||||
|
|
||||||
|
```
|
||||||
|
Extraction pipeline
|
||||||
|
│
|
||||||
|
├─ triples ──────────────────► knowledge core store (Cassandra)
|
||||||
|
├─ graph embeddings ─────────► knowledge core store (Cassandra)
|
||||||
|
├─ document embeddings ──────► knowledge core store (Cassandra)
|
||||||
|
├─ provenance triples ───────► graph store (only)
|
||||||
|
└─ source documents ─────────► librarian bucket store (only)
|
||||||
|
|
||||||
|
Download: Cassandra ──► knowledge manager ──► API gateway ──► client file
|
||||||
|
Upload: client file ──► API gateway ──► knowledge manager ──► Cassandra
|
||||||
|
Load: Cassandra ──► knowledge manager ──► Pulsar topics ──► graph/vector stores
|
||||||
|
```
|
||||||
|
|
||||||
|
### Current Core File Format (msgpack)
|
||||||
|
|
||||||
|
A core file is a sequence of concatenated msgpack records. Each record is a
|
||||||
|
2-element tuple: `(type_tag, payload)`.
|
||||||
|
|
||||||
|
| Type tag | Payload | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `"t"` | `{"m": {id, root, collection}, "t": [triple_dicts]}` | Triple batch |
|
||||||
|
| `"ge"` | `{"m": {id, root, collection}, "e": [{entity, vector}]}` | Graph embedding batch |
|
||||||
|
|
||||||
|
### What's Missing
|
||||||
|
|
||||||
|
#### Named Graphs
|
||||||
|
|
||||||
|
The `Triple` dataclass has a `g: str | None` field (graph name IRI), used to
|
||||||
|
separate provenance graphs (`urn:graph:source`, `urn:graph:retrieval`) from
|
||||||
|
the default graph. However:
|
||||||
|
|
||||||
|
- **Cassandra schema** (`knowledge.triples` table): stores a 6-tuple per
|
||||||
|
triple `(s_val, s_is_uri, p_val, p_is_uri, o_val, o_is_uri)` — no graph
|
||||||
|
field.
|
||||||
|
- **`add_triples()`** (`tables/knowledge.py:231`): destructures only `s`,
|
||||||
|
`p`, `o` — `g` is discarded.
|
||||||
|
- **`get_triples()`** (`tables/knowledge.py:396`): reconstructs `Triple`
|
||||||
|
with `g` defaulting to `None`.
|
||||||
|
- **Core file format**: triple dicts do not include a graph field.
|
||||||
|
|
||||||
|
#### Provenance Triples
|
||||||
|
|
||||||
|
Provenance triples are generated in the extraction pipeline
|
||||||
|
(`trustgraph-base/trustgraph/provenance/triples.py`) and published to graph
|
||||||
|
store topics. They use named graphs (`urn:graph:source`,
|
||||||
|
`urn:graph:retrieval`) and PROV-O vocabulary.
|
||||||
|
|
||||||
|
The knowledge core store processor (`storage/knowledge/store.py`) listens on
|
||||||
|
`triples-input` and `graph-embeddings-input`. Whether provenance triples
|
||||||
|
arrive on the same `triples-input` topic or a separate one needs
|
||||||
|
verification. Even if they do arrive, the graph name would be lost (per
|
||||||
|
above).
|
||||||
|
|
||||||
|
#### Source Material
|
||||||
|
|
||||||
|
The librarian stores the full document hierarchy in a separate system:
|
||||||
|
|
||||||
|
- **Blob store** (S3/MinIO): original documents, text pages, chunks —
|
||||||
|
keyed by object UUID under `doc/{object_id}`.
|
||||||
|
- **Cassandra `library` keyspace**: document metadata including `id`,
|
||||||
|
`kind` (MIME type), `title`, `parent_id`, `document_type`
|
||||||
|
(`source`/`extracted`), `object_id` (blob reference).
|
||||||
|
|
||||||
|
Provenance triples link extracted facts back to chunk/page/document IDs.
|
||||||
|
Those IDs resolve through the librarian. When a core is loaded on a
|
||||||
|
different instance, the librarian has no matching documents, so the entire
|
||||||
|
provenance chain is broken.
|
||||||
|
|
||||||
|
### Key Source Files
|
||||||
|
|
||||||
|
| Component | File | Purpose |
|
||||||
|
|-----------|------|---------|
|
||||||
|
| Core Cassandra schema | `trustgraph-flow/trustgraph/tables/knowledge.py` | Table definitions, read/write |
|
||||||
|
| Core manager | `trustgraph-flow/trustgraph/cores/knowledge.py` | API operations, load-to-store |
|
||||||
|
| Core store processor | `trustgraph-flow/trustgraph/storage/knowledge/store.py` | Extraction → Cassandra |
|
||||||
|
| CLI download | `trustgraph-cli/trustgraph/cli/get_kg_core.py` | Core → msgpack file |
|
||||||
|
| CLI upload | `trustgraph-cli/trustgraph/cli/put_kg_core.py` | Msgpack file → core |
|
||||||
|
| CLI load | `trustgraph-cli/trustgraph/cli/load_kg_core.py` | Core → graph/vector stores |
|
||||||
|
| API client | `trustgraph-base/trustgraph/api/knowledge.py` | Client-side knowledge API |
|
||||||
|
| Triple schema | `trustgraph-base/trustgraph/schema/core/primitives.py` | Triple dataclass with `g` field |
|
||||||
|
| Provenance generation | `trustgraph-base/trustgraph/provenance/triples.py` | PROV-O triple creation |
|
||||||
|
| Librarian | `trustgraph-flow/trustgraph/librarian/librarian.py` | Document storage service |
|
||||||
|
| Library tables | `trustgraph-flow/trustgraph/tables/library.py` | Document metadata in Cassandra |
|
||||||
|
| Blob store | `trustgraph-flow/trustgraph/librarian/blob_store.py` | S3/MinIO object storage |
|
||||||
|
|
||||||
|
## Technical Design
|
||||||
|
|
||||||
|
### Change 1: Named Graph Field in Core Storage
|
||||||
|
|
||||||
|
#### Cassandra Schema
|
||||||
|
|
||||||
|
Extend the `triples` tuple from 6 to 7 elements, adding the graph name:
|
||||||
|
|
||||||
|
```
|
||||||
|
triples list<tuple<
|
||||||
|
text, boolean, -- s_val, s_is_uri
|
||||||
|
text, boolean, -- p_val, p_is_uri
|
||||||
|
text, boolean, -- o_val, o_is_uri
|
||||||
|
text -- graph name (empty string = default graph)
|
||||||
|
>>
|
||||||
|
```
|
||||||
|
|
||||||
|
**Migration**: The schema change uses `ALTER TABLE` or is handled by
|
||||||
|
creating a new table version. Existing rows with 6-element tuples must be
|
||||||
|
handled gracefully on read — if the tuple has 6 elements, treat graph as
|
||||||
|
default.
|
||||||
|
|
||||||
|
#### Write Path (`add_triples`)
|
||||||
|
|
||||||
|
Change `tables/knowledge.py:add_triples()` to include `triple.g`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
triples = [
|
||||||
|
(
|
||||||
|
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o),
|
||||||
|
v.g or ""
|
||||||
|
)
|
||||||
|
for v in m.triples
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Read Path (`get_triples`)
|
||||||
|
|
||||||
|
Change `tables/knowledge.py:get_triples()` to restore the graph name:
|
||||||
|
|
||||||
|
```python
|
||||||
|
Triple(
|
||||||
|
s = tuple_to_term(elt[0], elt[1]),
|
||||||
|
p = tuple_to_term(elt[2], elt[3]),
|
||||||
|
o = tuple_to_term(elt[4], elt[5]),
|
||||||
|
g = elt[6] if len(elt) > 6 and elt[6] else None,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The `len(elt) > 6` guard provides backward compatibility with existing
|
||||||
|
6-element rows.
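
As a rough illustration of the write and read paths together, the sketch below round-trips a triple through the 7-element row and also decodes a legacy 6-element row. `Term`, `Triple`, `term_to_tuple`, and `tuple_to_term` here are simplified stand-ins written for this example, not the real definitions in `primitives.py` or `tables/knowledge.py`; `encode` and `decode` are hypothetical helper names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Term:
    """Simplified stand-in for the real term type (assumption)."""
    value: str
    is_uri: bool


@dataclass
class Triple:
    """Simplified stand-in for the Triple dataclass with its `g` field."""
    s: Term
    p: Term
    o: Term
    g: Optional[str] = None


def term_to_tuple(term):
    # Flatten a term into the (value, is_uri) pair used in the stored tuple.
    return (term.value, term.is_uri)


def tuple_to_term(value, is_uri):
    return Term(value, is_uri)


def encode(triple):
    # 7-element row: the empty string stands for the default graph.
    return (
        *term_to_tuple(triple.s), *term_to_tuple(triple.p),
        *term_to_tuple(triple.o), triple.g or "",
    )


def decode(elt):
    # Accepts both legacy 6-element rows and new 7-element rows.
    return Triple(
        s=tuple_to_term(elt[0], elt[1]),
        p=tuple_to_term(elt[2], elt[3]),
        o=tuple_to_term(elt[4], elt[5]),
        g=elt[6] if len(elt) > 6 and elt[6] else None,
    )
```

Under these assumptions, `decode(encode(t))` returns a triple equal to `t`, and a default-graph triple is stored with `""` and read back with `g` set to `None`.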
|
||||||
|
|
||||||
|
#### Core File Format
|
||||||
|
|
||||||
|
Extend triple dicts in the `"t"` record to include the graph name:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In get_kg_core.py write_triple — each triple dict gains "g" key
|
||||||
|
{"s": ..., "p": ..., "o": ..., "g": "urn:graph:source"}
|
||||||
|
```
|
||||||
|
|
||||||
|
On read (`put_kg_core.py`), treat missing `"g"` key as default graph for
|
||||||
|
backward compatibility with old core files.
|
||||||
|
|
||||||
|
### Change 2: Provenance Triples in Cores
|
||||||
|
|
||||||
|
#### Investigation Required
|
||||||
|
|
||||||
|
Before implementation, verify:
|
||||||
|
|
||||||
|
1. Whether provenance triples arrive on the `triples-input` topic that the
|
||||||
|
knowledge core store processor already listens on.
|
||||||
|
2. If not, which topic they use, and whether the store processor should
|
||||||
|
subscribe to it.
|
||||||
|
|
||||||
|
#### If provenance triples already arrive at the store
|
||||||
|
|
||||||
|
The only change needed is Change 1 (named graphs) — the provenance triples
|
||||||
|
are already being stored, just without their graph name. Once graph names
|
||||||
|
are preserved, provenance triples will round-trip correctly.
|
||||||
|
|
||||||
|
#### If provenance triples do NOT arrive at the store
|
||||||
|
|
||||||
|
Two options:
|
||||||
|
|
||||||
|
**Option A — Route provenance to the existing store topic**: Configure the
|
||||||
|
flow so provenance triples are published to the same `triples-input` topic.
|
||||||
|
This is the simpler approach and keeps the store processor unchanged.
|
||||||
|
|
||||||
|
**Option B — Add a subscription**: Add a new `ConsumerSpec` in the store
|
||||||
|
processor for the provenance topic. This keeps provenance routing
|
||||||
|
independent but adds complexity.
|
||||||
|
|
||||||
|
Recommendation: Option A, unless there is a reason provenance triples are
|
||||||
|
intentionally kept off the core store topic.
|
||||||
|
|
||||||
|
### Change 3: Source Material in Cores
|
||||||
|
|
||||||
|
This is the largest change. The goal is that when a core is loaded on a
|
||||||
|
fresh instance, provenance links to source material resolve.
|
||||||
|
|
||||||
|
#### Architecture
|
||||||
|
|
||||||
|
Source material is **not stored in the knowledge core tables**. It lives in
|
||||||
|
the librarian (Cassandra `library` keyspace + S3/MinIO blob store) and is
|
||||||
|
fetched on demand via the librarian's existing service API.
|
||||||
|
|
||||||
|
The knowledge manager acts as a **client of the librarian service** — it
|
||||||
|
calls the librarian's request/response API over pub/sub to retrieve document
|
||||||
|
metadata and content. It does not access the library's Cassandra tables or
|
||||||
|
blob store directly.
|
||||||
|
|
||||||
|
#### Transport
|
||||||
|
|
||||||
|
The librarian's pub/sub API already handles chunking of large documents.
|
||||||
|
This chunking is designed to be websocket-friendly, so library content
|
||||||
|
flowing through the API gateway to external clients does not require
|
||||||
|
re-chunking. The API gateway remains a transport layer.
|
||||||
|
|
||||||
|
```
|
||||||
|
Download:
|
||||||
|
Knowledge manager ──pub/sub──► Librarian (fetch metadata + content)
|
||||||
|
Knowledge manager ──pub/sub──► API gateway ──websocket──► Client
|
||||||
|
|
||||||
|
Upload:
|
||||||
|
Client ──websocket──► API gateway ──pub/sub──► Knowledge manager
|
||||||
|
Knowledge manager ──pub/sub──► Librarian (store metadata + content)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### What to Include
|
||||||
|
|
||||||
|
The provenance chain links facts → chunks → pages → documents. For the
|
||||||
|
chain to resolve, the core must include:
|
||||||
|
|
||||||
|
1. **Document metadata** — the library record for each document in the
|
||||||
|
hierarchy (id, kind, title, parent_id, document_type, etc.)
|
||||||
|
2. **Document content** — the blob data for each document (original file,
|
||||||
|
extracted text pages, text chunks)
|
||||||
|
|
||||||
|
Including the full hierarchy is necessary because:
|
||||||
|
- A user viewing provenance needs to traverse fact → chunk → page → document
|
||||||
|
- The chunk text is needed to show what text a fact was extracted from
|
||||||
|
- The page text provides broader context
|
||||||
|
- The original document is needed for full source attribution
|
||||||
|
|
||||||
|
#### Size Implications
|
||||||
|
|
||||||
|
Source material will significantly increase core file sizes. A rough model:
|
||||||
|
|
||||||
|
| Component | Typical size per document |
|
||||||
|
|-----------|-------------------------|
|
||||||
|
| Triples + embeddings (current) | 1-10 MB |
|
||||||
|
| Chunk text (all chunks) | ~same as original document |
|
||||||
|
| Page text (all pages) | ~same as original document |
|
||||||
|
| Original document (PDF, etc.) | Varies widely (KB to hundreds of MB) |
|
||||||
|
|
||||||
|
For a 10 MB PDF, the core could grow from ~5 MB to ~25 MB (original +
|
||||||
|
derived text + existing data). For large document sets, cores could become
|
||||||
|
very large.
|
||||||
|
|
||||||
|
**Decision needed**: Whether to include original documents or just derived
|
||||||
|
text (pages + chunks). Including only derived text still allows provenance
|
||||||
|
display but loses the ability to serve the original file.
|
||||||
|
|
||||||
|
#### New Core File Record Types
|
||||||
|
|
||||||
|
Add new msgpack record types for library content:
|
||||||
|
|
||||||
|
| Type tag | Payload | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `"lm"` | `{"id", "kind", "title", "parent_id", "document_type", "comments", "tags", "metadata"}` | Library document metadata |
|
||||||
|
| `"lb"` | `{"id", "data"}` | Library document blob content (chunked by pub/sub layer) |
|
||||||
|
|
||||||
|
These are emitted after the existing `"t"` and `"ge"` records during
|
||||||
|
download and processed during upload.
|
||||||
|
|
||||||
|
#### Download Path
|
||||||
|
|
||||||
|
Extend `KnowledgeManager.get_kg_core()` to:
|
||||||
|
|
||||||
|
1. Stream triples and graph embeddings from the core store (existing
|
||||||
|
behavior).
|
||||||
|
2. Use the librarian service API to retrieve documents associated with
|
||||||
|
this core ID:
|
||||||
|
a. Fetch the root document metadata and content.
|
||||||
|
b. Use `list-children` to discover child documents (pages, chunks).
|
||||||
|
c. Recursively fetch metadata and content for each child.
|
||||||
|
3. Stream each document as `"lm"` (metadata) and `"lb"` (content) records.
|
||||||
|
|
||||||
|
The knowledge manager gains the librarian service as a pub/sub dependency.
|
||||||
|
Large document content is chunked by the librarian's existing pub/sub
|
||||||
|
transport — the knowledge manager receives and forwards these chunks without
|
||||||
|
buffering the full blob in memory.
|
||||||
|
|
||||||
|
#### Upload Path
|
||||||
|
|
||||||
|
Extend `KnowledgeManager.put_kg_core()` to handle the new record types:
|
||||||
|
|
||||||
|
1. For `"lm"` records: call the librarian service API to create/update
|
||||||
|
the document metadata.
|
||||||
|
2. For `"lb"` records: call the librarian service API to store the
|
||||||
|
document content.
|
||||||
|
|
||||||
|
Parent-child relationships are preserved because `parent_id` is stored in
|
||||||
|
the metadata. Documents should be processed in hierarchy order (parent
|
||||||
|
before child) to satisfy any ordering constraints.
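
One possible way to get parent-before-child ordering on upload is sketched below. It assumes the `"lm"` metadata records have been collected as plain dicts with `id` and `parent_id` keys; `order_parent_first` is a hypothetical helper, not an existing TrustGraph function.

```python
def order_parent_first(records):
    """Return metadata dicts ordered so each parent precedes its children."""
    by_id = {r["id"]: r for r in records}
    ordered = []
    seen = set()

    def visit(rec):
        if rec["id"] in seen:
            return
        # Mark before recursing so a malformed parent cycle cannot loop forever.
        seen.add(rec["id"])
        parent = by_id.get(rec.get("parent_id"))
        if parent is not None:
            visit(parent)
        ordered.append(rec)

    for rec in records:
        visit(rec)
    return ordered
```

Records whose parent is not present in the batch are kept in place rather than rejected, which is a design choice of this sketch; a real implementation may prefer to fail on dangling `parent_id` values.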
|
||||||
|
|
||||||
|
#### Load Path
|
||||||
|
|
||||||
|
The load path (`_load_kg_core`) publishes triples and embeddings to Pulsar
|
||||||
|
topics for ingestion into graph/vector stores. Source material does not need
|
||||||
|
to flow through the load path — it is already in the librarian after the
|
||||||
|
upload step and can be accessed directly by services that need it.
|
||||||
|
|
||||||
|
No changes to the load path for source material.
|
||||||
|
|
||||||
|
#### CLI Changes
|
||||||
|
|
||||||
|
**`tg-get-kg-core`**: Add handling for `"lm"` and `"lb"` record types in
|
||||||
|
the file writer.
|
||||||
|
|
||||||
|
**`tg-put-kg-core`**: Add handling for `"lm"` and `"lb"` record types in
|
||||||
|
the file reader. Send library records to the knowledge manager alongside
|
||||||
|
triple/embedding records.
|
||||||
|
|
||||||
|
#### Associating Documents with Cores
|
||||||
|
|
||||||
|
The core ID is `metadata.root`, which is the root document ID from the
|
||||||
|
librarian. This provides a natural join: the core's root document and all
|
||||||
|
its children (pages, chunks) are the source material for that core.
|
||||||
|
|
||||||
|
The librarian's `list-children` API provides the child documents. A
|
||||||
|
recursive traversal from the root document collects the full hierarchy.
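
The traversal could look roughly like the following iterative depth-first walk, which is equivalent to the recursive version. The `list_children` parameter is a hypothetical callable standing in for a call to the librarian's `list-children` operation; its real request and response shapes are not shown in this spec.

```python
def collect_hierarchy(root_id, list_children):
    """Yield root_id and every descendant ID, parents before children."""
    stack = [root_id]
    seen = set()
    while stack:
        doc_id = stack.pop()
        if doc_id in seen:
            continue  # guard against duplicate or cyclic references
        seen.add(doc_id)
        yield doc_id
        # Reverse so children are visited in the order the service returned them.
        stack.extend(reversed(list(list_children(doc_id))))
```

Because parents are always yielded before their children, the resulting order can be used directly when emitting `"lm"` and `"lb"` records.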
|
||||||
|
|
||||||
|
### API Changes
|
||||||
|
|
||||||
|
#### KnowledgeResponse Schema
|
||||||
|
|
||||||
|
Add optional fields to `KnowledgeResponse` for library data:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class KnowledgeResponse:
|
||||||
|
error: Error | None = None
|
||||||
|
ids: list | None = None
|
||||||
|
eos: bool = False
|
||||||
|
triples: Triples | None = None
|
||||||
|
graph_embeddings: GraphEmbeddings | None = None
|
||||||
|
document_embeddings: DocumentEmbeddings | None = None
|
||||||
|
library_metadata: LibraryMetadata | None = None # new
|
||||||
|
library_blob: LibraryBlob | None = None # new
|
||||||
|
```
|
||||||
|
|
||||||
|
#### New Schema Types
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class LibraryMetadata:
|
||||||
|
id: str
|
||||||
|
kind: str | None = None
|
||||||
|
title: str | None = None
|
||||||
|
parent_id: str | None = None
|
||||||
|
document_type: str | None = None
|
||||||
|
comments: str | None = None
|
||||||
|
tags: list[str] | None = None
|
||||||
|
metadata: list[Triple] | None = None
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LibraryBlob:
|
||||||
|
id: str
|
||||||
|
data: bytes
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Socket API
|
||||||
|
|
||||||
|
The existing streaming protocol for `get-kg-core` / `put-kg-core` carries
|
||||||
|
these new fields naturally — responses already stream multiple record types.
|
||||||
|
|
||||||
|
### Dependencies Between Changes
|
||||||
|
|
||||||
|
```
|
||||||
|
Change 1 (named graphs) ◄── Change 2 depends on this
|
||||||
|
│
|
||||||
|
└── Change 2 (provenance triples)
|
||||||
|
│
|
||||||
|
└── Change 3 (source material) is independent
|
||||||
|
```
|
||||||
|
|
||||||
|
Change 1 is a prerequisite for Change 2 (provenance triples use named
|
||||||
|
graphs). Change 3 is independent and can be implemented in parallel.
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
- **Workspace isolation**: Core download/upload must respect workspace
|
||||||
|
boundaries. Source material from the librarian must only be included if
|
||||||
|
it belongs to the same workspace as the core. This is already enforced
|
||||||
|
by the existing workspace-scoped queries.
|
||||||
|
- **Large blob transfer**: Streaming large documents through the API
|
||||||
|
is handled by the librarian's existing pub/sub chunking, which is
|
||||||
|
designed to be websocket-friendly. No additional chunking layer is
|
||||||
|
needed.
|
||||||
|
- **Cross-instance trust**: When uploading a core from an external source,
|
||||||
|
the library content should be treated as untrusted input. Document
|
||||||
|
metadata and blob content should be validated before insertion.
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- **Core file size**: Including source material will significantly increase
|
||||||
|
core file sizes. Consider adding a flag to download/upload commands to
|
||||||
|
optionally exclude source material for use cases where only the knowledge
|
||||||
|
graph is needed.
|
||||||
|
- **Streaming**: All paths already use streaming (paged Cassandra queries,
|
||||||
|
msgpack record-at-a-time). Library content should follow the same pattern.
|
||||||
|
- **Cassandra schema migration**: Changing the tuple width in the `triples`
|
||||||
|
table requires careful handling. Cassandra frozen tuples cannot be altered
|
||||||
|
in place — a migration strategy is needed (see Migration Plan).
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
|
||||||
|
- **Unit tests**: Triple round-trip with graph name (write → read →
|
||||||
|
verify `g` field preserved). Backward compatibility with 6-element tuples.
|
||||||
|
- **Integration tests**: Full lifecycle — extract with provenance → download
|
||||||
|
core → upload to fresh instance → load → verify provenance chain resolves.
|
||||||
|
- **File format tests**: Read old-format core files (no graph name, no
|
||||||
|
library records) and verify they load without error.
|
||||||
|
- **Library inclusion tests**: Download core with source material → upload →
|
||||||
|
verify documents accessible through librarian.
|
||||||
|
|
||||||
|
## Migration Plan
|
||||||
|
|
||||||
|
### Cassandra Schema
|
||||||
|
|
||||||
|
The `triples` table stores tuples in a `list<tuple<...>>` column. Cassandra
|
||||||
|
does not support altering the type of an existing column. Options:
|
||||||
|
|
||||||
|
**Option A — New table**: Create a `triples_v2` table with the 7-element
|
||||||
|
tuple. Migrate data from `triples` to `triples_v2`. The read path checks
|
||||||
|
both tables during a transition period, then the old table is dropped.
|
||||||
|
|
||||||
|
**Option B — Dual read**: Keep the existing table. The read path handles
|
||||||
|
both 6-element and 7-element tuples by checking length. New writes use
|
||||||
|
7-element tuples. This works if Cassandra accepts variable-length tuples in
|
||||||
|
a list — **needs verification**.
|
||||||
|
|
||||||
|
**Option C — Separate graph column**: Instead of extending the tuple, add a
|
||||||
|
parallel `graphs list<text>` column where `graphs[i]` corresponds to
|
||||||
|
`triples[i]`. This avoids tuple migration entirely but requires keeping the
|
||||||
|
two lists in sync.
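
A minimal sketch of how the read path might pair the two columns under Option C, assuming legacy rows return no `graphs` column at all or a shorter list. `pair_graphs` is a hypothetical helper name used only for illustration.

```python
from itertools import zip_longest


def pair_graphs(triples, graphs):
    """Pair each stored triple tuple with its graph name (None = default graph)."""
    graphs = graphs or []  # legacy rows: column missing or null
    if len(graphs) > len(triples):
        raise ValueError("graphs list is longer than triples list")
    # Missing or empty entries fall back to the default graph.
    return [
        (t, g or None)
        for t, g in zip_longest(triples, graphs, fillvalue="")
    ]
```

The length check makes the "keep the two lists in sync" requirement explicit: a row with more graph names than triples is treated as corrupt rather than silently truncated.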
|
||||||
|
|
||||||
|
Recommendation: Verify Option B first (simplest). Fall back to Option A if
|
||||||
|
Cassandra rejects mixed tuple lengths.
|
||||||
|
|
||||||
|
### Core File Format
|
||||||
|
|
||||||
|
Backward compatible by design:
|
||||||
|
- Old files lack `"g"` in triple dicts and have no `"lm"`/`"lb"` records →
|
||||||
|
handled by defaults.
|
||||||
|
- New files read by old code → old code ignores unknown record types (the
|
||||||
|
existing `read_message` raises on unknown types, so this needs a small
|
||||||
|
fix to skip unknown types gracefully).
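
The two compatibility rules above can be sketched as follows. The example works on records that have already been decoded from msgpack into `(type_tag, payload)` pairs, so the decoding step itself is assumed and not shown; `read_records` and `normalise_triples` are hypothetical names, not the existing `read_message` implementation.

```python
def normalise_triples(payload):
    """Return the triple dicts of a "t" payload with an explicit "g" key."""
    # Old files have no "g" key; treat that (or "") as the default graph.
    return [{**t, "g": t.get("g") or None} for t in payload["t"]]


def read_records(records, handlers):
    """Dispatch decoded (type_tag, payload) records; return the skipped tags."""
    skipped = []
    for tag, payload in records:
        handler = handlers.get(tag)
        if handler is None:
            # Unknown record type: skip instead of raising, so older readers
            # can still load files that contain newer record types.
            skipped.append(tag)
            continue
        handler(payload)
    return skipped
```

Returning the skipped tags lets the caller log what was ignored, which may help when diagnosing a core that loads without its library records.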

## Open Questions

1. **Provenance topic routing**: Do provenance triples currently arrive at
   the `triples-input` topic consumed by the knowledge core store? If not,
   what topic are they on?

2. **Include original documents?**: Should cores include the original
   uploaded document (e.g. PDF), or only derived text (pages + chunks)?
   Including originals makes cores fully self-contained but potentially
   very large. Excluding them preserves provenance text display but loses
   the ability to serve the original file.

3. **Optional source material**: Should there be a flag on download/upload
   to include or exclude source material? This would let users choose
   between compact cores (knowledge only) and complete cores (knowledge +
   sources).

4. **Cassandra tuple migration**: Can Cassandra handle mixed-length tuples
   in a `list<tuple<...>>` column, or is a table migration required?

5. **Document embedding cores**: DE cores are managed alongside KG cores.
   Do they need the same treatment (source material inclusion)? The
   document embeddings reference chunk IDs, so the same provenance chain
   applies.

6. **Core versioning**: Should the core file include a version marker so
   readers can distinguish old-format from new-format files without
   trial-and-error parsing?

## References

- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
- Query-time explainability: `docs/tech-specs/query-time-explainability.md`
- Agent explainability: `docs/tech-specs/agent-explainability.md`
- Data ownership model: `docs/tech-specs/data-ownership-model.md`
|
||||||
|
|
@@ -410,3 +410,56 @@ class TestEdgeCases:
             assert hosts == ['mixed-host']
             assert username is None  # Stays None
             assert password == 'mixed-pass'
+
+
+class TestReplicationFactorParamPath:
+
+    def test_explicit_kwarg(self):
+        with patch.dict(os.environ, {}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=3,
+            )
+            assert rf == 3
+
+    def test_kwarg_overrides_env(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=3,
+            )
+            assert rf == 3
+
+    def test_env_fallback_when_kwarg_none(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=None,
+            )
+            assert rf == 5
+
+    def test_default_when_no_kwarg_no_env(self):
+        with patch.dict(os.environ, {}, clear=True):
+            _, _, _, _, rf = resolve_cassandra_config()
+            assert rf == 1
+
+    def test_params_dict_path(self):
+        with patch.dict(os.environ, {}, clear=True):
+            params = {'cassandra_replication_factor': 3}
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=params.get('cassandra_replication_factor'),
+            )
+            assert rf == 3
+
+    def test_params_dict_overrides_env(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            params = {'cassandra_replication_factor': 3}
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=params.get('cassandra_replication_factor'),
+            )
+            assert rf == 3
+
+    def test_params_dict_missing_falls_to_env(self):
+        with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+            params = {}
+            _, _, _, _, rf = resolve_cassandra_config(
+                replication_factor=params.get('cassandra_replication_factor'),
+            )
+            assert rf == 5
|
||||||
136
tests/unit/test_base/test_qdrant_config.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from trustgraph.base.qdrant_config import (
|
||||||
|
get_qdrant_defaults,
|
||||||
|
resolve_qdrant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetQdrantDefaults:
|
||||||
|
|
||||||
|
def test_defaults_with_no_env_vars(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
assert defaults['url'] == 'http://localhost:6333'
|
||||||
|
assert defaults['api_key'] is None
|
||||||
|
assert defaults['replication_factor'] == 1
|
||||||
|
assert defaults['shard_number'] == 1
|
||||||
|
|
||||||
|
def test_defaults_from_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_URL': 'http://qdrant:6333',
|
||||||
|
'QDRANT_API_KEY': 'secret',
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '3',
|
||||||
|
'QDRANT_SHARD_NUMBER': '5',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
assert defaults['url'] == 'http://qdrant:6333'
|
||||||
|
assert defaults['api_key'] == 'secret'
|
||||||
|
assert defaults['replication_factor'] == 3
|
||||||
|
assert defaults['shard_number'] == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveQdrantConfig:
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
url, api_key, rf, sn = resolve_qdrant_config()
|
||||||
|
assert url == 'http://localhost:6333'
|
||||||
|
assert api_key is None
|
||||||
|
assert rf == 1
|
||||||
|
assert sn == 1
|
||||||
|
|
||||||
|
def test_explicit_kwargs(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
url, api_key, rf, sn = resolve_qdrant_config(
|
||||||
|
url='http://custom:6333',
|
||||||
|
api_key='key',
|
||||||
|
replication_factor=3,
|
||||||
|
shard_number=5,
|
||||||
|
)
|
||||||
|
assert url == 'http://custom:6333'
|
||||||
|
assert api_key == 'key'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_kwargs_override_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_URL': 'http://env:6333',
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '10',
|
||||||
|
'QDRANT_SHARD_NUMBER': '10',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
url, _, rf, sn = resolve_qdrant_config(
|
||||||
|
url='http://explicit:6333',
|
||||||
|
replication_factor=3,
|
||||||
|
shard_number=5,
|
||||||
|
)
|
||||||
|
assert url == 'http://explicit:6333'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_env_fallback_when_kwargs_none(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_URL': 'http://env:6333',
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '3',
|
||||||
|
'QDRANT_SHARD_NUMBER': '5',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
url, _, rf, sn = resolve_qdrant_config()
|
||||||
|
assert url == 'http://env:6333'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_params_dict_path(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
params = {
|
||||||
|
'store_uri': 'http://params:6333',
|
||||||
|
'api_key': 'pkey',
|
||||||
|
'qdrant_replication_factor': 3,
|
||||||
|
'qdrant_shard_number': 5,
|
||||||
|
}
|
||||||
|
url, api_key, rf, sn = resolve_qdrant_config(
|
||||||
|
url=params.get('store_uri'),
|
||||||
|
api_key=params.get('api_key'),
|
||||||
|
replication_factor=params.get('qdrant_replication_factor'),
|
||||||
|
shard_number=params.get('qdrant_shard_number'),
|
||||||
|
)
|
||||||
|
assert url == 'http://params:6333'
|
||||||
|
assert api_key == 'pkey'
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_params_dict_overrides_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '10',
|
||||||
|
'QDRANT_SHARD_NUMBER': '10',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
params = {
|
||||||
|
'qdrant_replication_factor': 3,
|
||||||
|
'qdrant_shard_number': 5,
|
||||||
|
}
|
||||||
|
_, _, rf, sn = resolve_qdrant_config(
|
||||||
|
replication_factor=params.get('qdrant_replication_factor'),
|
||||||
|
shard_number=params.get('qdrant_shard_number'),
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
||||||
|
def test_params_dict_missing_falls_to_env(self):
|
||||||
|
env = {
|
||||||
|
'QDRANT_REPLICATION_FACTOR': '3',
|
||||||
|
'QDRANT_SHARD_NUMBER': '5',
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=True):
|
||||||
|
params = {}
|
||||||
|
_, _, rf, sn = resolve_qdrant_config(
|
||||||
|
replication_factor=params.get('qdrant_replication_factor'),
|
||||||
|
shard_number=params.get('qdrant_shard_number'),
|
||||||
|
)
|
||||||
|
assert rf == 3
|
||||||
|
assert sn == 5
|
||||||
|
|
@@ -11,7 +11,12 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
 from unittest.mock import call
 
 from trustgraph.cores.knowledge import KnowledgeManager
-from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
+from trustgraph.schema import (
+    KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term,
+    EntityEmbeddings, IRI, LITERAL,
+    LibraryMetadata, LibraryBlob,
+    LibrarianResponse, DocumentMetadata,
+)
 
 
 @pytest.fixture
|
||||||
|
|
@ -381,3 +386,244 @@ class TestKnowledgeManagerOtherMethods:
|
||||||
mock_respond.assert_called_once()
|
mock_respond.assert_called_once()
|
||||||
response = mock_respond.call_args[0][0]
|
response = mock_respond.call_args[0][0]
|
||||||
assert response.error is None
|
assert response.error is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeManagerLibraryDownload:
|
||||||
|
"""Test get_kg_core streaming of library documents."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager_with_librarian(self, mock_flow_config):
|
||||||
|
with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
|
||||||
|
mock_librarian = AsyncMock()
|
||||||
|
manager = KnowledgeManager(
|
||||||
|
cassandra_host=["localhost"],
|
||||||
|
cassandra_username="test_user",
|
||||||
|
cassandra_password="test_pass",
|
||||||
|
keyspace="test_keyspace",
|
||||||
|
flow_config=mock_flow_config,
|
||||||
|
librarian=mock_librarian,
|
||||||
|
)
|
||||||
|
manager.table_store = AsyncMock()
|
||||||
|
return manager
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_kg_core_streams_library_docs(self, manager_with_librarian):
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.id = "root-doc"
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
|
||||||
|
manager_with_librarian.table_store.get_triples = AsyncMock()
|
||||||
|
manager_with_librarian.table_store.get_graph_embeddings = AsyncMock()
|
||||||
|
|
||||||
|
root_meta = DocumentMetadata(
|
||||||
|
id="root-doc", kind="application/pdf", title="Test PDF",
|
||||||
|
document_type="source",
|
||||||
|
)
|
||||||
|
child_meta = DocumentMetadata(
|
||||||
|
id="chunk-1", kind="text/plain", title="Chunk 1",
|
||||||
|
parent_id="root-doc", document_type="chunk",
|
||||||
|
)
|
||||||
|
|
||||||
|
manager_with_librarian.librarian.fetch_document_metadata.return_value = root_meta
|
||||||
|
manager_with_librarian.librarian.request.return_value = LibrarianResponse(
|
||||||
|
document_metadatas=[child_meta],
|
||||||
|
)
|
||||||
|
manager_with_librarian.librarian.fetch_document_content.side_effect = [
|
||||||
|
b"cm9vdCBjb250ZW50",
|
||||||
|
b"Y2h1bmsgY29udGVudA==",
|
||||||
|
]
|
||||||
|
|
||||||
|
await manager_with_librarian.get_kg_core(
|
||||||
|
mock_request, mock_respond, "test-user"
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = [c[0][0] for c in mock_respond.call_args_list]
|
||||||
|
|
||||||
|
lm_responses = [r for r in responses if r.library_metadata is not None]
|
||||||
|
lb_responses = [r for r in responses if r.library_blob is not None]
|
||||||
|
eos_responses = [r for r in responses if r.eos is True]
|
||||||
|
|
||||||
|
assert len(lm_responses) == 2
|
||||||
|
assert lm_responses[0].library_metadata.id == "root-doc"
|
||||||
|
assert lm_responses[0].library_metadata.document_type == "source"
|
||||||
|
assert lm_responses[1].library_metadata.id == "chunk-1"
|
||||||
|
assert lm_responses[1].library_metadata.parent_id == "root-doc"
|
||||||
|
|
||||||
|
assert len(lb_responses) == 2
|
||||||
|
assert lb_responses[0].library_blob.id == "root-doc"
|
||||||
|
assert lb_responses[0].library_blob.data == b"cm9vdCBjb250ZW50"
|
||||||
|
assert lb_responses[1].library_blob.id == "chunk-1"
|
||||||
|
|
||||||
|
assert len(eos_responses) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_kg_core_no_librarian_skips_library(self, mock_flow_config):
|
||||||
|
with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
|
||||||
|
manager = KnowledgeManager(
|
||||||
|
cassandra_host=["localhost"],
|
||||||
|
cassandra_username="u", cassandra_password="p",
|
||||||
|
keyspace="ks", flow_config=mock_flow_config,
|
||||||
|
)
|
||||||
|
manager.table_store = AsyncMock()
|
||||||
|
manager.table_store.get_triples = AsyncMock()
|
||||||
|
manager.table_store.get_graph_embeddings = AsyncMock()
|
||||||
|
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.id = "doc-1"
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
|
||||||
|
await manager.get_kg_core(mock_request, mock_respond, "w")
|
||||||
|
|
||||||
|
responses = [c[0][0] for c in mock_respond.call_args_list]
|
||||||
|
assert all(r.library_metadata is None for r in responses)
|
||||||
|
assert all(r.library_blob is None for r in responses)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_kg_core_librarian_metadata_failure_is_graceful(
|
||||||
|
self, manager_with_librarian,
|
||||||
|
):
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.id = "missing-doc"
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
|
||||||
|
manager_with_librarian.table_store.get_triples = AsyncMock()
|
||||||
|
manager_with_librarian.table_store.get_graph_embeddings = AsyncMock()
|
||||||
|
manager_with_librarian.librarian.fetch_document_metadata.side_effect = (
|
||||||
|
RuntimeError("not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
await manager_with_librarian.get_kg_core(
|
||||||
|
mock_request, mock_respond, "test-user"
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = [c[0][0] for c in mock_respond.call_args_list]
|
||||||
|
assert all(r.library_metadata is None for r in responses)
|
||||||
|
assert any(r.eos for r in responses)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeManagerLibraryUpload:
|
||||||
|
"""Test put_kg_core handling of library metadata and blob records."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager_with_librarian(self, mock_flow_config):
|
||||||
|
with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
|
||||||
|
mock_librarian = AsyncMock()
|
||||||
|
manager = KnowledgeManager(
|
||||||
|
cassandra_host=["localhost"],
|
||||||
|
cassandra_username="u", cassandra_password="p",
|
||||||
|
keyspace="ks", flow_config=mock_flow_config,
|
||||||
|
librarian=mock_librarian,
|
||||||
|
)
|
||||||
|
manager.table_store = AsyncMock()
|
||||||
|
return manager
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_put_metadata_then_blob_calls_librarian(
|
||||||
|
self, manager_with_librarian,
|
||||||
|
):
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
manager_with_librarian.librarian.request.return_value = LibrarianResponse()
|
||||||
|
|
||||||
|
# First call: metadata
|
||||||
|
req_meta = Mock()
|
||||||
|
req_meta.triples = None
|
||||||
|
req_meta.graph_embeddings = None
|
||||||
|
req_meta.library_metadata = LibraryMetadata(
|
||||||
|
id="doc-1", kind="application/pdf", title="Test",
|
||||||
|
document_type="source",
|
||||||
|
)
|
||||||
|
req_meta.library_blob = None
|
||||||
|
await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
|
||||||
|
|
||||||
|
# Metadata is buffered, librarian not called yet
|
||||||
|
manager_with_librarian.librarian.request.assert_not_called()
|
||||||
|
|
||||||
|
# Second call: blob
|
||||||
|
req_blob = Mock()
|
||||||
|
req_blob.triples = None
|
||||||
|
req_blob.graph_embeddings = None
|
||||||
|
req_blob.library_metadata = None
|
||||||
|
req_blob.library_blob = LibraryBlob(
|
||||||
|
id="doc-1", data=b"dGVzdA==",
|
||||||
|
)
|
||||||
|
await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
|
||||||
|
|
||||||
|
# Now librarian should have been called with add-document
|
||||||
|
manager_with_librarian.librarian.request.assert_called_once()
|
||||||
|
call_args = manager_with_librarian.librarian.request.call_args[0][0]
|
||||||
|
assert call_args.operation == "add-document"
|
||||||
|
assert call_args.document_metadata.id == "doc-1"
|
||||||
|
assert call_args.document_metadata.kind == "application/pdf"
|
||||||
|
assert call_args.content == b"dGVzdA=="
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_put_child_document_uses_add_child_operation(
|
||||||
|
self, manager_with_librarian,
|
||||||
|
):
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
manager_with_librarian.librarian.request.return_value = LibrarianResponse()
|
||||||
|
|
||||||
|
req_meta = Mock()
|
||||||
|
req_meta.triples = None
|
||||||
|
req_meta.graph_embeddings = None
|
||||||
|
req_meta.library_metadata = LibraryMetadata(
|
||||||
|
id="chunk-1", kind="text/plain", title="Chunk",
|
||||||
|
parent_id="doc-1", document_type="chunk",
|
||||||
|
)
|
||||||
|
req_meta.library_blob = None
|
||||||
|
await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
|
||||||
|
|
||||||
|
req_blob = Mock()
|
||||||
|
req_blob.triples = None
|
||||||
|
req_blob.graph_embeddings = None
|
||||||
|
req_blob.library_metadata = None
|
||||||
|
req_blob.library_blob = LibraryBlob(id="chunk-1", data=b"Y2h1bms=")
|
||||||
|
await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
|
||||||
|
|
||||||
|
call_args = manager_with_librarian.librarian.request.call_args[0][0]
|
||||||
|
assert call_args.operation == "add-child-document"
|
||||||
|
assert call_args.document_metadata.parent_id == "doc-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_put_blob_without_metadata_logs_warning(
|
||||||
|
self, manager_with_librarian,
|
||||||
|
):
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
|
||||||
|
req_blob = Mock()
|
||||||
|
req_blob.triples = None
|
||||||
|
req_blob.graph_embeddings = None
|
||||||
|
req_blob.library_metadata = None
|
||||||
|
req_blob.library_blob = LibraryBlob(id="orphan", data=b"data")
|
||||||
|
await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
|
||||||
|
|
||||||
|
# Librarian should not be called for orphan blob
|
||||||
|
manager_with_librarian.librarian.request.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_put_existing_document_is_graceful(
|
||||||
|
self, manager_with_librarian,
|
||||||
|
):
|
||||||
|
mock_respond = AsyncMock()
|
||||||
|
manager_with_librarian.librarian.request.side_effect = RuntimeError(
|
||||||
|
"Document already exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
req_meta = Mock()
|
||||||
|
req_meta.triples = None
|
||||||
|
req_meta.graph_embeddings = None
|
||||||
|
req_meta.library_metadata = LibraryMetadata(
|
||||||
|
id="doc-1", kind="application/pdf", title="Test",
|
||||||
|
document_type="source",
|
||||||
|
)
|
||||||
|
req_meta.library_blob = None
|
||||||
|
await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
|
||||||
|
|
||||||
|
req_blob = Mock()
|
||||||
|
req_blob.triples = None
|
||||||
|
req_blob.graph_embeddings = None
|
||||||
|
req_blob.library_metadata = None
|
||||||
|
req_blob.library_blob = LibraryBlob(id="doc-1", data=b"data")
|
||||||
|
await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
|
||||||
|
|
||||||
|
# Should not raise — "already exists" is handled gracefully
|
||||||
|
|
@@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
     async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
         """Test successful PDF processing"""
         # Mock PDF content
-        pdf_content = b"fake pdf content"
+        pdf_content = b"%PDF-1.7\nfake pdf content"
         pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
 
         # Mock PyPDFLoader
|
||||||
|
|
@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||||
# Verify triples were sent for each page (provenance)
|
# Verify triples were sent for each page (provenance)
|
||||||
assert mock_triples_flow.send.call_count == 2
|
assert mock_triples_flow.send.call_count == 2
|
||||||
|
|
||||||
|
@patch('trustgraph.base.librarian_client.Consumer')
|
||||||
|
@patch('trustgraph.base.librarian_client.Producer')
|
||||||
|
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||||
|
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||||
|
async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||||
|
"""Test rejecting non-PDF content before invoking the PDF loader"""
|
||||||
|
html_content = b"<html><body>Not found</body></html>"
|
||||||
|
html_base64 = base64.b64encode(html_content)
|
||||||
|
|
||||||
|
mock_metadata = Metadata(id="test-doc")
|
||||||
|
mock_document = Document(metadata=mock_metadata, document_id="doc-123")
|
||||||
|
mock_msg = MagicMock()
|
||||||
|
mock_msg.value.return_value = mock_document
|
||||||
|
|
||||||
|
mock_output_flow = AsyncMock()
|
||||||
|
mock_triples_flow = AsyncMock()
|
||||||
|
mock_flow = MagicMock(side_effect=lambda name: {
|
||||||
|
"output": mock_output_flow,
|
||||||
|
"triples": mock_triples_flow,
|
||||||
|
}.get(name))
|
||||||
|
mock_flow.librarian.fetch_document_metadata = AsyncMock(
|
||||||
|
return_value=MagicMock(kind="application/pdf")
|
||||||
|
)
|
||||||
|
mock_flow.librarian.fetch_document_content = AsyncMock(
|
||||||
|
return_value=html_base64
|
||||||
|
)
|
||||||
|
mock_flow.librarian.save_child_document = AsyncMock()
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'id': 'test-pdf-decoder',
|
||||||
|
'taskgroup': AsyncMock()
|
||||||
|
}
|
||||||
|
|
||||||
|
processor = Processor(**config)
|
||||||
|
|
||||||
|
await processor.on_message(mock_msg, None, mock_flow)
|
||||||
|
|
||||||
|
mock_pdf_loader_class.assert_not_called()
|
||||||
|
mock_output_flow.send.assert_not_called()
|
||||||
|
mock_triples_flow.send.assert_not_called()
|
||||||
|
mock_flow.librarian.save_child_document.assert_not_called()
|
||||||
|
|
||||||
@patch('trustgraph.base.librarian_client.Consumer')
|
@patch('trustgraph.base.librarian_client.Consumer')
|
||||||
@patch('trustgraph.base.librarian_client.Producer')
|
@patch('trustgraph.base.librarian_client.Producer')
|
||||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||||
"""Test handling of empty PDF"""
|
"""Test handling of empty PDF"""
|
||||||
pdf_content = b"fake pdf content"
|
pdf_content = b"%PDF-1.7\nfake pdf content"
|
||||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||||
|
|
||||||
mock_loader = MagicMock()
|
mock_loader = MagicMock()
|
||||||
|
|
@@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
         """Test handling of unicode content in PDF"""
-        pdf_content = b"fake pdf content"
+        pdf_content = b"%PDF-1.7\nfake pdf content"
         pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
 
         mock_loader = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from trustgraph.embeddings.hf.hf import Processor
|
||||||
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
"""Test HuggingFace dynamic model loading and caching"""
|
"""Test HuggingFace dynamic model loading and caching"""
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -39,7 +39,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
assert processor.cached_model_name == "test-model"
|
assert processor.cached_model_name == "test-model"
|
||||||
assert processor.embeddings is not None
|
assert processor.embeddings is not None
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -63,7 +63,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
mock_hf_class.assert_not_called()
|
mock_hf_class.assert_not_called()
|
||||||
assert processor.cached_model_name == "test-model"
|
assert processor.cached_model_name == "test-model"
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -84,7 +84,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
mock_hf_class.assert_called_once_with(model_name="different-model")
|
mock_hf_class.assert_called_once_with(model_name="different-model")
|
||||||
assert processor.cached_model_name == "different-model"
|
assert processor.cached_model_name == "different-model"
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -107,7 +107,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
assert processor.cached_model_name == "test-model" # Still using default
|
assert processor.cached_model_name == "test-model" # Still using default
|
||||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -130,7 +130,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
assert processor.cached_model_name == "custom-model"
|
assert processor.cached_model_name == "custom-model"
|
||||||
mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
|
mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -164,7 +164,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
assert call_count_after_b == initial_call_count + 2 # Reload for model-b
|
assert call_count_after_b == initial_call_count + 2 # Reload for model-b
|
||||||
assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a
|
assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
@ -187,7 +187,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
|
||||||
assert mock_hf_class.call_count == initial_count
|
assert mock_hf_class.call_count == initial_count
|
||||||
assert processor.cached_model_name == "test-model"
|
assert processor.cached_model_name == "test-model"
|
||||||
|
|
||||||
@patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
|
@patch('langchain_huggingface.HuggingFaceEmbeddings')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
|
||||||
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
|
||||||
|
|
|
||||||
|
|
@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
|
||||||
"""Test queries against the unified rows table"""
|
"""Test queries against the unified rows table"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock)
|
||||||
async def test_query_with_index_match(self, mock_async_execute):
|
async def test_query_with_index_match(self, mock_async_execute_paged):
|
||||||
"""Test query execution with matching index"""
|
"""Test query execution with matching index"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.session = MagicMock()
|
processor.session = MagicMock()
|
||||||
|
|
@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
|
||||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||||
|
|
||||||
# Mock async_execute to return test data
|
# Mock async_execute_paged to return test data (list of pages)
|
||||||
mock_row = MagicMock()
|
mock_row = MagicMock()
|
||||||
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
||||||
mock_async_execute.return_value = [mock_row]
|
mock_async_execute_paged.return_value = [[mock_row]]
|
||||||
|
|
||||||
schema = RowSchema(
|
schema = RowSchema(
|
||||||
name="products",
|
name="products",
|
||||||
|
|
@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
|
||||||
|
|
||||||
# Verify Cassandra was connected and queried
|
# Verify Cassandra was connected and queried
|
||||||
processor.connect_cassandra.assert_called_once()
|
processor.connect_cassandra.assert_called_once()
|
||||||
mock_async_execute.assert_called_once()
|
mock_async_execute_paged.assert_called_once()
|
||||||
|
|
||||||
# Verify query structure - should query unified rows table
|
# Verify query structure - should query unified rows table
|
||||||
call_args = mock_async_execute.call_args
|
call_args = mock_async_execute_paged.call_args
|
||||||
query = call_args[0][1]
|
query = call_args[0][1]
|
||||||
params = call_args[0][2]
|
params = call_args[0][2]
|
||||||
|
|
||||||
|
|
@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
|
||||||
assert results[0]["category"] == "electronics"
|
assert results[0]["category"] == "electronics"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock)
|
||||||
async def test_query_without_index_match(self, mock_async_execute):
|
async def test_query_without_index_match(self, mock_async_scan):
|
||||||
"""Test query execution without matching index (scan mode)"""
|
"""Test query execution without matching index (scan mode)"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.session = MagicMock()
|
processor.session = MagicMock()
|
||||||
|
|
@ -406,12 +406,10 @@ class TestUnifiedTableQueries:
|
||||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||||
|
|
||||||
# Mock async_execute to return test data
|
# Mock async_scan to return filtered test data
|
||||||
mock_row1 = MagicMock()
|
mock_row1 = MagicMock()
|
||||||
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
||||||
mock_row2 = MagicMock()
|
mock_async_scan.return_value = [mock_row1]
|
||||||
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
|
|
||||||
mock_async_execute.return_value = [mock_row1, mock_row2]
|
|
||||||
|
|
||||||
schema = RowSchema(
|
schema = RowSchema(
|
||||||
name="products",
|
name="products",
|
||||||
|
|
@ -432,13 +430,16 @@ class TestUnifiedTableQueries:
|
||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query should use ALLOW FILTERING for scan
|
# Verify async_scan was called
|
||||||
call_args = mock_async_execute.call_args
|
mock_async_scan.assert_called_once()
|
||||||
|
|
||||||
|
# Verify query structure
|
||||||
|
call_args = mock_async_scan.call_args
|
||||||
query = call_args[0][1]
|
query = call_args[0][1]
|
||||||
|
|
||||||
assert "ALLOW FILTERING" in query
|
assert "ALLOW FILTERING" in query
|
||||||
|
|
||||||
# Should post-filter results
|
# Should return filtered results
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0]["name"] == "Product A"
|
assert results[0]["name"] == "Product A"
|
||||||
|
|
||||||
|
|
|
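The tests above change the mocked return value from a flat list of rows to a list of pages (`[[mock_row]]`). A minimal sketch of that pattern follows, assuming the paged helper is awaited once and yields a list of pages. `collect_rows` is a hypothetical consumer written for illustration; it is not the service's actual `query_cassandra`.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def collect_rows(execute_paged, session, query, params):
    """Hypothetical consumer: flatten a list of pages into row dicts."""
    pages = await execute_paged(session, query, params)
    return [row.data for page in pages for row in page]


def run_demo():
    mock_row = MagicMock()
    mock_row.data = {"id": "123", "name": "Test Product"}

    # One page holding one row, mirroring `[[mock_row]]` in the tests.
    execute_paged = AsyncMock(return_value=[[mock_row]])

    rows = asyncio.run(
        collect_rows(execute_paged, MagicMock(), "SELECT ...", ("products",))
    )
    return rows, execute_paged


rows, execute_paged = run_demo()
```

The positional layout (`session`, `query`, `params`) matches how the tests read `call_args[0][1]` and `call_args[0][2]`.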
||||||
|
|
@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection:
|
||||||
proc.collection_exists = MagicMock(return_value=True)
|
proc.collection_exists = MagicMock(return_value=True)
|
||||||
proc._cache_lock = asyncio.Lock()
|
proc._cache_lock = asyncio.Lock()
|
||||||
proc._known_collections = set()
|
proc._known_collections = set()
|
||||||
|
proc.replication_factor = 1
|
||||||
|
proc.shard_number = 1
|
||||||
|
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.metadata.collection = "graphs"
|
msg.metadata.collection = "graphs"
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,9 @@ def _make_store():
|
||||||
class TestGetGraphEmbeddings:
|
class TestGetGraphEmbeddings:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||||
async def test_row_converts_to_entity_embeddings_with_singular_vector(
|
async def test_row_converts_to_entity_embeddings_with_singular_vector(
|
||||||
self, mock_async_execute
|
self, mock_async_execute_paged
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Cassandra rows return entities as a list of [entity_tuple, vector]
|
Cassandra rows return entities as a list of [entity_tuple, vector]
|
||||||
|
|
@ -57,7 +57,7 @@ class TestGetGraphEmbeddings:
|
||||||
store = _make_store()
|
store = _make_store()
|
||||||
store.cassandra = Mock()
|
store.cassandra = Mock()
|
||||||
store.get_graph_embeddings_stmt = Mock()
|
store.get_graph_embeddings_stmt = Mock()
|
||||||
mock_async_execute.return_value = [fake_row]
|
mock_async_execute_paged.return_value = [[fake_row]]
|
||||||
|
|
||||||
received = []
|
received = []
|
||||||
|
|
||||||
|
|
@ -66,7 +66,7 @@ class TestGetGraphEmbeddings:
|
||||||
|
|
||||||
await store.get_graph_embeddings("alice", "doc-1", receiver)
|
await store.get_graph_embeddings("alice", "doc-1", receiver)
|
||||||
|
|
||||||
mock_async_execute.assert_called_once_with(
|
mock_async_execute_paged.assert_called_once_with(
|
||||||
store.cassandra,
|
store.cassandra,
|
||||||
store.get_graph_embeddings_stmt,
|
store.get_graph_embeddings_stmt,
|
||||||
("alice", "doc-1"),
|
("alice", "doc-1"),
|
||||||
|
|
@ -96,8 +96,8 @@ class TestGetGraphEmbeddings:
|
||||||
assert ge.entities[2].entity.value == "a literal entity"
|
assert ge.entities[2].entity.value == "a literal entity"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||||
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
|
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
|
||||||
"""row[3] being None / empty must produce a GraphEmbeddings with
|
"""row[3] being None / empty must produce a GraphEmbeddings with
|
||||||
no entities, not raise."""
|
no entities, not raise."""
|
||||||
fake_row = (None, None, None, None)
|
fake_row = (None, None, None, None)
|
||||||
|
|
@ -105,7 +105,7 @@ class TestGetGraphEmbeddings:
|
||||||
store = _make_store()
|
store = _make_store()
|
||||||
store.cassandra = Mock()
|
store.cassandra = Mock()
|
||||||
store.get_graph_embeddings_stmt = Mock()
|
store.get_graph_embeddings_stmt = Mock()
|
||||||
mock_async_execute.return_value = [fake_row]
|
mock_async_execute_paged.return_value = [[fake_row]]
|
||||||
|
|
||||||
received = []
|
received = []
|
||||||
|
|
||||||
|
|
@ -118,8 +118,8 @@ class TestGetGraphEmbeddings:
|
||||||
assert received[0].entities == []
|
assert received[0].entities == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||||
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
|
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
|
||||||
fake_rows = [
|
fake_rows = [
|
||||||
(None, None, None, [
|
(None, None, None, [
|
||||||
(("http://example.org/a", True), [1.0]),
|
(("http://example.org/a", True), [1.0]),
|
||||||
|
|
@ -132,7 +132,7 @@ class TestGetGraphEmbeddings:
|
||||||
store = _make_store()
|
store = _make_store()
|
||||||
store.cassandra = Mock()
|
store.cassandra = Mock()
|
||||||
store.get_graph_embeddings_stmt = Mock()
|
store.get_graph_embeddings_stmt = Mock()
|
||||||
mock_async_execute.return_value = fake_rows
|
mock_async_execute_paged.return_value = [fake_rows]
|
||||||
|
|
||||||
received = []
|
received = []
|
||||||
|
|
||||||
|
|
@ -153,9 +153,9 @@ class TestGetTriples:
|
||||||
the same Metadata construction. Cover it for parity."""
|
the same Metadata construction. Cover it for parity."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||||
async def test_row_converts_to_triples(self, mock_async_execute):
|
async def test_row_converts_to_triples(self, mock_async_execute_paged):
|
||||||
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
|
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri, graph)
|
||||||
fake_row = (
|
fake_row = (
|
||||||
None, None, None,
|
None, None, None,
|
||||||
[
|
[
|
||||||
|
|
@ -163,6 +163,7 @@ class TestGetTriples:
|
||||||
"http://example.org/alice", True,
|
"http://example.org/alice", True,
|
||||||
"http://example.org/knows", True,
|
"http://example.org/knows", True,
|
||||||
"http://example.org/bob", True,
|
"http://example.org/bob", True,
|
||||||
|
"urn:graph:source",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -170,7 +171,7 @@ class TestGetTriples:
|
||||||
store = _make_store()
|
store = _make_store()
|
||||||
store.cassandra = Mock()
|
store.cassandra = Mock()
|
||||||
store.get_triples_stmt = Mock()
|
store.get_triples_stmt = Mock()
|
||||||
mock_async_execute.return_value = [fake_row]
|
mock_async_execute_paged.return_value = [[fake_row]]
|
||||||
|
|
||||||
received = []
|
received = []
|
||||||
|
|
||||||
|
|
@ -191,3 +192,33 @@ class TestGetTriples:
|
||||||
assert t.s.iri == "http://example.org/alice"
|
assert t.s.iri == "http://example.org/alice"
|
||||||
assert t.p.iri == "http://example.org/knows"
|
assert t.p.iri == "http://example.org/knows"
|
||||||
assert t.o.iri == "http://example.org/bob"
|
assert t.o.iri == "http://example.org/bob"
|
||||||
|
assert t.g == "urn:graph:source"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||||
|
async def test_empty_graph_name_becomes_none(self, mock_async_execute_paged):
|
||||||
|
fake_row = (
|
||||||
|
None, None, None,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"http://example.org/alice", True,
|
||||||
|
"http://example.org/knows", True,
|
||||||
|
"http://example.org/bob", True,
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
store = _make_store()
|
||||||
|
store.cassandra = Mock()
|
||||||
|
store.get_triples_stmt = Mock()
|
||||||
|
mock_async_execute_paged.return_value = [[fake_row]]
|
||||||
|
|
||||||
|
received = []
|
||||||
|
|
||||||
|
async def receiver(msg):
|
||||||
|
received.append(msg)
|
||||||
|
|
||||||
|
await store.get_triples("w", "d", receiver)
|
||||||
|
|
||||||
|
assert received[0].triples[0].g is None
|
||||||
|
|
|
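The two triple tests above pin down one behaviour: each stored element carries a seventh `graph` field, and an empty graph name is surfaced as `None`. A minimal sketch of that conversion follows, under the assumption that elements are plain 7-tuples. `SketchTerm`, `SketchTriple` and `triple_from_row` are hypothetical names and not the TrustGraph schema classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchTerm:
    value: str
    is_uri: bool


@dataclass
class SketchTriple:
    s: SketchTerm
    p: SketchTerm
    o: SketchTerm
    g: Optional[str]


def triple_from_row(elt) -> SketchTriple:
    """Convert (s_val, s_uri, p_val, p_uri, o_val, o_uri, graph) to a triple."""
    s_val, s_uri, p_val, p_uri, o_val, o_uri, graph = elt
    return SketchTriple(
        s=SketchTerm(s_val, s_uri),
        p=SketchTerm(p_val, p_uri),
        o=SketchTerm(o_val, o_uri),
        # An empty graph name means "no named graph", so normalise it to None.
        g=graph or None,
    )


named = triple_from_row((
    "http://example.org/alice", True,
    "http://example.org/knows", True,
    "http://example.org/bob", True,
    "urn:graph:source",
))
unnamed = triple_from_row((
    "http://example.org/alice", True,
    "http://example.org/knows", True,
    "http://example.org/bob", True,
    "",
))
```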
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Round-trip unit tests for KnowledgeRequestTranslator.
|
Round-trip unit tests for KnowledgeRequestTranslator and
|
||||||
|
KnowledgeResponseTranslator.
|
||||||
|
|
||||||
Regression coverage: a previous version of the decode side constructed
|
Regression coverage: a previous version of the decode side constructed
|
||||||
EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
|
EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
|
||||||
|
|
@ -15,9 +16,13 @@ Triples breaks the test.
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
|
from trustgraph.messaging.translators.knowledge import (
|
||||||
|
KnowledgeRequestTranslator,
|
||||||
|
KnowledgeResponseTranslator,
|
||||||
|
)
|
||||||
from trustgraph.schema import (
|
from trustgraph.schema import (
|
||||||
KnowledgeRequest,
|
KnowledgeRequest,
|
||||||
|
KnowledgeResponse,
|
||||||
GraphEmbeddings,
|
GraphEmbeddings,
|
||||||
EntityEmbeddings,
|
EntityEmbeddings,
|
||||||
Triples,
|
Triples,
|
||||||
|
|
@ -25,6 +30,8 @@ from trustgraph.schema import (
|
||||||
Metadata,
|
Metadata,
|
||||||
Term,
|
Term,
|
||||||
IRI,
|
IRI,
|
||||||
|
LibraryMetadata,
|
||||||
|
LibraryBlob,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -145,3 +152,161 @@ class TestKnowledgeRequestTranslatorTriples:
|
||||||
assert t.s.iri == "http://example.org/alice"
|
assert t.s.iri == "http://example.org/alice"
|
||||||
assert t.p.iri == "http://example.org/knows"
|
assert t.p.iri == "http://example.org/knows"
|
||||||
assert t.o.iri == "http://example.org/bob"
|
assert t.o.iri == "http://example.org/bob"
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeRequestTranslatorLibrary:
|
||||||
|
|
||||||
|
def test_roundtrip_preserves_library_metadata(self, translator):
|
||||||
|
request = KnowledgeRequest(
|
||||||
|
operation="put-kg-core",
|
||||||
|
id="doc-1",
|
||||||
|
library_metadata=LibraryMetadata(
|
||||||
|
id="doc-1",
|
||||||
|
kind="application/pdf",
|
||||||
|
title="Test Document",
|
||||||
|
parent_id="",
|
||||||
|
document_type="source",
|
||||||
|
comments="test comments",
|
||||||
|
tags=["tag1", "tag2"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = translator.encode(request)
|
||||||
|
assert "library-metadata" in encoded
|
||||||
|
lm = encoded["library-metadata"]
|
||||||
|
assert lm["id"] == "doc-1"
|
||||||
|
assert lm["kind"] == "application/pdf"
|
||||||
|
assert lm["title"] == "Test Document"
|
||||||
|
assert lm["parent-id"] == ""
|
||||||
|
assert lm["document-type"] == "source"
|
||||||
|
assert lm["comments"] == "test comments"
|
||||||
|
assert lm["tags"] == ["tag1", "tag2"]
|
||||||
|
|
||||||
|
decoded = translator.decode(encoded)
|
||||||
|
assert decoded.library_metadata is not None
|
||||||
|
assert decoded.library_metadata.id == "doc-1"
|
||||||
|
assert decoded.library_metadata.kind == "application/pdf"
|
||||||
|
assert decoded.library_metadata.title == "Test Document"
|
||||||
|
assert decoded.library_metadata.parent_id == ""
|
||||||
|
assert decoded.library_metadata.document_type == "source"
|
||||||
|
assert decoded.library_metadata.comments == "test comments"
|
||||||
|
assert decoded.library_metadata.tags == ["tag1", "tag2"]
|
||||||
|
|
||||||
|
def test_roundtrip_preserves_child_document_metadata(self, translator):
|
||||||
|
request = KnowledgeRequest(
|
||||||
|
operation="put-kg-core",
|
||||||
|
id="doc-1",
|
||||||
|
library_metadata=LibraryMetadata(
|
||||||
|
id="chunk-1",
|
||||||
|
kind="text/plain",
|
||||||
|
title="Chunk 1",
|
||||||
|
parent_id="doc-1",
|
||||||
|
document_type="chunk",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = translator.encode(request)
|
||||||
|
decoded = translator.decode(encoded)
|
||||||
|
|
||||||
|
assert decoded.library_metadata.parent_id == "doc-1"
|
||||||
|
assert decoded.library_metadata.document_type == "chunk"
|
||||||
|
|
||||||
|
def test_roundtrip_preserves_library_blob(self, translator):
|
||||||
|
request = KnowledgeRequest(
|
||||||
|
operation="put-kg-core",
|
||||||
|
id="doc-1",
|
||||||
|
library_blob=LibraryBlob(
|
||||||
|
id="doc-1",
|
||||||
|
data=b"SGVsbG8gV29ybGQ=",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = translator.encode(request)
|
||||||
|
assert "library-blob" in encoded
|
||||||
|
assert encoded["library-blob"]["id"] == "doc-1"
|
||||||
|
assert encoded["library-blob"]["data"] == "SGVsbG8gV29ybGQ="
|
||||||
|
|
||||||
|
decoded = translator.decode(encoded)
|
||||||
|
assert decoded.library_blob is not None
|
||||||
|
assert decoded.library_blob.id == "doc-1"
|
||||||
|
assert decoded.library_blob.data == "SGVsbG8gV29ybGQ="
|
||||||
|
|
||||||
|
def test_absent_library_fields_decode_as_none(self, translator):
|
||||||
|
decoded = translator.decode({
|
||||||
|
"operation": "get-kg-core",
|
||||||
|
"id": "doc-1",
|
||||||
|
})
|
||||||
|
assert decoded.library_metadata is None
|
||||||
|
assert decoded.library_blob is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeResponseTranslatorLibrary:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def response_translator(self):
|
||||||
|
return KnowledgeResponseTranslator()
|
||||||
|
|
||||||
|
def test_encode_library_metadata(self, response_translator):
|
||||||
|
response = KnowledgeResponse(
|
||||||
|
ids=None,
|
||||||
|
library_metadata=LibraryMetadata(
|
||||||
|
id="doc-1",
|
||||||
|
kind="application/pdf",
|
||||||
|
title="Test",
|
||||||
|
parent_id="",
|
||||||
|
document_type="source",
|
||||||
|
comments="",
|
||||||
|
tags=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
encoded = response_translator.encode(response)
|
||||||
|
assert "library-metadata" in encoded
|
||||||
|
assert encoded["library-metadata"]["id"] == "doc-1"
|
||||||
|
assert encoded["library-metadata"]["kind"] == "application/pdf"
|
||||||
|
assert encoded["library-metadata"]["document-type"] == "source"
|
||||||
|
|
||||||
|
def test_encode_library_blob_bytes_to_string(self, response_translator):
|
||||||
|
response = KnowledgeResponse(
|
||||||
|
ids=None,
|
||||||
|
library_blob=LibraryBlob(
|
||||||
|
id="doc-1",
|
||||||
|
data=b"dGVzdCBkYXRh",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
encoded = response_translator.encode(response)
|
||||||
|
assert "library-blob" in encoded
|
||||||
|
assert encoded["library-blob"]["id"] == "doc-1"
|
||||||
|
assert encoded["library-blob"]["data"] == "dGVzdCBkYXRh"
|
||||||
|
assert isinstance(encoded["library-blob"]["data"], str)
|
||||||
|
|
||||||
|
def test_encode_library_blob_string_passthrough(self, response_translator):
|
||||||
|
response = KnowledgeResponse(
|
||||||
|
ids=None,
|
||||||
|
library_blob=LibraryBlob(
|
||||||
|
id="doc-1",
|
||||||
|
data="already-a-string",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
encoded = response_translator.encode(response)
|
||||||
|
assert encoded["library-blob"]["data"] == "already-a-string"
|
||||||
|
|
||||||
|
def test_library_metadata_is_not_final(self, response_translator):
|
||||||
|
response = KnowledgeResponse(
|
||||||
|
ids=None,
|
||||||
|
library_metadata=LibraryMetadata(id="doc-1"),
|
||||||
|
)
|
||||||
|
_, is_final = response_translator.encode_with_completion(response)
|
||||||
|
assert is_final is False
|
||||||
|
|
||||||
|
def test_library_blob_is_not_final(self, response_translator):
|
||||||
|
response = KnowledgeResponse(
|
||||||
|
ids=None,
|
||||||
|
library_blob=LibraryBlob(id="doc-1", data=b"data"),
|
||||||
|
)
|
||||||
|
_, is_final = response_translator.encode_with_completion(response)
|
||||||
|
assert is_final is False
|
||||||
|
|
||||||
|
def test_eos_is_final(self, response_translator):
|
||||||
|
response = KnowledgeResponse(eos=True)
|
||||||
|
_, is_final = response_translator.encode_with_completion(response)
|
||||||
|
assert is_final is True
|
||||||
|
|
|
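The library tests above expect kebab-case keys on the wire (`parent-id`, `document-type`) and a blob `data` field that is always a string, whether the object held bytes or a string. A minimal sketch of that encoding follows. The helper names are hypothetical, and decoding bytes as UTF-8 is an assumption consistent with the test values, not the translator's confirmed implementation.

```python
from typing import Any, Dict, Union


def encode_blob_data(data: Union[bytes, str]) -> str:
    """Return blob data as a string: decode bytes, pass strings through."""
    if isinstance(data, bytes):
        return data.decode("utf-8")
    return data


def encode_library_blob(blob_id: str, data: Union[bytes, str]) -> Dict[str, Any]:
    """Hypothetical wire form of a library blob."""
    return {"id": blob_id, "data": encode_blob_data(data)}


from_bytes = encode_library_blob("doc-1", b"dGVzdCBkYXRh")
from_str = encode_library_blob("doc-1", "already-a-string")
```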
||||||
|
|
@ -337,7 +337,7 @@ class Api:
|
||||||
from . bulk_client import BulkClient
|
from . bulk_client import BulkClient
|
||||||
# Extract base URL (remove api/v1/ suffix)
|
# Extract base URL (remove api/v1/ suffix)
|
||||||
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
||||||
self._bulk_client = BulkClient(base_url, self.timeout, self.token)
|
self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
|
||||||
return self._bulk_client
|
return self._bulk_client
|
||||||
|
|
||||||
def metrics(self):
|
def metrics(self):
|
||||||
|
|
@ -462,7 +462,7 @@ class Api:
|
||||||
from . async_bulk_client import AsyncBulkClient
|
from . async_bulk_client import AsyncBulkClient
|
||||||
# Extract base URL (remove api/v1/ suffix)
|
# Extract base URL (remove api/v1/ suffix)
|
||||||
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
||||||
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
|
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
|
||||||
return self._async_bulk_client
|
return self._async_bulk_client
|
||||||
|
|
||||||
def async_metrics(self):
|
def async_metrics(self):
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,11 @@ from . types import Triple
|
||||||
class AsyncBulkClient:
|
class AsyncBulkClient:
|
||||||
"""Asynchronous bulk operations client"""
|
"""Asynchronous bulk operations client"""
|
||||||
|
|
||||||
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
|
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
|
||||||
self.url: str = self._convert_to_ws_url(url)
|
self.url: str = self._convert_to_ws_url(url)
|
||||||
self.timeout: int = timeout
|
self.timeout: int = timeout
|
||||||
self.token: Optional[str] = token
|
self.token: Optional[str] = token
|
||||||
|
self.workspace: str = workspace
|
||||||
|
|
||||||
def _convert_to_ws_url(self, url: str) -> str:
|
def _convert_to_ws_url(self, url: str) -> str:
|
||||||
"""Convert HTTP URL to WebSocket URL"""
|
"""Convert HTTP URL to WebSocket URL"""
|
||||||
|
|
@ -25,11 +26,21 @@ class AsyncBulkClient:
|
||||||
else:
|
else:
|
||||||
return f"ws://{url}"
|
return f"ws://{url}"
|
||||||
|
|
||||||
|
def _build_ws_url(self, path: str) -> str:
|
||||||
|
"""Build a WebSocket URL with token and workspace query params."""
|
||||||
|
ws_url = f"{self.url}{path}"
|
||||||
|
params = []
|
||||||
|
if self.token:
|
||||||
|
params.append(f"token={self.token}")
|
||||||
|
if self.workspace:
|
||||||
|
params.append(f"workspace={self.workspace}")
|
||||||
|
if params:
|
||||||
|
ws_url = f"{ws_url}?{'&'.join(params)}"
|
||||||
|
return ws_url
|
||||||
|
|
||||||
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
|
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
|
||||||
"""Bulk import triples via WebSocket"""
|
"""Bulk import triples via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for triple in triples:
|
async for triple in triples:
|
||||||
|
|
@ -42,9 +53,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
|
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
|
||||||
"""Bulk export triples via WebSocket"""
|
"""Bulk export triples via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -57,9 +66,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||||
"""Bulk import graph embeddings via WebSocket"""
|
"""Bulk import graph embeddings via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for embedding in embeddings:
|
async for embedding in embeddings:
|
||||||
|
|
@ -67,9 +74,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||||
"""Bulk export graph embeddings via WebSocket"""
|
"""Bulk export graph embeddings via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -77,9 +82,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||||
"""Bulk import document embeddings via WebSocket"""
|
"""Bulk import document embeddings via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for embedding in embeddings:
|
async for embedding in embeddings:
|
||||||
|
|
@ -87,9 +90,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||||
"""Bulk export document embeddings via WebSocket"""
|
"""Bulk export document embeddings via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -97,9 +98,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||||
"""Bulk import entity contexts via WebSocket"""
|
"""Bulk import entity contexts via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for context in contexts:
|
async for context in contexts:
|
||||||
|
|
@ -107,9 +106,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||||
"""Bulk export entity contexts via WebSocket"""
|
"""Bulk export entity contexts via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -117,9 +114,7 @@ class AsyncBulkClient:
|
||||||
|
|
||||||
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||||
"""Bulk import rows via WebSocket"""
|
"""Bulk import rows via WebSocket"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for row in rows:
|
async for row in rows:
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class BulkClient:
|
||||||
Note: For true async support, use AsyncBulkClient instead.
|
Note: For true async support, use AsyncBulkClient instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
|
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
|
||||||
"""
|
"""
|
||||||
Initialize synchronous bulk client.
|
Initialize synchronous bulk client.
|
||||||
|
|
||||||
|
|
@ -42,10 +42,12 @@ class BulkClient:
|
||||||
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
|
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
|
||||||
timeout: WebSocket timeout in seconds
|
timeout: WebSocket timeout in seconds
|
||||||
token: Optional bearer token for authentication
|
token: Optional bearer token for authentication
|
||||||
|
workspace: Workspace for data isolation
|
||||||
"""
|
"""
|
||||||
self.url: str = self._convert_to_ws_url(url)
|
self.url: str = self._convert_to_ws_url(url)
|
||||||
self.timeout: int = timeout
|
self.timeout: int = timeout
|
||||||
self.token: Optional[str] = token
|
self.token: Optional[str] = token
|
||||||
|
self.workspace: str = workspace
|
||||||
|
|
||||||
def _convert_to_ws_url(self, url: str) -> str:
|
def _convert_to_ws_url(self, url: str) -> str:
|
||||||
"""Convert HTTP URL to WebSocket URL"""
|
"""Convert HTTP URL to WebSocket URL"""
|
||||||
|
|
@ -58,6 +60,18 @@ class BulkClient:
|
||||||
else:
|
else:
|
||||||
return f"ws://{url}"
|
return f"ws://{url}"
|
||||||
|
|
||||||
|
def _build_ws_url(self, path: str) -> str:
|
||||||
|
"""Build a WebSocket URL with token and workspace query params."""
|
||||||
|
ws_url = f"{self.url}{path}"
|
||||||
|
params = []
|
||||||
|
if self.token:
|
||||||
|
params.append(f"token={self.token}")
|
||||||
|
if self.workspace:
|
||||||
|
params.append(f"workspace={self.workspace}")
|
||||||
|
if params:
|
||||||
|
ws_url = f"{ws_url}?{'&'.join(params)}"
|
||||||
|
return ws_url
|
||||||
|
|
||||||
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
|
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
|
||||||
"""Run async coroutine synchronously"""
|
"""Run async coroutine synchronously"""
|
||||||
try:
|
try:
|
||||||
|
|
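The new `_build_ws_url` helper joins `token` and `workspace` into the query string with plain f-strings, so a value containing `&`, `=` or a space would change the URL's meaning. A minimal sketch of an escaped variant follows, assuming the server accepts standard percent-encoded query values. `build_ws_url` is a hypothetical standalone function, not part of `BulkClient`.

```python
from typing import Optional
from urllib.parse import urlencode


def build_ws_url(base_url: str, path: str, token: Optional[str] = None,
                 workspace: Optional[str] = "default") -> str:
    """Hypothetical standalone variant of BulkClient._build_ws_url."""
    params = {}
    if token:
        params["token"] = token
    if workspace:
        params["workspace"] = workspace

    ws_url = f"{base_url}{path}"
    if params:
        # urlencode percent-escapes reserved characters in each value
        # and joins the pairs with '&'.
        ws_url = f"{ws_url}?{urlencode(params)}"
    return ws_url


example = build_ws_url(
    "ws://localhost:8088",
    "/api/v1/flow/default/import/rows",
    token="a&b=c",
    workspace="team one",
)
print(example)
```

With neither a token nor a workspace, the function returns the bare path URL, matching the `if params:` guard in the diff.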
@ -116,9 +130,7 @@ class BulkClient:
|
||||||
metadata: Optional[Dict[str, Any]], batch_size: int
|
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Async implementation of triple import"""
|
"""Async implementation of triple import"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {"id": "", "metadata": [], "collection": "default"}
|
metadata = {"id": "", "metadata": [], "collection": "default"}
|
||||||
|
|
@ -194,9 +206,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
|
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
|
||||||
"""Async implementation of triple export"""
|
"""Async implementation of triple export"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -238,9 +248,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
||||||
"""Async implementation of graph embeddings import"""
|
"""Async implementation of graph embeddings import"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
for embedding in embeddings:
|
for embedding in embeddings:
|
||||||
|
|
@ -296,9 +304,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||||
"""Async implementation of graph embeddings export"""
|
"""Async implementation of graph embeddings export"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -336,9 +342,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
||||||
"""Async implementation of document embeddings import"""
|
"""Async implementation of document embeddings import"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
for embedding in embeddings:
|
for embedding in embeddings:
|
||||||
|
|
@ -394,9 +398,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||||
"""Async implementation of document embeddings export"""
|
"""Async implementation of document embeddings export"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -446,9 +448,7 @@ class BulkClient:
|
||||||
metadata: Optional[Dict[str, Any]], batch_size: int
|
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Async implementation of entity contexts import"""
|
"""Async implementation of entity contexts import"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {"id": "", "metadata": [], "collection": "default"}
|
metadata = {"id": "", "metadata": [], "collection": "default"}
|
||||||
|
|
@ -522,9 +522,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||||
"""Async implementation of entity contexts export"""
|
"""Async implementation of entity contexts export"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
async for raw_message in websocket:
|
async for raw_message in websocket:
|
||||||
|
|
@ -562,9 +560,7 @@ class BulkClient:
|
||||||
|
|
||||||
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
|
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
|
||||||
"""Async implementation of rows import"""
|
"""Async implementation of rows import"""
|
||||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
|
||||||
if self.token:
|
|
||||||
ws_url = f"{ws_url}?token={self.token}"
|
|
||||||
|
|
||||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
|
|
||||||
|
|
@ -167,6 +167,7 @@ class SocketClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp.get("type") == "auth-ok":
|
if resp.get("type") == "auth-ok":
|
||||||
|
if self.workspace == "default":
|
||||||
self.workspace = resp.get("workspace", self.workspace)
|
self.workspace = resp.get("workspace", self.workspace)
|
||||||
elif resp.get("type") == "auth-failed":
|
elif resp.get("type") == "auth-failed":
|
||||||
await self._socket.close()
|
await self._socket.close()
|
||||||
|
|
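As the hunk above reads, the server-reported workspace in an `auth-ok` reply is now adopted only when the client is still on the `"default"` workspace, so an explicitly chosen workspace is no longer overwritten. The sketch below isolates that rule. `resolve_workspace` is a hypothetical helper for illustration and does not exist in `SocketClient`.

```python
from typing import Any, Dict


def resolve_workspace(current: str, resp: Dict[str, Any]) -> str:
    """Hypothetical helper mirroring the auth-ok branch in the hunk above."""
    if resp.get("type") == "auth-ok" and current == "default":
        # Only a client that never chose a workspace takes the server's value.
        return resp.get("workspace", current)
    return current


reply = {"type": "auth-ok", "workspace": "acme"}
print(resolve_workspace("default", reply))
print(resolve_workspace("team-a", reply))
```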
@ -501,6 +502,7 @@ class SocketClient:
|
||||||
|
|
||||||
def put_kg_core(
|
def put_kg_core(
|
||||||
self, id: str, triples=None, graph_embeddings=None,
|
self, id: str, triples=None, graph_embeddings=None,
|
||||||
|
library_metadata=None, library_blob=None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
request = {
|
request = {
|
||||||
"operation": "put-kg-core",
|
"operation": "put-kg-core",
|
||||||
|
|
@ -511,6 +513,10 @@ class SocketClient:
|
||||||
request["triples"] = triples
|
request["triples"] = triples
|
||||||
if graph_embeddings is not None:
|
if graph_embeddings is not None:
|
||||||
request["graph-embeddings"] = graph_embeddings
|
request["graph-embeddings"] = graph_embeddings
|
||||||
|
if library_metadata is not None:
|
||||||
|
request["library-metadata"] = library_metadata
|
||||||
|
if library_blob is not None:
|
||||||
|
request["library-blob"] = library_blob
|
||||||
return self._send_request_sync("knowledge", None, request)
|
return self._send_request_sync("knowledge", None, request)
|
||||||
|
|
||||||
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
|
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
|
||||||
|
|
|
||||||
|
|
@ -103,35 +103,19 @@ def resolve_cassandra_config(
|
||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
password: Optional[str] = None,
|
password: Optional[str] = None,
|
||||||
default_keyspace: Optional[str] = None
|
default_keyspace: Optional[str] = None,
|
||||||
|
replication_factor: Optional[int] = None,
|
||||||
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
|
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
|
||||||
"""
|
|
||||||
Resolve Cassandra configuration from various sources.
|
|
||||||
|
|
||||||
Can accept either argparse args object or explicit parameters.
|
|
||||||
Converts host string to list format for Cassandra driver.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
|
|
||||||
host: Optional explicit host parameter (overrides args)
|
|
||||||
username: Optional explicit username parameter (overrides args)
|
|
||||||
password: Optional explicit password parameter (overrides args)
|
|
||||||
default_keyspace: Optional default keyspace if not specified elsewhere
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (hosts_list, username, password, keyspace, replication_factor)
|
|
||||||
"""
|
|
||||||
# If args provided, extract values
|
|
||||||
keyspace = None
|
keyspace = None
|
||||||
replication_factor = 1
|
|
||||||
if args is not None:
|
if args is not None:
|
||||||
host = host or getattr(args, 'cassandra_host', None)
|
host = host or getattr(args, 'cassandra_host', None)
|
||||||
username = username or getattr(args, 'cassandra_username', None)
|
username = username or getattr(args, 'cassandra_username', None)
|
||||||
password = password or getattr(args, 'cassandra_password', None)
|
password = password or getattr(args, 'cassandra_password', None)
|
||||||
keyspace = getattr(args, 'cassandra_keyspace', None)
|
keyspace = getattr(args, 'cassandra_keyspace', None)
|
||||||
replication_factor = getattr(args, 'cassandra_replication_factor', 1)
|
replication_factor = replication_factor or getattr(
|
||||||
|
args, 'cassandra_replication_factor', None
|
||||||
|
)
|
||||||
|
|
||||||
# Apply defaults if still None
|
|
||||||
defaults = get_cassandra_defaults()
|
defaults = get_cassandra_defaults()
|
||||||
host = host or defaults['host']
|
host = host or defaults['host']
|
||||||
username = username or defaults['username']
|
username = username or defaults['username']
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ Supports dual output to console and Loki for centralized log aggregation.
|
||||||
import contextvars
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
|
import uuid
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -132,14 +133,12 @@ def setup_logging(args: dict[str, Any]) -> None:
|
||||||
try:
|
try:
|
||||||
from logging_loki import LokiHandler
|
from logging_loki import LokiHandler
|
||||||
|
|
||||||
# Create Loki handler with optional authentication. The
|
instance_id = str(uuid.uuid4())[:8]
|
||||||
# processor label is NOT baked in here — it's stamped onto
|
|
||||||
# each record by _ProcessorIdFilter reading the task-local
|
|
||||||
# contextvar, and logging_loki's emitter reads record.tags
|
|
||||||
# to build per-record Loki labels.
|
|
||||||
loki_handler_kwargs = {
|
loki_handler_kwargs = {
|
||||||
'url': loki_url,
|
'url': loki_url,
|
||||||
'version': "1",
|
'version': "1",
|
||||||
|
'tags': {'instance': instance_id},
|
||||||
}
|
}
|
||||||
|
|
||||||
if loki_username and loki_password:
|
if loki_username and loki_password:
|
||||||
|
|
|
||||||
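The logging hunk above adds a static `instance` tag to the Loki handler, derived from the first eight characters of a random UUID. A small sketch of how that tag is built is shown here. `make_instance_tags` is a hypothetical name, and only the standard-library `uuid` module is used.

```python
import uuid


def make_instance_tags() -> dict:
    """Build the tags dict as in the hunk: a short random instance id."""
    # str(uuid4()) starts with an 8-character hex group, so the slice
    # never includes a hyphen.
    instance_id = str(uuid.uuid4())[:8]
    return {"instance": instance_id}


tags = make_instance_tags()
print(tags)
```

Eight hex characters carry 32 bits of randomness. That is likely enough to tell a handful of processes apart, though it is not a guarantee of uniqueness.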
87
trustgraph-base/trustgraph/base/qdrant_config.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from typing import Optional, Any, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def get_qdrant_defaults() -> dict:
|
||||||
|
return {
|
||||||
|
'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
|
||||||
|
'api_key': os.getenv('QDRANT_API_KEY'),
|
||||||
|
'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')),
|
||||||
|
'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
|
||||||
|
url_help = f"Qdrant URL (default: {defaults['url']})"
|
||||||
|
if 'QDRANT_URL' in os.environ:
|
||||||
|
url_help += " [from QDRANT_URL]"
|
||||||
|
|
||||||
|
api_key_help = "Qdrant API key"
|
||||||
|
if defaults['api_key']:
|
||||||
|
api_key_help += " (default: <set>)"
|
||||||
|
if 'QDRANT_API_KEY' in os.environ:
|
||||||
|
api_key_help += " [from QDRANT_API_KEY]"
|
||||||
|
|
||||||
|
replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
|
||||||
|
if 'QDRANT_REPLICATION_FACTOR' in os.environ:
|
||||||
|
replication_help += " [from QDRANT_REPLICATION_FACTOR]"
|
||||||
|
|
||||||
|
shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
|
||||||
|
if 'QDRANT_SHARD_NUMBER' in os.environ:
|
||||||
|
shard_help += " [from QDRANT_SHARD_NUMBER]"
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--store-uri',
|
||||||
|
default=defaults['url'],
|
||||||
|
help=url_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--api-key',
|
||||||
|
default=defaults['api_key'],
|
||||||
|
help=api_key_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--qdrant-replication-factor',
|
||||||
|
type=int,
|
||||||
|
default=defaults['replication_factor'],
|
||||||
|
help=replication_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--qdrant-shard-number',
|
||||||
|
type=int,
|
||||||
|
default=defaults['shard_number'],
|
||||||
|
help=shard_help,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_qdrant_config(
|
||||||
|
args: Optional[Any] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
replication_factor: Optional[int] = None,
|
||||||
|
shard_number: Optional[int] = None,
|
||||||
|
) -> Tuple[str, Optional[str], int, int]:
|
||||||
|
if args is not None:
|
||||||
|
url = url or getattr(args, 'store_uri', None)
|
||||||
|
api_key = api_key or getattr(args, 'api_key', None)
|
||||||
|
replication_factor = replication_factor or getattr(
|
||||||
|
args, 'qdrant_replication_factor', None
|
||||||
|
)
|
||||||
|
shard_number = shard_number or getattr(
|
||||||
|
args, 'qdrant_shard_number', None
|
||||||
|
)
|
||||||
|
|
||||||
|
defaults = get_qdrant_defaults()
|
||||||
|
url = url or defaults['url']
|
||||||
|
api_key = api_key or defaults['api_key']
|
||||||
|
replication_factor = replication_factor or defaults['replication_factor']
|
||||||
|
shard_number = shard_number or defaults['shard_number']
|
||||||
|
|
||||||
|
return url, api_key, replication_factor, shard_number
|
||||||
|
|
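`resolve_qdrant_config` resolves each setting in a fixed order: explicit keyword argument first, then the argparse namespace, then the environment-backed default. The following is a reduced sketch of that precedence for two of the four settings. `get_defaults` and `resolve` are hypothetical stand-ins, trimmed from the functions in the new file rather than copied from the package.

```python
import argparse
import os
from typing import Any, Optional, Tuple


def get_defaults() -> dict:
    """Environment-backed defaults, read at call time."""
    return {
        "url": os.getenv("QDRANT_URL", "http://localhost:6333"),
        "replication_factor": int(os.getenv("QDRANT_REPLICATION_FACTOR", "1")),
    }


def resolve(args: Optional[Any] = None, url: Optional[str] = None,
            replication_factor: Optional[int] = None) -> Tuple[str, int]:
    """Explicit parameter, then args attribute, then default."""
    if args is not None:
        url = url or getattr(args, "store_uri", None)
        replication_factor = replication_factor or getattr(
            args, "qdrant_replication_factor", None
        )
    defaults = get_defaults()
    return (
        url or defaults["url"],
        replication_factor or defaults["replication_factor"],
    )


args = argparse.Namespace(store_uri="http://qdrant:6333",
                          qdrant_replication_factor=3)
print(resolve(args))
print(resolve(args, url="http://other:6333", replication_factor=2))
```

Because the chain uses `or`, any falsy value (including an explicit `0`) is treated as unset and falls through to the next source.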
@ -2,7 +2,8 @@ from typing import Dict, Any, Tuple, Optional
|
||||||
from ...schema import (
|
from ...schema import (
|
||||||
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
|
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
|
||||||
DocumentEmbeddings, ChunkEmbeddings,
|
DocumentEmbeddings, ChunkEmbeddings,
|
||||||
Metadata, EntityEmbeddings
|
Metadata, EntityEmbeddings,
|
||||||
|
LibraryMetadata, LibraryBlob,
|
||||||
)
|
)
|
||||||
from .base import MessageTranslator
|
from .base import MessageTranslator
|
||||||
from .primitives import ValueTranslator, SubgraphTranslator
|
from .primitives import ValueTranslator, SubgraphTranslator
|
||||||
|
|
@ -61,6 +62,27 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
library_metadata = None
|
||||||
|
if "library-metadata" in data:
|
||||||
|
lm = data["library-metadata"]
|
||||||
|
library_metadata = LibraryMetadata(
|
||||||
|
id=lm.get("id", ""),
|
||||||
|
kind=lm.get("kind", ""),
|
||||||
|
title=lm.get("title", ""),
|
||||||
|
parent_id=lm.get("parent-id", ""),
|
||||||
|
document_type=lm.get("document-type", ""),
|
||||||
|
comments=lm.get("comments", ""),
|
||||||
|
tags=lm.get("tags", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
library_blob = None
|
||||||
|
if "library-blob" in data:
|
||||||
|
lb = data["library-blob"]
|
||||||
|
library_blob = LibraryBlob(
|
||||||
|
id=lb.get("id", ""),
|
||||||
|
data=lb.get("data", b""),
|
||||||
|
)
|
||||||
|
|
||||||
return KnowledgeRequest(
|
return KnowledgeRequest(
|
||||||
operation=data.get("operation"),
|
operation=data.get("operation"),
|
||||||
id=data.get("id"),
|
id=data.get("id"),
|
||||||
|
|
@ -69,6 +91,8 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
triples=triples,
|
triples=triples,
|
||||||
graph_embeddings=graph_embeddings,
|
graph_embeddings=graph_embeddings,
|
||||||
document_embeddings=document_embeddings,
|
document_embeddings=document_embeddings,
|
||||||
|
library_metadata=library_metadata,
|
||||||
|
library_blob=library_blob,
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
|
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
|
||||||
|
|
@ -125,6 +149,26 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if obj.library_metadata:
|
||||||
|
result["library-metadata"] = {
|
||||||
|
"id": obj.library_metadata.id,
|
||||||
|
"kind": obj.library_metadata.kind,
|
||||||
|
"title": obj.library_metadata.title,
|
||||||
|
"parent-id": obj.library_metadata.parent_id,
|
||||||
|
"document-type": obj.library_metadata.document_type,
|
||||||
|
"comments": obj.library_metadata.comments,
|
||||||
|
"tags": obj.library_metadata.tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.library_blob:
|
||||||
|
data = obj.library_blob.data
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
data = data.decode("utf-8")
|
||||||
|
result["library-blob"] = {
|
||||||
|
"id": obj.library_blob.id,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -194,6 +238,32 @@ class KnowledgeResponseTranslator(MessageTranslator):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Streaming library metadata response
|
||||||
|
if obj.library_metadata:
|
||||||
|
return {
|
||||||
|
"library-metadata": {
|
||||||
|
"id": obj.library_metadata.id,
|
||||||
|
"kind": obj.library_metadata.kind,
|
||||||
|
"title": obj.library_metadata.title,
|
||||||
|
"parent-id": obj.library_metadata.parent_id,
|
||||||
|
"document-type": obj.library_metadata.document_type,
|
||||||
|
"comments": obj.library_metadata.comments,
|
||||||
|
"tags": obj.library_metadata.tags,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Streaming library blob response
|
||||||
|
if obj.library_blob:
|
||||||
|
data = obj.library_blob.data
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
data = data.decode("utf-8")
|
||||||
|
return {
|
||||||
|
"library-blob": {
|
||||||
|
"id": obj.library_blob.id,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# End of stream marker
|
# End of stream marker
|
||||||
if obj.eos is True:
|
if obj.eos is True:
|
||||||
return {"eos": True}
|
return {"eos": True}
|
||||||
|
|
@ -209,7 +279,9 @@ class KnowledgeResponseTranslator(MessageTranslator):
|
||||||
is_final = (
|
is_final = (
|
||||||
obj.ids is not None or # List response
|
obj.ids is not None or # List response
|
||||||
obj.eos is True or # End of stream
|
obj.eos is True or # End of stream
|
||||||
(not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response
|
(not obj.triples and not obj.graph_embeddings
|
||||||
|
and not obj.document_embeddings
|
||||||
|
and not obj.library_metadata and not obj.library_blob) # Empty response
|
||||||
)
|
)
|
||||||
|
|
||||||
return response, is_final
|
return response, is_final
|
||||||
|
|
@ -21,6 +21,21 @@ from .embeddings import GraphEmbeddings, DocumentEmbeddings
|
||||||
# <- ()
|
# <- ()
|
||||||
# <- (error)
|
# <- (error)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LibraryMetadata:
|
||||||
|
id: str = ""
|
||||||
|
kind: str = ""
|
||||||
|
title: str = ""
|
||||||
|
parent_id: str = ""
|
||||||
|
document_type: str = ""
|
||||||
|
comments: str = ""
|
||||||
|
tags: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LibraryBlob:
|
||||||
|
id: str = ""
|
||||||
|
data: bytes = b""
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KnowledgeRequest:
|
class KnowledgeRequest:
|
||||||
# get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
|
# get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
|
||||||
|
|
@ -44,6 +59,10 @@ class KnowledgeRequest:
|
||||||
# put-de-core
|
# put-de-core
|
||||||
document_embeddings: DocumentEmbeddings | None = None
|
document_embeddings: DocumentEmbeddings | None = None
|
||||||
|
|
||||||
|
# put-kg-core (source material)
|
||||||
|
library_metadata: LibraryMetadata | None = None
|
||||||
|
library_blob: LibraryBlob | None = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KnowledgeResponse:
|
class KnowledgeResponse:
|
||||||
error: Error | None = None
|
error: Error | None = None
|
||||||
|
|
@ -52,6 +71,8 @@ class KnowledgeResponse:
|
||||||
triples: Triples | None = None
|
triples: Triples | None = None
|
||||||
graph_embeddings: GraphEmbeddings | None = None
|
graph_embeddings: GraphEmbeddings | None = None
|
||||||
document_embeddings: DocumentEmbeddings | None = None
|
document_embeddings: DocumentEmbeddings | None = None
|
||||||
|
library_metadata: LibraryMetadata | None = None
|
||||||
|
library_blob: LibraryBlob | None = None
|
||||||
|
|
||||||
knowledge_request_queue = queue('knowledge', cls='request')
|
knowledge_request_queue = queue('knowledge', cls='request')
|
||||||
knowledge_response_queue = queue('knowledge', cls='response')
|
knowledge_response_queue = queue('knowledge', cls='response')
|
||||||
|
|
|
||||||
|
|
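The translator hunks map the kebab-case wire keys (`parent-id`, `document-type`) to the snake_case fields of the new `LibraryMetadata` dataclass and back again. Below is a self-contained sketch of that round trip. The dataclass repeats the fields from the schema hunk, while `decode_library_metadata` and `encode_library_metadata` are hypothetical free functions standing in for the translator methods.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class LibraryMetadata:
    id: str = ""
    kind: str = ""
    title: str = ""
    parent_id: str = ""
    document_type: str = ""
    comments: str = ""
    tags: List[str] = field(default_factory=list)


def decode_library_metadata(lm: Dict[str, Any]) -> LibraryMetadata:
    """Wire dict (kebab-case keys) -> dataclass; missing keys get defaults."""
    return LibraryMetadata(
        id=lm.get("id", ""),
        kind=lm.get("kind", ""),
        title=lm.get("title", ""),
        parent_id=lm.get("parent-id", ""),
        document_type=lm.get("document-type", ""),
        comments=lm.get("comments", ""),
        tags=lm.get("tags", []),
    )


def encode_library_metadata(obj: LibraryMetadata) -> Dict[str, Any]:
    """Dataclass -> wire dict with kebab-case keys."""
    return {
        "id": obj.id,
        "kind": obj.kind,
        "title": obj.title,
        "parent-id": obj.parent_id,
        "document-type": obj.document_type,
        "comments": obj.comments,
        "tags": obj.tags,
    }


wire = {"id": "doc-1", "kind": "document", "parent-id": "root", "tags": ["a"]}
decoded = decode_library_metadata(wire)
print(encode_library_metadata(decoded))
```

The blob path in the diff calls `data.decode("utf-8")` on bytes when encoding, which only succeeds for valid UTF-8 payloads. Whether blobs are already text-safe at that point is not shown in this diff.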
@ -5,7 +5,7 @@ Gets document content from the library by document ID.
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from trustgraph.api import Api
|
import requests
|
||||||
|
|
||||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
@ -13,15 +13,29 @@ default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||||
|
|
||||||
def get_content(url, document_id, output_file, token=None, workspace="default"):
|
def get_content(url, document_id, output_file, token=None, workspace="default"):
|
||||||
|
|
||||||
api = Api(url, token=token, workspace=workspace).library()
|
stream_url = url.rstrip("/") + "/api/v1/document-stream"
|
||||||
|
|
||||||
content = api.get_document_content(id=document_id)
|
params = {
|
||||||
|
"document-id": document_id,
|
||||||
|
"workspace": workspace,
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
|
||||||
|
resp = requests.get(stream_url, params=params, headers=headers, stream=True)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
if output_file:
|
if output_file:
|
||||||
|
total = 0
|
||||||
with open(output_file, 'wb') as f:
|
with open(output_file, 'wb') as f:
|
||||||
f.write(content)
|
for chunk in resp.iter_content(chunk_size=65536):
|
||||||
print(f"Written {len(content)} bytes to {output_file}")
|
f.write(chunk)
|
||||||
|
total += len(chunk)
|
||||||
|
print(f"Written {total} bytes to {output_file}")
|
||||||
else:
|
else:
|
||||||
|
content = resp.content
|
||||||
try:
|
try:
|
||||||
text = content.decode('utf-8')
|
text = content.decode('utf-8')
|
||||||
print(text)
|
print(text)
|
||||||
|
|
|
||||||
|
|
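The rewritten `get_content` replaces a single in-memory `f.write(content)` with a loop over `resp.iter_content(chunk_size=65536)` that keeps a running byte count. The sketch below shows that write loop on its own, with a plain list of byte strings standing in for the HTTP response so it runs without a server. `write_chunks` is a hypothetical helper name.

```python
import os
import tempfile
from typing import Iterable


def write_chunks(chunks: Iterable[bytes], output_file: str) -> int:
    """Write chunks to output_file as they arrive; return total bytes written."""
    total = 0
    with open(output_file, "wb") as f:
        for chunk in chunks:
            f.write(chunk)
            total += len(chunk)
    return total


# Stand-in for resp.iter_content(chunk_size=65536) from the diff.
fake_chunks = [b"a" * 65536, b"b" * 1000]
path = os.path.join(tempfile.mkdtemp(), "doc.bin")
written = write_chunks(fake_chunks, path)
print(f"Written {written} bytes to {path}")
```

Only the file branch streams. The `else` branch in the diff still reads `resp.content`, which loads the whole body into memory before decoding it.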
@ -47,6 +47,31 @@ def write_ge(f, data):
|
||||||
)
|
)
|
||||||
f.write(msgpack.packb(msg, use_bin_type=True))
|
f.write(msgpack.packb(msg, use_bin_type=True))
|
||||||
|
|
||||||
|
def write_library_metadata(f, data):
|
||||||
|
msg = (
|
||||||
|
"lm",
|
||||||
|
{
|
||||||
|
"i": data["id"],
|
||||||
|
"k": data.get("kind", ""),
|
||||||
|
"t": data.get("title", ""),
|
||||||
|
"p": data.get("parent-id", ""),
|
||||||
|
"d": data.get("document-type", ""),
|
||||||
|
"c": data.get("comments", ""),
|
||||||
|
"g": data.get("tags", []),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
f.write(msgpack.packb(msg, use_bin_type=True))
|
||||||
|
|
||||||
|
def write_library_blob(f, data):
|
||||||
|
msg = (
|
||||||
|
"lb",
|
||||||
|
{
|
||||||
|
"i": data["id"],
|
||||||
|
"d": data.get("data", b""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
f.write(msgpack.packb(msg, use_bin_type=True))
|
||||||
|
|
||||||
def fetch(url, workspace, id, output, token=None):
|
def fetch(url, workspace, id, output, token=None):
|
||||||
|
|
||||||
api = Api(url=url, token=token, workspace=workspace)
|
api = Api(url=url, token=token, workspace=workspace)
|
||||||
|
|
@ -55,6 +80,8 @@ def fetch(url, workspace, id, output, token=None):
|
||||||
try:
|
try:
|
||||||
ge = 0
|
ge = 0
|
||||||
t = 0
|
t = 0
|
||||||
|
lm = 0
|
||||||
|
lb = 0
|
||||||
|
|
||||||
with open(output, "wb") as f:
|
with open(output, "wb") as f:
|
||||||
|
|
||||||
|
|
@ -68,7 +95,15 @@ def fetch(url, workspace, id, output, token=None):
|
||||||
ge += 1
|
ge += 1
|
||||||
write_ge(f, response["graph-embeddings"])
|
write_ge(f, response["graph-embeddings"])
|
||||||
|
|
||||||
print(f"Got: {t} triple, {ge} GE messages.")
|
if "library-metadata" in response:
|
||||||
|
lm += 1
|
||||||
|
write_library_metadata(f, response["library-metadata"])
|
||||||
|
|
||||||
|
if "library-blob" in response:
|
||||||
|
lb += 1
|
||||||
|
write_library_blob(f, response["library-blob"])
|
||||||
|
|
||||||
|
print(f"Got: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
socket.close()
|
socket.close()
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ def load_structured_data(
|
||||||
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
||||||
|
|
||||||
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
||||||
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not discovered_schema:
|
if not discovered_schema:
|
||||||
logger.error("Failed to discover suitable schema automatically")
|
logger.error("Failed to discover suitable schema automatically")
|
||||||
print("❌ Could not automatically determine the best schema for your data.")
|
print("❌ Could not automatically determine the best schema for your data.")
|
||||||
|
|
@ -90,7 +90,7 @@ def load_structured_data(
|
||||||
|
|
||||||
# Step 2: Auto-generate descriptor
|
# Step 2: Auto-generate descriptor
|
||||||
logger.info("Step 2: Generating descriptor configuration...")
|
logger.info("Step 2: Generating descriptor configuration...")
|
||||||
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
|
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not auto_descriptor:
|
if not auto_descriptor:
|
||||||
logger.error("Failed to generate descriptor automatically")
|
logger.error("Failed to generate descriptor automatically")
|
||||||
print("❌ Could not automatically generate descriptor configuration.")
|
print("❌ Could not automatically generate descriptor configuration.")
|
||||||
|
|
@ -172,7 +172,7 @@ def load_structured_data(
|
||||||
logger.info(f"Sample chars: {sample_chars} characters")
|
logger.info(f"Sample chars: {sample_chars} characters")
|
||||||
|
|
||||||
# Use the helper function to discover schema (get raw response for display)
|
# Use the helper function to discover schema (get raw response for display)
|
||||||
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
|
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
# Debug: print response type and content
|
# Debug: print response type and content
|
||||||
|
|
@ -203,7 +203,7 @@ def load_structured_data(
|
||||||
# If no schema specified, discover it first
|
# If no schema specified, discover it first
|
||||||
if not schema_name:
|
if not schema_name:
|
||||||
logger.info("No schema specified, auto-discovering...")
|
logger.info("No schema specified, auto-discovering...")
|
||||||
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not schema_name:
|
if not schema_name:
|
||||||
print("Error: Could not determine schema automatically.")
|
print("Error: Could not determine schema automatically.")
|
||||||
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
||||||
|
|
@ -213,7 +213,7 @@ def load_structured_data(
|
||||||
logger.info(f"Target schema: {schema_name}")
|
logger.info(f"Target schema: {schema_name}")
|
||||||
|
|
||||||
# Generate descriptor using helper function
|
# Generate descriptor using helper function
|
||||||
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
|
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
|
|
||||||
if descriptor:
|
if descriptor:
|
||||||
# Output the generated descriptor
|
# Output the generated descriptor
|
||||||
|
|
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for auto mode
|
# Helper functions for auto mode
|
||||||
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
|
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
|
||||||
"""Auto-discover the best matching schema for the input data
|
"""Auto-discover the best matching schema for the input data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
||||||
# Import API modules
|
# Import API modules
|
||||||
from trustgraph.api import Api
|
from trustgraph.api import Api
|
||||||
from trustgraph.api.types import ConfigKey
|
from trustgraph.api.types import ConfigKey
|
||||||
api = Api(api_url, workspace=workspace)
|
api = Api(api_url, token=token, workspace=workspace)
|
||||||
config_api = api.config()
|
config_api = api.config()
|
||||||
|
|
||||||
# Get available schemas
|
# Get available schemas
|
||||||
|
|
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
|
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
|
||||||
"""Auto-generate descriptor configuration for the discovered schema"""
|
"""Auto-generate descriptor configuration for the discovered schema"""
|
||||||
try:
|
try:
|
||||||
# Read sample data
|
# Read sample data
|
||||||
|
|
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
|
||||||
# Import API modules
|
# Import API modules
|
||||||
from trustgraph.api import Api
|
from trustgraph.api import Api
|
||||||
from trustgraph.api.types import ConfigKey
|
from trustgraph.api.types import ConfigKey
|
||||||
api = Api(api_url, workspace=workspace)
|
api = Api(api_url, token=token, workspace=workspace)
|
||||||
config_api = api.config()
|
config_api = api.config()
|
||||||
|
|
||||||
# Get schema definition
|
# Get schema definition
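The hunks in this file thread an optional `token` keyword through both auto-mode helpers and into the `Api` constructor. A minimal sketch of that pattern follows. `StubApi`, `discover_schema`, and the example URL are hypothetical stand-ins so the snippet runs without a gateway; they are not the real `trustgraph.api.Api`.

```python
# Hypothetical stand-in for trustgraph.api.Api: it only records its arguments.
class StubApi:
    def __init__(self, url, token=None, workspace="default"):
        self.url = url
        self.token = token
        self.workspace = workspace


def discover_schema(api_url, token=None, workspace="default", api_cls=StubApi):
    """Hypothetical helper mirroring how the auto-mode helpers forward token."""
    # Forward the credential by keyword so an unauthenticated caller
    # (token=None) keeps working unchanged.
    return api_cls(api_url, token=token, workspace=workspace)


api = discover_schema("http://gateway.example/", token="secret", workspace="demo")
print(api.token, api.workspace)
```

Defaulting `token` to `None` keeps the existing call sites source-compatible, which appears to be why every caller in the diff passes it by keyword.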
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,23 @@ def read_message(unpacked, id):
|
||||||
},
|
},
|
||||||
"triples": msg["t"],
|
"triples": msg["t"],
|
||||||
}
|
}
|
||||||
|
elif unpacked[0] == "lm":
|
||||||
|
msg = unpacked[1]
|
||||||
|
return "lm", {
|
||||||
|
"id": msg["i"],
|
||||||
|
"kind": msg.get("k", ""),
|
||||||
|
"title": msg.get("t", ""),
|
||||||
|
"parent-id": msg.get("p", ""),
|
||||||
|
"document-type": msg.get("d", ""),
|
||||||
|
"comments": msg.get("c", ""),
|
||||||
|
"tags": msg.get("g", []),
|
||||||
|
}
|
||||||
|
elif unpacked[0] == "lb":
|
||||||
|
msg = unpacked[1]
|
||||||
|
return "lb", {
|
||||||
|
"id": msg["i"],
|
||||||
|
"data": msg.get("d", b""),
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unpacked unexpected message type", unpacked[0])
|
raise RuntimeError("Unpacked unexpected message type", unpacked[0])
|
||||||
|
|
||||||
|
|
@ -51,6 +68,8 @@ def put(url, workspace, id, input, token=None):
|
||||||
try:
|
try:
|
||||||
ge = 0
|
ge = 0
|
||||||
t = 0
|
t = 0
|
||||||
|
lm = 0
|
||||||
|
lb = 0
|
||||||
|
|
||||||
with open(input, "rb") as f:
|
with open(input, "rb") as f:
|
||||||
|
|
||||||
|
|
@ -73,10 +92,18 @@ def put(url, workspace, id, input, token=None):
|
||||||
t += 1
|
t += 1
|
||||||
socket.put_kg_core(id, triples=msg)
|
socket.put_kg_core(id, triples=msg)
|
||||||
|
|
||||||
|
elif kind == "lm":
|
||||||
|
lm += 1
|
||||||
|
socket.put_kg_core(id, library_metadata=msg)
|
||||||
|
|
||||||
|
elif kind == "lb":
|
||||||
|
lb += 1
|
||||||
|
socket.put_kg_core(id, library_blob=msg)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unexpected message kind", kind)
|
raise RuntimeError("Unexpected message kind", kind)
|
||||||
|
|
||||||
print(f"Put: {t} triple, {ge} GE messages.")
|
print(f"Put: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
socket.close()
|
socket.close()
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,8 @@ class Processor(AsyncProcessor):
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="config"
|
default_keyspace="config",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store resolved configuration
|
# Store resolved configuration
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
|
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
|
||||||
from .. schema import DocumentEmbeddings
|
from .. schema import DocumentEmbeddings, LibraryMetadata, LibraryBlob
|
||||||
|
from .. schema import LibrarianRequest, DocumentMetadata
|
||||||
from .. knowledge import hash
|
from .. knowledge import hash
|
||||||
from .. exceptions import RequestError
|
from .. exceptions import RequestError
|
||||||
from .. tables.knowledge import KnowledgeTableStore
|
from .. tables.knowledge import KnowledgeTableStore
|
||||||
|
|
@ -18,7 +19,7 @@ class KnowledgeManager:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cassandra_host, cassandra_username, cassandra_password,
|
self, cassandra_host, cassandra_username, cassandra_password,
|
||||||
keyspace, flow_config, replication_factor=1,
|
keyspace, flow_config, librarian=None, replication_factor=1,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.table_store = KnowledgeTableStore(
|
self.table_store = KnowledgeTableStore(
|
||||||
|
|
@ -26,6 +27,9 @@ class KnowledgeManager:
|
||||||
replication_factor
|
replication_factor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.librarian = librarian
|
||||||
|
self._pending_library_metadata = {}
|
||||||
|
|
||||||
self.loader_queue = asyncio.Queue(maxsize=20)
|
self.loader_queue = asyncio.Queue(maxsize=20)
|
||||||
self.background_task = None
|
self.background_task = None
|
||||||
self.flow_config = flow_config
|
self.flow_config = flow_config
|
||||||
|
|
@ -86,6 +90,9 @@ class KnowledgeManager:
|
||||||
publish_ge,
|
publish_ge,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.librarian:
|
||||||
|
await self._stream_library_docs(request.id, respond)
|
||||||
|
|
||||||
logger.debug("Knowledge core retrieval complete")
|
logger.debug("Knowledge core retrieval complete")
|
||||||
|
|
||||||
await respond(
|
await respond(
|
||||||
|
|
@ -122,6 +129,12 @@ class KnowledgeManager:
|
||||||
workspace, request.graph_embeddings
|
workspace, request.graph_embeddings
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if request.library_metadata and self.librarian:
|
||||||
|
await self._put_library_metadata(request.library_metadata, workspace)
|
||||||
|
|
||||||
|
if request.library_blob and self.librarian:
|
||||||
|
await self._put_library_blob(request.library_blob, workspace)
|
||||||
|
|
||||||
await respond(
|
await respond(
|
||||||
KnowledgeResponse(
|
KnowledgeResponse(
|
||||||
error = None,
|
error = None,
|
||||||
|
|
@ -250,6 +263,112 @@ class KnowledgeManager:
|
||||||
|
|
||||||
await self.loader_queue.put((request, respond, workspace))
|
await self.loader_queue.put((request, respond, workspace))
|
||||||
|
|
||||||
|
async def _stream_library_docs(self, document_id, respond):
|
||||||
|
|
||||||
|
try:
|
||||||
|
root_meta = await self.librarian.fetch_document_metadata(
|
||||||
|
document_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not fetch library metadata for {document_id}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if root_meta is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._stream_one_doc(root_meta, respond)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await self.librarian.request(
|
||||||
|
LibrarianRequest(
|
||||||
|
operation="list-children",
|
||||||
|
document_id=document_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not list children for {document_id}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for child_meta in resp.document_metadatas:
|
||||||
|
await self._stream_one_doc(child_meta, respond)
|
||||||
|
|
||||||
|
async def _stream_one_doc(self, doc_meta, respond):
|
||||||
|
|
||||||
|
lm = LibraryMetadata(
|
||||||
|
id=doc_meta.id,
|
||||||
|
kind=doc_meta.kind,
|
||||||
|
title=doc_meta.title,
|
||||||
|
parent_id=doc_meta.parent_id,
|
||||||
|
document_type=doc_meta.document_type,
|
||||||
|
comments=doc_meta.comments,
|
||||||
|
tags=doc_meta.tags or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(library_metadata=lm)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = await self.librarian.fetch_document_content(
|
||||||
|
doc_meta.id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not fetch content for {doc_meta.id}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await respond(
|
||||||
|
KnowledgeResponse(
|
||||||
|
library_blob=LibraryBlob(
|
||||||
|
id=doc_meta.id,
|
||||||
|
data=content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _put_library_metadata(self, lm, workspace):
|
||||||
|
self._pending_library_metadata[lm.id] = lm
|
||||||
|
|
||||||
|
async def _put_library_blob(self, lb, workspace):
|
||||||
|
|
||||||
|
lm = self._pending_library_metadata.pop(lb.id, None)
|
||||||
|
if lm is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Received library blob for {lb.id} with no preceding metadata"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
doc_meta = DocumentMetadata(
|
||||||
|
id=lm.id,
|
||||||
|
kind=lm.kind,
|
||||||
|
title=lm.title,
|
||||||
|
parent_id=lm.parent_id,
|
||||||
|
document_type=lm.document_type,
|
||||||
|
comments=lm.comments,
|
||||||
|
tags=lm.tags or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
if lm.parent_id:
|
||||||
|
operation = "add-child-document"
|
||||||
|
else:
|
||||||
|
operation = "add-document"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.librarian.request(
|
||||||
|
LibrarianRequest(
|
||||||
|
operation=operation,
|
||||||
|
document_id=lm.id,
|
||||||
|
document_metadata=doc_meta,
|
||||||
|
content=lb.data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "already exists" in str(e):
|
||||||
|
logger.debug(f"Library document {lm.id} already exists, skipping")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not save library document {lm.id}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not save library document {lm.id}: {e}")
|
||||||
|
|
||||||
async def core_loader(self):
|
async def core_loader(self):
|
||||||
|
|
||||||
logger.info("Knowledge background processor running...")
|
logger.info("Knowledge background processor running...")
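The import path added to `KnowledgeManager` buffers each library metadata record by document id and only writes to the librarian once the matching blob arrives. The sketch below isolates that pairing logic. `PendingDocs` is a hypothetical class for illustration only, and the librarian call is replaced by appending to a list.

```python
class PendingDocs:
    """Hypothetical stand-in for the metadata/blob pairing in KnowledgeManager."""

    def __init__(self):
        self._pending = {}   # document id -> metadata awaiting its blob
        self.saved = []      # (metadata, data) pairs that were completed

    def put_metadata(self, doc_id, metadata):
        # Metadata arrives first and is only buffered.
        self._pending[doc_id] = metadata

    def put_blob(self, doc_id, data):
        # pop() removes the entry so each metadata record is used once.
        metadata = self._pending.pop(doc_id, None)
        if metadata is None:
            return False     # blob without preceding metadata is dropped
        self.saved.append((metadata, data))
        return True


docs = PendingDocs()
docs.put_metadata("doc-1", {"title": "Report"})
ok_first = docs.put_blob("doc-1", b"%PDF-1.7")
ok_orphan = docs.put_blob("doc-2", b"data")
```

In this sketch, metadata whose blob never arrives stays buffered indefinitely, so the approach relies on the stream delivering each blob after its metadata.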
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ import logging
|
||||||
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
|
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
|
||||||
from .. base import ConsumerMetrics, ProducerMetrics
|
from .. base import ConsumerMetrics, ProducerMetrics
|
||||||
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||||
|
from .. base import LibrarianClient
|
||||||
|
|
||||||
from .. schema import KnowledgeRequest, KnowledgeResponse, Error
|
from .. schema import KnowledgeRequest, KnowledgeResponse, Error
|
||||||
from .. schema import knowledge_request_queue, knowledge_response_queue
|
from .. schema import knowledge_request_queue, knowledge_response_queue
|
||||||
|
|
@ -60,7 +61,8 @@ class Processor(WorkspaceProcessor):
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="knowledge"
|
default_keyspace="knowledge",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cassandra_host = hosts
|
self.cassandra_host = hosts
|
||||||
|
|
@ -77,12 +79,17 @@ class Processor(WorkspaceProcessor):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.librarian_client = LibrarianClient(
|
||||||
|
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
|
||||||
|
)
|
||||||
|
|
||||||
self.knowledge = KnowledgeManager(
|
self.knowledge = KnowledgeManager(
|
||||||
cassandra_host = self.cassandra_host,
|
cassandra_host = self.cassandra_host,
|
||||||
cassandra_username = self.cassandra_username,
|
cassandra_username = self.cassandra_username,
|
||||||
cassandra_password = self.cassandra_password,
|
cassandra_password = self.cassandra_password,
|
||||||
keyspace = keyspace,
|
keyspace = keyspace,
|
||||||
flow_config = self,
|
flow_config = self,
|
||||||
|
librarian = self.librarian_client,
|
||||||
replication_factor = replication_factor,
|
replication_factor = replication_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -156,6 +163,7 @@ class Processor(WorkspaceProcessor):
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
||||||
await super(Processor, self).start()
|
await super(Processor, self).start()
|
||||||
|
await self.librarian_client.start()
|
||||||
|
|
||||||
async def on_knowledge_config(self, workspace, config, version):
|
async def on_knowledge_config(self, workspace, config, version):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -219,7 +219,14 @@ class Processor(FlowProcessor):
|
||||||
source_doc_id = v.document_id or v.metadata.id
|
source_doc_id = v.document_id or v.metadata.id
|
||||||
|
|
||||||
# Run OCR, get per-page markdown
|
# Run OCR, get per-page markdown
|
||||||
|
try:
|
||||||
pages = self.ocr(blob)
|
pages = self.ocr(blob)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to decode PDF {source_doc_id}: "
|
||||||
|
f"{type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
for markdown, page_num in pages:
|
for markdown, page_num in pages:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
|
||||||
default_ident = "document-decoder"
|
default_ident = "document-decoder"
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_pdf(content):
|
||||||
|
return content.lstrip().startswith(b"%PDF-")
|
||||||
|
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
@ -94,14 +98,10 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
|
|
||||||
temp_path = fp.name
|
|
||||||
|
|
||||||
# Check if we should fetch from librarian or use inline data
|
# Check if we should fetch from librarian or use inline data
|
||||||
if v.document_id:
|
if v.document_id:
|
||||||
# Fetch from librarian via Pulsar
|
# Fetch from librarian via Pulsar
|
||||||
logger.info(f"Fetching document {v.document_id} from librarian...")
|
logger.info(f"Fetching document {v.document_id} from librarian...")
|
||||||
fp.close()
|
|
||||||
|
|
||||||
content = await flow.librarian.fetch_document_content(
|
content = await flow.librarian.fetch_document_content(
|
||||||
document_id=v.document_id,
|
document_id=v.document_id,
|
||||||
|
|
@ -113,13 +113,21 @@ class Processor(FlowProcessor):
|
||||||
content = content.encode('utf-8')
|
content = content.encode('utf-8')
|
||||||
decoded_content = base64.b64decode(content)
|
decoded_content = base64.b64decode(content)
|
||||||
|
|
||||||
with open(temp_path, 'wb') as f:
|
|
||||||
f.write(decoded_content)
|
|
||||||
|
|
||||||
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
|
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
|
||||||
else:
|
else:
|
||||||
# Use inline data (backward compatibility)
|
# Use inline data (backward compatibility)
|
||||||
fp.write(base64.b64decode(v.data))
|
decoded_content = base64.b64decode(v.data)
|
||||||
|
|
||||||
|
if not _looks_like_pdf(decoded_content):
|
||||||
|
logger.error(
|
||||||
|
f"Document {v.metadata.id} is not valid PDF content. "
|
||||||
|
f"Ignoring document."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
|
||||||
|
temp_path = fp.name
|
||||||
|
fp.write(decoded_content)
|
||||||
fp.close()
|
fp.close()
|
||||||
|
|
||||||
global PyPDFLoader
|
global PyPDFLoader
|
||||||
|
|
@ -129,7 +137,15 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
PyPDFLoader = _cls
|
PyPDFLoader = _cls
|
||||||
loader = PyPDFLoader(temp_path)
|
loader = PyPDFLoader(temp_path)
|
||||||
|
try:
|
||||||
pages = loader.load()
|
pages = loader.load()
|
||||||
|
except Exception as e:
|
||||||
|
source_doc_id = v.document_id or v.metadata.id
|
||||||
|
logger.error(
|
||||||
|
f"Failed to decode PDF {source_doc_id}: "
|
||||||
|
f"{type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Get the source document ID
|
# Get the source document ID
|
||||||
source_doc_id = v.document_id or v.metadata.id
|
source_doc_id = v.document_id or v.metadata.id
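The decoder now checks the decoded bytes for the PDF magic header before creating a temporary file. A minimal sketch of that order of operations is below. `write_pdf_tempfile` is a hypothetical helper name, and the sample byte strings are made up for illustration.

```python
import os
import tempfile


def looks_like_pdf(content: bytes) -> bool:
    # Same test as _looks_like_pdf in the diff: tolerate leading
    # whitespace, then require the PDF magic bytes.
    return content.lstrip().startswith(b"%PDF-")


def write_pdf_tempfile(content: bytes):
    """Hypothetical helper: return a temp path, or None for non-PDF input."""
    if not looks_like_pdf(content):
        return None
    # delete=False keeps the file after the with-block so a loader can
    # open it by path; the caller is responsible for removing it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as fp:
        fp.write(content)
        return fp.name


path = write_pdf_tempfile(b"\n%PDF-1.4 minimal")
rejected = write_pdf_tempfile(b"<html>not a pdf</html>")
if path:
    os.remove(path)
```

A header check only rules out obviously wrong content. A file can start with `%PDF-` and still fail to parse, which is why the diff also wraps `loader.load()` in a `try`/`except`.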
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import logging
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from cassandra.query import BatchStatement, SimpleStatement
|
from cassandra.query import BatchStatement, SimpleStatement
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
from ..tables.cassandra_async import async_execute
|
from ..tables.cassandra_async import async_execute
|
||||||
|
|
||||||
|
|
@ -41,13 +41,15 @@ class KnowledgeGraph:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hosts=None,
|
self, hosts=None,
|
||||||
keyspace="trustgraph", username=None, password=None
|
keyspace="trustgraph", username=None, password=None,
|
||||||
|
replication_factor=1,
|
||||||
):
|
):
|
||||||
|
|
||||||
if hosts is None:
|
if hosts is None:
|
||||||
hosts = ["localhost"]
|
hosts = ["localhost"]
|
||||||
|
|
||||||
self.keyspace = keyspace
|
self.keyspace = keyspace
|
||||||
|
self.replication_factor = replication_factor
|
||||||
self.username = username
|
self.username = username
|
||||||
|
|
||||||
# 7-table schema for quads with full query pattern support
|
# 7-table schema for quads with full query pattern support
|
||||||
|
|
@ -68,7 +70,7 @@ class KnowledgeGraph:
|
||||||
self.collection_metadata_table = "collection_metadata"
|
self.collection_metadata_table = "collection_metadata"
|
||||||
|
|
||||||
if username and password:
|
if username and password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||||
else:
|
else:
|
||||||
|
|
@ -92,7 +94,7 @@ class KnowledgeGraph:
|
||||||
create keyspace if not exists {self.keyspace}
|
create keyspace if not exists {self.keyspace}
|
||||||
with replication = {{
|
with replication = {{
|
||||||
'class' : 'SimpleStrategy',
|
'class' : 'SimpleStrategy',
|
||||||
'replication_factor' : 1
|
'replication_factor' : {self.replication_factor}
|
||||||
}};
|
}};
|
||||||
""")
|
""")
|
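The keyspace statement now interpolates `self.replication_factor` instead of a hard-coded `1`. The sketch below builds the same statement in a standalone function. `keyspace_cql` is a hypothetical name, and the integer coercion and range check are additions for illustration that do not appear in the diff.

```python
def keyspace_cql(keyspace: str, replication_factor: int = 1) -> str:
    """Hypothetical helper: build the CREATE KEYSPACE statement."""
    # Coerce to int so a value read from config as a string cannot
    # inject arbitrary text into the statement.
    rf = int(replication_factor)
    if rf < 1:
        raise ValueError("replication_factor must be >= 1")
    return (
        f"create keyspace if not exists {keyspace} "
        f"with replication = {{"
        f"'class' : 'SimpleStrategy', "
        f"'replication_factor' : {rf}"
        f"}};"
    )


print(keyspace_cql("trustgraph", 3))
```

Because the statement uses `if not exists`, it leaves an already existing keyspace untouched, so changing the factor later would presumably need a separate `ALTER KEYSPACE`.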
||||||
|
|
||||||
|
|
@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hosts=None,
|
self, hosts=None,
|
||||||
keyspace="trustgraph", username=None, password=None
|
keyspace="trustgraph", username=None, password=None,
|
||||||
|
replication_factor=1,
|
||||||
):
|
):
|
||||||
|
|
||||||
if hosts is None:
|
if hosts is None:
|
||||||
hosts = ["localhost"]
|
hosts = ["localhost"]
|
||||||
|
|
||||||
self.keyspace = keyspace
|
self.keyspace = keyspace
|
||||||
|
self.replication_factor = replication_factor
|
||||||
self.username = username
|
self.username = username
|
||||||
|
|
||||||
# 2-table entity-centric schema
|
# 2-table entity-centric schema
|
||||||
|
|
@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph:
|
||||||
self.collection_metadata_table = "collection_metadata"
|
self.collection_metadata_table = "collection_metadata"
|
||||||
|
|
||||||
if username and password:
|
if username and password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
auth_provider = PlainTextAuthProvider(username=username, password=password)
|
||||||
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
|
||||||
else:
|
else:
|
||||||
|
|
@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph:
|
||||||
create keyspace if not exists {self.keyspace}
|
create keyspace if not exists {self.keyspace}
|
||||||
with replication = {{
|
with replication = {{
|
||||||
'class' : 'SimpleStrategy',
|
'class' : 'SimpleStrategy',
|
||||||
'replication_factor' : 1
|
'replication_factor' : {self.replication_factor}
|
||||||
}};
|
}};
|
||||||
""")
|
""")
|
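Both graph classes in this file replace `SSLContext(PROTOCOL_TLSv1_2)` with `ssl.create_default_context()`. The practical difference is that the default context verifies the server certificate and hostname, whereas a bare `SSLContext` built from a protocol constant does not verify by default. A short sketch using only the standard library shows the settings the new call produces:

```python
import ssl

# Default client context: certificate verification and hostname
# checking are both switched on.
default_ctx = ssl.create_default_context()

print(default_ctx.verify_mode == ssl.CERT_REQUIRED)
print(default_ctx.check_hostname)
```

One consequence, stated as an assumption about deployments rather than something the diff shows: a cluster using a private or self-signed certificate authority would likely need its CA supplied, for example through the `cafile` argument of `ssl.create_default_context()`.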
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,39 @@ class CoreExport:
|
||||||
enc = msgpack.packb(msg)
|
enc = msgpack.packb(msg)
|
||||||
await response.write(enc)
|
await response.write(enc)
|
||||||
|
|
||||||
|
if "library-metadata" in resp:
|
||||||
|
|
||||||
|
data = resp["library-metadata"]
|
||||||
|
msg = (
|
||||||
|
"lm",
|
||||||
|
{
|
||||||
|
"i": data["id"],
|
||||||
|
"k": data.get("kind", ""),
|
||||||
|
"t": data.get("title", ""),
|
||||||
|
"p": data.get("parent-id", ""),
|
||||||
|
"d": data.get("document-type", ""),
|
||||||
|
"c": data.get("comments", ""),
|
||||||
|
"g": data.get("tags", []),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
enc = msgpack.packb(msg)
|
||||||
|
await response.write(enc)
|
||||||
|
|
||||||
|
if "library-blob" in resp:
|
||||||
|
|
||||||
|
data = resp["library-blob"]
|
||||||
|
msg = (
|
||||||
|
"lb",
|
||||||
|
{
|
||||||
|
"i": data["id"],
|
||||||
|
"d": data.get("data", b""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
enc = msgpack.packb(msg, use_bin_type=True)
|
||||||
|
await response.write(enc)
|
||||||
|
|
||||||
await kr.process(
|
await kr.process(
|
||||||
{
|
{
|
||||||
"operation": "get-kg-core",
|
"operation": "get-kg-core",
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,39 @@ class CoreImport:
|
||||||
|
|
||||||
await kr.process(msg)
|
await kr.process(msg)
|
||||||
|
|
||||||
|
elif unpacked[0] == "lm":
|
||||||
|
msg = unpacked[1]
|
||||||
|
msg = {
|
||||||
|
"operation": "put-kg-core",
|
||||||
|
"workspace": workspace,
|
||||||
|
"id": id,
|
||||||
|
"library-metadata": {
|
||||||
|
"id": msg["i"],
|
||||||
|
"kind": msg.get("k", ""),
|
||||||
|
"title": msg.get("t", ""),
|
||||||
|
"parent-id": msg.get("p", ""),
|
||||||
|
"document-type": msg.get("d", ""),
|
||||||
|
"comments": msg.get("c", ""),
|
||||||
|
"tags": msg.get("g", []),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await kr.process(msg)
|
||||||
|
|
||||||
|
elif unpacked[0] == "lb":
|
||||||
|
msg = unpacked[1]
|
||||||
|
msg = {
|
||||||
|
"operation": "put-kg-core",
|
||||||
|
"workspace": workspace,
|
||||||
|
"id": id,
|
||||||
|
"library-blob": {
|
||||||
|
"id": msg["i"],
|
||||||
|
"data": msg.get("d", b""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await kr.process(msg)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Core import exception: {e}", exc_info=True)
|
logger.error(f"Core import exception: {e}", exc_info=True)
|
||||||
await error(str(e))
|
await error(str(e))
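The export and import dispatchers share a compact envelope for library metadata: a two-element record whose first item is `"lm"` and whose second is a dict with single-letter keys. The sketch below shows the mapping in both directions using the keys visible in the diff. The function names are hypothetical, and the msgpack serialisation step is left out so the example needs only the standard library.

```python
# Short wire keys as they appear in the "lm" records of this diff.
LM_KEYS = {
    "i": "id",
    "k": "kind",
    "t": "title",
    "p": "parent-id",
    "d": "document-type",
    "c": "comments",
    "g": "tags",
}


def pack_library_metadata(data):
    """Build the ("lm", {...}) record from a long-key dict (sketch)."""
    body = {"i": data["id"]}  # id is required, as in the diff
    for short, long_key in LM_KEYS.items():
        if short == "i":
            continue
        default = [] if short == "g" else ""
        body[short] = data.get(long_key, default)
    return ("lm", body)


def unpack_library_metadata(record):
    """Inverse mapping: short keys back to the long-key dict."""
    kind, body = record
    if kind != "lm":
        raise ValueError(f"unexpected record kind: {kind}")
    out = {"id": body["i"]}
    for short, long_key in LM_KEYS.items():
        if short == "i":
            continue
        default = [] if short == "g" else ""
        out[long_key] = body.get(short, default)
    return out


record = pack_library_metadata({"id": "doc-1", "title": "Report"})
restored = unpack_library_metadata(record)
```

The `"lb"` blob record follows the same shape with only `i` and `d`. In the export hunk it is packed with `use_bin_type=True`, which keeps the payload as bytes rather than text.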
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
from . librarian import LibrarianRequestor
|
from . librarian import LibrarianRequestor
|
||||||
|
from ... schema import librarian_request_queue, librarian_response_queue
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -23,10 +24,13 @@ class DocumentStreamExport:
|
||||||
|
|
||||||
response = await ok()
|
response = await ok()
|
||||||
|
|
||||||
|
uid = str(uuid.uuid4())
|
||||||
lr = LibrarianRequestor(
|
lr = LibrarianRequestor(
|
||||||
backend=self.backend,
|
backend=self.backend,
|
||||||
consumer="api-gateway-doc-stream-" + str(uuid.uuid4()),
|
consumer="api-gateway-doc-stream-" + uid,
|
||||||
subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()),
|
subscriber="api-gateway-doc-stream-" + uid,
|
||||||
|
request_queue=f"{librarian_request_queue}:{workspace}",
|
||||||
|
response_queue=f"{librarian_response_queue}:{workspace}",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ import queue
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from ..capabilities import PUBLIC, AUTHENTICATED
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -156,15 +158,18 @@ class Mux:
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
# Resolve workspace first (default-fill from the caller's
|
# Resolve workspace (default-fill from the caller's
|
||||||
# bound workspace), then ask the regime to authorise the
|
# bound workspace). Workspace resolution applies to all
|
||||||
# service-level capability against the matched
|
# operations regardless of capability level.
|
||||||
# operation's resource shape.
|
|
||||||
try:
|
try:
|
||||||
await enforce_workspace(data, self.identity, self.auth)
|
await enforce_workspace(data, self.identity, self.auth)
|
||||||
if isinstance(inner, dict):
|
if isinstance(inner, dict):
|
||||||
await enforce_workspace(inner, self.identity, self.auth)
|
await enforce_workspace(inner, self.identity, self.auth)
|
||||||
|
|
||||||
|
# Authorisation: capability sentinels short-circuit
|
||||||
|
# the regime call; capability strings go through
|
||||||
|
# authorise().
|
||||||
|
if op.capability not in (PUBLIC, AUTHENTICATED):
|
||||||
if data.get("flow"):
|
if data.get("flow"):
|
||||||
resource = {
|
resource = {
|
||||||
"workspace": data.get("workspace", ""),
|
"workspace": data.get("workspace", ""),
|
||||||
|
|
@ -173,8 +178,9 @@ class Mux:
|
||||||
parameters = {}
|
parameters = {}
|
||||||
else:
|
else:
|
||||||
# Build a minimal RequestContext so the matched
|
# Build a minimal RequestContext so the matched
|
||||||
# operation's own extractors decide resource and
|
# operation's own extractors decide resource
|
||||||
# parameters — same path the HTTP endpoints take.
|
# and parameters — same path the HTTP
|
||||||
|
# endpoints take.
|
||||||
from ..registry import RequestContext
|
from ..registry import RequestContext
|
||||||
ctx = RequestContext(
|
ctx = RequestContext(
|
||||||
body=inner if isinstance(inner, dict) else {},
|
body=inner if isinstance(inner, dict) else {},
|
||||||
|
|
@ -288,6 +294,8 @@ class Mux:
|
||||||
await self.maybe_tidy_workers(workers)
|
await self.maybe_tidy_workers(workers)
|
||||||
|
|
||||||
async def responder(resp, fin):
|
async def responder(resp, fin):
|
||||||
|
if self.ws is None:
|
||||||
|
return
|
||||||
await self.ws.send_json({
|
await self.ws.send_json({
|
||||||
"id": id,
|
"id": id,
|
||||||
"response": resp,
|
"response": resp,
|
||||||
|
|
@ -321,6 +329,8 @@ class Mux:
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if self.ws is None:
|
||||||
|
return
|
||||||
await self.ws.send_json({
|
await self.ws.send_json({
|
||||||
"id": id,
|
"id": id,
|
||||||
"error": {"message": str(e), "type": "error"},
|
"error": {"message": str(e), "type": "error"},
|
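The `Mux` change skips the regime's `authorise()` call when an operation's capability is one of the `PUBLIC` or `AUTHENTICATED` sentinels, while workspace resolution still runs for every operation. A minimal sketch of the short-circuit test follows. The sentinel definitions here are assumptions (the real ones are imported from `..capabilities`), and `"graph:read"` is a made-up capability string.

```python
# Hypothetical sentinels; the real ones are imported from ..capabilities.
PUBLIC = object()
AUTHENTICATED = object()


def needs_authorise(capability) -> bool:
    """True when the capability must go through the regime's authorise()."""
    # A tuple membership test compares by identity first, then by
    # equality, so unique sentinel objects never collide with
    # capability strings.
    return capability not in (PUBLIC, AUTHENTICATED)


checks = {
    "public": needs_authorise(PUBLIC),
    "authenticated": needs_authorise(AUTHENTICATED),
    "graph:read": needs_authorise("graph:read"),
}
print(checks)
```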
||||||
|
|
|
||||||
|
|
@ -117,8 +117,10 @@ class SocketEndpoint:
|
||||||
|
|
||||||
running = Running()
|
running = Running()
|
||||||
|
|
||||||
|
params = dict(request.query)
|
||||||
|
params.update(request.match_info)
|
||||||
dispatcher = await self.dispatcher(
|
dispatcher = await self.dispatcher(
|
||||||
ws, running, request.match_info
|
ws, running, params
|
||||||
)
|
)
|
||||||
|
|
||||||
worker_task = tg.create_task(
|
worker_task = tg.create_task(
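The socket endpoint now passes the dispatcher a merged mapping of query-string and route parameters instead of `request.match_info` alone. The sketch below reproduces the merge with plain dicts to make the precedence explicit. `merge_params` and the sample values are hypothetical.

```python
def merge_params(query: dict, match_info: dict) -> dict:
    """Combine query-string and route parameters as the hunk does."""
    params = dict(query)        # copy, so the request mapping is untouched
    params.update(match_info)   # route parameters win on a key clash
    return params


merged = merge_params(
    {"workspace": "from-query", "token": "abc"},
    {"workspace": "from-route"},
)
print(merged)
```

Applying `match_info` last means a value fixed by the URL route cannot be overridden from the query string.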
|
||||||
|
|
|
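Review note on the `SocketEndpoint` hunk: the dispatcher now receives a merged dictionary rather than `request.match_info` alone. Below is a minimal sketch of that merge, using plain dicts in place of the request attributes. The function name `merge_params` is hypothetical and not part of the change. Path parameters are applied last, so they override a query-string key of the same name.

```python
def merge_params(query, match_info):
    """Merge query-string and URL path parameters into one dict.

    Mirrors the two-line pattern in the hunk: start from the query
    string, then let path parameters overwrite any duplicate keys.
    """
    params = dict(query)
    params.update(match_info)
    return params


if __name__ == "__main__":
    # Hypothetical inputs: "flow" appears in both sources.
    query = {"token": "abc", "flow": "from-query"}
    match_info = {"flow": "from-path"}
    print(merge_params(query, match_info))
```

Applying the path parameters last means a caller cannot shadow a routed value by adding the same key to the query string.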
||||||
|
|
@ -101,6 +101,7 @@ class Processor(AsyncProcessor):
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="iam",
|
default_keyspace="iam",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cassandra_host = hosts
|
self.cassandra_host = hosts
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,9 @@ class Librarian:
|
||||||
request.document_id
|
request.document_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if object_id is None:
|
||||||
|
raise RequestError(f"Document not found: {request.document_id}")
|
||||||
|
|
||||||
content = await self.blob_store.get(
|
content = await self.blob_store.get(
|
||||||
object_id
|
object_id
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import asyncio
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
|
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
|
||||||
|
|
@ -54,6 +55,16 @@ default_object_store_access_key = "object-user"
|
||||||
default_object_store_secret_key = "object-password"
|
default_object_store_secret_key = "object-password"
|
||||||
default_object_store_use_ssl = False
|
default_object_store_use_ssl = False
|
||||||
default_object_store_region = None
|
default_object_store_region = None
|
||||||
|
|
||||||
|
# Environment variables consulted as a fallback when the
|
||||||
|
# corresponding params field is not set in the processor-group YAML
|
||||||
|
# or via CLI. Intended for K8s Secret / env-var injection so
|
||||||
|
# credentials never have to live in the YAML (and thus in git).
|
||||||
|
ENV_OBJECT_STORE_ENDPOINT = "OBJECT_STORE_ENDPOINT"
|
||||||
|
ENV_OBJECT_STORE_ACCESS_KEY = "OBJECT_STORE_ACCESS_KEY"
|
||||||
|
ENV_OBJECT_STORE_SECRET_KEY = "OBJECT_STORE_SECRET_KEY"
|
||||||
|
ENV_OBJECT_STORE_USE_SSL = "OBJECT_STORE_USE_SSL"
|
||||||
|
ENV_OBJECT_STORE_REGION = "OBJECT_STORE_REGION"
|
||||||
default_cassandra_host = "cassandra"
|
default_cassandra_host = "cassandra"
|
||||||
default_min_chunk_size = 1 # No minimum by default (for Garage)
|
default_min_chunk_size = 1 # No minimum by default (for Garage)
|
||||||
|
|
||||||
|
|
@ -89,22 +100,36 @@ class Processor(WorkspaceProcessor):
|
||||||
"config_response_queue", default_config_response_queue
|
"config_response_queue", default_config_response_queue
|
||||||
)
|
)
|
||||||
|
|
||||||
object_store_endpoint = params.get("object_store_endpoint", default_object_store_endpoint)
|
# Resolve object-store config. Precedence: explicit params
|
||||||
object_store_access_key = params.get(
|
# (CLI / processor-group YAML) → environment variable →
|
||||||
"object_store_access_key",
|
# hardcoded default. The env-var path lets K8s Secrets feed
|
||||||
default_object_store_access_key
|
# credentials without them appearing in the YAML.
|
||||||
|
object_store_endpoint = (
|
||||||
|
params.get("object_store_endpoint")
|
||||||
|
or os.environ.get(ENV_OBJECT_STORE_ENDPOINT)
|
||||||
|
or default_object_store_endpoint
|
||||||
)
|
)
|
||||||
object_store_secret_key = params.get(
|
object_store_access_key = (
|
||||||
"object_store_secret_key",
|
params.get("object_store_access_key")
|
||||||
default_object_store_secret_key
|
or os.environ.get(ENV_OBJECT_STORE_ACCESS_KEY)
|
||||||
|
or default_object_store_access_key
|
||||||
)
|
)
|
||||||
object_store_use_ssl = params.get(
|
object_store_secret_key = (
|
||||||
"object_store_use_ssl",
|
params.get("object_store_secret_key")
|
||||||
default_object_store_use_ssl
|
or os.environ.get(ENV_OBJECT_STORE_SECRET_KEY)
|
||||||
|
or default_object_store_secret_key
|
||||||
)
|
)
|
||||||
object_store_region = params.get(
|
object_store_use_ssl = params.get("object_store_use_ssl")
|
||||||
"object_store_region",
|
if object_store_use_ssl is None:
|
||||||
default_object_store_region
|
env_ssl = os.environ.get(ENV_OBJECT_STORE_USE_SSL)
|
||||||
|
if env_ssl is not None:
|
||||||
|
object_store_use_ssl = env_ssl.lower() in ("true", "1", "yes")
|
||||||
|
else:
|
||||||
|
object_store_use_ssl = default_object_store_use_ssl
|
||||||
|
object_store_region = (
|
||||||
|
params.get("object_store_region")
|
||||||
|
or os.environ.get(ENV_OBJECT_STORE_REGION)
|
||||||
|
or default_object_store_region
|
||||||
)
|
)
|
||||||
|
|
||||||
min_chunk_size = params.get(
|
min_chunk_size = params.get(
|
||||||
|
|
@ -121,7 +146,8 @@ class Processor(WorkspaceProcessor):
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password,
|
password=cassandra_password,
|
||||||
default_keyspace="librarian"
|
default_keyspace="librarian",
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store resolved configuration
|
# Store resolved configuration
|
||||||
|
|
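Review note on the object-store hunk: the precedence described in the new comment (explicit params, then environment variable, then hardcoded default) can be sketched as two small helpers. The helper names and the `env` argument are hypothetical and exist only to make the sketch testable; the change itself inlines this logic in `__init__`.

```python
import os


def resolve_setting(params, key, env_name, default, env=None):
    """Return the first truthy value: params, then environment, then default."""
    env = os.environ if env is None else env
    return params.get(key) or env.get(env_name) or default


def resolve_bool_setting(params, key, env_name, default, env=None):
    """Resolve a boolean flag without letting an explicit False fall through.

    An `or` chain would treat False as "unset", so the flag is checked
    against None instead, as the hunk does for object_store_use_ssl.
    """
    env = os.environ if env is None else env
    value = params.get(key)
    if value is not None:
        return value
    env_value = env.get(env_name)
    if env_value is not None:
        return env_value.lower() in ("true", "1", "yes")
    return default


if __name__ == "__main__":
    fake_env = {"OBJECT_STORE_ACCESS_KEY": "from-env", "OBJECT_STORE_USE_SSL": "yes"}
    print(resolve_setting({}, "object_store_access_key",
                          "OBJECT_STORE_ACCESS_KEY", "object-user", env=fake_env))
    print(resolve_bool_setting({}, "object_store_use_ssl",
                               "OBJECT_STORE_USE_SSL", False, env=fake_env))
```

One consequence of the `or` chain for the string settings is that an empty string in params is treated as unset and falls through to the environment variable.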
|
||||||
|
|
@ -12,31 +12,33 @@ from qdrant_client import QdrantClient
|
||||||
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||||
from .... schema import Error
|
from .... schema import Error
|
||||||
from .... base import DocumentEmbeddingsQueryService
|
from .... base import DocumentEmbeddingsQueryService
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "doc-embeddings-query"
|
default_ident = "doc-embeddings-query"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(DocumentEmbeddingsQueryService):
|
class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
#optional api key
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
api_key = params.get("api_key", None)
|
url=store_uri,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
|
||||||
async def query_document_embeddings(self, workspace, msg):
|
async def query_document_embeddings(self, workspace, msg):
|
||||||
|
|
||||||
|
|
@ -85,18 +87,7 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
DocumentEmbeddingsQueryService.add_args(parser)
|
DocumentEmbeddingsQueryService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'API key for qdrant (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
|
||||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||||
from .... schema import Error, Term, IRI, LITERAL
|
from .... schema import Error, Term, IRI, LITERAL
|
||||||
from .... base import GraphEmbeddingsQueryService
|
from .... base import GraphEmbeddingsQueryService
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "graph-embeddings-query"
|
default_ident = "graph-embeddings-query"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(GraphEmbeddingsQueryService):
|
class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
#optional api key
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
api_key = params.get("api_key", None)
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
|
||||||
def create_value(self, ent):
|
def create_value(self, ent):
|
||||||
if ent.startswith("http://") or ent.startswith("https://"):
|
if ent.startswith("http://") or ent.startswith("https://"):
|
||||||
|
|
@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
GraphEmbeddingsQueryService.add_args(parser)
|
GraphEmbeddingsQueryService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'API key for qdrant (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
|
||||||
# Create keyspace
|
# Create keyspace
|
||||||
self.session.execute(f"""
|
self.session.execute(f"""
|
||||||
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
|
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
|
||||||
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
|
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Create triples table optimized for SPARQL queries
|
# Create triples table optimized for SPARQL queries
|
||||||
|
|
|
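Review note on the `CassandraTripleStore` hunk: the replication factor is now interpolated straight into the CQL string. `dict.get('replication_factor', 1)` returns the stored value whenever the key is present, so if the key were ever present with the value `None`, the statement would contain `None` rather than a number. Whether `cassandra_config` can hold `None` is not visible in this diff. The sketch below shows one way to normalise the value first; `keyspace_cql` is a hypothetical helper, not code from the change.

```python
def keyspace_cql(keyspace, replication_factor=None):
    """Build a CREATE KEYSPACE statement with a validated replication factor.

    None falls back to 1, matching the default used in the hunk.
    """
    rf = 1 if replication_factor is None else int(replication_factor)
    if rf < 1:
        raise ValueError("replication_factor must be >= 1")
    return (
        f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
        f"WITH replication = {{'class': 'SimpleStrategy', "
        f"'replication_factor': {rf}}}"
    )


if __name__ == "__main__":
    print(keyspace_cql("knowledge", 3))
    print(keyspace_cql("knowledge"))
```

Coercing through `int()` also rejects non-numeric strings before they reach the statement.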
||||||
|
|
@ -19,12 +19,12 @@ from .... schema import (
|
||||||
RowIndexMatch, Error
|
RowIndexMatch, Error
|
||||||
)
|
)
|
||||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "row-embeddings-query"
|
default_ident = "row-embeddings-query"
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
default_concurrency = 10
|
default_concurrency = 10
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -35,13 +35,17 @@ class Processor(FlowProcessor):
|
||||||
id = params.get("id", default_ident)
|
id = params.get("id", default_ident)
|
||||||
concurrency = params.get("concurrency", default_concurrency)
|
concurrency = params.get("concurrency", default_concurrency)
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, _, _ = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -62,7 +66,7 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
|
||||||
def sanitize_name(self, name: str) -> str:
|
def sanitize_name(self, name: str) -> str:
|
||||||
"""Sanitize names for Qdrant collection naming"""
|
"""Sanitize names for Qdrant collection naming"""
|
||||||
|
|
@ -192,21 +196,9 @@ class Processor(FlowProcessor):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add command-line arguments"""
|
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help='API key for Qdrant (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--concurrency',
|
'-c', '--concurrency',
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||||
from .... schema import Error, RowSchema, Field as SchemaField
|
from .... schema import Error, RowSchema, Field as SchemaField
|
||||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||||
from .... tables.cassandra_async import async_execute
|
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
|
||||||
|
|
||||||
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
||||||
|
|
||||||
|
|
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
|
||||||
description=field_def.get("description", ""),
|
description=field_def.get("description", ""),
|
||||||
required=field_def.get("required", False),
|
required=field_def.get("required", False),
|
||||||
enum_values=field_def.get("enum", []),
|
enum_values=field_def.get("enum", []),
|
||||||
indexed=field_def.get("indexed", False)
|
indexed=field_def.get("indexed", False),
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
|
|
@ -232,6 +232,8 @@ class Processor(FlowProcessor):
|
||||||
for index_name in index_names:
|
for index_name in index_names:
|
||||||
if index_name in filters:
|
if index_name in filters:
|
||||||
value = filters[index_name]
|
value = filters[index_name]
|
||||||
|
if value == "" or value is None:
|
||||||
|
continue
|
||||||
# Single field index -> single element list
|
# Single field index -> single element list
|
||||||
index_value = [str(value)]
|
index_value = [str(value)]
|
||||||
return (index_name, index_value)
|
return (index_name, index_value)
|
||||||
|
|
@ -282,9 +284,11 @@ class Processor(FlowProcessor):
|
||||||
query += f" LIMIT {limit}"
|
query += f" LIMIT {limit}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(self.session, query, params)
|
pages = await async_execute_paged(
|
||||||
for row in rows:
|
self.session, query, params
|
||||||
# Convert data map to dict with proper field names
|
)
|
||||||
|
for page in pages:
|
||||||
|
for row in page:
|
||||||
row_dict = dict(row.data) if row.data else {}
|
row_dict = dict(row.data) if row.data else {}
|
||||||
results.append(row_dict)
|
results.append(row_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -308,8 +312,6 @@ class Processor(FlowProcessor):
|
||||||
# Query using the first index (arbitrary choice for scan)
|
# Query using the first index (arbitrary choice for scan)
|
||||||
primary_index = index_names[0]
|
primary_index = index_names[0]
|
||||||
|
|
||||||
# We need to scan all values for this index
|
|
||||||
# This requires ALLOW FILTERING or a different approach
|
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT data, source FROM {safe_keyspace}.rows
|
SELECT data, source FROM {safe_keyspace}.rows
|
||||||
WHERE collection = %s
|
WHERE collection = %s
|
||||||
|
|
@ -320,18 +322,19 @@ class Processor(FlowProcessor):
|
||||||
params = [collection, schema_name, primary_index]
|
params = [collection, schema_name, primary_index]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(self.session, query, params)
|
def row_filter(row):
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
row_dict = dict(row.data) if row.data else {}
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
return self._matches_filters(row_dict, filters, row_schema)
|
||||||
|
|
||||||
# Apply post-filters
|
matched_rows = await async_scan(
|
||||||
if self._matches_filters(row_dict, filters, row_schema):
|
self.session, query, params,
|
||||||
|
row_filter=row_filter,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
for row in matched_rows:
|
||||||
|
row_dict = dict(row.data) if row.data else {}
|
||||||
results.append(row_dict)
|
results.append(row_dict)
|
||||||
|
|
||||||
if limit and len(results) >= limit:
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
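Review note on the scan hunk: post-filtering and the limit check move out of the caller and into `async_scan`, which receives `row_filter` and `limit`. The implementation of `async_scan` is not part of this diff, so the following is only a sketch of the general technique, under the assumption that results arrive as pages. `scan_pages` and `fake_pages` are hypothetical names.

```python
import asyncio


async def scan_pages(pages, row_filter, limit=None):
    """Collect rows that pass row_filter, stopping once limit rows match.

    `pages` is any async iterable of row lists; it stands in for a
    paged driver result and is an assumption of this sketch.
    """
    matched = []
    async for page in pages:
        for row in page:
            if row_filter(row):
                matched.append(row)
                # Stop early so later pages are never fetched.
                if limit and len(matched) >= limit:
                    return matched
    return matched


async def fake_pages():
    """Hypothetical data source yielding two pages of dict rows."""
    yield [{"n": 1}, {"n": 2}, {"n": 3}]
    yield [{"n": 4}, {"n": 5}]


if __name__ == "__main__":
    odd = lambda row: row["n"] % 2 == 1
    print(asyncio.run(scan_pages(fake_pages(), odd, limit=2)))
```

Filtering page by page means the scan can stop as soon as enough rows match, instead of materialising the whole result set before the limit is applied.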
@ -363,7 +366,7 @@ class Processor(FlowProcessor):
|
||||||
# Parse filter key for operator
|
# Parse filter key for operator
|
||||||
if '_' in filter_key:
|
if '_' in filter_key:
|
||||||
parts = filter_key.rsplit('_', 1)
|
parts = filter_key.rsplit('_', 1)
|
||||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
|
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
|
||||||
field_name = parts[0]
|
field_name = parts[0]
|
||||||
operator = parts[1]
|
operator = parts[1]
|
||||||
else:
|
else:
|
||||||
|
|
@ -400,6 +403,18 @@ class Processor(FlowProcessor):
|
||||||
elif operator == 'in':
|
elif operator == 'in':
|
||||||
if str(row_value) not in [str(v) for v in filter_value]:
|
if str(row_value) not in [str(v) for v in filter_value]:
|
||||||
return False
|
return False
|
||||||
|
elif operator == 'not':
|
||||||
|
if str(row_value) == str(filter_value):
|
||||||
|
return False
|
||||||
|
elif operator == 'startsWith':
|
||||||
|
if not str(row_value).startswith(str(filter_value)):
|
||||||
|
return False
|
||||||
|
elif operator == 'endsWith':
|
||||||
|
if not str(row_value).endswith(str(filter_value)):
|
||||||
|
return False
|
||||||
|
elif operator == 'not_in':
|
||||||
|
if str(row_value) in [str(v) for v in filter_value]:
|
||||||
|
return False
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
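Review note on the filter-operator hunks: the key is split with `filter_key.rsplit('_', 1)`, so the second part can never contain an underscore. For a key such as `age_not_in` the split yields `['age_not', 'in']`, which means the new `not_in` entry in the operator list cannot match and the key is read as field `age_not` with operator `in`. The `not_in` branch therefore appears unreachable from this parsing path. The sketch below shows one possible alternative that matches known suffixes, longest first. `split_filter_key` is a hypothetical helper, and the `"eq"` default is an assumption about what the untouched `else:` branch does.

```python
# Multi-token suffixes come first so "not_in" is tried before "in".
OPERATORS = ("not_in", "startsWith", "endsWith", "contains",
             "gte", "lte", "gt", "lt", "in", "not")


def split_filter_key(filter_key):
    """Split 'field_op' into (field, op); fall back to equality.

    The "eq" default is an assumption of this sketch.
    """
    for op in OPERATORS:
        suffix = "_" + op
        # Require a non-empty field name in front of the suffix.
        if filter_key.endswith(suffix) and len(filter_key) > len(suffix):
            return filter_key[: -len(suffix)], op
    return filter_key, "eq"


if __name__ == "__main__":
    for key in ("age_not_in", "age_in", "name_startsWith", "first_name"):
        print(key, split_filter_key(key))
```

Suffix matching stays ambiguous for a field whose own name ends in `_not`, so a stricter scheme (for example a distinct separator) may be preferable if such field names are possible.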
||||||
|
|
@ -14,29 +14,36 @@ from qdrant_client.models import Distance, VectorParams
|
||||||
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
|
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
|
||||||
from .... base import AsyncProcessor, Consumer, Producer
|
from .... base import AsyncProcessor, Consumer, Producer
|
||||||
from .... base import ConsumerMetrics, ProducerMetrics
|
from .... base import ConsumerMetrics, ProducerMetrics
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "doc-embeddings-write"
|
default_ident = "doc-embeddings-write"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
replication_factor=params.get("qdrant_replication_factor"),
|
||||||
|
shard_number=params.get("qdrant_shard_number"),
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
self.shard_number = shard_number
|
||||||
self._cache_lock = asyncio.Lock()
|
self._cache_lock = asyncio.Lock()
|
||||||
self._known_collections: set[str] = set()
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -61,6 +68,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
vectors_config=VectorParams(
|
vectors_config=VectorParams(
|
||||||
size=dim, distance=Distance.COSINE
|
size=dim, distance=Distance.COSINE
|
||||||
),
|
),
|
||||||
|
replication_factor=self.replication_factor,
|
||||||
|
shard_number=self.shard_number,
|
||||||
)
|
)
|
||||||
self._known_collections.add(collection_name)
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
|
|
@ -109,18 +118,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
DocumentEmbeddingsStoreService.add_args(parser)
|
DocumentEmbeddingsStoreService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'Qdrant API key (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from qdrant_client.models import Distance, VectorParams
|
||||||
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
|
||||||
from .... base import AsyncProcessor, Consumer, Producer
|
from .... base import AsyncProcessor, Consumer, Producer
|
||||||
from .... base import ConsumerMetrics, ProducerMetrics
|
from .... base import ConsumerMetrics, ProducerMetrics
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
from .... schema import IRI, LITERAL
|
from .... schema import IRI, LITERAL
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
|
|
@ -29,29 +30,34 @@ def get_term_value(term):
|
||||||
elif term.type == LITERAL:
|
elif term.type == LITERAL:
|
||||||
return term.value
|
return term.value
|
||||||
else:
|
else:
|
||||||
# For blank nodes or other types, use id or value
|
|
||||||
return term.id or term.value
|
return term.id or term.value
|
||||||
|
|
||||||
|
|
||||||
default_ident = "graph-embeddings-write"
|
default_ident = "graph-embeddings-write"
|
||||||
|
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
replication_factor=params.get("qdrant_replication_factor"),
|
||||||
|
shard_number=params.get("qdrant_shard_number"),
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
self.shard_number = shard_number
|
||||||
self._cache_lock = asyncio.Lock()
|
self._cache_lock = asyncio.Lock()
|
||||||
self._known_collections: set[str] = set()
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -76,6 +82,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
vectors_config=VectorParams(
|
vectors_config=VectorParams(
|
||||||
size=dim, distance=Distance.COSINE
|
size=dim, distance=Distance.COSINE
|
||||||
),
|
),
|
||||||
|
replication_factor=self.replication_factor,
|
||||||
|
shard_number=self.shard_number,
|
||||||
)
|
)
|
||||||
self._known_collections.add(collection_name)
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
|
|
@ -128,18 +136,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
GraphEmbeddingsStoreService.add_args(parser)
|
GraphEmbeddingsStoreService.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help=f'Qdrant API key'
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
async def create_collection(self, workspace: str, collection: str, metadata: dict):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,8 @@ class Processor(FlowProcessor):
|
||||||
host=params.get("cassandra_host"),
|
host=params.get("cassandra_host"),
|
||||||
username=params.get("cassandra_username"),
|
username=params.get("cassandra_username"),
|
||||||
password=params.get("cassandra_password"),
|
password=params.get("cassandra_password"),
|
||||||
default_keyspace='knowledge'
|
default_keyspace='knowledge',
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams
|
||||||
from .... schema import RowEmbeddings
|
from .... schema import RowEmbeddings
|
||||||
from .... base import FlowProcessor, ConsumerSpec
|
from .... base import FlowProcessor, ConsumerSpec
|
||||||
from .... base import CollectionConfigHandler
|
from .... base import CollectionConfigHandler
|
||||||
|
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "row-embeddings-write"
|
default_ident = "row-embeddings-write"
|
||||||
default_store_uri = 'http://localhost:6333'
|
|
||||||
|
|
||||||
|
|
||||||
class Processor(CollectionConfigHandler, FlowProcessor):
|
class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
@ -41,13 +41,19 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
id = params.get("id", default_ident)
|
id = params.get("id", default_ident)
|
||||||
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
store_uri = params.get("store_uri")
|
||||||
api_key = params.get("api_key", None)
|
api_key = params.get("api_key")
|
||||||
|
|
||||||
|
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
|
||||||
|
url=store_uri, api_key=api_key,
|
||||||
|
replication_factor=params.get("qdrant_replication_factor"),
|
||||||
|
shard_number=params.get("qdrant_shard_number"),
|
||||||
|
)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
"store_uri": store_uri,
|
"store_uri": url,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -63,7 +69,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
# Register config handler for collection management
|
# Register config handler for collection management
|
||||||
self.register_config_handler(self.on_collection_config, types=["collection"])
|
self.register_config_handler(self.on_collection_config, types=["collection"])
|
||||||
|
|
||||||
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
|
self.qdrant = QdrantClient(url=url, api_key=api_key)
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
self.shard_number = shard_number
|
||||||
self._cache_lock = asyncio.Lock()
|
self._cache_lock = asyncio.Lock()
|
||||||
self._known_collections: set[str] = set()
|
self._known_collections: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -103,6 +111,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
size=dimension,
|
size=dimension,
|
||||||
distance=Distance.COSINE
|
distance=Distance.COSINE
|
||||||
),
|
),
|
||||||
|
replication_factor=self.replication_factor,
|
||||||
|
shard_number=self.shard_number,
|
||||||
)
|
)
|
||||||
self._known_collections.add(collection_name)
|
self._known_collections.add(collection_name)
|
||||||
|
|
||||||
|
|
@ -249,21 +259,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
"""Add command-line arguments"""
|
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
add_qdrant_args(parser)
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Qdrant URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-k', '--api-key',
|
|
||||||
default=None,
|
|
||||||
help='Qdrant API key (default: None)'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -47,16 +47,18 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
cassandra_password = params.get("cassandra_password")
|
cassandra_password = params.get("cassandra_password")
|
||||||
|
|
||||||
# Resolve configuration with environment variable fallback
|
# Resolve configuration with environment variable fallback
|
||||||
hosts, username, password, keyspace, _ = resolve_cassandra_config(
|
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
|
||||||
host=cassandra_host,
|
host=cassandra_host,
|
||||||
username=cassandra_username,
|
username=cassandra_username,
|
||||||
password=cassandra_password
|
password=cassandra_password,
|
||||||
|
replication_factor=params.get("cassandra_replication_factor"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store resolved configuration with proper names
|
# Store resolved configuration with proper names
|
||||||
self.cassandra_host = hosts # Store as list
|
self.cassandra_host = hosts # Store as list
|
||||||
self.cassandra_username = username
|
self.cassandra_username = username
|
||||||
self.cassandra_password = password
|
self.cassandra_password = password
|
||||||
|
self.replication_factor = replication_factor
|
||||||
|
|
||||||
# Config key for schemas
|
# Config key for schemas
|
||||||
self.config_key = params.get("config_type", "schema")
|
self.config_key = params.get("config_type", "schema")
|
||||||
|
|
@ -170,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
description=field_def.get("description", ""),
|
description=field_def.get("description", ""),
|
||||||
required=field_def.get("required", False),
|
required=field_def.get("required", False),
|
||||||
enum_values=field_def.get("enum", []),
|
enum_values=field_def.get("enum", []),
|
||||||
indexed=field_def.get("indexed", False)
|
indexed=field_def.get("indexed", False),
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
|
|
@ -232,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
|
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
|
||||||
WITH REPLICATION = {{
|
WITH REPLICATION = {{
|
||||||
'class': 'SimpleStrategy',
|
'class': 'SimpleStrategy',
|
||||||
'replication_factor': 1
|
'replication_factor': {self.replication_factor}
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
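The hunk above replaces the hard-coded `'replication_factor': 1` with a configured value interpolated into the CQL text. Below is a minimal sketch of that statement-building step. The helper name `build_create_keyspace_cql` and its validation rules are illustrative assumptions, not code from this change set.

```python
import re


def build_create_keyspace_cql(keyspace: str, replication_factor) -> str:
    """Return a CREATE KEYSPACE statement using SimpleStrategy.

    Hypothetical helper: the diff interpolates self.replication_factor
    directly; the validation here is an added assumption.
    """
    # Keyspace names cannot be bound as query parameters, so only a
    # conservative unquoted-identifier pattern is accepted.
    if not re.fullmatch(r"[A-Za-z][A-Za-z0-9_]*", keyspace):
        raise ValueError(f"unsafe keyspace name: {keyspace!r}")

    # The value may arrive as a string (CLI flag or environment variable),
    # so coerce it to int and range-check it before interpolating.
    rf = int(replication_factor)
    if rf < 1:
        raise ValueError("replication_factor must be >= 1")

    return (
        f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
        f"WITH REPLICATION = {{'class': 'SimpleStrategy', "
        f"'replication_factor': {rf}}}"
    )


if __name__ == "__main__":
    print(build_create_keyspace_cql("trustgraph", "3"))
```

Coercing to `int` before formatting means a malformed setting fails at startup rather than producing invalid CQL.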
@ -27,6 +27,8 @@ Notes:
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from cassandra.query import SimpleStatement
|
||||||
|
|
||||||
|
|
||||||
async def async_execute(session, query, parameters=None):
|
async def async_execute(session, query, parameters=None):
|
||||||
"""Execute a CQL statement asynchronously.
|
"""Execute a CQL statement asynchronously.
|
||||||
|
|
@ -76,3 +78,83 @@ def _set_result_if_pending(fut, result):
|
||||||
def _set_exception_if_pending(fut, exc):
|
def _set_exception_if_pending(fut, exc):
|
||||||
if not fut.done():
|
if not fut.done():
|
||||||
fut.set_exception(exc)
|
fut.set_exception(exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
|
||||||
|
"""Execute a CQL query with page-by-page iteration.
|
||||||
|
|
||||||
|
Uses synchronous session.execute() inside run_in_executor so that
|
||||||
|
the driver's ResultSet paging runs off the event loop, one page of
|
||||||
|
fetch_size rows at a time.
|
||||||
|
|
||||||
|
Returns all pages as a list of lists, so the complete result is held in memory; use async_scan to bound memory.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
if isinstance(query, str):
|
||||||
|
stmt = SimpleStatement(query, fetch_size=fetch_size)
|
||||||
|
else:
|
||||||
|
stmt = query
|
||||||
|
stmt.fetch_size = fetch_size
|
||||||
|
|
||||||
|
def _fetch_all_pages():
|
||||||
|
pages = []
|
||||||
|
result_set = session.execute(stmt, parameters)
|
||||||
|
while True:
|
||||||
|
pages.append(list(result_set.current_rows))
|
||||||
|
if result_set.has_more_pages:
|
||||||
|
result_set.fetch_next_page()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return pages
|
||||||
|
|
||||||
|
return await loop.run_in_executor(
|
||||||
|
None, _fetch_all_pages
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_scan(
|
||||||
|
session, query, parameters=None, row_filter=None,
|
||||||
|
limit=None, fetch_size=5000,
|
||||||
|
):
|
||||||
|
"""Scan a CQL query page-by-page, applying a filter and limit.
|
||||||
|
|
||||||
|
Only matching rows accumulate in memory. Each page is discarded
|
||||||
|
after processing, so peak memory is bounded by fetch_size plus
|
||||||
|
the number of matching rows (capped by limit).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: cassandra.cluster.Session
|
||||||
|
query: CQL statement string
|
||||||
|
parameters: bind params
|
||||||
|
row_filter: callable(row) -> bool, or None to accept all
|
||||||
|
limit: max results to return, or None for unlimited
|
||||||
|
fetch_size: rows per Cassandra page fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching rows.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
if isinstance(query, str):
|
||||||
|
stmt = SimpleStatement(query, fetch_size=fetch_size)
|
||||||
|
else:
|
||||||
|
stmt = query
|
||||||
|
stmt.fetch_size = fetch_size
|
||||||
|
|
||||||
|
def _scan():
|
||||||
|
results = []
|
||||||
|
result_set = session.execute(stmt, parameters)
|
||||||
|
while True:
|
||||||
|
for row in result_set.current_rows:
|
||||||
|
if row_filter is None or row_filter(row):
|
||||||
|
results.append(row)
|
||||||
|
if limit and len(results) >= limit:
|
||||||
|
return results
|
||||||
|
if result_set.has_more_pages:
|
||||||
|
result_set.fetch_next_page()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return results
|
||||||
|
|
||||||
|
return await loop.run_in_executor(None, _scan)
|
||||||
|
|
|
||||||
|
|
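The page-by-page loop in `async_scan` can be tried without a Cassandra cluster. The sketch below copies the filter-and-limit loop from the diff and runs it against `FakeResultSet`, a hypothetical stand-in that exposes only the three `ResultSet` members the loop touches (`current_rows`, `has_more_pages`, `fetch_next_page`).

```python
class FakeResultSet:
    """Hypothetical stand-in for the driver's ResultSet (illustration only)."""

    def __init__(self, pages):
        self._pages = list(pages)
        self._index = 0
        self.fetches = 0  # counts how many extra pages were requested

    @property
    def current_rows(self):
        return self._pages[self._index] if self._pages else []

    @property
    def has_more_pages(self):
        return self._index < len(self._pages) - 1

    def fetch_next_page(self):
        self._index += 1
        self.fetches += 1


def scan_pages(result_set, row_filter=None, limit=None):
    """Same loop as _scan() in the diff: filter each page, stop at limit."""
    results = []
    while True:
        for row in result_set.current_rows:
            if row_filter is None or row_filter(row):
                results.append(row)
                # Return early so later pages are never fetched.
                if limit and len(results) >= limit:
                    return results
        if result_set.has_more_pages:
            result_set.fetch_next_page()
        else:
            break
    return results


if __name__ == "__main__":
    rs = FakeResultSet([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    evens = scan_pages(rs, row_filter=lambda r: r % 2 == 0, limit=2)
    print(evens, "extra pages fetched:", rs.fetches)
```

Because the loop returns as soon as the limit is reached, the third page in the usage example is never requested. Note that `if limit` treats `limit=0` the same as `None`, that is, unlimited.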
@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
@ -33,7 +33,7 @@ class ConfigTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password
|
username=cassandra_username, password=cassandra_password
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
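Several table stores in this diff switch from `SSLContext(PROTOCOL_TLSv1_2)` to `ssl.create_default_context()`. The practical difference is that the default context turns on certificate verification and hostname checking, whereas a bare `SSLContext` built with a protocol constant other than `PROTOCOL_TLS_CLIENT` does not verify certificates unless configured to. The sketch below shows those defaults; the helper name `make_client_tls_context` and its `cafile` parameter are illustrative assumptions for deployments with a private CA.

```python
import ssl


def make_client_tls_context(cafile=None):
    """Build a verifying client-side TLS context.

    Hypothetical helper: the diff calls ssl.create_default_context()
    with no arguments; cafile is an assumption for private-CA setups.
    """
    # Loads the system trust store, or the given CA bundle if provided.
    ctx = ssl.create_default_context(cafile=cafile)
    # Keep the TLS 1.2 floor that the replaced code implied.
    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
    return ctx


if __name__ == "__main__":
    ctx = make_client_tls_context()
    print("verifies certificates:", ctx.verify_mode == ssl.CERT_REQUIRED)
    print("checks hostname:", ctx.check_hostname)
```

One consequence worth checking in review: clusters that present self-signed certificates, which connected under the old context, will likely fail the handshake until a CA bundle is supplied.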
@ -15,7 +15,7 @@ import logging
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
from . cassandra_async import async_execute
|
from . cassandra_async import async_execute
|
||||||
|
|
||||||
|
|
@ -39,7 +39,7 @@ class IamTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
|
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password,
|
username=cassandra_username, password=cassandra_password,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from .. schema import DocumentEmbeddings, ChunkEmbeddings
|
||||||
|
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
|
|
||||||
from . cassandra_async import async_execute
|
from . cassandra_async import async_execute, async_execute_paged
|
||||||
|
|
||||||
|
|
||||||
def term_to_tuple(term):
|
def term_to_tuple(term):
|
||||||
|
|
@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri):
|
||||||
else:
|
else:
|
||||||
return Term(type=LITERAL, value=value)
|
return Term(type=LITERAL, value=value)
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
@ -50,7 +50,7 @@ class KnowledgeTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password
|
username=cassandra_username, password=cassandra_password
|
||||||
)
|
)
|
||||||
|
|
@ -98,7 +98,8 @@ class KnowledgeTableStore:
|
||||||
text, boolean, text, boolean, text, boolean
|
text, boolean, text, boolean, text, boolean
|
||||||
>>,
|
>>,
|
||||||
triples list<tuple<
|
triples list<tuple<
|
||||||
text, boolean, text, boolean, text, boolean
|
text, boolean, text, boolean, text, boolean,
|
||||||
|
text
|
||||||
>>,
|
>>,
|
||||||
PRIMARY KEY ((workspace, document_id), id)
|
PRIMARY KEY ((workspace, document_id), id)
|
||||||
);
|
);
|
||||||
|
|
@ -234,7 +235,8 @@ class KnowledgeTableStore:
|
||||||
|
|
||||||
triples = [
|
triples = [
|
||||||
(
|
(
|
||||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o),
|
||||||
|
v.g or ""
|
||||||
)
|
)
|
||||||
for v in m.triples
|
for v in m.triples
|
||||||
]
|
]
|
||||||
|
|
@ -398,7 +400,7 @@ class KnowledgeTableStore:
|
||||||
logger.debug("Get triples...")
|
logger.debug("Get triples...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(
|
pages = await async_execute_paged(
|
||||||
self.cassandra,
|
self.cassandra,
|
||||||
self.get_triples_stmt,
|
self.get_triples_stmt,
|
||||||
(workspace, document_id),
|
(workspace, document_id),
|
||||||
|
|
@ -407,7 +409,8 @@ class KnowledgeTableStore:
|
||||||
logger.error("Exception occurred", exc_info=True)
|
logger.error("Exception occurred", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
for row in rows:
|
for page in pages:
|
||||||
|
for row in page:
|
||||||
|
|
||||||
if row[3]:
|
if row[3]:
|
||||||
triples = [
|
triples = [
|
||||||
|
|
@ -415,6 +418,7 @@ class KnowledgeTableStore:
|
||||||
s = tuple_to_term(elt[0], elt[1]),
|
s = tuple_to_term(elt[0], elt[1]),
|
||||||
p = tuple_to_term(elt[2], elt[3]),
|
p = tuple_to_term(elt[2], elt[3]),
|
||||||
o = tuple_to_term(elt[4], elt[5]),
|
o = tuple_to_term(elt[4], elt[5]),
|
||||||
|
g = elt[6] if elt[6] else None,
|
||||||
)
|
)
|
||||||
for elt in row[3]
|
for elt in row[3]
|
||||||
]
|
]
|
||||||
|
|
@ -425,7 +429,7 @@ class KnowledgeTableStore:
|
||||||
Triples(
|
Triples(
|
||||||
metadata = Metadata(
|
metadata = Metadata(
|
||||||
id = document_id,
|
id = document_id,
|
||||||
collection = "default", # FIXME: What to put here?
|
collection = "default",
|
||||||
),
|
),
|
||||||
triples = triples
|
triples = triples
|
||||||
)
|
)
|
||||||
|
|
@ -438,7 +442,7 @@ class KnowledgeTableStore:
|
||||||
logger.debug("Get GE...")
|
logger.debug("Get GE...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(
|
pages = await async_execute_paged(
|
||||||
self.cassandra,
|
self.cassandra,
|
||||||
self.get_graph_embeddings_stmt,
|
self.get_graph_embeddings_stmt,
|
||||||
(workspace, document_id),
|
(workspace, document_id),
|
||||||
|
|
@ -447,7 +451,8 @@ class KnowledgeTableStore:
|
||||||
logger.error("Exception occurred", exc_info=True)
|
logger.error("Exception occurred", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
for row in rows:
|
for page in pages:
|
||||||
|
for row in page:
|
||||||
|
|
||||||
if row[3]:
|
if row[3]:
|
||||||
entities = [
|
entities = [
|
||||||
|
|
@ -464,7 +469,7 @@ class KnowledgeTableStore:
|
||||||
GraphEmbeddings(
|
GraphEmbeddings(
|
||||||
metadata = Metadata(
|
metadata = Metadata(
|
||||||
id = document_id,
|
id = document_id,
|
||||||
collection = "default", # FIXME: What to put here?
|
collection = "default",
|
||||||
),
|
),
|
||||||
entities = entities
|
entities = entities
|
||||||
)
|
)
|
||||||
|
|
@ -477,7 +482,7 @@ class KnowledgeTableStore:
|
||||||
logger.debug("Get DE...")
|
logger.debug("Get DE...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(
|
pages = await async_execute_paged(
|
||||||
self.cassandra,
|
self.cassandra,
|
||||||
self.get_document_embeddings_stmt,
|
self.get_document_embeddings_stmt,
|
||||||
(workspace, document_id),
|
(workspace, document_id),
|
||||||
|
|
@ -486,7 +491,8 @@ class KnowledgeTableStore:
|
||||||
logger.error("Exception occurred", exc_info=True)
|
logger.error("Exception occurred", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
for row in rows:
|
for page in pages:
|
||||||
|
for row in page:
|
||||||
|
|
||||||
if row[3]:
|
if row[3]:
|
||||||
chunks = [
|
chunks = [
|
||||||
|
|
|
||||||
|
|
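The knowledge store hunks above widen each stored triple from six tuple elements to seven, adding the named graph as a trailing `text` element. An empty string stands for "no graph" on write and maps back to `None` on read. Here is a rough sketch of that round trip. `Term` and `Triple` are simplified stand-ins for the real schema classes, and the length guard for six-element tuples is an added assumption that the diff itself does not include.

```python
from typing import NamedTuple, Optional


class Term(NamedTuple):
    """Simplified stand-in for the schema's Term (value + URI flag)."""
    value: str
    is_uri: bool


class Triple(NamedTuple):
    """Simplified stand-in for the schema's Triple."""
    s: Term
    p: Term
    o: Term
    g: Optional[str] = None


def encode_triple(t: Triple) -> tuple:
    # Seven elements: three (value, is_uri) pairs, then the graph name.
    # None is stored as "" so the text element is always populated.
    return (*t.s, *t.p, *t.o, t.g or "")


def decode_triple(elt: tuple) -> Triple:
    # Assumption: tolerate six-element tuples in case any exist; the diff
    # reads elt[6] unconditionally.
    g = elt[6] if len(elt) > 6 and elt[6] else None
    return Triple(
        s=Term(elt[0], elt[1]),
        p=Term(elt[2], elt[3]),
        o=Term(elt[4], elt[5]),
        g=g,
    )


if __name__ == "__main__":
    t = Triple(Term("http://ex/a", True), Term("http://ex/p", True),
               Term("hello", False), g=None)
    print(decode_triple(encode_triple(t)) == t)
```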
@ -24,7 +24,7 @@ from .. exceptions import RequestError
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.auth import PlainTextAuthProvider
|
from cassandra.auth import PlainTextAuthProvider
|
||||||
from cassandra.query import BatchStatement
|
from cassandra.query import BatchStatement
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
import ssl
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
|
@ -53,7 +53,7 @@ class LibraryTableStore:
|
||||||
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
|
||||||
|
|
||||||
if cassandra_username and cassandra_password:
|
if cassandra_username and cassandra_password:
|
||||||
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
|
ssl_context = ssl.create_default_context()
|
||||||
auth_provider = PlainTextAuthProvider(
|
auth_provider = PlainTextAuthProvider(
|
||||||
username=cassandra_username, password=cassandra_password
|
username=cassandra_username, password=cassandra_password
|
||||||
)
|
)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
|
|
@ -1,49 +1,110 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from websockets.asyncio.client import connect
|
from websockets.asyncio.client import connect
|
||||||
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import hashlib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _token_key(token):
|
||||||
|
"""Derive a dict key from a token without storing the raw secret."""
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
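`_token_key` derives a dictionary key from a bearer token so that per-caller state can be indexed without keeping the raw secret as a key. Below is a small sketch of how such a key might index a per-caller cache. The `ManagerCache` class and its `factory` argument are hypothetical, since the diff shows only the hashing helper.

```python
import hashlib


def token_key(token: str) -> str:
    """Same derivation as _token_key in the diff: truncated SHA-256 hex."""
    return hashlib.sha256(token.encode()).hexdigest()[:16]


class ManagerCache:
    """Hypothetical per-caller cache keyed by the derived token key."""

    def __init__(self, factory):
        self._factory = factory   # builds a manager for a given token
        self._managers = {}

    def get(self, token: str):
        key = token_key(token)
        # One manager per distinct token, so callers never share a connection.
        if key not in self._managers:
            self._managers[key] = self._factory(token)
        return self._managers[key]


if __name__ == "__main__":
    cache = ManagerCache(factory=lambda token: object())
    a1 = cache.get("token-a")
    a2 = cache.get("token-a")
    b = cache.get("token-b")
    print(a1 is a2, a1 is b)
```

The key is truncated to 16 hex characters (64 bits), so collisions between distinct tokens are very unlikely but not impossible; a deployment that cannot tolerate that risk could keep the full digest.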
||||||
class WebSocketManager:
|
class WebSocketManager:
|
||||||
|
"""Manages an authenticated WebSocket connection to the TrustGraph
|
||||||
|
gateway on behalf of a single caller.
|
||||||
|
|
||||||
def __init__(self, url, token=None):
|
Each caller token gets its own WebSocketManager so that gateway-side
|
||||||
|
identity, workspace, and capability scoping are preserved end-to-end.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url, token):
|
||||||
self.url = url
|
self.url = url
|
||||||
|
# ── Security boundary: token storage ──
|
||||||
|
# This is the MCP caller's Bearer token, forwarded verbatim to
|
||||||
|
# the gateway. It MUST NOT be logged, persisted, or shared
|
||||||
|
# across callers. It is held only for the lifetime of this
|
||||||
|
# connection so that re-auth (e.g. after a reconnect) is
|
||||||
|
# possible.
|
||||||
self.token = token
|
self.token = token
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
self.identity = None
|
||||||
# FIXME: authentication is broken. The /api/v1/socket endpoint uses
|
self.last_used = None
|
||||||
# in-band auth (first-frame protocol via the Mux dispatcher), not
|
|
||||||
# query-parameter tokens. This query-string token is silently ignored.
|
|
||||||
# Fix: after connect(), send an auth frame with the bearer token as
|
|
||||||
# the first message, matching the gateway's in-band auth protocol.
|
|
||||||
def _build_url(self):
|
|
||||||
if not self.token:
|
|
||||||
return self.url
|
|
||||||
parsed = urlparse(self.url)
|
|
||||||
params = parse_qs(parsed.query)
|
|
||||||
params["token"] = [self.token]
|
|
||||||
new_query = urlencode(params, doseq=True)
|
|
||||||
return urlunparse(parsed._replace(query=new_query))
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.socket = await connect(self._build_url())
|
"""Connect and authenticate via the gateway's in-band auth
|
||||||
|
protocol. Raises on auth failure."""
|
||||||
|
|
||||||
|
# ── Security boundary: MCP server → gateway ──
|
||||||
|
# The WebSocket connects to the gateway and authenticates using
|
||||||
|
# the caller's Bearer token via the in-band first-frame auth
|
||||||
|
# protocol. The token belongs to the MCP client — we forward
|
||||||
|
# it as-is and never interpret its contents.
|
||||||
|
self.socket = await connect(self.url)
|
||||||
self.pending_requests = {}
|
self.pending_requests = {}
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
|
await self._authenticate()
|
||||||
|
|
||||||
self.reader_task = asyncio.create_task(self.reader())
|
self.reader_task = asyncio.create_task(self.reader())
|
||||||
|
|
||||||
|
async def _authenticate(self):
|
||||||
|
"""Send in-band auth frame and wait for auth-ok / auth-failed.
|
||||||
|
|
||||||
|
The gateway expects ``{"type": "auth", "token": "..."}`` as the
|
||||||
|
first frame on a new WebSocket. Any service frame sent before
|
||||||
|
auth-ok is rejected.
|
||||||
|
"""
|
||||||
|
await self.socket.send(json.dumps({
|
||||||
|
"type": "auth",
|
||||||
|
"token": self.token,
|
||||||
|
}))
|
||||||
|
|
||||||
|
response_text = await asyncio.wait_for(self.socket.recv(), 10)
|
||||||
|
response = json.loads(response_text)
|
||||||
|
|
||||||
|
if response.get("type") == "auth-ok":
|
||||||
|
logger.info(
|
||||||
|
"WebSocket authenticated, default workspace: %s",
|
||||||
|
response.get("workspace"),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Auth failed — close immediately, do not leave an
|
||||||
|
# unauthenticated socket open.
|
||||||
|
await self.socket.close()
|
||||||
|
self.socket = None
|
||||||
|
|
||||||
|
if response.get("type") == "auth-failed":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Gateway rejected the authentication token"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected auth response type: {response.get('type')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def whoami(self):
|
||||||
|
"""Verify the token by calling the gateway's whoami endpoint.
|
||||||
|
Returns the identity dict and caches it on ``self.identity``.
|
||||||
|
"""
|
||||||
|
gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
|
||||||
|
async for response in gen:
|
||||||
|
self.identity = response
|
||||||
|
return response
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
if hasattr(self, "reader_task"):
|
||||||
await self.reader_task
|
await self.reader_task
|
||||||
|
|
||||||
async def reader(self):
|
async def reader(self):
|
||||||
"""
|
"""Background task: read WebSocket frames and route them to the
|
||||||
Background task to read websocket responses and route to correct
|
correct pending-request queue by ``id``."""
|
||||||
request
|
|
||||||
"""
|
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
|
|
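The in-band handshake added in `_authenticate` comes down to two steps: send `{"type": "auth", "token": ...}` as the first frame, then classify the first reply as `auth-ok`, `auth-failed`, or anything else. The sketch below isolates those two steps as pure functions so they can be exercised without a socket. The function names are illustrative and do not appear in the diff.

```python
import json


def build_auth_frame(token: str) -> str:
    """Serialise the first frame the gateway expects on a new socket."""
    return json.dumps({"type": "auth", "token": token})


def check_auth_response(response_text: str):
    """Return the default workspace on auth-ok, raise otherwise.

    Mirrors the branching in _authenticate(); the caller is still
    responsible for closing the socket when this raises.
    """
    response = json.loads(response_text)
    kind = response.get("type")

    if kind == "auth-ok":
        return response.get("workspace")
    if kind == "auth-failed":
        raise RuntimeError("Gateway rejected the authentication token")
    raise RuntimeError(f"Unexpected auth response type: {kind}")


if __name__ == "__main__":
    print(build_auth_frame("example-token"))
    print(check_auth_response('{"type": "auth-ok", "workspace": "default"}'))
```

Keeping the classification separate from the socket I/O makes the failure paths easy to unit test.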
@ -59,23 +120,21 @@ class WebSocketManager:
|
||||||
|
|
||||||
request_id = response.get("id")
|
request_id = response.get("id")
|
||||||
if request_id and request_id in self.pending_requests:
|
if request_id and request_id in self.pending_requests:
|
||||||
# Put the response in the queue
|
|
||||||
queue = self.pending_requests[request_id]
|
queue = self.pending_requests[request_id]
|
||||||
await queue.put(response)
|
await queue.put(response)
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logger.warning(
|
||||||
f"Response for unknown request ID: {request_id}"
|
"Response for unknown request ID: %s", request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
logging.error(f"Error in websocket reader: {e}")
|
logger.error("Error in websocket reader: %s", e)
|
||||||
|
|
||||||
# Put error in all pending queues
|
|
||||||
for queue in self.pending_requests.values():
|
for queue in self.pending_requests.values():
|
||||||
try:
|
try:
|
||||||
await queue.put({"error": str(e)})
|
await queue.put({"error": str(e)})
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.pending_requests.clear()
|
self.pending_requests.clear()
|
||||||
|
|
@ -86,25 +145,29 @@ class WebSocketManager:
|
||||||
|
|
||||||
async def request(
|
async def request(
|
||||||
self, service, request_data, flow_id="default",
|
self, service, request_data, flow_id="default",
|
||||||
|
workspace=None,
|
||||||
):
|
):
|
||||||
"""
|
"""Send a request via WebSocket and yield responses.
|
||||||
Send a request via websocket and handle single or streaming responses
|
|
||||||
|
Args:
|
||||||
|
service: Gateway service name (e.g. "graph-rag", "config").
|
||||||
|
request_data: Inner request payload.
|
||||||
|
flow_id: Optional flow identifier. ``None`` omits the field
|
||||||
|
(workspace-level services don't use flows).
|
||||||
|
workspace: Optional workspace override. When ``None`` the
|
||||||
|
gateway uses the caller's default workspace.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Generate unique request ID
|
import time
|
||||||
|
self.last_used = time.monotonic()
|
||||||
|
|
||||||
request_id = f"{uuid.uuid4()}"
|
request_id = f"{uuid.uuid4()}"
|
||||||
|
|
||||||
# Determine if this service streams responses
|
|
||||||
streaming_services = {"agent"}
|
|
||||||
is_streaming = service in streaming_services
|
|
||||||
|
|
||||||
# Create a queue for all responses (streaming and single)
|
|
||||||
response_queue = asyncio.Queue()
|
response_queue = asyncio.Queue()
|
||||||
self.pending_requests[request_id] = response_queue
|
self.pending_requests[request_id] = response_queue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Build request message
|
|
||||||
message = {
|
message = {
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"service": service,
|
"service": service,
|
||||||
|
|
@ -114,7 +177,16 @@ class WebSocketManager:
|
||||||
if flow_id is not None:
|
if flow_id is not None:
|
||||||
message["flow"] = flow_id
|
message["flow"] = flow_id
|
||||||
|
|
||||||
# Send request
|
# ── Security boundary: workspace scoping ──
|
||||||
|
# When the caller supplies a workspace, we set it on the
|
||||||
|
# message envelope. The gateway's enforce_workspace()
|
||||||
|
# validates that the authenticated identity is permitted
|
||||||
|
# to access the target workspace — we MUST NOT skip or
|
||||||
|
# override that check. When workspace is None, the
|
||||||
|
# gateway default-fills from the identity's bound workspace.
|
||||||
|
if workspace is not None:
|
||||||
|
message["workspace"] = workspace
|
||||||
|
|
||||||
await self.socket.send(json.dumps(message))
|
await self.socket.send(json.dumps(message))
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
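The workspace-scoping hunk above adds an optional `workspace` field to the request envelope and leaves it out entirely when the caller does not supply one, so the gateway can fall back to the identity's bound workspace. Here is a minimal sketch of that envelope construction. The `"request"` key for the inner payload is an assumption, because the line that sets it falls outside the hunks shown here.

```python
import json
import uuid


def build_envelope(service, request_data, flow_id="default", workspace=None):
    """Build a gateway request message with optional flow and workspace."""
    message = {
        "id": str(uuid.uuid4()),
        "service": service,
        "request": request_data,   # assumed key name, not shown in the diff
    }
    # Workspace-level services pass flow_id=None and omit the field.
    if flow_id is not None:
        message["flow"] = flow_id
    # Only set workspace when the caller asked for one; the gateway is the
    # component that decides whether the identity may access it.
    if workspace is not None:
        message["workspace"] = workspace
    return message


if __name__ == "__main__":
    msg = build_envelope("iam", {"operation": "whoami"}, flow_id=None)
    print(json.dumps(msg, indent=2))
```

Omitting the key, rather than sending `null`, keeps the client out of the authorisation decision, which the comments in the diff place on the gateway.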
@ -127,19 +199,17 @@ class WebSocketManager:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
if "message" in response["error"]:
|
if isinstance(response["error"], dict):
|
||||||
raise RuntimeError(response["error"]["text"])
|
raise RuntimeError(
|
||||||
|
response["error"].get("message", str(response["error"]))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(str(response["error"]))
|
raise RuntimeError(str(response["error"]))
|
||||||
|
|
||||||
yield response["response"]
|
yield response["response"]
|
||||||
|
|
||||||
if "complete" in response:
|
if response.get("complete"):
|
||||||
if response["complete"]:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
finally:
|
||||||
# Clean up on error
|
|
||||||
self.pending_requests.pop(request_id, None)
|
self.pending_requests.pop(request_id, None)
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
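The error-handling change above fixes a key mismatch: the old code tested for `"message"` in the error but then read `["text"]`, which would raise `KeyError` on a well-formed error object. The new code reads `message` from dict errors and falls back to `str()`. A short sketch of the normalisation follows; the function name `error_text` is illustrative.

```python
def error_text(error) -> str:
    """Extract a human-readable message from a gateway error value."""
    if isinstance(error, dict):
        # Prefer the "message" field; fall back to the dict's repr so
        # unexpected shapes still produce something readable.
        return error.get("message", str(error))
    # Errors injected by the reader task are plain strings.
    return str(error)


if __name__ == "__main__":
    print(error_text({"type": "not-found", "message": "No such flow"}))
    print(error_text("connection closed"))
```

The same hunk also moves the `pending_requests.pop(request_id, None)` cleanup into a `finally:` block, so the queue entry is released on normal completion as well as on failure.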
@ -107,7 +107,14 @@ class Processor(FlowProcessor):
|
||||||
# Get the source document ID
|
# Get the source document ID
|
||||||
source_doc_id = v.document_id or v.metadata.id
|
source_doc_id = v.document_id or v.metadata.id
|
||||||
|
|
||||||
|
try:
|
||||||
pages = convert_from_bytes(blob)
|
pages = convert_from_bytes(blob)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to decode PDF {source_doc_id}: "
|
||||||
|
f"{type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
for ix, page in enumerate(pages):
|
for ix, page in enumerate(pages):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -418,7 +418,14 @@ class Processor(FlowProcessor):
|
||||||
doc_uri_str = document_uri(source_doc_id)
|
doc_uri_str = document_uri(source_doc_id)
|
||||||
|
|
||||||
# Extract elements using unstructured
|
# Extract elements using unstructured
|
||||||
|
try:
|
||||||
elements = self.extract_elements(blob, mime_type)
|
elements = self.extract_elements(blob, mime_type)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to extract elements from {source_doc_id}: "
|
||||||
|
f"{type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if not elements:
|
if not elements:
|
||||||
logger.warning("No elements extracted from document")
|
logger.warning("No elements extracted from document")
|
||||||
|
|
|
||||||
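Both decoder hunks above wrap the extraction call in `try`/`except`, log the document ID with the exception type, and return early so that one malformed document does not crash the processor. The sketch below shows the same guard as a reusable helper. The helper name `safe_extract` and its callable argument are hypothetical, since the diff inlines the pattern at each call site.

```python
import logging

logger = logging.getLogger(__name__)


def safe_extract(extract, blob, source_doc_id):
    """Run extract(blob); on failure, log and return None instead of raising.

    Hypothetical helper illustrating the guard added in the diff.
    """
    try:
        return extract(blob)
    except Exception as e:
        # Log the exception type as well as the message: decoder errors
        # often carry terse or empty messages.
        logger.error(
            "Failed to decode %s: %s: %s",
            source_doc_id, type(e).__name__, e,
        )
        return None


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    def broken_decoder(blob):
        raise ValueError("not a PDF")

    print(safe_extract(broken_decoder, b"\x00", "doc-1"))
    print(safe_extract(lambda blob: [blob], b"%PDF", "doc-2"))
```

A caller would treat `None` as "skip this document", which matches the bare `return` in the diff. Whether a skipped document should also be reported upstream is a design question the diff leaves open.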