Merge branch 'release/v2.1'

Commit 824f993985 by Cyber MacGeddon, 2026-03-17 20:44:03 +00:00
266 changed files with 33,195 additions and 5,834 deletions


@@ -22,7 +22,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup packages
- run: make update-package-versions VERSION=2.0.999
+ run: make update-package-versions VERSION=2.1.999
- name: Setup environment
run: python3 -m venv env


@@ -13,9 +13,17 @@
# The context backend for reliable AI
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
# The context backend for AI agents
</div>
LLMs alone hallucinate and diverge from ground truth. [TrustGraph](https://trustgraph.ai) is a context system that stores, enriches, and delivers context to LLMs to enable reliable AI agents. Think of it like [Supabase](https://github.com/supabase/supabase), but AI-native and powered by context graphs.
Durable agent memory you can trust. Build, version, and retrieve grounded context from a context graph.
- Give agents **memory** that persists across sessions and deployments.
- Reduce hallucinations with **grounded context retrieval**.
- Ship reusable, portable [Context Cores](#context-cores) (packaged context you can move between projects/environments).
The context backend:
- [x] Multi-model and multimodal database system
@@ -45,21 +53,6 @@ The context backend:
- [x] Websocket API [Docs](https://docs.trustgraph.ai/reference/apis/websocket.html)
- [x] Python API [Docs](https://docs.trustgraph.ai/reference/apis/python)
- [x] CLI [Docs](https://docs.trustgraph.ai/reference/cli/)
## No API Keys Required
How many times have you cloned a repo and opened the `.env.example` to see the dozens of API keys for 3rd party dependencies needed to make the services work? There are only 3 things in TrustGraph that might need an API key:
- 3rd party LLM services like Anthropic, Cohere, Gemini, Mistral, OpenAI, etc.
- 3rd party OCR like Mistral OCR
- The API key *you set* for the TrustGraph API gateway
Everything else is included.
- [x] Managed Multi-model storage in [Cassandra](https://cassandra.apache.org/_/index.html)
- [x] Managed Vector embedding storage in [Qdrant](https://github.com/qdrant/qdrant)
- [x] Managed File and Object storage in [Garage](https://github.com/deuxfleurs-org/garage) (S3 compatible)
- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar)
- [x] Complete LLM inferencing stack for open LLMs with [vLLM](https://github.com/vllm-project/vllm), [TGI](https://github.com/huggingface/text-generation-inference), [Ollama](https://github.com/ollama/ollama), [LM Studio](https://github.com/lmstudio-ai), and [Llamafiles](https://github.com/mozilla-ai/llamafile)
## Quickstart
@@ -76,8 +69,6 @@ TrustGraph downloads as Docker containers and can be run locally with Docker, Po
width="80%" controls></video>
</p>
For a browser based quickstart, try the [Configuration Terminal](https://config-ui.demo.trustgraph.ai/).
<details>
<summary>Table of Contents</summary>
<br>
@@ -181,24 +172,28 @@ TrustGraph provides component flexibility to optimize agent workflows.
</details>
<details>
<summary>Multi-model storage</summary>
<summary>Graph Storage</summary>
<br>
- Apache Cassandra (default)<br>
- Neo4j<br>
- Memgraph<br>
- FalkorDB<br>
</details>
<details>
<summary>VectorDBs</summary>
<br>
- Apache Cassandra<br>
</details>
<details>
<summary>VectorDB</summary>
<br>
- Qdrant<br>
</details>
<details>
<summary>File and Object Storage</summary>
<br>
- Garage<br>
- Garage (default)<br>
- MinIO<br>
</details>
<details>


@@ -0,0 +1,108 @@
# API Gateway Changes: v1.8 to v2.1
## Summary
The API gateway gained new WebSocket service dispatchers for embeddings
queries, a new REST streaming endpoint for document content, and underwent
a significant wire format change from `Value` to `Term`. The "objects"
service was renamed to "rows".
---
## New WebSocket Service Dispatchers
These are new request/response services available through the WebSocket
multiplexer at `/api/v1/socket` (flow-scoped):
| Service Key | Description |
|-------------|-------------|
| `document-embeddings` | Queries document chunks by text similarity. Request/response uses `DocumentEmbeddingsRequest`/`DocumentEmbeddingsResponse` schemas. |
| `row-embeddings` | Queries structured data rows by text similarity on indexed fields. Request/response uses `RowEmbeddingsRequest`/`RowEmbeddingsResponse` schemas. |
These join the existing `graph-embeddings` dispatcher (which was already
present in v1.8 but may have been updated).
### Full list of WebSocket flow service dispatchers (v2.1)
Request/response services (via `/api/v1/flow/{flow}/service/{kind}` or
WebSocket mux):
- `agent`, `text-completion`, `prompt`, `mcp-tool`
- `graph-rag`, `document-rag`
- `embeddings`, `graph-embeddings`, `document-embeddings`
- `triples`, `rows`, `nlp-query`, `structured-query`, `structured-diag`
- `row-embeddings`
---
## New REST Endpoint
| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/api/v1/document-stream` | Streams document content from the library as raw bytes. Query parameters: `user` (required), `document-id` (required), `chunk-size` (optional, default 1MB). Returns the document content in chunked transfer encoding, decoded from base64 internally. |
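A client of the new streaming endpoint might look like the following sketch. The base URL is a placeholder, and the helper names are illustrative; only the path and the `user`, `document-id`, and `chunk-size` query parameters come from the table above.

```python
from urllib.parse import urlencode
from urllib.request import urlopen

def document_stream_url(base, user, document_id, chunk_size=None):
    """Build the streaming URL from the documented query parameters."""
    params = {"user": user, "document-id": document_id}
    if chunk_size is not None:
        params["chunk-size"] = str(chunk_size)
    return f"{base}/api/v1/document-stream?{urlencode(params)}"

def stream_document(base, user, document_id, out_path):
    """Stream the raw document bytes to a local file (network call; illustrative)."""
    with urlopen(document_stream_url(base, user, document_id)) as resp, \
         open(out_path, "wb") as f:
        while chunk := resp.read(65536):
            f.write(chunk)
```

Because the response uses chunked transfer encoding, the reader loop above processes the body incrementally rather than buffering the whole document.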
---
## Renamed Service: "objects" to "rows"
| v1.8 | v2.1 | Notes |
|------|------|-------|
| `objects_query.py` / `ObjectsQueryRequestor` | `rows_query.py` / `RowsQueryRequestor` | Schema changed from `ObjectsQueryRequest`/`ObjectsQueryResponse` to `RowsQueryRequest`/`RowsQueryResponse`. |
| `objects_import.py` / `ObjectsImport` | `rows_import.py` / `RowsImport` | Import dispatcher for structured data. |
The WebSocket service key changed from `"objects"` to `"rows"`, and the
import dispatcher key similarly changed from `"objects"` to `"rows"`.
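For clients that hard-code service keys, the rename is a one-line change; a small compatibility shim (hypothetical, not part of the gateway) could bridge old callers:

```python
# v1.8 -> v2.1 service key renames (only "objects" changed)
RENAMED_SERVICES = {"objects": "rows"}

def service_key(key: str) -> str:
    """Map a v1.8 service key to its v2.1 name, passing others through."""
    return RENAMED_SERVICES.get(key, key)
```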
---
## Wire Format Change: Value to Term
The serialization layer (`serialize.py`) was rewritten to use the new `Term`
type instead of the old `Value` type.
### Old format (v1.8 — `Value`)
```json
{"v": "http://example.org/entity", "e": true}
```
- `v`: the value (string)
- `e`: boolean flag indicating whether the value is a URI
### New format (v2.1 — `Term`)
IRIs:
```json
{"t": "i", "i": "http://example.org/entity"}
```
Literals:
```json
{"t": "l", "v": "some text", "d": "datatype-uri", "l": "en"}
```
Quoted triples (RDF-star):
```json
{"t": "r", "r": {"s": {...}, "p": {...}, "o": {...}}}
```
- `t`: type discriminator — `"i"` (IRI), `"l"` (literal), `"r"` (quoted triple), `"b"` (blank node)
- Serialization now delegates to `TermTranslator` and `TripleTranslator` from `trustgraph.messaging.translators.primitives`
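As a rough migration aid, the old `Value` shape maps onto the new `Term` shape as follows. This is an illustrative sketch, not the actual `TermTranslator`; literal datatype and language handling are simplified, and blank nodes and quoted triples (which had no `Value` equivalent) are not covered.

```python
def value_to_term(value: dict) -> dict:
    """Map an old {"v": ..., "e": ...} Value to the new Term wire shape."""
    if value.get("e"):                     # "e" flagged the value as a URI
        return {"t": "i", "i": value["v"]}
    return {"t": "l", "v": value["v"]}     # datatype/language unknown in old format
```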
### Other serialization changes
| Field | v1.8 | v2.1 |
|-------|------|------|
| Metadata | `metadata.metadata` (subgraph) | `metadata.root` (simple value) |
| Graph embeddings entity | `entity.vectors` (plural) | `entity.vector` (singular) |
| Document embeddings chunk | `chunk.vectors` + `chunk.chunk` (text) | `chunk.vector` + `chunk.chunk_id` (ID reference) |
---
## Breaking Changes
- **`Value` to `Term` wire format**: All clients sending/receiving triples, embeddings, or entity contexts through the gateway must update to the new Term format.
- **`objects` to `rows` rename**: WebSocket service key and import key changed.
- **Metadata field change**: `metadata.metadata` (a serialized subgraph) replaced by `metadata.root` (a simple value).
- **Embeddings field changes**: `vectors` (plural) became `vector` (singular); document embeddings now reference `chunk_id` instead of inline `chunk` text.
- **New `/api/v1/document-stream` endpoint**: Additive, not breaking.



@@ -0,0 +1,112 @@
# CLI Changes: v1.8 to v2.1
## Summary
The CLI (`trustgraph-cli`) has significant additions focused on three themes:
**explainability/provenance**, **embeddings access**, and **graph querying**.
Two legacy tools were removed, one was renamed, and several existing tools
gained new capabilities.
---
## New CLI Tools
### Explainability & Provenance
| Command | Description |
|---------|-------------|
| `tg-list-explain-traces` | Lists all explainability sessions (GraphRAG and Agent) in a collection, showing session IDs, type, question text, and timestamps. |
| `tg-show-explain-trace` | Displays the full explainability trace for a session. For GraphRAG: Question, Exploration, Focus, Synthesis stages. For Agent: Session, Iterations (thought/action/observation), Final Answer. Auto-detects trace type. Supports `--show-provenance` to trace edges back to source documents. |
| `tg-show-extraction-provenance` | Given a document ID, traverses the provenance chain: Document -> Pages -> Chunks -> Edges, using `prov:wasDerivedFrom` relationships. Supports `--show-content` and `--max-content` options. |
### Embeddings
| Command | Description |
|---------|-------------|
| `tg-invoke-embeddings` | Converts text to a vector embedding via the embeddings service. Accepts one or more text inputs, returns vectors as lists of floats. |
| `tg-invoke-graph-embeddings` | Queries graph entities by text similarity using vector embeddings. Returns matching entities with similarity scores. |
| `tg-invoke-document-embeddings` | Queries document chunks by text similarity using vector embeddings. Returns matching chunk IDs with similarity scores. |
| `tg-invoke-row-embeddings` | Queries structured data rows by text similarity on indexed fields. Returns matching rows with index values and scores. Requires `--schema-name` and supports `--index-name`. |
### Graph Querying
| Command | Description |
|---------|-------------|
| `tg-query-graph` | Pattern-based triple store query. Unlike `tg-show-graph` (which dumps everything), this allows selective queries by any combination of subject, predicate, object, and graph. Auto-detects value types: IRIs (`http://...`, `urn:...`, `<...>`), quoted triples (`<<s p o>>`), and literals. |
| `tg-get-document-content` | Retrieves document content from the library by document ID. Can output to file or stdout, handles both text and binary content. |
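The value auto-detection described for `tg-query-graph` could be approximated as below. This is a sketch based only on the patterns listed in the table; the actual CLI logic may differ, and the `https://` prefix is an assumption beyond the documented `http://`.

```python
def detect_value_type(s: str) -> str:
    """Classify a CLI argument using the documented patterns."""
    if s.startswith("<<") and s.endswith(">>"):
        return "quoted-triple"                 # RDF-star: <<s p o>>
    if s.startswith("<") and s.endswith(">"):
        return "iri"                           # angle-bracketed IRI
    if s.startswith(("http://", "https://", "urn:")):
        return "iri"                           # bare IRI schemes
    return "literal"
```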
---
## Removed CLI Tools
| Command | Notes |
|---------|-------|
| `tg-load-pdf` | Removed. Document loading is now handled through the library/processing pipeline. |
| `tg-load-text` | Removed. Document loading is now handled through the library/processing pipeline. |
---
## Renamed CLI Tools
| Old Name | New Name | Notes |
|----------|----------|-------|
| `tg-invoke-objects-query` | `tg-invoke-rows-query` | Reflects the terminology rename from "objects" to "rows" for structured data. |
---
## Significant Changes to Existing Tools
### `tg-invoke-graph-rag`
- **Explainability support**: Now supports a 4-stage explainability pipeline (Question, Grounding/Exploration, Focus, Synthesis) with inline provenance event display.
- **Streaming**: Uses WebSocket streaming for real-time output.
- **Provenance tracing**: Can trace selected edges back to source documents via reification and `prov:wasDerivedFrom` chains.
- Grew from ~30 lines to ~760 lines to accommodate the full explainability pipeline.
### `tg-invoke-document-rag`
- **Explainability support**: Added `question_explainable()` mode that streams Document RAG responses with inline provenance events (Question, Grounding, Exploration, Synthesis stages).
### `tg-invoke-agent`
- **Explainability support**: Added `question_explainable()` mode showing provenance events inline during agent execution (Question, Analysis, Conclusion, AgentThought, AgentObservation, AgentAnswer).
- Verbose mode shows thought/observation streams with emoji prefixes.
### `tg-show-graph`
- **Streaming mode**: Now uses `triples_query_stream()` with configurable batch sizes for lower time-to-first-result and reduced memory overhead.
- **Named graph support**: New `--graph` filter option. Recognises named graphs:
- Default graph (empty): Core knowledge facts
- `urn:graph:source`: Extraction provenance
- `urn:graph:retrieval`: Query-time explainability
- **Show graph column**: New `--show-graph` flag to display the named graph for each triple.
- **Configurable limits**: New `--limit` and `--batch-size` options.
### `tg-graph-to-turtle`
- **RDF-star support**: Now handles quoted triples (RDF-star reification).
- **Streaming mode**: Uses streaming for lower time-to-first-processing.
- **Wire format handling**: Updated to use the new term wire format (`{"t": "i", "i": uri}` for IRIs, `{"t": "l", "v": value}` for literals, `{"t": "r", "r": {...}}` for quoted triples).
- **Named graph support**: New `--graph` filter option.
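The wire-format handling above might translate terms to Turtle-style tokens roughly as follows. This is a simplified sketch of what such a conversion entails, not the tool's actual code; string escaping, datatypes, language tags, and blank nodes are omitted.

```python
def term_to_turtle(term: dict) -> str:
    """Render a Term wire object as a Turtle/N-Triples-style token."""
    if term["t"] == "i":
        return f'<{term["i"]}>'
    if term["t"] == "l":
        return f'"{term["v"]}"'
    if term["t"] == "r":                    # RDF-star quoted triple
        q = term["r"]
        inner = " ".join(term_to_turtle(q[k]) for k in ("s", "p", "o"))
        return f"<< {inner} >>"
    raise ValueError(f"unhandled term type: {term['t']}")
```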
### `tg-set-tool`
- **New tool type**: `row-embeddings-query` for semantic search on structured data indexes.
- **New options**: `--schema-name`, `--index-name`, `--limit` for configuring row embeddings query tools.
### `tg-show-tools`
- Displays the new `row-embeddings-query` tool type with its `schema-name`, `index-name`, and `limit` fields.
### `tg-load-knowledge`
- **Progress reporting**: Now counts and reports triples and entity contexts loaded per file and in total.
- **Term format update**: Entity contexts now use the new Term format (`{"t": "i", "i": uri}`) instead of the old Value format (`{"v": entity, "e": True}`).
---
## Breaking Changes
- **Terminology rename**: The `Value` schema was renamed to `Term` across the system (PR #622). This affects the wire format used by CLI tools that interact with the graph store. The new format uses `{"t": "i", "i": uri}` for IRIs and `{"t": "l", "v": value}` for literals, replacing the old `{"v": ..., "e": ...}` format.
- **`tg-invoke-objects-query` renamed** to `tg-invoke-rows-query`.
- **`tg-load-pdf` and `tg-load-text` removed**.



@@ -0,0 +1,272 @@
# Agent Explainability: Provenance Recording
## Overview
Add provenance recording to the React agent loop so agent sessions can be traced and debugged using the same explainability infrastructure as GraphRAG.
**Design Decisions:**
- Write to `urn:graph:retrieval` (generic explainability graph)
- Linear dependency chain for now (analysis N → wasDerivedFrom → analysis N-1)
- Tools are opaque black boxes (record input/output only)
- DAG support deferred to future iteration
## Entity Types
Both GraphRAG and Agent use PROV-O as the base ontology with TrustGraph-specific subtypes:
### GraphRAG Types
| Entity | PROV-O Type | TG Types | Description |
|--------|-------------|----------|-------------|
| Question | `prov:Activity` | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
| Exploration | `prov:Entity` | `tg:Exploration` | Edges retrieved from knowledge graph |
| Focus | `prov:Entity` | `tg:Focus` | Selected edges with reasoning |
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
### Agent Types
| Entity | PROV-O Type | TG Types | Description |
|--------|-------------|----------|-------------|
| Question | `prov:Activity` | `tg:Question`, `tg:AgentQuestion` | The user's query |
| Analysis | `prov:Entity` | `tg:Analysis` | Each think/act/observe cycle |
| Conclusion | `prov:Entity` | `tg:Conclusion` | Final answer |
### Document RAG Types
| Entity | PROV-O Type | TG Types | Description |
|--------|-------------|----------|-------------|
| Question | `prov:Activity` | `tg:Question`, `tg:DocRagQuestion` | The user's query |
| Exploration | `prov:Entity` | `tg:Exploration` | Chunks retrieved from document store |
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
**Note:** Document RAG uses a subset of GraphRAG's types (no Focus step since there's no edge selection/reasoning phase).
### Question Subtypes
All Question entities share `tg:Question` as a base type but have a specific subtype to identify the retrieval mechanism:
| Subtype | URI Pattern | Mechanism |
|---------|-------------|-----------|
| `tg:GraphRagQuestion` | `urn:trustgraph:question:{uuid}` | Knowledge graph RAG |
| `tg:DocRagQuestion` | `urn:trustgraph:docrag:{uuid}` | Document/chunk RAG |
| `tg:AgentQuestion` | `urn:trustgraph:agent:{uuid}` | ReAct agent |
This allows querying all questions via `tg:Question` while filtering by specific mechanism via the subtype.
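Given the URI patterns above, distinguishing the retrieval mechanism reduces to a prefix check. The sketch below mirrors what such detection entails; the CLI's actual implementation may differ.

```python
def question_kind(uri: str) -> str:
    """Classify a tg:Question URI by its documented prefix pattern."""
    prefixes = {
        "urn:trustgraph:question:": "graphrag",
        "urn:trustgraph:docrag:": "docrag",
        "urn:trustgraph:agent:": "agent",
    }
    for prefix, kind in prefixes.items():
        if uri.startswith(prefix):
            return kind
    return "unknown"
```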
## Provenance Model
```
Question (urn:trustgraph:agent:{uuid})
│ tg:query = "User's question"
│ prov:startedAtTime = timestamp
│ rdf:type = prov:Activity, tg:Question
↓ prov:wasDerivedFrom
Analysis1 (urn:trustgraph:agent:{uuid}/i1)
│ tg:thought = "I need to query the knowledge base..."
│ tg:action = "knowledge-query"
│ tg:arguments = {"question": "..."}
│ tg:observation = "Result from tool..."
│ rdf:type = prov:Entity, tg:Analysis
↓ prov:wasDerivedFrom
Analysis2 (urn:trustgraph:agent:{uuid}/i2)
│ ...
↓ prov:wasDerivedFrom
Conclusion (urn:trustgraph:agent:{uuid}/final)
│ tg:answer = "The final response..."
│ rdf:type = prov:Entity, tg:Conclusion
```
### Document RAG Provenance Model
```
Question (urn:trustgraph:docrag:{uuid})
│ tg:query = "User's question"
│ prov:startedAtTime = timestamp
│ rdf:type = prov:Activity, tg:Question
↓ prov:wasGeneratedBy
Exploration (urn:trustgraph:docrag:{uuid}/exploration)
│ tg:chunkCount = 5
│ tg:selectedChunk = "chunk-id-1"
│ tg:selectedChunk = "chunk-id-2"
│ ...
│ rdf:type = prov:Entity, tg:Exploration
↓ prov:wasDerivedFrom
Synthesis (urn:trustgraph:docrag:{uuid}/synthesis)
│ tg:content = "The synthesized answer..."
│ rdf:type = prov:Entity, tg:Synthesis
```
## Changes Required
### 1. Schema Changes
**File:** `trustgraph-base/trustgraph/schema/services/agent.py`
Add `session_id` and `collection` fields to `AgentRequest`:
```python
@dataclass
class AgentRequest:
    question: str = ""
    state: str = ""
    group: list[str] | None = None
    history: list[AgentStep] = field(default_factory=list)
    user: str = ""
    collection: str = "default"  # NEW: Collection for provenance traces
    streaming: bool = False
    session_id: str = ""  # NEW: For provenance tracking across iterations
```
**File:** `trustgraph-base/trustgraph/messaging/translators/agent.py`
Update translator to handle `session_id` and `collection` in both `to_pulsar()` and `from_pulsar()`.
### 2. Add Explainability Producer to Agent Service
**File:** `trustgraph-flow/trustgraph/agent/react/service.py`
Register an "explainability" producer (same pattern as GraphRAG):
```python
from ... base import ProducerSpec
from ... schema import Triples

# In __init__:
self.register_specification(
    ProducerSpec(
        name = "explainability",
        schema = Triples,
    )
)
```
### 3. Provenance Triple Generation
**File:** `trustgraph-base/trustgraph/provenance/agent.py`
Create helper functions (similar to GraphRAG's `question_triples`, `exploration_triples`, etc.):
```python
import json

def agent_session_triples(session_uri, query, timestamp):
    """Generate triples for agent Question."""
    return [
        Triple(s=session_uri, p=RDF_TYPE, o=PROV_ACTIVITY),
        Triple(s=session_uri, p=RDF_TYPE, o=TG_QUESTION),
        Triple(s=session_uri, p=TG_QUERY, o=query),
        Triple(s=session_uri, p=PROV_STARTED_AT_TIME, o=timestamp),
    ]

def agent_iteration_triples(iteration_uri, parent_uri, thought, action,
                            arguments, observation):
    """Generate triples for one Analysis step."""
    return [
        Triple(s=iteration_uri, p=RDF_TYPE, o=PROV_ENTITY),
        Triple(s=iteration_uri, p=RDF_TYPE, o=TG_ANALYSIS),
        Triple(s=iteration_uri, p=TG_THOUGHT, o=thought),
        Triple(s=iteration_uri, p=TG_ACTION, o=action),
        Triple(s=iteration_uri, p=TG_ARGUMENTS, o=json.dumps(arguments)),
        Triple(s=iteration_uri, p=TG_OBSERVATION, o=observation),
        Triple(s=iteration_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
    ]

def agent_final_triples(final_uri, parent_uri, answer):
    """Generate triples for Conclusion."""
    return [
        Triple(s=final_uri, p=RDF_TYPE, o=PROV_ENTITY),
        Triple(s=final_uri, p=RDF_TYPE, o=TG_CONCLUSION),
        Triple(s=final_uri, p=TG_ANSWER, o=answer),
        Triple(s=final_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
    ]
```
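The linear dependency chain these generators produce can be sketched in isolation. The constants and `Triple` type below are stand-ins for illustration (the real ones live in `provenance/namespaces.py` and the schema package); only the URI patterns and the linear `prov:wasDerivedFrom` structure come from the model above.

```python
from typing import NamedTuple

class Triple(NamedTuple):
    s: str
    p: str
    o: str

# Stand-in constant (the real value is defined in provenance/namespaces.py)
PROV_WAS_DERIVED_FROM = "http://www.w3.org/ns/prov#wasDerivedFrom"

def agent_trace_uris(session_id: str, n_iterations: int) -> list[str]:
    """URI chain per the model above: Question -> Analysis 1..N -> Conclusion."""
    base = f"urn:trustgraph:agent:{session_id}"
    return ([base]
            + [f"{base}/i{i}" for i in range(1, n_iterations + 1)]
            + [f"{base}/final"])

def linear_dependency_chain(uris: list[str]) -> list[Triple]:
    """Each node wasDerivedFrom its immediate predecessor (linear, no DAG yet)."""
    return [Triple(s=child, p=PROV_WAS_DERIVED_FROM, o=parent)
            for parent, child in zip(uris, uris[1:])]
```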
### 4. Type Definitions
**File:** `trustgraph-base/trustgraph/provenance/namespaces.py`
Add explainability entity types and agent predicates:
```python
# Explainability entity types (used by both GraphRAG and Agent)
TG_QUESTION = TG + "Question"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
# Agent predicates
TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
```
## Files Modified
| File | Change |
|------|--------|
| `trustgraph-base/trustgraph/schema/services/agent.py` | Add session_id and collection to AgentRequest |
| `trustgraph-base/trustgraph/messaging/translators/agent.py` | Update translator for new fields |
| `trustgraph-base/trustgraph/provenance/namespaces.py` | Add entity types, agent predicates, and Document RAG predicates |
| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id and explain_graph to DocumentRagResponse |
| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields |
| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |
| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Handle agent trace types |
| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | List agent sessions alongside GraphRAG |
## Files Created
| File | Purpose |
|------|---------|
| `trustgraph-base/trustgraph/provenance/agent.py` | Agent-specific triple generators |
## CLI Updates
**Detection:** Both GraphRAG and Agent Questions have `tg:Question` type. Distinguished by:
1. URI pattern: `urn:trustgraph:agent:` vs `urn:trustgraph:question:`
2. Derived entities: `tg:Analysis` (agent) vs `tg:Exploration` (GraphRAG)
**`list_explain_traces.py`:**
- Shows Type column (Agent vs GraphRAG)
**`show_explain_trace.py`:**
- Auto-detects trace type
- Agent rendering shows: Question → Analysis step(s) → Conclusion
## Backwards Compatibility
- `session_id` defaults to `""` - old requests work, just won't have provenance
- `collection` defaults to `"default"` - reasonable fallback
- CLI gracefully handles both trace types
## Verification
```bash
# Run an agent query
tg-invoke-agent -q "What is the capital of France?"
# List traces (should show agent sessions with Type column)
tg-list-explain-traces -U trustgraph -C default
# Show agent trace
tg-show-explain-trace "urn:trustgraph:agent:xxx"
```
## Future Work (Not This PR)
- DAG dependencies (when analysis N uses results from multiple prior analyses)
- Tool-specific provenance linking (KnowledgeQuery → its GraphRAG trace)
- Streaming provenance emission (emit as we go, not batch at end)


@@ -0,0 +1,136 @@
# Document Embeddings Chunk ID
## Overview
Document embeddings storage currently stores chunk text directly in the vector store payload, duplicating data that exists in Garage. This spec replaces chunk text storage with `chunk_id` references.
## Current State
```python
@dataclass
class ChunkEmbeddings:
    chunk: bytes = b""
    vectors: list[list[float]] = field(default_factory=list)

@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunks: list[str] = field(default_factory=list)
```
Vector store payload:
```python
payload={"doc": chunk} # Duplicates Garage content
```
## Design
### Schema Changes
**ChunkEmbeddings** - replace chunk with chunk_id:
```python
@dataclass
class ChunkEmbeddings:
    chunk_id: str = ""
    vectors: list[list[float]] = field(default_factory=list)
```
**DocumentEmbeddingsResponse** - return chunk_ids instead of chunks:
```python
@dataclass
class DocumentEmbeddingsResponse:
    error: Error | None = None
    chunk_ids: list[str] = field(default_factory=list)
```
### Vector Store Payload
All stores (Qdrant, Milvus, Pinecone):
```python
payload={"chunk_id": chunk_id}
```
### Document RAG Changes
The document RAG processor fetches chunk content from Garage:
```python
# Get chunk_ids from embeddings store
chunk_ids = await self.rag.doc_embeddings_client.query(...)

# Fetch chunk content from Garage
docs = []
for chunk_id in chunk_ids:
    content = await self.rag.librarian_client.get_document_content(
        chunk_id, self.user
    )
    docs.append(content)
```
### API/SDK Changes
**DocumentEmbeddingsClient** returns chunk_ids:
```python
return resp.chunk_ids # Changed from resp.chunks
```
**Wire format** (DocumentEmbeddingsResponseTranslator):
```python
result["chunk_ids"] = obj.chunk_ids # Changed from chunks
```
### CLI Changes
CLI tool displays chunk_ids (callers can fetch content separately if needed).
## Files to Modify
### Schema
- `trustgraph-base/trustgraph/schema/knowledge/embeddings.py` - ChunkEmbeddings
- `trustgraph-base/trustgraph/schema/services/query.py` - DocumentEmbeddingsResponse
### Messaging/Translators
- `trustgraph-base/trustgraph/messaging/translators/embeddings_query.py` - DocumentEmbeddingsResponseTranslator
### Client
- `trustgraph-base/trustgraph/base/document_embeddings_client.py` - return chunk_ids
### Python SDK/API
- `trustgraph-base/trustgraph/api/flow.py` - document_embeddings_query
- `trustgraph-base/trustgraph/api/socket_client.py` - document_embeddings_query
- `trustgraph-base/trustgraph/api/async_flow.py` - if applicable
- `trustgraph-base/trustgraph/api/bulk_client.py` - import/export document embeddings
- `trustgraph-base/trustgraph/api/async_bulk_client.py` - import/export document embeddings
### Embeddings Service
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` - pass chunk_id
### Storage Writers
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
### Query Services
- `trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py`
- `trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py`
- `trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py`
### Gateway
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py`
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py`
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py`
### Document RAG
- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` - add librarian client
- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` - fetch from Garage
### CLI
- `trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py`
- `trustgraph-cli/trustgraph/cli/save_doc_embeds.py`
- `trustgraph-cli/trustgraph/cli/load_doc_embeds.py`
## Benefits
1. Single source of truth - chunk text only in Garage
2. Reduced vector store storage
3. Enables query-time provenance via chunk_id


@@ -0,0 +1,667 @@
# Embeddings Batch Processing Technical Specification
## Overview
This specification describes optimizations for the embeddings service to support batch processing of multiple texts in a single request. The current implementation processes one text at a time, missing the significant performance benefits that embedding models provide when processing batches.
1. **Single-Text Processing Inefficiency**: Current implementation wraps single texts in a list, underutilizing FastEmbed's batch capabilities
2. **Request-Per-Text Overhead**: Each text requires a separate Pulsar message round-trip
3. **Model Inference Inefficiency**: Embedding models have fixed per-batch overhead; small batches waste GPU/CPU resources
4. **Serial Processing in Callers**: Key services loop over items and call embeddings one at a time
## Goals
- **Batch API Support**: Enable processing multiple texts in a single request
- **Backward Compatibility**: Maintain support for single-text requests
- **Significant Throughput Improvement**: Target 5-10x throughput improvement for bulk operations
- **Reduced Latency per Text**: Lower amortized latency when embedding multiple texts
- **Memory Efficiency**: Process batches without excessive memory consumption
- **Provider Agnostic**: Support batching across FastEmbed, Ollama, and other providers
- **Caller Migration**: Update all embedding callers to use batch API where beneficial
## Background
### Current Implementation - Embeddings Service
The embeddings implementation in `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py` exhibits a significant performance inefficiency:
```python
# fastembed/processor.py line 56
async def on_embeddings(self, text, model=None):
    use_model = model or self.default_model
    self._load_model(use_model)
    vecs = self.embeddings.embed([text])  # Single text wrapped in list
    return [v.tolist() for v in vecs]
```
**Problems:**
1. **Batch Size 1**: FastEmbed's `embed()` method is optimized for batch processing, but we always call it with `[text]` - a batch of size 1
2. **Per-Request Overhead**: Each embedding request incurs:
- Pulsar message serialization/deserialization
- Network round-trip latency
- Model inference startup overhead
- Python async scheduling overhead
3. **Schema Limitation**: The `EmbeddingsRequest` schema only supports a single text:
```python
@dataclass
class EmbeddingsRequest:
    text: str = ""  # Single text only
```
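One way the schema could grow a batch path while keeping the single-text field is sketched below. The `texts` field name is hypothetical; this spec does not fix the final shape.

```python
from dataclasses import dataclass, field

@dataclass
class EmbeddingsRequest:
    text: str = ""                                  # kept for single-text compatibility
    texts: list[str] = field(default_factory=list)  # hypothetical batch field
```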
### Current Callers - Serial Processing
#### 1. API Gateway
**File:** `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`
The gateway accepts single-text embedding requests via HTTP/WebSocket and forwards them to the embeddings service. Currently no batch endpoint exists.
```python
class EmbeddingsRequestor(ServiceRequestor):
    # Handles single EmbeddingsRequest -> EmbeddingsResponse
    request_schema=EmbeddingsRequest,  # Single text only
    response_schema=EmbeddingsResponse,
```
**Impact:** External clients (web apps, scripts) must make N HTTP requests to embed N texts.
#### 2. Document Embeddings Service
**File:** `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`
Processes document chunks one at a time:
```python
async def on_message(self, msg, consumer, flow):
    v = msg.value()
    # Single chunk per request
    resp = await flow("embeddings-request").request(
        EmbeddingsRequest(text=v.chunk)
    )
    vectors = resp.vectors
```
**Impact:** Each document chunk requires a separate embedding call. A document with 100 chunks = 100 embedding requests.
#### 3. Graph Embeddings Service
**File:** `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
Loops over entities and embeds each one serially:
```python
async def on_message(self, msg, consumer, flow):
    v = msg.value()
    entities = []
    for entity in v.entities:
        # Serial embedding - one entity at a time
        vectors = await flow("embeddings-request").embed(
            text=entity.context
        )
        entities.append(EntityEmbeddings(
            entity=entity.entity,
            vectors=vectors,
            chunk_id=entity.chunk_id,
        ))
```
**Impact:** A message with 50 entities = 50 serial embedding requests. This is a major bottleneck during knowledge graph construction.
#### 4. Row Embeddings Service
**File:** `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
Loops over unique texts and embeds each one serially:
```python
async def on_message(self, msg, consumer, flow):
    embeddings_list = []
    for text, (index_name, index_value) in texts_to_embed.items():
        # Serial embedding - one text at a time
        vectors = await flow("embeddings-request").embed(text=text)
        embeddings_list.append(RowIndexEmbedding(
            index_name=index_name,
            index_value=index_value,
            text=text,
            vectors=vectors
        ))
```
**Impact:** Processing a table with 100 unique indexed values = 100 serial embedding requests.
#### 5. EmbeddingsClient (Base Client)
**File:** `trustgraph-base/trustgraph/base/embeddings_client.py`
The client used by all flow processors only supports single-text embedding:
```python
class EmbeddingsClient(RequestResponse):
async def embed(self, text, timeout=30):
resp = await self.request(
EmbeddingsRequest(text=text), # Single text
timeout=timeout
)
return resp.vectors
```
**Impact:** All callers using this client are limited to single-text operations.
#### 6. Command-Line Tools
**File:** `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`
CLI tool accepts single text argument:
```python
def query(url, flow_id, text, token=None):
result = flow.embeddings(text=text) # Single text
vectors = result.get("vectors", [])
```
**Impact:** Users cannot batch-embed from command line. Processing a file of texts requires N invocations.
#### 7. Python SDK
The Python SDK provides two client classes for interacting with TrustGraph services. Both only support single-text embedding.
**File:** `trustgraph-base/trustgraph/api/flow.py`
```python
class FlowInstance:
def embeddings(self, text):
"""Get embeddings for a single text"""
input = {"text": text}
return self.request("service/embeddings", input)["vectors"]
```
**File:** `trustgraph-base/trustgraph/api/socket_client.py`
```python
class SocketFlowInstance:
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
"""Get embeddings for a single text via WebSocket"""
request = {"text": text}
return self.client._send_request_sync(
"embeddings", self.flow_id, request, False
)
```
**Impact:** Python developers using the SDK must loop over texts and make N separate API calls. No batch embedding support exists for SDK users.
### Performance Impact
For typical document ingestion (1000 text chunks):
- **Current**: 1000 separate requests, 1000 model inference calls
- **Batched (batch_size=32)**: 32 requests, 32 model inference calls (96.8% reduction)
For graph embedding (message with 50 entities):
- **Current**: 50 serial await calls, ~5-10 seconds
- **Batched**: 1-2 batch calls, ~0.5-1 second (5-10x improvement)
FastEmbed and similar libraries achieve near-linear throughput scaling with batch size up to hardware limits (typically 32-128 texts per batch).
## Technical Design
### Architecture
The embeddings batch processing optimization requires changes to the following components:
#### 1. **Schema Enhancement**
- Extend `EmbeddingsRequest` to support multiple texts
- Extend `EmbeddingsResponse` to return multiple vector sets
- Express single-text requests as one-element batches (the wire schema change itself is breaking; see Migration Plan)
Module: `trustgraph-base/trustgraph/schema/services/llm.py`
#### 2. **Base Service Enhancement**
- Update `EmbeddingsService` to handle batch requests
- Add batch size configuration
- Implement batch-aware request handling
Module: `trustgraph-base/trustgraph/base/embeddings_service.py`
#### 3. **Provider Processor Updates**
- Update FastEmbed processor to pass full batch to `embed()`
- Update Ollama processor to handle batches (if supported)
- Add fallback sequential processing for providers without batch support
Modules:
- `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py`
- `trustgraph-flow/trustgraph/embeddings/ollama/processor.py`
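The sequential fallback for providers without batch support could look like the following minimal sketch. It assumes the batch `on_embeddings(texts)` signature proposed in this spec; the class and `embed_one` are illustrative stand-ins, not existing TrustGraph code:

```python
import asyncio

class SequentialFallback:
    """Sketch: batch API layered over a single-text provider (hypothetical)."""

    async def embed_one(self, text: str) -> list[list[float]]:
        # Stand-in for a provider call that only accepts one text
        return [[float(len(text))]]

    async def on_embeddings(self, texts: list[str]) -> list[list[list[float]]]:
        # Sequential fallback: one provider call per text, order preserved
        return [await self.embed_one(t) for t in texts]

result = asyncio.run(SequentialFallback().on_embeddings(["a", "bb", "ccc"]))
print(result)  # [[[1.0]], [[2.0]], [[3.0]]]
```

Order preservation matters here: the response contract pairs `vectors[i]` with `texts[i]`, so the fallback must not reorder results.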
#### 4. **Client Enhancement**
- Add batch embedding method to `EmbeddingsClient`
- Support both single and batch APIs
- Add automatic batching for large inputs
Module: `trustgraph-base/trustgraph/base/embeddings_client.py`
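Automatic batching for large inputs might be sketched as below. The helper name, the stand-in `fake_embed`, and the cap constant are hypothetical; a real client would call its own `embed()`:

```python
import asyncio

MAX_BATCH_SIZE = 128  # assumed cap, mirroring the configuration section

async def embed_in_batches(embed, texts, batch_size=32):
    """Split a large input into capped batches and concatenate the
    per-batch results in input order. `embed` is any async callable
    taking list[str] and returning a list of vector sets."""
    batch_size = min(batch_size, MAX_BATCH_SIZE)
    results = []
    for i in range(0, len(texts), batch_size):
        results.extend(await embed(texts[i:i + batch_size]))
    return results

async def fake_embed(batch):
    # Stand-in for EmbeddingsClient.embed (hypothetical)
    return [[[float(len(t))]] for t in batch]

out = asyncio.run(embed_in_batches(fake_embed, [str(n) * n for n in range(1, 6)], batch_size=2))
print(len(out))  # 5
```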
#### 5. **Caller Updates - Flow Processors**
- Update `graph_embeddings` to batch entity contexts
- Update `row_embeddings` to batch index texts
- Update `document_embeddings` if message batching is feasible
Modules:
- `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
- `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`
#### 6. **API Gateway Enhancement**
- Add batch embedding endpoint
- Support array of texts in request body
Module: `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`
#### 7. **CLI Tool Enhancement**
- Add support for multiple texts or file input
- Add batch size parameter
Module: `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`
#### 8. **Python SDK Enhancement**
- Add `embeddings_batch()` method to `FlowInstance`
- Add `embeddings_batch()` method to `SocketFlowInstance`
- Support both single and batch APIs for SDK users
Modules:
- `trustgraph-base/trustgraph/api/flow.py`
- `trustgraph-base/trustgraph/api/socket_client.py`
### Data Models
#### EmbeddingsRequest
```python
@dataclass
class EmbeddingsRequest:
texts: list[str] = field(default_factory=list)
```
Usage:
- Single text: `EmbeddingsRequest(texts=["hello world"])`
- Batch: `EmbeddingsRequest(texts=["text1", "text2", "text3"])`
#### EmbeddingsResponse
```python
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
```
Response structure:
- `vectors[i]` contains the vector set for `texts[i]`
- Each vector set is `list[list[float]]` (models may return multiple vectors per text)
- Example: 3 texts → `vectors` has 3 entries, each containing that text's embeddings
### APIs
#### EmbeddingsClient
```python
class EmbeddingsClient(RequestResponse):
async def embed(
self,
texts: list[str],
timeout: float = 300,
) -> list[list[list[float]]]:
"""
Embed one or more texts in a single request.
Args:
texts: List of texts to embed
timeout: Timeout for the operation
Returns:
List of vector sets, one per input text
"""
resp = await self.request(
EmbeddingsRequest(texts=texts),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.vectors
```
#### API Gateway Embeddings Endpoint
Updated endpoint supporting single or batch embedding:
```
POST /api/v1/embeddings
Content-Type: application/json
{
"texts": ["text1", "text2", "text3"],
"flow_id": "default"
}
Response:
{
"vectors": [
[[0.1, 0.2, ...]],
[[0.3, 0.4, ...]],
[[0.5, 0.6, ...]]
]
}
```
### Implementation Details
#### Phase 1: Schema Changes
**EmbeddingsRequest:**
```python
@dataclass
class EmbeddingsRequest:
texts: list[str] = field(default_factory=list)
```
**EmbeddingsResponse:**
```python
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
```
**Updated EmbeddingsService.on_request:**
```python
async def on_request(self, msg, consumer, flow):
request = msg.value()
id = msg.properties()["id"]
model = flow("model")
vectors = await self.on_embeddings(request.texts, model=model)
response = EmbeddingsResponse(error=None, vectors=vectors)
await flow("response").send(response, properties={"id": id})
```
#### Phase 2: FastEmbed Processor Update
**Current (Inefficient):**
```python
async def on_embeddings(self, text, model=None):
use_model = model or self.default_model
self._load_model(use_model)
vecs = self.embeddings.embed([text]) # Batch of 1
return [v.tolist() for v in vecs]
```
**Updated:**
```python
async def on_embeddings(self, texts: list[str], model=None):
"""Embed texts - processes all texts in single model call"""
if not texts:
return []
use_model = model or self.default_model
self._load_model(use_model)
# FastEmbed handles the full batch efficiently
all_vecs = list(self.embeddings.embed(texts))
# Return list of vector sets, one per input text
return [[v.tolist()] for v in all_vecs]
```
#### Phase 3: Graph Embeddings Service Update
**Current (Serial):**
```python
async def on_message(self, msg, consumer, flow):
    v = msg.value()
    entities = []
    for entity in v.entities:
        vectors = await flow("embeddings-request").embed(text=entity.context)
        entities.append(EntityEmbeddings(...))
```
**Updated (Batch):**
```python
async def on_message(self, msg, consumer, flow):
    v = msg.value()
    # Collect all contexts
    contexts = [entity.context for entity in v.entities]
    # Single batch embedding call
    all_vectors = await flow("embeddings-request").embed(texts=contexts)
    # Pair results with entities; keep each full vector set so the
    # field carries the same content as the previous per-text call
    entities = [
        EntityEmbeddings(
            entity=entity.entity,
            vectors=vec_set,
            chunk_id=entity.chunk_id,
        )
        for entity, vec_set in zip(v.entities, all_vectors)
    ]
```
#### Phase 4: Row Embeddings Service Update
**Current (Serial):**
```python
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
embeddings_list.append(RowIndexEmbedding(...))
```
**Updated (Batch):**
```python
# Collect texts and metadata
texts = list(texts_to_embed.keys())
metadata = list(texts_to_embed.values())
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results; keep the full vector set for each text
embeddings_list = [
    RowIndexEmbedding(
        index_name=index_name,
        index_value=index_value,
        text=text,
        vectors=vec_set
    )
    for text, (index_name, index_value), vec_set in zip(texts, metadata, all_vectors)
]
```
#### Phase 5: CLI Tool Enhancement
**Updated CLI:**
```python
def main():
parser = argparse.ArgumentParser(...)
parser.add_argument(
'text',
nargs='*', # Zero or more texts
help='Text(s) to convert to embedding vectors',
)
parser.add_argument(
'-f', '--file',
help='File containing texts (one per line)',
)
parser.add_argument(
'--batch-size',
type=int,
default=32,
help='Batch size for processing (default: 32)',
)
```
Usage:
```bash
# Single text (existing)
tg-invoke-embeddings "hello world"
# Multiple texts
tg-invoke-embeddings "text one" "text two" "text three"
# From file
tg-invoke-embeddings -f texts.txt --batch-size 64
```
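Input handling for the updated CLI might look like this sketch; `collect_texts` is a hypothetical helper written around the argparse definition above, not the existing tool's code:

```python
import argparse

def collect_texts(argv):
    """Gather texts from positional args and/or --file (one per line)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("text", nargs="*")
    parser.add_argument("-f", "--file")
    parser.add_argument("--batch-size", type=int, default=32)
    args = parser.parse_args(argv)
    texts = list(args.text)
    if args.file:
        with open(args.file) as f:
            # Skip blank lines; each remaining line is one text
            texts.extend(line.strip() for line in f if line.strip())
    return texts, args.batch_size

texts, batch_size = collect_texts(["text one", "text two"])
print(texts, batch_size)  # ['text one', 'text two'] 32
```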
#### Phase 6: Python SDK Enhancement
**FlowInstance (HTTP client):**
```python
class FlowInstance:
def embeddings(self, texts: list[str]) -> list[list[list[float]]]:
"""
Get embeddings for one or more texts.
Args:
texts: List of texts to embed
Returns:
List of vector sets, one per input text
"""
input = {"texts": texts}
return self.request("service/embeddings", input)["vectors"]
```
**SocketFlowInstance (WebSocket client):**
```python
class SocketFlowInstance:
def embeddings(self, texts: list[str], **kwargs: Any) -> list[list[list[float]]]:
"""
Get embeddings for one or more texts via WebSocket.
Args:
texts: List of texts to embed
Returns:
List of vector sets, one per input text
"""
request = {"texts": texts}
response = self.client._send_request_sync(
"embeddings", self.flow_id, request, False
)
return response["vectors"]
```
**SDK Usage Examples:**
```python
# Single text
vectors = flow.embeddings(["hello world"])
print(f"Dimensions: {len(vectors[0][0])}")
# Batch embedding
texts = ["text one", "text two", "text three"]
all_vectors = flow.embeddings(texts)
# Process results
for text, vecs in zip(texts, all_vectors):
print(f"{text}: {len(vecs[0])} dimensions")
```
## Security Considerations
- **Request Size Limits**: Enforce maximum batch size to prevent resource exhaustion
- **Timeout Handling**: Scale timeouts appropriately for batch size
- **Memory Limits**: Monitor memory usage for large batches
- **Input Validation**: Validate all texts in batch before processing
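A validation sketch combining the batch-size and input checks above; both limits are assumed values for illustration, not shipped configuration:

```python
MAX_BATCH_SIZE = 128        # assumed limit from the configuration section
MAX_TEXT_LENGTH = 100_000   # hypothetical per-text cap

def validate_batch(texts):
    """Reject malformed batches before any model work is done."""
    if not texts:
        raise ValueError("empty batch")
    if len(texts) > MAX_BATCH_SIZE:
        raise ValueError(f"batch size {len(texts)} exceeds {MAX_BATCH_SIZE}")
    for i, text in enumerate(texts):
        if not isinstance(text, str):
            raise ValueError(f"texts[{i}] is not a string")
        if len(text) > MAX_TEXT_LENGTH:
            raise ValueError(f"texts[{i}] exceeds {MAX_TEXT_LENGTH} chars")

validate_batch(["ok"])  # passes silently
```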
## Performance Considerations
### Expected Improvements
**Throughput:**
- Single-text: ~10-50 texts/second (depending on model)
- Batch (size 32): ~200-500 texts/second (5-10x improvement)
**Latency per Text:**
- Single-text: 50-200ms per text
- Batch (size 32): 5-20ms per text (amortized)
**Service-Specific Improvements:**
| Service | Current | Batched | Improvement |
|---------|---------|---------|-------------|
| Graph Embeddings (50 entities) | 5-10s | 0.5-1s | 5-10x |
| Row Embeddings (100 texts) | 10-20s | 1-2s | 5-10x |
| Document Ingestion (1000 chunks) | 100-200s | 10-30s | 5-10x |
### Configuration Parameters
```python
# Recommended defaults
DEFAULT_BATCH_SIZE = 32
MAX_BATCH_SIZE = 128
BATCH_TIMEOUT_MULTIPLIER = 2.0
```
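One way the timeout multiplier could interact with batch size, as an assumed policy rather than a definitive implementation:

```python
import math

DEFAULT_BATCH_SIZE = 32
BATCH_TIMEOUT_MULTIPLIER = 2.0

def batch_timeout(n_texts, base_timeout=30.0):
    """Scale the request timeout with the number of model calls a
    batch implies (hypothetical policy for illustration)."""
    batches = max(1, math.ceil(n_texts / DEFAULT_BATCH_SIZE))
    return base_timeout * BATCH_TIMEOUT_MULTIPLIER * batches

print(batch_timeout(1))    # 60.0
print(batch_timeout(100))  # 240.0
```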
## Testing Strategy
### Unit Testing
- Single text embedding (backward compatibility)
- Empty batch handling
- Maximum batch size enforcement
- Error handling for partial batch failures
### Integration Testing
- End-to-end batch embedding through Pulsar
- Graph embeddings service batch processing
- Row embeddings service batch processing
- API gateway batch endpoint
### Performance Testing
- Benchmark single vs batch throughput
- Memory usage under various batch sizes
- Latency distribution analysis
## Migration Plan
This is a breaking change release. All phases are implemented together.
### Phase 1: Schema Changes
- Replace `text: str` with `texts: list[str]` in EmbeddingsRequest
- Change `vectors` type to `list[list[list[float]]]` in EmbeddingsResponse
### Phase 2: Processor Updates
- Update `on_embeddings` signature in FastEmbed and Ollama processors
- Process full batch in single model call
### Phase 3: Client Updates
- Update `EmbeddingsClient.embed()` to accept `texts: list[str]`
### Phase 4: Caller Updates
- Update graph_embeddings to batch entity contexts
- Update row_embeddings to batch index texts
- Update document_embeddings to use new schema
- Update CLI tool
### Phase 5: API Gateway
- Update embeddings endpoint for new schema
### Phase 6: Python SDK
- Update `FlowInstance.embeddings()` signature
- Update `SocketFlowInstance.embeddings()` signature
## Open Questions
- **Streaming Large Batches**: Should we support streaming results for very large batches (>100 texts)?
- **Provider-Specific Limits**: How should we handle providers with different maximum batch sizes?
- **Partial Failure Handling**: If one text in a batch fails, should we fail the entire batch or return partial results?
- **Document Embeddings Batching**: Should we batch across multiple Chunk messages or keep per-message processing?
## References
- [FastEmbed Documentation](https://github.com/qdrant/fastembed)
- [Ollama Embeddings API](https://github.com/ollama/ollama)
- [EmbeddingsService Implementation](trustgraph-base/trustgraph/base/embeddings_service.py)
- [GraphRAG Performance Optimization](graphrag-performance-optimization.md)
```
CREATE TABLE quads_by_entity (
    ...
    d text,       -- Dataset/graph of the quad
    dtype text,   -- XSD datatype (when otype = 'L'), e.g. 'xsd:string'
    lang text,    -- Language tag (when otype = 'L'), e.g. 'en', 'fr'
    PRIMARY KEY ((collection, entity), role, p, otype, s, o, d, dtype, lang)
);
```
1. **role** — first clustering column (see the PRIMARY KEY above)
2. **p** — next most common filter, "give me all `knows` relationships"
3. **otype** — enables filtering by URI-valued vs literal-valued relationships
4. **s, o, d** — remaining columns for uniqueness
5. **dtype, lang** — distinguish literals with same value but different type metadata (e.g., `"thing"` vs `"thing"@en` vs `"thing"^^xsd:string`)
### Table 2: quads_by_collection
```
CREATE TABLE quads_by_collection (
    ...
    otype text,   -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
    dtype text,   -- XSD datatype (when otype = 'L')
    lang text,    -- Language tag (when otype = 'L')
    PRIMARY KEY (collection, d, s, p, o, otype, dtype, lang)
);
```
Clustered by dataset first, enabling deletion at either collection or dataset granularity. The `otype`, `dtype`, and `lang` columns are included in the clustering key to distinguish literals with the same value but different type metadata — in RDF, `"thing"`, `"thing"@en`, and `"thing"^^xsd:string` are semantically distinct values.
## Write Path
# Explainability CLI Technical Specification
## Status
Draft
## Overview
This specification describes CLI tools for debugging and exploring explainability data in TrustGraph. These tools enable users to trace how answers were derived and debug the provenance chain from edges back to source documents.
Three CLI tools:
1. **`tg-show-document-hierarchy`** - Show document → page → chunk → edge hierarchy
2. **`tg-list-explain-traces`** - List all GraphRAG sessions with questions
3. **`tg-show-explain-trace`** - Show full explainability trace for a session
## Goals
- **Debugging**: Enable developers to inspect document processing results
- **Auditability**: Trace any extracted fact back to its source document
- **Transparency**: Show exactly how GraphRAG derived an answer
- **Usability**: Simple CLI interface with sensible defaults
## Background
TrustGraph has two provenance systems:
1. **Extraction-time provenance** (see `extraction-time-provenance.md`): Records document → page → chunk → edge relationships during ingestion. Stored in `urn:graph:source` named graph using `prov:wasDerivedFrom`.
2. **Query-time explainability** (see `query-time-explainability.md`): Records question → exploration → focus → synthesis chain during GraphRAG queries. Stored in `urn:graph:retrieval` named graph.
Current limitations:
- No easy way to visualize document hierarchy after processing
- Must manually query triples to see explainability data
- No consolidated view of a GraphRAG session
## Technical Design
### Tool 1: tg-show-document-hierarchy
**Purpose**: Given a document ID, traverse and display all derived entities.
**Usage**:
```bash
tg-show-document-hierarchy "urn:trustgraph:doc:abc123"
tg-show-document-hierarchy --show-content --max-content 500 "urn:trustgraph:doc:abc123"
```
**Arguments**:
| Arg | Description |
|-----|-------------|
| `document_id` | Document URI (positional) |
| `-u/--api-url` | Gateway URL (default: `$TRUSTGRAPH_URL`) |
| `-t/--token` | Auth token (default: `$TRUSTGRAPH_TOKEN`) |
| `-U/--user` | User ID (default: `trustgraph`) |
| `-C/--collection` | Collection (default: `default`) |
| `--show-content` | Include blob/document content |
| `--max-content` | Max chars per blob (default: 200) |
| `--format` | Output: `tree` (default), `json` |
**Implementation**:
1. Query triples: `?child prov:wasDerivedFrom <document_id>` in `urn:graph:source`
2. Recursively query children of each result
3. Build tree structure: Document → Pages → Chunks
4. If `--show-content`, fetch content from librarian API
5. Display as indented tree or JSON
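Steps 1–3 amount to a depth-first walk over `prov:wasDerivedFrom` links. In this sketch, `children_of` stands in for the triples query and is purely illustrative:

```python
def build_hierarchy(node, children_of, depth=0, lines=None):
    """Depth-first walk: append each node indented by its depth,
    then recurse into its direct children."""
    lines = [] if lines is None else lines
    lines.append("  " * depth + node)
    for child in children_of(node):
        build_hierarchy(child, children_of, depth + 1, lines)
    return lines

# Toy graph: document -> page -> chunks
graph = {
    "doc:abc123": ["doc:abc123/p1"],
    "doc:abc123/p1": ["doc:abc123/p1/c0", "doc:abc123/p1/c1"],
}
tree = build_hierarchy("doc:abc123", lambda uri: graph.get(uri, []))
print("\n".join(tree))
# doc:abc123
#   doc:abc123/p1
#     doc:abc123/p1/c0
#     doc:abc123/p1/c1
```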
**Output Example**:
```
Document: urn:trustgraph:doc:abc123
Title: "Sample PDF"
Type: application/pdf
└── Page 1: urn:trustgraph:doc:abc123/p1
├── Chunk 0: urn:trustgraph:doc:abc123/p1/c0
│ Content: "The quick brown fox..." [truncated]
└── Chunk 1: urn:trustgraph:doc:abc123/p1/c1
Content: "Machine learning is..." [truncated]
```
### Tool 2: tg-list-explain-traces
**Purpose**: List all GraphRAG sessions (questions) in a collection.
**Usage**:
```bash
tg-list-explain-traces
tg-list-explain-traces --limit 20 --format json
```
**Arguments**:
| Arg | Description |
|-----|-------------|
| `-u/--api-url` | Gateway URL |
| `-t/--token` | Auth token |
| `-U/--user` | User ID |
| `-C/--collection` | Collection |
| `--limit` | Max results (default: 50) |
| `--format` | Output: `table` (default), `json` |
**Implementation**:
1. Query: `?session tg:query ?text` in `urn:graph:retrieval`
2. Query timestamps: `?session prov:startedAtTime ?time`
3. Display as table
**Output Example**:
```
Session ID | Question | Time
----------------------------------------------|--------------------------------|---------------------
urn:trustgraph:question:abc123 | What was the War on Terror? | 2024-01-15 10:30:00
urn:trustgraph:question:def456 | Who founded OpenAI? | 2024-01-15 09:15:00
```
### Tool 3: tg-show-explain-trace
**Purpose**: Show full explainability cascade for a GraphRAG session.
**Usage**:
```bash
tg-show-explain-trace "urn:trustgraph:question:abc123"
tg-show-explain-trace --max-answer 1000 --show-provenance "urn:trustgraph:question:abc123"
```
**Arguments**:
| Arg | Description |
|-----|-------------|
| `question_id` | Question URI (positional) |
| `-u/--api-url` | Gateway URL |
| `-t/--token` | Auth token |
| `-U/--user` | User ID |
| `-C/--collection` | Collection |
| `--max-answer` | Max chars for answer (default: 500) |
| `--show-provenance` | Trace edges to source documents |
| `--format` | Output: `text` (default), `json` |
**Implementation**:
1. Get question text from `tg:query` predicate
2. Find exploration: `?exp prov:wasGeneratedBy <question_id>`
3. Find focus: `?focus prov:wasDerivedFrom <exploration_id>`
4. Get selected edges: `<focus_id> tg:selectedEdge ?edge`
5. For each edge, get `tg:edge` (quoted triple) and `tg:reasoning`
6. Find synthesis: `?synth prov:wasDerivedFrom <focus_id>`
7. Get answer from `tg:document` via librarian
8. If `--show-provenance`, trace edges to source documents
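Steps 2–4 can be sketched as a chain of triple-store lookups. The two callables below are hypothetical query interfaces, not the real client API:

```python
def follow_chain(question_id, subjects_where, objects_of):
    """Walk question -> exploration -> focus -> selected edges.
    subjects_where(p, o) returns subjects s with (s, p, o);
    objects_of(s, p) returns objects o with (s, p, o)."""
    exploration = subjects_where("prov:wasGeneratedBy", question_id)[0]
    focus = subjects_where("prov:wasDerivedFrom", exploration)[0]
    edges = objects_of(focus, "tg:selectedEdge")
    return exploration, focus, edges

# Toy triple store
triples = [
    ("exp:1", "prov:wasGeneratedBy", "q:abc123"),
    ("focus:1", "prov:wasDerivedFrom", "exp:1"),
    ("focus:1", "tg:selectedEdge", "edge:1"),
    ("focus:1", "tg:selectedEdge", "edge:2"),
]
subs = lambda p, o: [s for s, pp, oo in triples if pp == p and oo == o]
objs = lambda s, p: [o for ss, pp, o in triples if ss == s and pp == p]
print(follow_chain("q:abc123", subs, objs))
# ('exp:1', 'focus:1', ['edge:1', 'edge:2'])
```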
**Output Example**:
```
=== GraphRAG Session: urn:trustgraph:question:abc123 ===
Question: What was the War on Terror?
Time: 2024-01-15 10:30:00
--- Exploration ---
Retrieved 50 edges from knowledge graph
--- Focus (Edge Selection) ---
Selected 12 edges:
1. (War on Terror, definition, "A military campaign...")
Reasoning: Directly defines the subject of the query
Source: chunk → page 2 → "Beyond the Vigilant State"
2. (Guantanamo Bay, part_of, War on Terror)
Reasoning: Shows key component of the campaign
--- Synthesis ---
Answer:
The War on Terror was a military campaign initiated...
[truncated at 500 chars]
```
## Files to Create
| File | Purpose |
|------|---------|
| `trustgraph-cli/trustgraph/cli/show_document_hierarchy.py` | Tool 1 |
| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | Tool 2 |
| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Tool 3 |
## Files to Modify
| File | Change |
|------|--------|
| `trustgraph-cli/setup.py` | Add console_scripts entries |
## Implementation Notes
1. **Binary content safety**: Try UTF-8 decode; if fails, show `[Binary: {size} bytes]`
2. **Truncation**: Respect `--max-content`/`--max-answer` with `[truncated]` indicator
3. **Quoted triples**: Parse RDF-star format from `tg:edge` predicate
4. **Patterns**: Follow existing CLI patterns from `query_graph.py`
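Notes 1 and 2 combined, as a sketch (`render_content` is an illustrative name, not existing CLI code):

```python
def render_content(blob: bytes, max_chars: int = 200) -> str:
    """UTF-8 decode with a binary fallback, plus truncation marking."""
    try:
        text = blob.decode("utf-8")
    except UnicodeDecodeError:
        return f"[Binary: {len(blob)} bytes]"
    if len(text) > max_chars:
        return text[:max_chars] + " [truncated]"
    return text

print(render_content(b"hello"))                  # hello
print(render_content(b"\xff\xfe\x00"))           # [Binary: 3 bytes]
print(render_content(b"x" * 300, max_chars=10))  # xxxxxxxxxx [truncated]
```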
## Security Considerations
- All queries respect user/collection boundaries
- Token authentication supported via `--token` or `$TRUSTGRAPH_TOKEN`
## Testing Strategy
Manual verification with sample data:
```bash
# Load a test document
tg-load-pdf -f test.pdf -c test-collection
# Verify hierarchy
tg-show-document-hierarchy "urn:trustgraph:doc:test"
# Run a GraphRAG query with explainability
tg-invoke-graph-rag --explainable -q "Test question"
# List and inspect traces
tg-list-explain-traces
tg-show-explain-trace "urn:trustgraph:question:xxx"
```
## References
- Query-time explainability: `docs/tech-specs/query-time-explainability.md`
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
- Existing CLI example: `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py`
# Extraction Flows
This document describes how data flows through the TrustGraph extraction pipeline, from document submission through to storage in knowledge stores.
## Overview
```
┌──────────┐ ┌─────────────┐ ┌─────────┐ ┌────────────────────┐
│ Librarian│────▶│ PDF Decoder │────▶│ Chunker │────▶│ Knowledge │
│ │ │ (PDF only) │ │ │ │ Extraction │
│ │────────────────────────▶│ │ │ │
└──────────┘ └─────────────┘ └─────────┘ └────────────────────┘
│ │
│ ├──▶ Triples
│ ├──▶ Entity Contexts
│ └──▶ Rows
└──▶ Document Embeddings
```
## Content Storage
### Blob Storage (S3/Minio)
Document content is stored in S3-compatible blob storage:
- Path format: `doc/{object_id}` where object_id is a UUID
- All document types stored here: source documents, pages, chunks
### Metadata Storage (Cassandra)
Document metadata stored in Cassandra includes:
- Document ID, title, kind (MIME type)
- `object_id` reference to blob storage
- `parent_id` for child documents (pages, chunks)
- `document_type`: "source", "page", "chunk", "answer"
### Inline vs Streaming Threshold
Content transmission uses a size-based strategy:
- **< 2MB**: Content included inline in message (base64-encoded)
- **≥ 2MB**: Only `document_id` sent; processor fetches via librarian API
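The threshold logic, sketched as a helper; field names follow the Document schema later in this doc, but the function itself is illustrative:

```python
import base64

INLINE_THRESHOLD = 2 * 1024 * 1024  # 2MB

def build_message(content: bytes, document_id: str) -> dict:
    """Inline small payloads (base64-encoded); reference large ones
    by document_id only, leaving the processor to fetch via librarian."""
    if len(content) < INLINE_THRESHOLD:
        return {"data": base64.b64encode(content).decode("ascii"),
                "document_id": None}
    return {"data": None, "document_id": document_id}

small = build_message(b"abc", "doc123")
large = build_message(b"\x00" * INLINE_THRESHOLD, "doc123")
print(small["document_id"], large["document_id"])  # None doc123
```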
## Stage 1: Document Submission (Librarian)
### Entry Point
Documents enter the system via librarian's `add-document` operation:
1. Content uploaded to blob storage
2. Metadata record created in Cassandra
3. Returns document ID
### Triggering Extraction
The `add-processing` operation triggers extraction:
- Specifies `document_id`, `flow` (pipeline ID), `collection` (target store)
- Librarian's `load_document()` fetches content and publishes to flow input queue
### Schema: Document
```
Document
├── metadata: Metadata
│ ├── id: str # Document identifier
│ ├── user: str # Tenant/user ID
│ ├── collection: str # Target collection
│ └── metadata: list[Triple] # (largely unused, historical)
├── data: bytes # PDF content (base64, if inline)
└── document_id: str # Librarian reference (if streaming)
```
**Routing**: Based on `kind` field:
- `application/pdf``document-load` queue → PDF Decoder
- `text/plain``text-load` queue → Chunker
## Stage 2: PDF Decoder
Converts PDF documents into text pages.
### Process
1. Fetch content (inline `data` or via `document_id` from librarian)
2. Extract pages using PyPDF
3. For each page:
- Save as child document in librarian (`{doc_id}/p{page_num}`)
- Emit provenance triples (page derived from document)
- Forward to chunker
### Schema: TextDocument
```
TextDocument
├── metadata: Metadata
│ ├── id: str # Page URI (e.g., https://trustgraph.ai/doc/xxx/p1)
│ ├── user: str
│ ├── collection: str
│ └── metadata: list[Triple]
├── text: bytes # Page text content (if inline)
└── document_id: str # Librarian reference (e.g., "doc123/p1")
```
## Stage 3: Chunker
Splits text into chunks at configured size.
### Parameters (flow-configurable)
- `chunk_size`: Target chunk size in characters (default: 2000)
- `chunk_overlap`: Overlap between chunks (default: 100)
### Process
1. Fetch text content (inline or via librarian)
2. Split using recursive character splitter
3. For each chunk:
- Save as child document in librarian (`{parent_id}/c{index}`)
- Emit provenance triples (chunk derived from page/document)
- Forward to extraction processors
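The real chunker uses a recursive character splitter; this simplified fixed-window version only illustrates how `chunk_size` and `chunk_overlap` interact:

```python
def split_text(text: str, chunk_size: int = 2000, chunk_overlap: int = 100) -> list[str]:
    """Fixed-size windows that advance by (chunk_size - chunk_overlap),
    so adjacent chunks share chunk_overlap characters."""
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

chunks = split_text("abcdefghij", chunk_size=4, chunk_overlap=2)
print(chunks)  # ['abcd', 'cdef', 'efgh', 'ghij', 'ij']
```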
### Schema: Chunk
```
Chunk
├── metadata: Metadata
│ ├── id: str # Chunk URI
│ ├── user: str
│ ├── collection: str
│ └── metadata: list[Triple]
├── chunk: bytes # Chunk text content
└── document_id: str # Librarian chunk ID (e.g., "doc123/p1/c3")
```
### Document ID Hierarchy
Child documents encode their lineage in the ID:
- Source: `doc123`
- Page: `doc123/p5`
- Chunk from page: `doc123/p5/c2`
- Chunk from text: `doc123/c2`
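Because lineage is encoded in the ID, the full ancestor chain can be recovered by splitting on `/` (illustrative helper, not existing code):

```python
def lineage(document_id: str) -> list[str]:
    """Recover the ancestor chain from a hierarchical ID:
    'doc123/p5/c2' -> ['doc123', 'doc123/p5', 'doc123/p5/c2']."""
    parts = document_id.split("/")
    return ["/".join(parts[:i + 1]) for i in range(len(parts))]

print(lineage("doc123/p5/c2"))  # ['doc123', 'doc123/p5', 'doc123/p5/c2']
```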
## Stage 4: Knowledge Extraction
Multiple extraction patterns available, selected by flow configuration.
### Pattern A: Basic GraphRAG
Two parallel processors:
**kg-extract-definitions**
- Input: Chunk
- Output: Triples (entity definitions), EntityContexts
- Extracts: entity labels, definitions
**kg-extract-relationships**
- Input: Chunk
- Output: Triples (relationships), EntityContexts
- Extracts: subject-predicate-object relationships
### Pattern B: Ontology-Driven (kg-extract-ontology)
- Input: Chunk
- Output: Triples, EntityContexts
- Uses configured ontology to guide extraction
### Pattern C: Agent-Based (kg-extract-agent)
- Input: Chunk
- Output: Triples, EntityContexts
- Uses agent framework for extraction
### Pattern D: Row Extraction (kg-extract-rows)
- Input: Chunk
- Output: Rows (structured data, not triples)
- Uses schema definition to extract structured records
### Schema: Triples
```
Triples
├── metadata: Metadata
│ ├── id: str
│ ├── user: str
│ ├── collection: str
│ └── metadata: list[Triple] # (set to [] by extractors)
└── triples: list[Triple]
└── Triple
├── s: Term # Subject
├── p: Term # Predicate
├── o: Term # Object
└── g: str | None # Named graph
```
### Schema: EntityContexts
```
EntityContexts
├── metadata: Metadata
└── entities: list[EntityContext]
└── EntityContext
├── entity: Term # Entity identifier (IRI)
├── context: str # Textual description for embedding
└── chunk_id: str # Source chunk ID (provenance)
```
### Schema: Rows
```
Rows
├── metadata: Metadata
├── row_schema: RowSchema
│ ├── name: str
│ ├── description: str
│ └── fields: list[Field]
└── rows: list[dict[str, str]] # Extracted records
```
## Stage 5: Embeddings Generation
### Graph Embeddings
Converts entity contexts into vector embeddings.
**Process:**
1. Receive EntityContexts
2. Call embeddings service with context text
3. Output GraphEmbeddings (entity → vector mapping)
**Schema: GraphEmbeddings**
```
GraphEmbeddings
├── metadata: Metadata
└── entities: list[EntityEmbeddings]
└── EntityEmbeddings
├── entity: Term # Entity identifier
├── vector: list[float] # Embedding vector
└── chunk_id: str # Source chunk (provenance)
```
### Document Embeddings
Converts chunk text directly into vector embeddings.
**Process:**
1. Receive Chunk
2. Call embeddings service with chunk text
3. Output DocumentEmbeddings
**Schema: DocumentEmbeddings**
```
DocumentEmbeddings
├── metadata: Metadata
└── chunks: list[ChunkEmbeddings]
└── ChunkEmbeddings
├── chunk_id: str # Chunk identifier
└── vector: list[float] # Embedding vector
```
### Row Embeddings
Converts row index fields into vector embeddings.
**Process:**
1. Receive Rows
2. Embed configured index fields
3. Output to row vector store
## Stage 6: Storage
### Triple Store
- Receives: Triples
- Storage: Cassandra (entity-centric tables)
- Named graphs separate core knowledge from provenance:
- `""` (default): Core knowledge facts
- `urn:graph:source`: Extraction provenance
- `urn:graph:retrieval`: Query-time explainability
### Vector Store (Graph Embeddings)
- Receives: GraphEmbeddings
- Storage: Qdrant, Milvus, or Pinecone
- Indexed by: entity IRI
- Metadata: chunk_id for provenance
### Vector Store (Document Embeddings)
- Receives: DocumentEmbeddings
- Storage: Qdrant, Milvus, or Pinecone
- Indexed by: chunk_id
### Row Store
- Receives: Rows
- Storage: Cassandra
- Schema-driven table structure
### Row Vector Store
- Receives: Row embeddings
- Storage: Vector DB
- Indexed by: row index fields
## Metadata Field Analysis
### Actively Used Fields
| Field | Usage |
|-------|-------|
| `metadata.id` | Document/chunk identifier, logging, provenance |
| `metadata.user` | Multi-tenancy, storage routing |
| `metadata.collection` | Target collection selection |
| `document_id` | Librarian reference, provenance linking |
| `chunk_id` | Provenance tracking through pipeline |
### Potentially Redundant Fields
| Field | Status |
|-------|--------|
| `metadata.metadata` | Set to `[]` by all extractors; document-level metadata is now handled by librarian at submission time |
### Bytes Fields Pattern
All content fields (`data`, `text`, `chunk`) are `bytes` but immediately decoded to UTF-8 strings by all processors. No processor uses raw bytes.
## Flow Configuration
Flows are defined externally and provided to librarian via config service. Each flow specifies:
- Input queues (`text-load`, `document-load`)
- Processor chain
- Parameters (chunk size, extraction method, etc.)
Example flow patterns:
- `pdf-graphrag`: PDF → Decoder → Chunker → Definitions + Relationships → Embeddings
- `text-graphrag`: Text → Chunker → Definitions + Relationships → Embeddings
- `pdf-ontology`: PDF → Decoder → Chunker → Ontology Extraction → Embeddings
- `text-rows`: Text → Chunker → Row Extraction → Row Store
# Extraction Provenance: Subgraph Model
## Problem
Extraction-time provenance currently generates a full reification per
extracted triple: a unique `stmt_uri`, `activity_uri`, and associated
PROV-O metadata for every single knowledge fact. Processing one chunk
that yields 20 relationships produces ~260 provenance triples on top of
the ~20 knowledge triples, a roughly 13:1 overhead.
This is both expensive (storage, indexing, transmission) and semantically
inaccurate. Each chunk is processed by a single LLM call that produces
all its triples in one transaction. The current per-triple model
obscures that by creating the illusion of 20 independent extraction
events.
Additionally, two of the four extraction processors (kg-extract-ontology,
kg-extract-agent) have no provenance at all, leaving gaps in the audit
trail.
## Solution
Replace per-triple reification with a **subgraph model**: one provenance
record per chunk extraction, shared across all triples produced from that
chunk.
### Terminology Change
| Old | New |
|-----|-----|
| `stmt_uri` (`https://trustgraph.ai/stmt/{uuid}`) | `subgraph_uri` (`https://trustgraph.ai/subgraph/{uuid}`) |
| `statement_uri()` | `subgraph_uri()` |
| `tg:reifies` (1:1, identity) | `tg:contains` (1:many, containment) |
### Target Structure
All provenance triples go in the `urn:graph:source` named graph.
```
# Subgraph contains each extracted triple (RDF-star quoted triples)
<subgraph> tg:contains <<s1 p1 o1>> .
<subgraph> tg:contains <<s2 p2 o2>> .
<subgraph> tg:contains <<s3 p3 o3>> .
# Derivation from source chunk
<subgraph> prov:wasDerivedFrom <chunk_uri> .
<subgraph> prov:wasGeneratedBy <activity> .
# Activity: one per chunk extraction
<activity> rdf:type prov:Activity .
<activity> rdfs:label "{component_name} extraction" .
<activity> prov:used <chunk_uri> .
<activity> prov:wasAssociatedWith <agent> .
<activity> prov:startedAtTime "2026-03-13T10:00:00Z" .
<activity> tg:componentVersion "0.25.0" .
<activity> tg:llmModel "gpt-4" . # if available
<activity> tg:ontology <ontology_uri> . # if available
# Agent: stable per component
<agent> rdf:type prov:Agent .
<agent> rdfs:label "{component_name}" .
```
### Volume Comparison
For a chunk producing N extracted triples:
| | Old (per-triple) | New (subgraph) |
|---|---|---|
| `tg:contains` / `tg:reifies` | N | N |
| Activity triples | ~9 x N | ~9 |
| Agent triples | 2 x N | 2 |
| Statement/subgraph metadata | 2 x N | 2 |
| **Total provenance triples** | **~13N** | **N + 13** |
| **Example (N=20)** | **~260** | **33** |
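The table's totals can be sanity-checked with a quick calculation, using the approximate per-triple counts above:

```python
# Approximate provenance-triple counts from the comparison table.
# Old model: every extracted triple carries its own activity/agent/statement block.
# New model: one tg:contains link per triple, plus one shared ~13-triple block per chunk.

def old_model_count(n: int) -> int:
    """Per-triple reification: ~13 provenance triples per extracted triple."""
    return 13 * n

def new_model_count(n: int) -> int:
    """Subgraph model: one tg:contains per triple plus a shared block of ~13."""
    return n + 13

# For a chunk yielding 20 relationships:
# old_model_count(20) -> 260, new_model_count(20) -> 33
```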
## Scope
### Processors to Update (existing provenance, per-triple)
**kg-extract-definitions**
(`trustgraph-flow/trustgraph/extract/kg/definitions/extract.py`)
Currently calls `statement_uri()` + `triple_provenance_triples()` inside
the per-definition loop.
Changes:
- Move `subgraph_uri()` and `activity_uri()` creation before the loop
- Collect `tg:contains` triples inside the loop
- Emit shared activity/agent/derivation block once after the loop
**kg-extract-relationships**
(`trustgraph-flow/trustgraph/extract/kg/relationships/extract.py`)
Same pattern as definitions. Same changes.
### Processors to Add Provenance (currently missing)
**kg-extract-ontology**
(`trustgraph-flow/trustgraph/extract/kg/ontology/extract.py`)
Currently emits triples with no provenance. Add subgraph provenance
using the same pattern: one subgraph per chunk, `tg:contains` for each
extracted triple.
**kg-extract-agent**
(`trustgraph-flow/trustgraph/extract/kg/agent/extract.py`)
Currently emits triples with no provenance. Add subgraph provenance
using the same pattern.
### Shared Provenance Library Changes
**`trustgraph-base/trustgraph/provenance/triples.py`**
- Replace `triple_provenance_triples()` with `subgraph_provenance_triples()`
- New function accepts a list of extracted triples instead of a single one
- Generates one `tg:contains` per triple, shared activity/agent block
- Remove old `triple_provenance_triples()`
**`trustgraph-base/trustgraph/provenance/uris.py`**
- Replace `statement_uri()` with `subgraph_uri()`
**`trustgraph-base/trustgraph/provenance/namespaces.py`**
- Replace `TG_REIFIES` with `TG_CONTAINS`
### Not in Scope
- **kg-extract-topics**: older-style processor, not currently used in
standard flows
- **kg-extract-rows**: produces rows not triples, different provenance
model
- **Query-time provenance** (`urn:graph:retrieval`): separate concern,
already uses a different pattern (question/exploration/focus/synthesis)
- **Document/page/chunk provenance** (PDF decoder, chunker): already uses
`derived_entity_triples()` which is per-entity, not per-triple — no
redundancy issue
## Implementation Notes
### Processor Loop Restructure
Before (per-triple, in relationships):
```python
for rel in rels:
# ... build relationship_triple ...
stmt_uri = statement_uri()
prov_triples = triple_provenance_triples(
stmt_uri=stmt_uri,
extracted_triple=relationship_triple,
...
)
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
```
After (subgraph):
```python
sg_uri = subgraph_uri()
extracted_triples = []
for rel in rels:
    # ... build relationship_triple ...
    extracted_triples.append(relationship_triple)
prov_triples = subgraph_provenance_triples(
subgraph_uri=sg_uri,
extracted_triples=extracted_triples,
chunk_uri=chunk_uri,
component_name=default_ident,
component_version=COMPONENT_VERSION,
llm_model=llm_model,
ontology_uri=ontology_uri,
)
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
```
### New Helper Signature
```python
def subgraph_provenance_triples(
subgraph_uri: str,
extracted_triples: List[Triple],
chunk_uri: str,
component_name: str,
component_version: str,
llm_model: Optional[str] = None,
ontology_uri: Optional[str] = None,
timestamp: Optional[str] = None,
) -> List[Triple]:
"""
Build provenance triples for a subgraph of extracted knowledge.
Creates:
- tg:contains link for each extracted triple (RDF-star quoted)
- One prov:wasDerivedFrom link to source chunk
- One activity with agent metadata
"""
```
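A minimal sketch of how the helper might assemble its output, assuming a simple `(s, p, o)` tuple for `Triple` and plain string URIs. The real `Triple` class and namespace constants live in trustgraph-base; the `quote()` rendering of RDF-star quoted triples and the activity/agent URI derivations here are hypothetical:

```python
from collections import namedtuple
from datetime import datetime, timezone
from typing import List, Optional

Triple = namedtuple("Triple", ["s", "p", "o"])  # stand-in for the real class

TG_CONTAINS = "https://trustgraph.ai/ns/contains"

def quote(t: Triple) -> str:
    # RDF-star quoted triple, rendered as a string for this sketch
    return f"<<{t.s} {t.p} {t.o}>>"

def subgraph_provenance_triples(
    subgraph_uri: str,
    extracted_triples: List[Triple],
    chunk_uri: str,
    component_name: str,
    component_version: str,
    llm_model: Optional[str] = None,
    ontology_uri: Optional[str] = None,
    timestamp: Optional[str] = None,
) -> List[Triple]:
    activity = subgraph_uri + "/activity"                    # hypothetical URI scheme
    agent = "https://trustgraph.ai/agent/" + component_name  # hypothetical URI scheme
    ts = timestamp or datetime.now(timezone.utc).isoformat()

    # One tg:contains per extracted triple (the N part of N + 13)
    triples = [Triple(subgraph_uri, TG_CONTAINS, quote(t)) for t in extracted_triples]

    # Shared activity/agent/derivation block, emitted once per chunk extraction
    triples += [
        Triple(subgraph_uri, "prov:wasDerivedFrom", chunk_uri),
        Triple(subgraph_uri, "prov:wasGeneratedBy", activity),
        Triple(activity, "rdf:type", "prov:Activity"),
        Triple(activity, "rdfs:label", f"{component_name} extraction"),
        Triple(activity, "prov:used", chunk_uri),
        Triple(activity, "prov:wasAssociatedWith", agent),
        Triple(activity, "prov:startedAtTime", ts),
        Triple(activity, "tg:componentVersion", component_version),
        Triple(agent, "rdf:type", "prov:Agent"),
        Triple(agent, "rdfs:label", component_name),
    ]
    if llm_model:
        triples.append(Triple(activity, "tg:llmModel", llm_model))
    if ontology_uri:
        triples.append(Triple(activity, "tg:ontology", ontology_uri))
    return triples
```

The shared block stays constant-size as the number of extracted triples grows, which is the whole point of the subgraph model.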
### Breaking Change
This is a breaking change to the provenance model. Provenance has not
been released, so no migration is needed. The old `tg:reifies` /
`statement_uri` code can be removed outright.

View file

@ -0,0 +1,619 @@
# Extraction-Time Provenance: Source Layer
## Overview
This document captures notes on extraction-time provenance for future specification work. Extraction-time provenance records the "source layer" - where data came from originally, how it was extracted and transformed.
This is separate from query-time provenance (see `query-time-provenance.md`) which records agent reasoning.
## Problem Statement
### Current Implementation
Provenance currently works as follows:
- Document metadata is stored as RDF triples in the knowledge graph
- A document ID ties metadata to the document, so the document appears as a node in the graph
- When edges (relationships/facts) are extracted from documents, a `subjectOf` relationship links the extracted edge back to the source document
### Problems with Current Approach
1. **Repetitive metadata loading:** Document metadata is bundled and loaded repeatedly with every batch of triples extracted from that document. This is wasteful and redundant - the same metadata travels as cargo with every extraction output.
2. **Shallow provenance:** The current `subjectOf` relationship only links facts directly to the top-level document. There is no visibility into the transformation chain - which page the fact came from, which chunk, what extraction method was used.
### Desired State
1. **Load metadata once:** Document metadata should be loaded once and attached to the top-level document node, not repeated with every triple batch.
2. **Rich provenance DAG:** Capture the full transformation chain from source document through all intermediate artifacts down to extracted facts. For example, a PDF document transformation:
```
PDF file (source document with metadata)
→ Page 1 (decoded text)
→ Chunk 1
→ Extracted edge/fact (via subjectOf)
→ Extracted edge/fact
→ Chunk 2
→ Extracted edge/fact
→ Page 2
→ Chunk 3
→ ...
```
3. **Unified storage:** The provenance DAG is stored in the same knowledge graph as the extracted knowledge. This allows provenance to be queried the same way as knowledge - following edges back up the chain from any fact to its exact source location.
4. **Stable IDs:** Each intermediate artifact (page, chunk) has a stable ID as a node in the graph.
5. **Parent-child linking:** Derived documents are linked to their parents all the way up to the top-level source document using consistent relationship types.
6. **Precise fact attribution:** The `subjectOf` relationship on extracted edges points to the immediate parent (chunk), not the top-level document. Full provenance is recovered by traversing up the DAG.
## Use Cases
### UC1: Source Attribution in GraphRAG Responses
**Scenario:** A user runs a GraphRAG query and receives a response from the agent.
**Flow:**
1. User submits a query to the GraphRAG agent
2. Agent retrieves relevant facts from the knowledge graph to formulate a response
3. Per the query-time provenance spec, the agent reports which facts contributed to the response
4. Each fact links to its source chunk via the provenance DAG
5. Chunks link to pages, pages link to source documents
**UX Outcome:** The interface displays the LLM response alongside source attribution. The user can:
- See which facts supported the response
- Drill down from facts → chunks → pages → documents
- Peruse the original source documents to verify claims
- Understand exactly where in a document (which page, which section) a fact originated
**Value:** Users can verify AI-generated responses against primary sources, building trust and enabling fact-checking.
### UC2: Debugging Extraction Quality
A fact looks wrong. Trace back through chunk → page → document to see the original text. Was it a bad extraction, or was the source itself wrong?
### UC3: Incremental Re-extraction
Source document gets updated. Which chunks/facts were derived from it? Invalidate and regenerate just those, rather than re-processing everything.
### UC4: Data Deletion / Right to be Forgotten
A source document must be removed (GDPR, legal, etc.). Traverse the DAG to find and remove all derived facts.
### UC5: Conflict Resolution
Two facts contradict each other. Trace both back to their sources to understand why and decide which to trust (more authoritative source, more recent, etc.).
### UC6: Source Authority Weighting
Some sources are more authoritative than others. Facts can be weighted or filtered based on the authority/quality of their origin documents.
### UC7: Extraction Pipeline Comparison
Compare outputs from different extraction methods/versions. Which extractor produced better facts from the same source?
## Integration Points
### Librarian
The librarian component already provides document storage with unique document IDs. The provenance system integrates with this existing infrastructure.
#### Existing Capabilities (already implemented)
**Parent-Child Document Linking:**
- `parent_id` field in `DocumentMetadata` - links child to parent document
- `document_type` field - values: `"source"` (original) or `"extracted"` (derived)
- `add-child-document` API - creates child document with automatic `document_type = "extracted"`
- `list-children` API - retrieves all children of a parent document
- Cascade deletion - removing a parent automatically deletes all child documents
**Document Identification:**
- Document IDs are client-specified (not auto-generated)
- Documents keyed by composite `(user, document_id)` in Cassandra
- Object IDs (UUIDs) generated internally for blob storage
**Metadata Support:**
- `metadata: list[Triple]` field - RDF triples for structured metadata
- `title`, `comments`, `tags` - basic document metadata
- `time` - timestamp, `kind` - MIME type
**Storage Architecture:**
- Metadata stored in Cassandra (`librarian` keyspace, `document` table)
- Content stored in MinIO/S3 blob storage (`library` bucket)
- Smart content delivery: documents < 2MB embedded, larger documents streamed
#### Key Files
- `trustgraph-flow/trustgraph/librarian/librarian.py` - Core librarian operations
- `trustgraph-flow/trustgraph/librarian/service.py` - Service processor, document loading
- `trustgraph-flow/trustgraph/tables/library.py` - Cassandra table store
- `trustgraph-base/trustgraph/schema/services/library.py` - Schema definitions
#### Gaps to Address
The librarian has the building blocks but currently:
1. Parent-child linking is one level deep - no multi-level DAG traversal helpers
2. No standard relationship type vocabulary (e.g., `derivedFrom`, `extractedFrom`)
3. Provenance metadata (extraction method, confidence, chunk position) not standardized
4. No query API to traverse the full provenance chain from a fact back to source
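To illustrate gap 4, a chain traversal could be as simple as walking `prov:wasDerivedFrom` edges upward until no parent remains. The `graph.objects(s, p)` lookup below is a hypothetical client method, not an existing API:

```python
def provenance_chain(graph, start_uri):
    """Walk prov:wasDerivedFrom edges from a fact's subgraph back to the
    source document. graph.objects(s, p) is a hypothetical lookup that
    returns the objects of all matching triples."""
    chain = [start_uri]
    node = start_uri
    while True:
        parents = graph.objects(node, "prov:wasDerivedFrom")
        if not parents:
            return chain      # reached the root source document
        node = parents[0]     # one parent per level in this DAG
        chain.append(node)
```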
## End-to-End Flow Design
Each processor in the pipeline follows a consistent pattern:
- Receive document ID from upstream
- Fetch content from librarian
- Produce child artifacts
- For each child: save to librarian, emit edge to graph, forward ID downstream
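The pattern above can be sketched as one generic loop. The client objects (`librarian`, `graph`, `downstream`) and their method names are hypothetical stand-ins for the real service interfaces:

```python
def process_document(doc_id, librarian, graph, downstream):
    """Generic processor step: receive ID, fetch, derive, save, link, forward.

    librarian, graph and downstream are hypothetical clients standing in
    for the real librarian / triple-store / Pulsar interfaces.
    """
    content = librarian.fetch(doc_id)                # fetch content from librarian
    for child_content in derive_children(content):   # produce child artifacts
        child_id = librarian.add_child(parent=doc_id, content=child_content)
        graph.emit_edge(child_id, "prov:wasDerivedFrom", doc_id)  # DAG edge
        downstream.send(child_id)                    # forward ID downstream

def derive_children(content):
    # Placeholder: a real processor would page-split, chunk, extract, etc.
    yield content
```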
### Processing Flows
There are two flows depending on document type:
#### PDF Document Flow
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Librarian (initiate processing) │
│ 1. Emit root document metadata to knowledge graph (once) │
│ 2. Send root document ID to PDF extractor │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│ PDF Extractor (per page) │
│ 1. Fetch PDF content from librarian using document ID │
│ 2. Extract pages as text │
│ 3. For each page: │
│ a. Save page as child document in librarian (parent = root doc) │
│ b. Emit parent-child edge to knowledge graph │
│ c. Send page document ID to chunker │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│ Chunker (per chunk) │
│ 1. Fetch page content from librarian using document ID │
│ 2. Split text into chunks │
│ 3. For each chunk: │
│ a. Save chunk as child document in librarian (parent = page) │
│ b. Emit parent-child edge to knowledge graph │
│ c. Send chunk document ID + chunk content to next processor │
└─────────────────────────────────────────────────────────────────────────┘
─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
Post-chunker optimization: messages carry both
chunk ID (for provenance) and content (to avoid
librarian round-trip). Chunks are small (2-4KB).
─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
┌─────────────────────────────────────────────────────────────────────────┐
│ Knowledge Extractor (per chunk) │
│ 1. Receive chunk ID + content directly (no librarian fetch needed) │
│ 2. Extract facts/triples and embeddings from chunk content │
│ 3. For each triple: │
│ a. Emit triple to knowledge graph │
│ b. Emit reified edge linking triple → chunk ID (edge pointing │
│ to edge - first use of reification support) │
│ 4. For each embedding: │
│ a. Emit embedding with its entity ID │
│ b. Link entity ID → chunk ID in knowledge graph │
└─────────────────────────────────────────────────────────────────────────┘
```
#### Text Document Flow
Text documents skip the PDF extractor and go directly to the chunker:
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Librarian (initiate processing) │
│ 1. Emit root document metadata to knowledge graph (once) │
│ 2. Send root document ID directly to chunker (skip PDF extractor) │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│ Chunker (per chunk) │
│ 1. Fetch text content from librarian using document ID │
│ 2. Split text into chunks │
│ 3. For each chunk: │
│ a. Save chunk as child document in librarian (parent = root doc) │
│ b. Emit parent-child edge to knowledge graph │
│ c. Send chunk document ID + chunk content to next processor │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│ Knowledge Extractor │
│ (same as PDF flow) │
└─────────────────────────────────────────────────────────────────────────┘
```
The resulting DAG is one level shorter:
```
PDF: Document → Pages → Chunks → Triples/Embeddings
Text: Document → Chunks → Triples/Embeddings
```
The design accommodates both because the chunker treats its input generically - it uses whatever document ID it receives as the parent, regardless of whether that's a source document or a page.
### Metadata Schema (PROV-O)
Provenance metadata uses the W3C PROV-O ontology. This provides a standard vocabulary and enables future signing/authentication of extraction outputs.
#### PROV-O Core Concepts
| PROV-O Type | TrustGraph Usage |
|-------------|------------------|
| `prov:Entity` | Document, Page, Chunk, Triple, Embedding |
| `prov:Activity` | Instances of extraction operations |
| `prov:Agent` | TG components (PDF extractor, chunker, etc.) with versions |
#### PROV-O Relationships
| Predicate | Meaning | Example |
|-----------|---------|---------|
| `prov:wasDerivedFrom` | Entity derived from another entity | Page wasDerivedFrom Document |
| `prov:wasGeneratedBy` | Entity generated by an activity | Page wasGeneratedBy PDFExtractionActivity |
| `prov:used` | Activity used an entity as input | PDFExtractionActivity used Document |
| `prov:wasAssociatedWith` | Activity performed by an agent | PDFExtractionActivity wasAssociatedWith tg:PDFExtractor |
#### Metadata at Each Level
**Source Document (emitted by Librarian):**
```
doc:123 a prov:Entity .
doc:123 dc:title "Research Paper" .
doc:123 dc:source <https://example.com/paper.pdf> .
doc:123 dc:date "2024-01-15" .
doc:123 dc:creator "Author Name" .
doc:123 tg:pageCount 42 .
doc:123 tg:mimeType "application/pdf" .
```
**Page (emitted by PDF Extractor):**
```
page:123-1 a prov:Entity .
page:123-1 prov:wasDerivedFrom doc:123 .
page:123-1 prov:wasGeneratedBy activity:pdf-extract-456 .
page:123-1 tg:pageNumber 1 .
activity:pdf-extract-456 a prov:Activity .
activity:pdf-extract-456 prov:used doc:123 .
activity:pdf-extract-456 prov:wasAssociatedWith tg:PDFExtractor .
activity:pdf-extract-456 tg:componentVersion "1.2.3" .
activity:pdf-extract-456 prov:startedAtTime "2024-01-15T10:30:00Z" .
```
**Chunk (emitted by Chunker):**
```
chunk:123-1-1 a prov:Entity .
chunk:123-1-1 prov:wasDerivedFrom page:123-1 .
chunk:123-1-1 prov:wasGeneratedBy activity:chunk-789 .
chunk:123-1-1 tg:chunkIndex 1 .
chunk:123-1-1 tg:charOffset 0 .
chunk:123-1-1 tg:charLength 2048 .
activity:chunk-789 a prov:Activity .
activity:chunk-789 prov:used page:123-1 .
activity:chunk-789 prov:wasAssociatedWith tg:Chunker .
activity:chunk-789 tg:componentVersion "1.0.0" .
activity:chunk-789 tg:chunkSize 2048 .
activity:chunk-789 tg:chunkOverlap 200 .
```
**Triple (emitted by Knowledge Extractor):**
```
# The extracted triple (edge)
entity:JohnSmith rel:worksAt entity:AcmeCorp .
# Subgraph containing the extracted triples
subgraph:001 tg:contains <<entity:JohnSmith rel:worksAt entity:AcmeCorp>> .
subgraph:001 prov:wasDerivedFrom chunk:123-1-1 .
subgraph:001 prov:wasGeneratedBy activity:extract-999 .
activity:extract-999 a prov:Activity .
activity:extract-999 prov:used chunk:123-1-1 .
activity:extract-999 prov:wasAssociatedWith tg:KnowledgeExtractor .
activity:extract-999 tg:componentVersion "2.1.0" .
activity:extract-999 tg:llmModel "claude-3" .
activity:extract-999 tg:ontology <http://example.org/ontologies/business-v1> .
```
**Embedding (stored in vector store, not triple store):**
Embeddings are stored in the vector store with metadata, not as RDF triples. Each embedding record contains:
| Field | Description | Example |
|-------|-------------|---------|
| vector | The embedding vector | [0.123, -0.456, ...] |
| entity | Node URI the embedding represents | `entity:JohnSmith` |
| chunk_id | Source chunk (provenance) | `chunk:123-1-1` |
| model | Embedding model used | `text-embedding-ada-002` |
| component_version | TG embedder version | `1.0.0` |
The `entity` field links the embedding to the knowledge graph (node URI). The `chunk_id` field provides provenance back to the source chunk, enabling traversal up the DAG to the original document.
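In code, a stored record might look like the following sketch (field names follow the table above; the values are illustrative):

```python
# Sketch of an embedding record as stored in the vector store.
# Field names follow the table above; values are illustrative.
embedding_record = {
    "vector": [0.123, -0.456],           # the embedding vector (truncated here)
    "entity": "entity:JohnSmith",        # node URI in the knowledge graph
    "chunk_id": "chunk:123-1-1",         # provenance: source chunk, DAG entry point
    "model": "text-embedding-ada-002",   # embedding model used
    "component_version": "1.0.0",        # TG embedder version
}
```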
#### TrustGraph Namespace Extensions
Custom predicates under the `tg:` namespace for extraction-specific metadata:
| Predicate | Domain | Description |
|-----------|--------|-------------|
| `tg:contains` | Subgraph | Points at a triple contained in this extraction subgraph |
| `tg:pageCount` | Document | Total number of pages in source document |
| `tg:mimeType` | Document | MIME type of source document |
| `tg:pageNumber` | Page | Page number in source document |
| `tg:chunkIndex` | Chunk | Index of chunk within parent |
| `tg:charOffset` | Chunk | Character offset in parent text |
| `tg:charLength` | Chunk | Length of chunk in characters |
| `tg:chunkSize` | Activity | Configured chunk size |
| `tg:chunkOverlap` | Activity | Configured overlap between chunks |
| `tg:componentVersion` | Activity | Version of TG component |
| `tg:llmModel` | Activity | LLM used for extraction |
| `tg:ontology` | Activity | Ontology URI used to guide extraction |
| `tg:embeddingModel` | Activity | Model used for embeddings |
| `tg:sourceText` | Statement | Exact text from which a triple was extracted |
| `tg:sourceCharOffset` | Statement | Character offset within chunk where source text starts |
| `tg:sourceCharLength` | Statement | Length of source text in characters |
#### Vocabulary Bootstrap (Per Collection)
The knowledge graph is ontology-neutral and initialises empty. When writing PROV-O provenance data to a collection for the first time, the vocabulary must be bootstrapped with RDF labels for all classes and predicates. This ensures human-readable display in queries and UI.
**PROV-O Classes:**
```
prov:Entity rdfs:label "Entity" .
prov:Activity rdfs:label "Activity" .
prov:Agent rdfs:label "Agent" .
```
**PROV-O Predicates:**
```
prov:wasDerivedFrom rdfs:label "was derived from" .
prov:wasGeneratedBy rdfs:label "was generated by" .
prov:used rdfs:label "used" .
prov:wasAssociatedWith rdfs:label "was associated with" .
prov:startedAtTime rdfs:label "started at" .
```
**TrustGraph Predicates:**
```
tg:contains rdfs:label "contains" .
tg:pageCount rdfs:label "page count" .
tg:mimeType rdfs:label "MIME type" .
tg:pageNumber rdfs:label "page number" .
tg:chunkIndex rdfs:label "chunk index" .
tg:charOffset rdfs:label "character offset" .
tg:charLength rdfs:label "character length" .
tg:chunkSize rdfs:label "chunk size" .
tg:chunkOverlap rdfs:label "chunk overlap" .
tg:componentVersion rdfs:label "component version" .
tg:llmModel rdfs:label "LLM model" .
tg:ontology rdfs:label "ontology" .
tg:embeddingModel rdfs:label "embedding model" .
tg:sourceText rdfs:label "source text" .
tg:sourceCharOffset rdfs:label "source character offset" .
tg:sourceCharLength rdfs:label "source character length" .
```
**Implementation note:** This vocabulary bootstrap should be idempotent - safe to run multiple times without creating duplicates. Could be triggered on first document processing in a collection, or as a separate collection initialisation step.
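An idempotent bootstrap might be sketched as a check-then-insert loop over a static label map. The `graph.ask()` / `graph.insert()` client methods are hypothetical names, and the `LABELS` dict is abbreviated:

```python
# Sketch of an idempotent vocabulary bootstrap, assuming a graph client
# with ask()/insert() operations (hypothetical method names).

LABELS = {
    "prov:Entity": "Entity",
    "prov:Activity": "Activity",
    "prov:Agent": "Agent",
    "prov:wasDerivedFrom": "was derived from",
    "tg:contains": "contains",
    # ... remaining classes and predicates from the lists above
}

def bootstrap_vocabulary(graph, collection):
    """Write rdfs:label triples once per collection; safe to re-run."""
    for term, label in LABELS.items():
        triple = (term, "rdfs:label", label)
        if not graph.ask(collection, *triple):   # skip if already present
            graph.insert(collection, *triple)
```

With a store whose inserts are naturally idempotent (set semantics), the `ask()` check can even be dropped and the labels blindly re-written.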
#### Sub-Chunk Provenance (Aspirational)
For finer-grained provenance, it would be valuable to record exactly where within a chunk a triple was extracted from. This enables:
- Highlighting the exact source text in the UI
- Verifying extraction accuracy against source
- Debugging extraction quality at the sentence level
**Example with position tracking:**
```
# The extracted triple
entity:JohnSmith rel:worksAt entity:AcmeCorp .
# Subgraph with sub-chunk provenance
subgraph:001 tg:contains <<entity:JohnSmith rel:worksAt entity:AcmeCorp>> .
subgraph:001 prov:wasDerivedFrom chunk:123-1-1 .
subgraph:001 tg:sourceText "John Smith has worked at Acme Corp since 2019" .
subgraph:001 tg:sourceCharOffset 1547 .
subgraph:001 tg:sourceCharLength 46 .
```
**Example with text range (alternative):**
```
subgraph:001 tg:contains <<entity:JohnSmith rel:worksAt entity:AcmeCorp>> .
subgraph:001 prov:wasDerivedFrom chunk:123-1-1 .
subgraph:001 tg:sourceRange "1547-1593" .
subgraph:001 tg:sourceText "John Smith has worked at Acme Corp since 2019" .
```
**Implementation considerations:**
- LLM-based extraction may not naturally provide character positions
- Could prompt the LLM to return the source sentence/phrase alongside extracted triples
- Alternatively, post-process to fuzzy-match extracted entities back to source text
- Trade-off between extraction complexity and provenance granularity
- May be easier to achieve with structured extraction methods than free-form LLM extraction
This is marked as aspirational - the basic chunk-level provenance should be implemented first, with sub-chunk tracking as a future enhancement if feasible.
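As a sketch of the post-processing option: when the LLM returns the source sentence verbatim, a plain substring search recovers the offsets, and fuzzy matching is only needed when it paraphrases. The function below is illustrative, not part of the codebase:

```python
from typing import Optional, Tuple

def locate_source_text(chunk: str, source_text: str) -> Optional[Tuple[int, int]]:
    """Return (charOffset, charLength) of source_text within chunk, if present."""
    offset = chunk.find(source_text)
    if offset < 0:
        return None   # fall back to fuzzy matching, or omit sub-chunk provenance
    return offset, len(source_text)
```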
### Dual Storage Model
The provenance DAG is built progressively as documents flow through the pipeline:
| Store | What's Stored | Purpose |
|-------|---------------|---------|
| Librarian | Document content + parent-child links | Content retrieval, cascade deletion |
| Knowledge Graph | Parent-child edges + metadata | Provenance queries, fact attribution |
Both stores maintain the same DAG structure. The librarian holds content; the graph holds relationships and enables traversal queries.
### Key Design Principles
1. **Document ID as the unit of flow** - Processors pass IDs, not content. Content is fetched from librarian when needed.
2. **Emit once at source** - Metadata is written to the graph once when processing begins, not repeated downstream.
3. **Consistent processor pattern** - Every processor follows the same receive/fetch/produce/save/emit/forward pattern.
4. **Progressive DAG construction** - Each processor adds its level to the DAG. The full provenance chain is built incrementally.
5. **Post-chunker optimization** - After chunking, messages carry both ID and content. Chunks are small (2-4KB), so including content avoids unnecessary librarian round-trips while preserving provenance via the ID.
## Implementation Tasks
### Librarian Changes
#### Current State
- Initiates document processing by sending document ID to first processor
- No connection to triple store - metadata is bundled with extraction outputs
- `add-child-document` creates one-level parent-child links
- `list-children` returns immediate children only
#### Required Changes
**1. New interface: Triple store connection**
Librarian needs to emit document metadata edges directly to the knowledge graph when initiating processing.
- Add triple store client/publisher to librarian service
- On processing initiation: emit root document metadata as graph edges (once)
**2. Document type vocabulary**
Standardize `document_type` values for child documents:
- `source` - original uploaded document
- `page` - page extracted from source (PDF, etc.)
- `chunk` - text chunk derived from page or source
#### Interface Changes Summary
| Interface | Change |
|-----------|--------|
| Triple store | New outbound connection - emit document metadata edges |
| Processing initiation | Emit metadata to graph before forwarding document ID |
### PDF Extractor Changes
#### Current State
- Receives document content (or streams large documents)
- Extracts text from PDF pages
- Forwards page content to chunker
- No interaction with librarian or triple store
#### Required Changes
**1. New interface: Librarian client**
PDF extractor needs to save each page as a child document in librarian.
- Add librarian client to PDF extractor service
- For each page: call `add-child-document` with parent = root document ID
**2. New interface: Triple store connection**
PDF extractor needs to emit parent-child edges to knowledge graph.
- Add triple store client/publisher
- For each page: emit edge linking page document to parent document
**3. Change output format**
Instead of forwarding page content directly, forward page document ID.
- Chunker will fetch content from librarian using the ID
#### Interface Changes Summary
| Interface | Change |
|-----------|--------|
| Librarian | New outbound - save child documents |
| Triple store | New outbound - emit parent-child edges |
| Output message | Change from content to document ID |
### Chunker Changes
#### Current State
- Receives page/text content
- Splits into chunks
- Forwards chunk content to downstream processors
- No interaction with librarian or triple store
#### Required Changes
**1. Change input handling**
Receive document ID instead of content, fetch from librarian.
- Add librarian client to chunker service
- Fetch page content using document ID
**2. New interface: Librarian client (write)**
Save each chunk as a child document in librarian.
- For each chunk: call `add-child-document` with parent = page document ID
**3. New interface: Triple store connection**
Emit parent-child edges to knowledge graph.
- Add triple store client/publisher
- For each chunk: emit edge linking chunk document to page document
**4. Change output format**
Forward both chunk document ID and chunk content (post-chunker optimization).
- Downstream processors receive ID for provenance + content to work with
#### Interface Changes Summary
| Interface | Change |
|-----------|--------|
| Input message | Change from content to document ID |
| Librarian | New outbound (read + write) - fetch content, save child documents |
| Triple store | New outbound - emit parent-child edges |
| Output message | Change from content-only to ID + content |
### Knowledge Extractor Changes
#### Current State
- Receives chunk content
- Extracts triples and embeddings
- Emits to triple store and embedding store
- `subjectOf` relationship points to top-level document (not chunk)
#### Required Changes
**1. Change input handling**
Receive chunk document ID alongside content.
- Use chunk ID for provenance linking (content already included per optimization)
**2. Update triple provenance**
Link extracted triples to chunk (not top-level document).
- Use reification to create edge pointing to edge
- `subjectOf` relationship: triple → chunk document ID
- First use of existing reification support
**3. Update embedding provenance**
Link embedding entity IDs to chunk.
- Emit edge: embedding entity ID → chunk document ID
#### Interface Changes Summary
| Interface | Change |
|-----------|--------|
| Input message | Expect chunk ID + content (not content only) |
| Triple store | Use reification for triple → chunk provenance |
| Embedding provenance | Link entity ID → chunk ID |
## References
- Query-time provenance: `docs/tech-specs/query-time-provenance.md`
- PROV-O standard for provenance modeling
- Existing source metadata in knowledge graph (needs audit)

View file

@ -0,0 +1,984 @@
# Large Document Loading Technical Specification
## Overview
This specification addresses scalability and user experience issues when loading
large documents into TrustGraph. The current architecture treats document upload
as a single atomic operation, causing memory pressure at multiple points in the
pipeline and providing no feedback or recovery options to users.
This implementation targets the following use cases:
1. **Large PDF Processing**: Upload and process multi-hundred-megabyte PDF files
without exhausting memory
2. **Resumable Uploads**: Allow interrupted uploads to continue from where they
left off rather than restarting
3. **Progress Feedback**: Provide users with real-time visibility into upload
and processing progress
4. **Memory-Efficient Processing**: Process documents in a streaming fashion
without holding entire files in memory
## Goals
- **Incremental Upload**: Support chunked document upload via REST and WebSocket
- **Resumable Transfers**: Enable recovery from interrupted uploads
- **Progress Visibility**: Provide upload/processing progress feedback to clients
- **Memory Efficiency**: Eliminate full-document buffering throughout the pipeline
- **Backward Compatibility**: Existing small-document workflows continue unchanged
- **Streaming Processing**: PDF decoding and text chunking operate on streams
## Background
### Current Architecture
Document submission flows through the following path:
1. **Client** submits document via REST (`POST /api/v1/librarian`) or WebSocket
2. **API Gateway** receives complete request with base64-encoded document content
3. **LibrarianRequestor** translates request to Pulsar message
4. **Librarian Service** receives message, decodes document into memory
5. **BlobStore** uploads document to Garage/S3
6. **Cassandra** stores metadata with object reference
7. For processing: document retrieved from S3, decoded, chunked—all in memory
Key files:
- REST/WebSocket entry: `trustgraph-flow/trustgraph/gateway/service.py`
- Librarian core: `trustgraph-flow/trustgraph/librarian/librarian.py`
- Blob storage: `trustgraph-flow/trustgraph/librarian/blob_store.py`
- Cassandra tables: `trustgraph-flow/trustgraph/tables/library.py`
- API schema: `trustgraph-base/trustgraph/schema/services/library.py`
### Current Limitations
The current design has several compounding memory and UX issues:
1. **Atomic Upload Operation**: The entire document must be transmitted in a
single request. Large documents require long-running requests with no
progress indication and no retry mechanism if the connection fails.
2. **API Design**: Both REST and WebSocket APIs expect the complete document
in a single message. The schema (`LibrarianRequest`) has a single `content`
field containing the entire base64-encoded document.
3. **Librarian Memory**: The librarian service decodes the entire document
into memory before uploading to S3. For a 500MB PDF, this means holding
500MB+ in process memory.
4. **PDF Decoder Memory**: When processing begins, the PDF decoder loads the
entire PDF into memory to extract text. PyPDF and similar libraries
typically require full document access.
5. **Chunker Memory**: The text chunker receives the complete extracted text
and holds it in memory while producing chunks.
**Memory Impact Example** (500MB PDF):
- Gateway: ~700MB (base64 encoding overhead)
- Librarian: ~500MB (decoded bytes)
- PDF Decoder: ~500MB + extraction buffers
- Chunker: extracted text (variable, potentially 100MB+)
Total peak memory can exceed 2GB for a single large document.
## Technical Design
### Design Principles
1. **API Facade**: All client interaction goes through the librarian API. Clients
have no direct access to or knowledge of the underlying S3/Garage storage.
2. **S3 Multipart Upload**: Use standard S3 multipart upload under the hood.
This is widely supported across S3-compatible systems (AWS S3, MinIO, Garage,
Ceph, DigitalOcean Spaces, Backblaze B2, etc.) ensuring portability.
3. **Atomic Completion**: S3 multipart uploads are inherently atomic - uploaded
parts are invisible until `CompleteMultipartUpload` is called. No temporary
files or rename operations needed.
4. **Trackable State**: Upload sessions tracked in Cassandra, providing
visibility into incomplete uploads and enabling resume capability.
### Chunked Upload Flow
```
Client Librarian API S3/Garage
│ │ │
│── begin-upload ───────────►│ │
│ (metadata, size) │── CreateMultipartUpload ────►│
│ │◄── s3_upload_id ─────────────│
│◄── upload_id ──────────────│ (store session in │
│ │ Cassandra) │
│ │ │
│── upload-chunk ───────────►│ │
│ (upload_id, index, data) │── UploadPart ───────────────►│
│ │◄── etag ─────────────────────│
│◄── ack + progress ─────────│ (store etag in session) │
│ ⋮ │ ⋮ │
│ (repeat for all chunks) │ │
│ │ │
│── complete-upload ────────►│ │
│ (upload_id) │── CompleteMultipartUpload ──►│
│ │ (parts coalesced by S3) │
│ │── store doc metadata ───────►│ Cassandra
│◄── document_id ────────────│ (delete session) │
```
The client never interacts with S3 directly. The librarian translates between
our chunked upload API and S3 multipart operations internally.
### Librarian API Operations
#### `begin-upload`
Initialize a chunked upload session.
Request:
```json
{
"operation": "begin-upload",
"document-metadata": {
"id": "doc-123",
"kind": "application/pdf",
"title": "Large Document",
"user": "user-id",
"tags": ["tag1", "tag2"]
},
"total-size": 524288000,
"chunk-size": 5242880
}
```
Response:
```json
{
"upload-id": "upload-abc-123",
"chunk-size": 5242880,
"total-chunks": 100
}
```
The librarian:
1. Generates a unique `upload_id` and `object_id` (UUID for blob storage)
2. Calls S3 `CreateMultipartUpload`, receives `s3_upload_id`
3. Creates session record in Cassandra
4. Returns `upload_id` to client
#### `upload-chunk`
Upload a single chunk.
Request:
```json
{
"operation": "upload-chunk",
"upload-id": "upload-abc-123",
"chunk-index": 0,
"content": "<base64-encoded-chunk>"
}
```
Response:
```json
{
"upload-id": "upload-abc-123",
"chunk-index": 0,
"chunks-received": 1,
"total-chunks": 100,
"bytes-received": 5242880,
"total-bytes": 524288000
}
```
The librarian:
1. Looks up session by `upload_id`
2. Validates ownership (user must match session creator)
3. Calls S3 `UploadPart` with chunk data, receives `etag`
4. Updates session record with chunk index and etag
5. Returns progress to client
Failed chunks can be retried - just send the same `chunk-index` again.
#### `complete-upload`
Finalize the upload and create the document.
Request:
```json
{
"operation": "complete-upload",
"upload-id": "upload-abc-123"
}
```
Response:
```json
{
"document-id": "doc-123",
"object-id": "550e8400-e29b-41d4-a716-446655440000"
}
```
The librarian:
1. Looks up session, verifies all chunks received
2. Calls S3 `CompleteMultipartUpload` with part etags (S3 coalesces parts
internally - zero memory cost to librarian)
3. Creates document record in Cassandra with metadata and object reference
4. Deletes upload session record
5. Returns document ID to client
#### `abort-upload`
Cancel an in-progress upload.
Request:
```json
{
"operation": "abort-upload",
"upload-id": "upload-abc-123"
}
```
The librarian:
1. Calls S3 `AbortMultipartUpload` to clean up parts
2. Deletes session record from Cassandra
#### `get-upload-status`
Query status of an upload (for resume capability).
Request:
```json
{
"operation": "get-upload-status",
"upload-id": "upload-abc-123"
}
```
Response:
```json
{
"upload-id": "upload-abc-123",
"state": "in-progress",
"chunks-received": [0, 1, 2, 5, 6],
"missing-chunks": [3, 4, 7, 8],
"total-chunks": 100,
  "bytes-received": 26214400,
"total-bytes": 524288000
}
```
#### `list-uploads`
List incomplete uploads for a user.
Request:
```json
{
"operation": "list-uploads"
}
```
Response:
```json
{
"uploads": [
{
"upload-id": "upload-abc-123",
"document-metadata": { "title": "Large Document", ... },
"progress": { "chunks-received": 43, "total-chunks": 100 },
"created-at": "2024-01-15T10:30:00Z"
}
]
}
```
### Upload Session Storage
Track in-progress uploads in Cassandra:
```sql
CREATE TABLE upload_session (
upload_id text PRIMARY KEY,
user text,
document_id text,
document_metadata text, -- JSON: title, kind, tags, comments, etc.
s3_upload_id text, -- internal, for S3 operations
object_id uuid, -- target blob ID
total_size bigint,
chunk_size int,
total_chunks int,
chunks_received map<int, text>, -- chunk_index → etag
created_at timestamp,
updated_at timestamp
) WITH default_time_to_live = 86400; -- 24 hour TTL
CREATE INDEX upload_session_user ON upload_session (user);
```
**TTL Behavior:**
- Sessions expire after 24 hours if not completed
- When Cassandra TTL expires, the session record is deleted
- Orphaned S3 parts are cleaned up by S3 lifecycle policy (configure on bucket)
### Failure Handling and Atomicity
**Chunk upload failure:**
- Client retries the failed chunk (same `upload_id` and `chunk-index`)
- Re-sending S3 `UploadPart` with the same part number overwrites the prior part, so retries are safe
- Session tracks which chunks succeeded
**Client disconnect mid-upload:**
- Session remains in Cassandra with received chunks recorded
- Client can call `get-upload-status` to see what's missing
- Resume by uploading only missing chunks, then `complete-upload`
**Complete-upload failure:**
- S3 `CompleteMultipartUpload` is atomic - either succeeds fully or fails
- On failure, parts remain and client can retry `complete-upload`
- No partial document is ever visible
**Session expiry:**
- Cassandra TTL deletes session record after 24 hours
- S3 bucket lifecycle policy cleans up incomplete multipart uploads
- No manual cleanup required
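The resume path described above can be sketched from the client's side. This is a hedged illustration, assuming thin wrappers over the `get-upload-status`, `upload-chunk`, and `complete-upload` operations; the callable names below are hypothetical stand-ins, not SDK functions:

```python
def resume_upload(status, read_chunk, upload_chunk, complete_upload):
    """Re-send only the chunks the server reports missing, then finalize.

    status: dict shaped like the get-upload-status response
    read_chunk(index) -> bytes: re-reads chunk `index` from the local file
    upload_chunk / complete_upload: hypothetical librarian API wrappers
    """
    upload_id = status["upload-id"]
    for index in status["missing-chunks"]:
        # Safe to re-send: S3 overwrites a part uploaded under the same number
        upload_chunk(upload_id, index, read_chunk(index))
    return complete_upload(upload_id)
```

Because the session record persists in Cassandra, this loop works identically whether the interruption happened seconds or hours earlier (within the TTL window).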
### S3 Multipart Atomicity
S3 multipart uploads provide built-in atomicity:
1. **Parts are invisible**: Uploaded parts cannot be accessed as objects.
They exist only as parts of an incomplete multipart upload.
2. **Atomic completion**: `CompleteMultipartUpload` either succeeds (object
appears atomically) or fails (no object created). No partial state.
3. **No rename needed**: The final object key is specified at
`CreateMultipartUpload` time. Parts are coalesced directly to that key.
4. **Server-side coalesce**: S3 combines parts internally. The librarian
never reads parts back - zero memory overhead regardless of document size.
### BlobStore Extensions
**File:** `trustgraph-flow/trustgraph/librarian/blob_store.py`
Add multipart upload methods:
```python
class BlobStore:
# Existing methods...
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
"""Initialize multipart upload, return s3_upload_id."""
# minio client: create_multipart_upload()
def upload_part(
self, object_id: UUID, s3_upload_id: str,
part_number: int, data: bytes
) -> str:
"""Upload a single part, return etag."""
# minio client: upload_part()
# Note: S3 part numbers are 1-indexed
def complete_multipart_upload(
self, object_id: UUID, s3_upload_id: str,
parts: List[Tuple[int, str]] # [(part_number, etag), ...]
) -> None:
"""Finalize multipart upload."""
# minio client: complete_multipart_upload()
def abort_multipart_upload(
self, object_id: UUID, s3_upload_id: str
) -> None:
"""Cancel multipart upload, clean up parts."""
# minio client: abort_multipart_upload()
```
### Chunk Size Considerations
- **S3 minimum**: 5MB per part (except last part)
- **S3 maximum**: 10,000 parts per upload
- **Practical default**: 5MB chunks
- 500MB document = 100 chunks
- 5GB document = 1,000 chunks
- **Progress granularity**: Smaller chunks = finer progress updates
- **Network efficiency**: Larger chunks = fewer round trips
Chunk size could be client-configurable within bounds (5MB - 100MB).
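As a sanity check on these bounds, a small helper can validate a proposed chunk size against the S3 constraints and compute the part count. A sketch; the function name is illustrative, and only the 5MB minimum and 10,000-part maximum are S3 limits (the upper chunk bound is this spec's suggestion):

```python
import math

S3_MIN_PART = 5 * 1024 * 1024   # S3 minimum part size (except last part)
S3_MAX_PARTS = 10_000           # S3 maximum parts per multipart upload

def plan_upload(total_size: int, chunk_size: int = S3_MIN_PART) -> int:
    """Validate chunk size against S3 limits; return the number of chunks."""
    if chunk_size < S3_MIN_PART:
        raise ValueError("chunk size below S3 minimum part size")
    total_chunks = math.ceil(total_size / chunk_size)
    if total_chunks > S3_MAX_PARTS:
        raise ValueError("too many parts; increase chunk size")
    return total_chunks
```

For example, a 500MB document with the 5MB default yields exactly the 100 chunks shown in the `begin-upload` example.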
### Document Processing: Streaming Retrieval
The upload flow addresses getting documents into storage efficiently. The
processing flow addresses extracting and chunking documents without loading
them entirely into memory.
#### Design Principle: Identifier, Not Content
Currently, when processing is triggered, document content flows through Pulsar
messages. This loads entire documents into memory. Instead:
- Pulsar messages carry only the **document identifier**
- Processors fetch document content directly from librarian
- Fetching happens as a **stream to temporary file**
- Document-specific parsing (PDF, text, etc.) works with files, not memory buffers
This keeps the librarian document-structure-agnostic. PDF parsing, text
extraction, and other format-specific logic stays in the respective decoders.
#### Processing Flow
```
Pulsar PDF Decoder Librarian S3
│ │ │ │
│── doc-id ───────────►│ │ │
│ (processing msg) │ │ │
│ │ │ │
│ │── stream-document ──────►│ │
│ │ (doc-id) │── GetObject ────►│
│ │ │ │
│ │◄── chunk ────────────────│◄── stream ───────│
│ │ (write to temp file) │ │
│ │◄── chunk ────────────────│◄── stream ───────│
│ │ (append to temp file) │ │
│ │ ⋮ │ ⋮ │
│ │◄── EOF ──────────────────│ │
│ │ │ │
│ │ ┌──────────────────────────┐ │
│ │ │ temp file on disk │ │
│ │ │ (memory stays bounded) │ │
│ │ └────────────┬─────────────┘ │
│ │ │ │
│ │ PDF library opens file │
│ │ extract page 1 text ──► chunker │
│ │ extract page 2 text ──► chunker │
│ │ ⋮ │
│ │ close file │
│ │ delete temp file │
```
#### Librarian Stream API
Add a streaming document retrieval operation:
**`stream-document`**
Request:
```json
{
"operation": "stream-document",
"document-id": "doc-123"
}
```
Response: Streamed binary chunks (not a single response).
For REST API, this returns a streaming response with `Transfer-Encoding: chunked`.
For internal service-to-service calls (processor to librarian), this could be:
- Direct S3 streaming via presigned URL (if internal network allows)
- Chunked responses over the service protocol
- A dedicated streaming endpoint
The key requirement: data flows in chunks, never fully buffered in librarian.
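Whichever transport is chosen, the librarian-side handler reduces to forwarding the storage response in bounded pieces rather than reading it whole. A minimal sketch over any file-like object (the name `iter_blob` is illustrative; an S3 `GetObject` response body exposes a compatible `read()`):

```python
def iter_blob(fileobj, chunk_size: int = 64 * 1024):
    """Yield fixed-size chunks from a file-like object without full buffering."""
    while True:
        chunk = fileobj.read(chunk_size)
        if not chunk:
            return  # EOF
        yield chunk
```

Peak librarian memory per stream is then one `chunk_size` buffer, independent of document size.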
#### PDF Decoder Changes
**Current implementation** (memory-intensive):
```python
def decode_pdf(document_content: bytes) -> str:
reader = PdfReader(BytesIO(document_content)) # full doc in memory
text = ""
for page in reader.pages:
text += page.extract_text() # accumulating
return text # full text in memory
```
**New implementation** (temp file, incremental):
```python
import tempfile
from typing import Iterator

from pypdf import PdfReader  # assumes a file-based PDF reader; pypdf shown as one option

def decode_pdf_streaming(doc_id: str, librarian_client) -> Iterator[str]:
"""Yield extracted text page by page."""
with tempfile.NamedTemporaryFile(delete=True, suffix='.pdf') as tmp:
# Stream document to temp file
for chunk in librarian_client.stream_document(doc_id):
tmp.write(chunk)
tmp.flush()
# Open PDF from file (not memory)
reader = PdfReader(tmp.name)
# Yield pages incrementally
for page in reader.pages:
yield page.extract_text()
# tmp file auto-deleted on context exit
```
Memory profile:
- Temp file on disk: size of PDF (disk is cheap)
- In memory: one page's text at a time
- Peak memory: bounded, independent of document size
#### Text Document Decoder Changes
For plain text documents the flow is even simpler; no temp file is needed:
```python
import codecs

def decode_text_streaming(doc_id: str, librarian_client) -> Iterator[str]:
    """Yield text in chunks as it streams from storage."""
    # Incremental decoder handles multi-byte UTF-8 sequences split across chunks
    decoder = codecs.getincrementaldecoder('utf-8')()
    buffer = ""
    for chunk in librarian_client.stream_document(doc_id):
        buffer += decoder.decode(chunk)
        # Yield complete paragraphs as they arrive
        while '\n\n' in buffer:
            paragraph, buffer = buffer.split('\n\n', 1)
            yield paragraph + '\n\n'
    # Flush any trailing partial sequence, then the remaining buffer
    buffer += decoder.decode(b'', final=True)
    if buffer:
        yield buffer
```
Text documents can stream directly without temp file since they're
linearly structured.
#### Streaming Chunker Integration
The chunker receives an iterator of text (pages or paragraphs) and produces
chunks incrementally:
```python
class StreamingChunker:
def __init__(self, chunk_size: int, overlap: int):
self.chunk_size = chunk_size
self.overlap = overlap
def process(self, text_stream: Iterator[str]) -> Iterator[str]:
"""Yield chunks as text arrives."""
buffer = ""
for text_segment in text_stream:
buffer += text_segment
while len(buffer) >= self.chunk_size:
chunk = buffer[:self.chunk_size]
yield chunk
# Keep overlap for context continuity
buffer = buffer[self.chunk_size - self.overlap:]
# Yield remaining buffer as final chunk
if buffer.strip():
yield buffer
```
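To see the overlap behaviour concretely, here is the chunker exercised on a tiny stream. The class is restated so the example stands alone, and the sizes are toy values chosen for readability:

```python
from typing import Iterator

class StreamingChunker:
    def __init__(self, chunk_size: int, overlap: int):
        self.chunk_size = chunk_size
        self.overlap = overlap

    def process(self, text_stream: Iterator[str]) -> Iterator[str]:
        buffer = ""
        for text_segment in text_stream:
            buffer += text_segment
            while len(buffer) >= self.chunk_size:
                yield buffer[:self.chunk_size]
                # Keep the overlap for context continuity
                buffer = buffer[self.chunk_size - self.overlap:]
        if buffer.strip():
            yield buffer

chunks = list(StreamingChunker(chunk_size=10, overlap=3).process(
    iter(["abcdefghij", "klmnopqrst"])))
# Each chunk begins with the last 3 characters of the previous one.
```

Memory use is bounded by `chunk_size` plus the largest incoming segment, regardless of total document length.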
#### End-to-End Processing Pipeline
```python
async def process_document(doc_id: str, librarian_client, embedder):
"""Process document with bounded memory."""
# Get document metadata to determine type
metadata = await librarian_client.get_document_metadata(doc_id)
# Select decoder based on document type
if metadata.kind == 'application/pdf':
text_stream = decode_pdf_streaming(doc_id, librarian_client)
elif metadata.kind == 'text/plain':
text_stream = decode_text_streaming(doc_id, librarian_client)
else:
raise UnsupportedDocumentType(metadata.kind)
# Chunk incrementally
chunker = StreamingChunker(chunk_size=1000, overlap=100)
# Process each chunk as it's produced
for chunk in chunker.process(text_stream):
# Generate embeddings, store in vector DB, etc.
embedding = await embedder.embed(chunk)
await store_chunk(doc_id, chunk, embedding)
```
At no point is the full document or full extracted text held in memory.
#### Temp File Considerations
**Location**: Use system temp directory (`/tmp` or equivalent). For
containerized deployments, ensure temp directory has sufficient space
and is on fast storage (not network-mounted if possible).
**Cleanup**: Use context managers (`with tempfile...`) to ensure cleanup
even on exceptions.
**Concurrent processing**: Each processing job gets its own temp file.
No conflicts between parallel document processing.
**Disk space**: Temp files are short-lived (duration of processing). For
a 500MB PDF, need 500MB temp space during processing. Size limit could
be enforced at upload time if disk space is constrained.
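A pre-flight check before streaming to disk can enforce that constraint. This is a sketch; the function name and the 20% headroom factor are arbitrary choices, not part of the spec:

```python
import shutil
import tempfile

def ensure_temp_space(required_bytes: int, headroom: float = 1.2) -> None:
    """Raise OSError if the temp directory lacks space for an incoming document."""
    tmp_dir = tempfile.gettempdir()
    free = shutil.disk_usage(tmp_dir).free
    if free < required_bytes * headroom:
        raise OSError(
            f"insufficient temp space in {tmp_dir}: "
            f"need {int(required_bytes * headroom)}, have {free}"
        )
```

Calling this with the document's known `total_size` (available from the upload session) before starting `stream-document` fails fast instead of partway through a long transfer.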
### Unified Processing Interface: Child Documents
PDF extraction and text document processing need to feed into the same
downstream pipeline (chunker → embeddings → storage). To achieve this with
a consistent "fetch by ID" interface, extracted text blobs are stored back
to librarian as child documents.
#### Processing Flow with Child Documents
```
PDF Document Text Document
│ │
▼ │
pdf-extractor │
│ │
│ (stream PDF from librarian) │
│ (extract page 1 text) │
│ (store as child doc → librarian) │
│ (extract page 2 text) │
│ (store as child doc → librarian) │
│ ⋮ │
▼ ▼
[child-doc-id, child-doc-id, ...] [doc-id]
│ │
└─────────────────────┬───────────────────────────────┘
chunker
│ (receives document ID)
│ (streams content from librarian)
│ (chunks incrementally)
[chunks → embedding → storage]
```
The chunker has one uniform interface:
- Receive a document ID (via Pulsar)
- Stream content from librarian
- Chunk it
It doesn't know or care whether the ID refers to:
- A user-uploaded text document
- An extracted text blob from a PDF page
- Any future document type
#### Child Document Metadata
Extend the document schema to track parent/child relationships:
```sql
-- Add columns to document table
ALTER TABLE document ADD parent_id text;
ALTER TABLE document ADD document_type text;
-- Index for finding children of a parent
CREATE INDEX document_parent ON document (parent_id);
```
**Document types:**
| `document_type` | Description |
|-----------------|-------------|
| `source` | User-uploaded document (PDF, text, etc.) |
| `extracted` | Derived from a source document (e.g., PDF page text) |
**Metadata fields:**
| Field | Source Document | Extracted Child |
|-------|-----------------|-----------------|
| `id` | user-provided or generated | generated (e.g., `{parent-id}-page-{n}`) |
| `parent_id` | `NULL` | parent document ID |
| `document_type` | `source` | `extracted` |
| `kind` | `application/pdf`, etc. | `text/plain` |
| `title` | user-provided | generated (e.g., "Page 3 of Report.pdf") |
| `user` | authenticated user | same as parent |
#### Librarian API for Child Documents
**Creating child documents** (internal, used by pdf-extractor):
```json
{
"operation": "add-child-document",
"parent-id": "doc-123",
"document-metadata": {
"id": "doc-123-page-1",
"kind": "text/plain",
"title": "Page 1"
},
"content": "<base64-encoded-text>"
}
```
For small extracted text (typical page text is < 100KB), single-operation
upload is acceptable. For very large text extractions, chunked upload
could be used.
**Listing child documents** (for debugging/admin):
```json
{
"operation": "list-children",
"parent-id": "doc-123"
}
```
Response:
```json
{
"children": [
{ "id": "doc-123-page-1", "title": "Page 1", "kind": "text/plain" },
{ "id": "doc-123-page-2", "title": "Page 2", "kind": "text/plain" },
...
]
}
```
#### User-Facing Behavior
**`list-documents` default behavior:**
```sql
-- Conceptual filter: Cassandra cannot match null values directly, so in
-- practice this needs e.g. a document_type discriminator or separate table
SELECT * FROM document WHERE user = ? AND parent_id IS NULL;
```
Only top-level (source) documents appear in the user's document list.
Child documents are filtered out by default.
**Optional include-children flag** (for admin/debugging):
```json
{
"operation": "list-documents",
"include-children": true
}
```
#### Cascade Delete
When a parent document is deleted, all children must be deleted:
```python
def delete_document(doc_id: str, user: str):
# Find all children
children = query("SELECT id, object_id FROM document WHERE parent_id = ?", doc_id)
# Delete child blobs from S3
for child in children:
blob_store.delete(child.object_id)
# Delete child metadata from Cassandra
execute("DELETE FROM document WHERE parent_id = ?", doc_id)
# Delete parent blob and metadata
parent = get_document(doc_id)
blob_store.delete(parent.object_id)
execute("DELETE FROM document WHERE id = ? AND user = ?", doc_id, user)
```
#### Storage Considerations
Extracted text blobs do duplicate content:
- Original PDF stored in Garage
- Extracted text per page also stored in Garage
This tradeoff enables:
- **Uniform chunker interface**: Chunker always fetches by ID
- **Resume/retry**: Can restart at chunker stage without re-extracting PDF
- **Debugging**: Extracted text is inspectable
- **Separation of concerns**: PDF extractor and chunker are independent services
For a 500MB PDF with 200 pages averaging 5KB text per page:
- PDF storage: 500MB
- Extracted text storage: ~1MB total
- Overhead: negligible
#### PDF Extractor Output
The pdf-extractor, after processing a document:
1. Streams PDF from librarian to temp file
2. Extracts text page by page
3. For each page, stores extracted text as child document via librarian
4. Sends child document IDs to chunker queue
```python
async def extract_pdf(doc_id: str, librarian_client, output_queue):
"""Extract PDF pages and store as child documents."""
with tempfile.NamedTemporaryFile(delete=True, suffix='.pdf') as tmp:
# Stream PDF to temp file
for chunk in librarian_client.stream_document(doc_id):
tmp.write(chunk)
tmp.flush()
# Extract pages
reader = PdfReader(tmp.name)
for page_num, page in enumerate(reader.pages, start=1):
text = page.extract_text()
# Store as child document
child_id = f"{doc_id}-page-{page_num}"
await librarian_client.add_child_document(
parent_id=doc_id,
document_id=child_id,
kind="text/plain",
title=f"Page {page_num}",
content=text.encode('utf-8')
)
# Send to chunker queue
await output_queue.send(child_id)
```
The chunker receives these child IDs and processes them identically to
how it would process a user-uploaded text document.
### Client Updates
#### Python SDK
The Python SDK (`trustgraph-base/trustgraph/api/library.py`) should handle
chunked uploads transparently. The public interface remains unchanged:
```python
# Existing interface - no change for users
library.add_document(
id="doc-123",
title="Large Report",
kind="application/pdf",
content=large_pdf_bytes, # Can be hundreds of MB
tags=["reports"]
)
```
Internally, the SDK detects document size and switches strategy:
```python
class Library:
CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024 # 2MB
def add_document(self, id, title, kind, content, tags=None, ...):
if len(content) < self.CHUNKED_UPLOAD_THRESHOLD:
# Small document: single operation (existing behavior)
return self._add_document_single(id, title, kind, content, tags)
else:
# Large document: chunked upload
return self._add_document_chunked(id, title, kind, content, tags)
def _add_document_chunked(self, id, title, kind, content, tags):
# 1. begin-upload
session = self._begin_upload(
document_metadata={...},
total_size=len(content),
chunk_size=5 * 1024 * 1024
)
# 2. upload-chunk for each chunk
for i, chunk in enumerate(self._chunk_bytes(content, session.chunk_size)):
self._upload_chunk(session.upload_id, i, chunk)
# 3. complete-upload
return self._complete_upload(session.upload_id)
```
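The `_chunk_bytes` helper referenced above is not shown in the SDK sketch; a plausible implementation is a simple slicing generator (the standalone name here is illustrative):

```python
def chunk_bytes(content: bytes, chunk_size: int):
    """Yield successive chunk_size slices of content; the last may be short."""
    for offset in range(0, len(content), chunk_size):
        yield content[offset:offset + chunk_size]
```

Pairing each slice with its index via `enumerate`, as the SDK sketch does, gives the `chunk-index` values the `upload-chunk` operation expects.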
**Progress callbacks** (optional enhancement):
```python
def add_document(self, ..., on_progress=None):
"""
on_progress: Optional callback(bytes_sent, total_bytes)
"""
```
This allows UIs to display upload progress without changing the basic API.
#### CLI Tools
**`tg-add-library-document`** continues to work unchanged:
```bash
# Works transparently for any size - SDK handles chunking internally
tg-add-library-document --file large-report.pdf --title "Large Report"
```
Optional progress display could be added:
```bash
tg-add-library-document --file large-report.pdf --title "Large Report" --progress
# Output:
# Uploading: 45% (225MB / 500MB)
```
**Legacy tools removed:**
- `tg-load-pdf`: removed; use `tg-add-library-document` instead
- `tg-load-text`: removed; use `tg-add-library-document` instead
**Admin/debug commands** (optional, low priority):
```bash
# List incomplete uploads (admin troubleshooting)
tg-add-library-document --list-pending
# Resume specific upload (recovery scenario)
tg-add-library-document --resume upload-abc-123 --file large-report.pdf
```
These could be flags on the existing command rather than separate tools.
#### API Specification Updates
The OpenAPI spec (`specs/api/paths/librarian.yaml`) needs updates for:
**New operations:**
- `begin-upload` - Initialize chunked upload session
- `upload-chunk` - Upload individual chunk
- `complete-upload` - Finalize upload
- `abort-upload` - Cancel upload
- `get-upload-status` - Query upload progress
- `list-uploads` - List incomplete uploads for user
- `stream-document` - Streaming document retrieval
- `add-child-document` - Store extracted text (internal)
- `list-children` - List child documents (admin)
**Modified operations:**
- `list-documents` - Add `include-children` parameter
**New schemas:**
- `ChunkedUploadBeginRequest`
- `ChunkedUploadBeginResponse`
- `ChunkedUploadChunkRequest`
- `ChunkedUploadChunkResponse`
- `UploadSession`
- `UploadProgress`
**WebSocket spec updates** (`specs/websocket/`):
Mirror the REST operations for WebSocket clients, enabling real-time
progress updates during upload.
#### UX Considerations
The API spec updates enable frontend improvements:
**Upload progress UI:**
- Progress bar showing chunks uploaded
- Estimated time remaining
- Pause/resume capability
**Error recovery:**
- "Resume upload" option for interrupted uploads
- List of pending uploads on reconnect
**Large file handling:**
- Client-side file size detection
- Automatic chunked upload for large files
- Clear feedback during long uploads
These UX improvements require frontend work guided by the updated API spec.

View file

@ -0,0 +1,263 @@
# Query-Time Explainability
## Status
Implemented
## Overview
This specification describes how GraphRAG records and communicates explainability data during query execution. The goal is full traceability: from final answer back through selected edges to source documents.
Query-time explainability captures what the GraphRAG pipeline did during reasoning. It connects to extraction-time provenance, which records where knowledge graph facts originated.
## Terminology
| Term | Definition |
|------|------------|
| **Explainability** | The record of how a result was derived |
| **Session** | A single GraphRAG query execution |
| **Edge Selection** | LLM-driven selection of relevant edges with reasoning |
| **Provenance Chain** | Path from edge → chunk → page → document |
## Architecture
### Explainability Flow
```
GraphRAG Query
├─► Session Activity
│ └─► Query text, timestamp
├─► Retrieval Entity
│ └─► All edges retrieved from subgraph
├─► Selection Entity
│ └─► Selected edges with LLM reasoning
│ └─► Each edge links to extraction provenance
└─► Answer Entity
└─► Reference to synthesized response (in librarian)
```
### Two-Stage GraphRAG Pipeline
1. **Edge Selection**: LLM selects relevant edges from subgraph, providing reasoning for each
2. **Synthesis**: LLM generates answer from selected edges only
This separation enables explainability - we know exactly which edges contributed.
### Storage
- Explainability triples stored in configurable collection (default: `explainability`)
- Uses PROV-O ontology for provenance relationships
- RDF-star reification for edge references
- Answer content stored in librarian service (not inline - too large)
### Real-Time Streaming
Explainability events stream to client as the query executes:
1. Session created → event emitted
2. Edges retrieved → event emitted
3. Edges selected with reasoning → event emitted
4. Answer synthesized → event emitted
Client receives `explain_id` and `explain_collection` to fetch full details.
## URI Structure
All URIs use the `urn:trustgraph:` namespace with UUIDs:
| Entity | URI Pattern |
|--------|-------------|
| Session | `urn:trustgraph:session:{uuid}` |
| Retrieval | `urn:trustgraph:prov:retrieval:{uuid}` |
| Selection | `urn:trustgraph:prov:selection:{uuid}` |
| Answer | `urn:trustgraph:prov:answer:{uuid}` |
| Edge Selection | `urn:trustgraph:prov:edge:{uuid}:{index}` |
## RDF Model (PROV-O)
### Session Activity
```turtle
<session-uri> a prov:Activity ;
rdfs:label "GraphRAG query session" ;
prov:startedAtTime "2024-01-15T10:30:00Z" ;
tg:query "What was the War on Terror?" .
```
### Retrieval Entity
```turtle
<retrieval-uri> a prov:Entity ;
rdfs:label "Retrieved edges" ;
prov:wasGeneratedBy <session-uri> ;
tg:edgeCount 50 .
```
### Selection Entity
```turtle
<selection-uri> a prov:Entity ;
rdfs:label "Selected edges" ;
prov:wasDerivedFrom <retrieval-uri> ;
tg:selectedEdge <edge-sel-0> ;
tg:selectedEdge <edge-sel-1> .
<edge-sel-0> tg:edge << <s> <p> <o> >> ;
tg:reasoning "This edge establishes the key relationship..." .
```
### Answer Entity
```turtle
<answer-uri> a prov:Entity ;
rdfs:label "GraphRAG answer" ;
prov:wasDerivedFrom <selection-uri> ;
tg:document <urn:trustgraph:answer:{uuid}> .
```
The `tg:document` references the answer stored in the librarian service.
## Namespace Constants
Defined in `trustgraph-base/trustgraph/provenance/namespaces.py`:
| Constant | URI |
|----------|-----|
| `TG_QUERY` | `https://trustgraph.ai/ns/query` |
| `TG_EDGE_COUNT` | `https://trustgraph.ai/ns/edgeCount` |
| `TG_SELECTED_EDGE` | `https://trustgraph.ai/ns/selectedEdge` |
| `TG_EDGE` | `https://trustgraph.ai/ns/edge` |
| `TG_REASONING` | `https://trustgraph.ai/ns/reasoning` |
| `TG_CONTENT` | `https://trustgraph.ai/ns/content` |
| `TG_DOCUMENT` | `https://trustgraph.ai/ns/document` |
## GraphRagResponse Schema
```python
@dataclass
class GraphRagResponse:
error: Error | None = None
response: str = ""
end_of_stream: bool = False
explain_id: str | None = None
explain_collection: str | None = None
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False
```
### Message Types
| message_type | Purpose |
|--------------|---------|
| `chunk` | Response text (streaming or final) |
| `explain` | Explainability event with IRI reference |
### Session Lifecycle
1. Multiple `explain` messages (session, retrieval, selection, answer)
2. Multiple `chunk` messages (streaming response)
3. Final `chunk` with `end_of_session=True`
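A client consuming this lifecycle can be sketched against the schema above. This uses plain dicts in place of the Pulsar message types, and the function name is illustrative:

```python
def consume_session(messages):
    """Collect explain references and the streamed answer from one session."""
    explain_refs, answer_parts = [], []
    for msg in messages:
        if msg["message_type"] == "explain":
            # Reference to full explainability data, fetchable later
            explain_refs.append((msg["explain_id"], msg["explain_collection"]))
        elif msg["message_type"] == "chunk":
            answer_parts.append(msg["response"])
        if msg.get("end_of_session"):
            break
    return explain_refs, "".join(answer_parts)
```

The caller can then resolve each `(explain_id, explain_collection)` pair into the stored PROV-O triples to display the full reasoning chain.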
## Edge Selection Format
LLM returns JSONL with selected edges:
```jsonl
{"id": "edge-hash-1", "reasoning": "This edge shows the key relationship..."}
{"id": "edge-hash-2", "reasoning": "Provides supporting evidence..."}
```
The `id` is a hash of `(labeled_s, labeled_p, labeled_o)` computed by `edge_id()`.
## URI Preservation
### The Problem
GraphRAG displays human-readable labels to the LLM, but explainability needs original URIs for provenance tracing.
### Solution
`get_labelgraph()` returns both:
- `labeled_edges`: List of `(label_s, label_p, label_o)` for LLM
- `uri_map`: Dict mapping `edge_id(labels)` to `(uri_s, uri_p, uri_o)`
When storing explainability data, URIs from `uri_map` are used.
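The mechanism can be illustrated with a stand-in `edge_id`; the real hash function lives in the GraphRAG module and may differ, so this is only a shape-preserving sketch:

```python
import hashlib

def edge_id(s: str, p: str, o: str) -> str:
    """Hypothetical stand-in: stable hash of a labeled edge."""
    return hashlib.sha256(f"{s}\x00{p}\x00{o}".encode()).hexdigest()[:16]

# Build the map alongside the labeled edges shown to the LLM...
labeled = ("Guantanamo", "definition", "A detention facility...")
uris = ("urn:x:guantanamo", "urn:x:definition", "urn:x:def-text")
uri_map = {edge_id(*labeled): uris}

# ...then, when the LLM returns {"id": ..., "reasoning": ...},
# recover the original URIs for the explainability triples:
selected_uris = uri_map[edge_id(*labeled)]
```

The point is that the LLM only ever sees labels, while the stored provenance is expressed in the original URI space.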
## Provenance Tracing
### From Edge to Source
Selected edges can be traced back to source documents:
1. Query for containing subgraph: `?subgraph tg:contains <<s p o>>`
2. Follow `prov:wasDerivedFrom` chain to root document
3. Each step in chain: chunk → page → document
### Cassandra Quoted Triple Support
The Cassandra query service supports matching quoted triples:
```python
# In get_term_value():
elif term.type == TRIPLE:
return serialize_triple(term.triple)
```
This enables queries like:
```
?subgraph tg:contains <<http://example.org/s http://example.org/p "value">>
```
## CLI Usage
```bash
tg-invoke-graph-rag --explainable -q "What was the War on Terror?"
```
### Output Format
```
[session] urn:trustgraph:session:abc123
[retrieval] urn:trustgraph:prov:retrieval:abc123
[selection] urn:trustgraph:prov:selection:abc123
Selected 12 edge(s)
Edge: (Guantanamo, definition, A detention facility...)
Reason: Directly connects Guantanamo to the War on Terror
Source: Chunk 1 → Page 2 → Beyond the Vigilant State
[answer] urn:trustgraph:prov:answer:abc123
Based on the provided knowledge statements...
```
### Features
- Real-time explainability events during query
- Label resolution for edge components via `rdfs:label`
- Source chain tracing via `prov:wasDerivedFrom`
- Label caching to avoid repeated queries
## Files Implemented
| File | Purpose |
|------|---------|
| `trustgraph-base/trustgraph/provenance/uris.py` | URI generators |
| `trustgraph-base/trustgraph/provenance/namespaces.py` | RDF namespace constants |
| `trustgraph-base/trustgraph/provenance/triples.py` | Triple builders |
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | GraphRagResponse schema |
| `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py` | Core GraphRAG with URI preservation |
| `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` | Service with librarian integration |
| `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` | Quoted triple query support |
| `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py` | CLI with explainability display |
## References
- PROV-O (W3C Provenance Ontology): https://www.w3.org/TR/prov-o/
- RDF-star: https://w3c.github.io/rdf-star/
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`


@ -0,0 +1,471 @@
# Tool Services: Dynamically Pluggable Agent Tools
## Status
Implemented
## Overview
This specification defines a mechanism for dynamically pluggable agent tools called "tool services". Unlike the existing built-in tool types (`KnowledgeQueryImpl`, `McpToolImpl`, etc.), tool services allow new tools to be introduced by:
1. Deploying a new Pulsar-based service
2. Adding a configuration descriptor that tells the agent how to invoke it
This enables extensibility without modifying the core agent-react framework.
## Terminology
| Term | Definition |
|------|------------|
| **Built-in Tool** | Existing tool types with hardcoded implementations in `tools.py` |
| **Tool Service** | A Pulsar service that can be invoked as an agent tool, defined by a service descriptor |
| **Tool** | A configured instance that references a tool service, exposed to the agent/LLM |
This is a two-tier model, analogous to MCP tools:
- MCP: MCP server defines the tool interface → Tool config references it
- Tool Services: Tool service defines the Pulsar interface → Tool config references it
## Background: Existing Tools
### Built-in Tool Implementation
Tools are currently defined in `trustgraph-flow/trustgraph/agent/react/tools.py` with typed implementations:
```python
class KnowledgeQueryImpl:
    async def invoke(self, question):
        client = self.context("graph-rag-request")
        return await client.rag(question, self.collection)
```
Each tool type:
- Has a hardcoded Pulsar service it calls (e.g., `graph-rag-request`)
- Knows the exact method to call on the client (e.g., `client.rag()`)
- Has typed arguments defined in the implementation
### Tool Registration (service.py:105-214)
Tools are loaded from config with a `type` field that maps to an implementation:
```python
if impl_id == "knowledge-query":
    impl = functools.partial(KnowledgeQueryImpl, collection=data.get("collection"))
elif impl_id == "text-completion":
    impl = TextCompletionImpl
# ... etc
```
## Architecture
### Two-Tier Model
#### Tier 1: Tool Service Descriptor
A tool service defines a Pulsar service interface. It declares:
- The Pulsar queues for request/response
- Configuration parameters it requires from tools that use it
```json
{
    "id": "custom-rag",
    "request-queue": "non-persistent://tg/request/custom-rag",
    "response-queue": "non-persistent://tg/response/custom-rag",
    "config-params": [
        {"name": "collection", "required": true}
    ]
}
```
A tool service that needs no configuration parameters:
```json
{
    "id": "calculator",
    "request-queue": "non-persistent://tg/request/calc",
    "response-queue": "non-persistent://tg/response/calc",
    "config-params": []
}
```
#### Tier 2: Tool Descriptor
A tool references a tool service and provides:
- Config parameter values (satisfying the service's requirements)
- Tool metadata for the agent (name, description)
- Argument definitions for the LLM
```json
{
    "type": "tool-service",
    "name": "query-customers",
    "description": "Query the customer knowledge base",
    "service": "custom-rag",
    "collection": "customers",
    "arguments": [
        {
            "name": "question",
            "type": "string",
            "description": "The question to ask about customers"
        }
    ]
}
```
Multiple tools can reference the same service with different configurations:
```json
{
    "type": "tool-service",
    "name": "query-products",
    "description": "Query the product knowledge base",
    "service": "custom-rag",
    "collection": "products",
    "arguments": [
        {
            "name": "question",
            "type": "string",
            "description": "The question to ask about products"
        }
    ]
}
```
### Request Format
When a tool is invoked, the request to the tool service includes:
- `user`: From the agent request (multi-tenancy)
- `config`: JSON-encoded config values from the tool descriptor
- `arguments`: JSON-encoded arguments from the LLM
```json
{
    "user": "alice",
    "config": "{\"collection\": \"customers\"}",
    "arguments": "{\"question\": \"What are the top customer complaints?\"}"
}
```
The tool service receives these as parsed dicts in the `invoke` method.
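A minimal sketch of that decoding step (the framework-side parsing is assumed to look roughly like this; the helper name is illustrative):

```python
import json

def decode_envelope(message):
    # "config" and "arguments" arrive as JSON-encoded strings; the framework
    # parses them so invoke() sees plain dicts. Sketch only.
    return (
        message["user"],
        json.loads(message["config"]),
        json.loads(message["arguments"]),
    )

envelope = {
    "user": "alice",
    "config": "{\"collection\": \"customers\"}",
    "arguments": "{\"question\": \"What are the top customer complaints?\"}",
}
user, config, arguments = decode_envelope(envelope)
```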
### Generic Tool Service Implementation
A `ToolServiceImpl` class invokes tool services based on configuration:
```python
class ToolServiceImpl:
    def __init__(self, context, request_queue, response_queue, config_values, arguments, processor):
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.config_values = config_values  # e.g., {"collection": "customers"}
        # ...

    async def invoke(self, **arguments):
        client = await self._get_or_create_client()
        # "user" is propagated from the agent request
        response = await client.call(user, self.config_values, arguments)
        if isinstance(response, str):
            return response
        else:
            return json.dumps(response)
```
## Design Decisions
### Two-Tier Configuration Model
Tool services follow a two-tier model similar to MCP tools:
1. **Tool Service**: Defines the Pulsar service interface (topic, required config params)
2. **Tool**: References a tool service, provides config values, defines LLM arguments
This separation allows:
- One tool service to be used by multiple tools with different configurations
- Clear distinction between service interface and tool configuration
- Reusability of service definitions
### Request Mapping: Pass-Through with Envelope
The request to a tool service is a structured envelope containing:
- `user`: Propagated from the agent request for multi-tenancy
- Config values: From the tool descriptor (e.g., `collection`)
- `arguments`: LLM-provided arguments, passed through as a dict
The agent manager parses the LLM's response into `act.arguments` as a dict (`agent_manager.py:117-154`). This dict is included in the request envelope.
### Schema Handling: Untyped
Requests and responses use untyped dicts. No schema validation occurs at the agent level; the tool service is responsible for validating its own inputs. This provides maximum flexibility for defining new services.
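Since validation is the service's responsibility, a tool service would typically check its own inputs and report failures using the standard error convention; this helper is illustrative, not part of the framework:

```python
def validate_arguments(arguments, required):
    # The agent does no schema validation, so the service checks its own
    # inputs. Returns an error dict (type/message, matching the Error
    # convention) or None on success. Illustrative helper only.
    missing = [name for name in required if name not in arguments]
    if missing:
        return {"type": "bad-request",
                "message": f"missing arguments: {', '.join(missing)}"}
    return None

err = validate_arguments({"question": "hi"},
                         required=["question", "collection"])
ok = validate_arguments({"question": "hi", "collection": "customers"},
                        required=["question", "collection"])
```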
### Client Interface: Direct Pulsar Topics
Tool services use direct Pulsar topics without requiring flow configuration. The tool-service descriptor specifies the full queue names:
```json
{
    "id": "joke-service",
    "request-queue": "non-persistent://tg/request/joke",
    "response-queue": "non-persistent://tg/response/joke",
    "config-params": [...]
}
```
This allows services to be hosted in any namespace.
### Error Handling: Standard Error Convention
Tool service responses follow the existing schema convention with an `error` field:
```python
@dataclass
class Error:
    type: str = ""
    message: str = ""
```
Response structure:
- Success: `error` is `None`, response contains result
- Error: `error` is populated with `type` and `message`
This matches the pattern used throughout existing service schemas (e.g., `PromptResponse`, `QueryResponse`, `AgentResponse`).
### Request/Response Correlation
Requests and responses are correlated using an `id` in Pulsar message properties:
- Request includes `id` in properties: `properties={"id": id}`
- Response(s) include the same `id`: `properties={"id": id}`
This follows the existing pattern used throughout the codebase (e.g., `agent_service.py`, `llm_service.py`).
### Streaming Support
Tool services can return streaming responses:
- Multiple response messages with the same `id` in properties
- Each response includes `end_of_stream: bool` field
- Final response has `end_of_stream: True`
This matches the pattern used in `AgentResponse` and other streaming services.
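A sketch of the client-side accumulation logic (message shape simplified to plain dicts; the real implementation consumes Pulsar messages):

```python
def collect_stream(messages, request_id):
    # Accumulate streamed responses that carry the request's "id" property,
    # stopping at end_of_stream. Sketch only; not the Pulsar consumer API.
    parts = []
    for msg in messages:
        if msg["properties"].get("id") != request_id:
            continue  # belongs to a different in-flight request
        parts.append(msg["response"])
        if msg["end_of_stream"]:
            break
    return "".join(parts)

stream = [
    {"properties": {"id": "req-1"}, "response": "First ", "end_of_stream": False},
    {"properties": {"id": "req-2"}, "response": "other", "end_of_stream": True},
    {"properties": {"id": "req-1"}, "response": "chunk", "end_of_stream": True},
]
result = collect_stream(stream, "req-1")
```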
### Response Handling: String Return
All existing tools follow the same pattern: **receive arguments as a dict, return the observation as a string**.
| Tool | Response Handling |
|------|------------------|
| `KnowledgeQueryImpl` | Returns `client.rag()` directly (string) |
| `TextCompletionImpl` | Returns `client.question()` directly (string) |
| `McpToolImpl` | Returns string, or `json.dumps(output)` if not string |
| `StructuredQueryImpl` | Formats result to string |
| `PromptImpl` | Returns `client.prompt()` directly (string) |
Tool services follow the same contract:
- The service returns a string response (the observation)
- If the response is not a string, it is converted via `json.dumps()`
- No extraction configuration needed in the descriptor
This keeps the descriptor simple and places responsibility on the service to return an appropriate text response for the agent.
## Configuration Guide
To add a new tool service, two configuration items are required:
### 1. Tool Service Configuration
Stored under the `tool-service` config key. Defines the Pulsar queues and available config parameters.
| Field | Required | Description |
|-------|----------|-------------|
| `id` | Yes | Unique identifier for the tool service |
| `request-queue` | Yes | Full Pulsar topic for requests (e.g., `non-persistent://tg/request/joke`) |
| `response-queue` | Yes | Full Pulsar topic for responses (e.g., `non-persistent://tg/response/joke`) |
| `config-params` | No | Array of config parameters the service accepts |
Each config param can specify:
- `name`: Parameter name (required)
- `required`: Whether the parameter must be provided by tools (default: false)
Example:
```json
{
    "id": "joke-service",
    "request-queue": "non-persistent://tg/request/joke",
    "response-queue": "non-persistent://tg/response/joke",
    "config-params": [
        {"name": "style", "required": false}
    ]
}
```
### 2. Tool Configuration
Stored under the `tool` config key. Defines a tool that the agent can use.
| Field | Required | Description |
|-------|----------|-------------|
| `type` | Yes | Must be `"tool-service"` |
| `name` | Yes | Tool name exposed to the LLM |
| `description` | Yes | Description of what the tool does (shown to LLM) |
| `service` | Yes | ID of the tool-service to invoke |
| `arguments` | No | Array of argument definitions for the LLM |
| *(config params)* | Varies | Any config params defined by the service |
Each argument can specify:
- `name`: Argument name (required)
- `type`: Data type, e.g., `"string"` (required)
- `description`: Description shown to the LLM (required)
Example:
```json
{
    "type": "tool-service",
    "name": "tell-joke",
    "description": "Tell a joke on a given topic",
    "service": "joke-service",
    "style": "pun",
    "arguments": [
        {
            "name": "topic",
            "type": "string",
            "description": "The topic for the joke (e.g., programming, animals, food)"
        }
    ]
}
```
### Loading Configuration
Use `tg-put-config-item` to load configurations:
```bash
# Load tool-service config
tg-put-config-item tool-service/joke-service < joke-service.json
# Load tool config
tg-put-config-item tool/tell-joke < tell-joke.json
```
The agent-manager must be restarted to pick up new configurations.
## Implementation Details
### Schema
Request and response types in `trustgraph-base/trustgraph/schema/services/tool_service.py`:
```python
@dataclass
class ToolServiceRequest:
    user: str = ""       # User context for multi-tenancy
    config: str = ""     # JSON-encoded config values from tool descriptor
    arguments: str = ""  # JSON-encoded arguments from LLM

@dataclass
class ToolServiceResponse:
    error: Error | None = None
    response: str = ""   # String response (the observation)
    end_of_stream: bool = False
```
### Server-Side: DynamicToolService
Base class in `trustgraph-base/trustgraph/base/dynamic_tool_service.py`:
```python
class DynamicToolService(AsyncProcessor):
    """Base class for implementing tool services."""

    def __init__(self, **params):
        topic = params.get("topic", default_topic)
        # Constructs topics: non-persistent://tg/request/{topic},
        # non-persistent://tg/response/{topic}
        # Sets up Consumer and Producer

    async def invoke(self, user, config, arguments):
        """Override this method to implement the tool's logic."""
        raise NotImplementedError()
```
### Client-Side: ToolServiceImpl
Implementation in `trustgraph-flow/trustgraph/agent/react/tools.py`:
```python
class ToolServiceImpl:
    def __init__(self, context, request_queue, response_queue, config_values, arguments, processor):
        # Uses the provided queue paths directly
        # Creates ToolServiceClient on first use
        ...

    async def invoke(self, **arguments):
        client = await self._get_or_create_client()
        # "user" is propagated from the agent request
        response = await client.call(user, self.config_values, arguments)
        return response if isinstance(response, str) else json.dumps(response)
```
### Files
| File | Purpose |
|------|---------|
| `trustgraph-base/trustgraph/schema/services/tool_service.py` | Request/response schemas |
| `trustgraph-base/trustgraph/base/tool_service_client.py` | Client for invoking services |
| `trustgraph-base/trustgraph/base/dynamic_tool_service.py` | Base class for service implementation |
| `trustgraph-flow/trustgraph/agent/react/tools.py` | `ToolServiceImpl` class |
| `trustgraph-flow/trustgraph/agent/react/service.py` | Config loading |
### Example: Joke Service
An example service in `trustgraph-flow/trustgraph/tool_service/joke/`:
```python
class Processor(DynamicToolService):
    async def invoke(self, user, config, arguments):
        style = config.get("style", "pun")
        topic = arguments.get("topic", "")
        joke = pick_joke(topic, style)
        return f"Hey {user}! Here's a {style} for you:\n\n{joke}"
```
Tool service config:
```json
{
    "id": "joke-service",
    "request-queue": "non-persistent://tg/request/joke",
    "response-queue": "non-persistent://tg/response/joke",
    "config-params": [{"name": "style", "required": false}]
}
```
Tool config:
```json
{
    "type": "tool-service",
    "name": "tell-joke",
    "description": "Tell a joke on a given topic",
    "service": "joke-service",
    "style": "pun",
    "arguments": [
        {"name": "topic", "type": "string", "description": "The topic for the joke"}
    ]
}
```
### Backward Compatibility
- Existing built-in tool types continue to work unchanged
- `tool-service` is a new tool type alongside existing types (`knowledge-query`, `mcp-tool`, etc.)
## Future Considerations
### Self-Announcing Services
A future enhancement could allow services to publish their own descriptors:
- Services publish to a well-known `tool-descriptors` topic on startup
- Agent subscribes and dynamically registers tools
- Enables true plug-and-play without config changes
This is out of scope for the initial implementation.
## References
- Current tool implementation: `trustgraph-flow/trustgraph/agent/react/tools.py`
- Tool registration: `trustgraph-flow/trustgraph/agent/react/service.py:105-214`
- Agent schemas: `trustgraph-base/trustgraph/schema/services/agent.py`

File diff suppressed because one or more lines are too long


@ -1,14 +1,60 @@
type: object
description: RDF value - can be entity/URI or literal
required:
- v
- e
description: |
RDF Term - typed representation of a value in the knowledge graph.
Term types (discriminated by `t` field):
- `i`: IRI (URI reference)
- `l`: Literal (string value, optionally with datatype or language tag)
- `r`: Quoted triple (RDF-star reification)
- `b`: Blank node
properties:
t:
type: string
description: Term type discriminator
enum: [i, l, r, b]
example: i
i:
type: string
description: IRI value (when t=i)
example: http://example.com/Person1
v:
type: string
description: Value (URI or literal text)
example: https://example.com/entity1
e:
type: boolean
description: True if entity/URI, false if literal
example: true
description: Literal value (when t=l)
example: John Doe
d:
type: string
description: Datatype IRI for literal (when t=l, optional)
example: http://www.w3.org/2001/XMLSchema#integer
l:
type: string
description: Language tag for literal (when t=l, optional)
example: en
r:
type: object
description: Quoted triple (when t=r) - contains s, p, o as nested Term objects with the same structure
properties:
s:
type: object
description: Subject term
p:
type: object
description: Predicate term
o:
type: object
description: Object term
required:
- t
examples:
- description: IRI term
value:
t: i
i: http://schema.org/name
- description: Literal term
value:
t: l
v: John Doe
- description: Literal with language tag
value:
t: l
v: Bonjour
l: fr


@ -1,5 +1,6 @@
type: object
description: RDF triple (subject-predicate-object)
description: |
RDF triple (subject-predicate-object), optionally scoped to a named graph.
required:
- s
- p
@ -14,3 +15,7 @@ properties:
o:
$ref: './RdfValue.yaml'
description: Object
g:
type: string
description: Named graph URI (optional)
example: urn:graph:source


@ -9,12 +9,26 @@ properties:
- action
- observation
- answer
- final-answer
- error
example: answer
content:
type: string
description: Chunk content (streaming mode only)
example: Paris is the capital of France.
message_type:
type: string
description: Message type - "chunk" for agent chunks, "explain" for explainability events
enum: [chunk, explain]
example: chunk
explain_id:
type: string
description: Explainability node URI (for explain messages)
example: urn:trustgraph:agent:abc123
explain_graph:
type: string
description: Named graph containing the explainability data
example: urn:graph:retrieval
end-of-message:
type: boolean
description: Current chunk type is complete (streaming mode)


@ -1,21 +1,60 @@
type: object
description: |
RDF value - represents either a URI/entity or a literal value.
RDF Term - typed representation of a value in the knowledge graph.
When `e` is true, `v` must be a full URI (e.g., http://schema.org/name).
When `e` is false, `v` is a literal value (string, number, etc.).
Term types (discriminated by `t` field):
- `i`: IRI (URI reference)
- `l`: Literal (string value, optionally with datatype or language tag)
- `r`: Quoted triple (RDF-star reification)
- `b`: Blank node
properties:
t:
type: string
description: Term type discriminator
enum: [i, l, r, b]
example: i
i:
type: string
description: IRI value (when t=i)
example: http://example.com/Person1
v:
type: string
description: The value - full URI when e=true, literal when e=false
example: http://example.com/Person1
e:
type: boolean
description: True if entity/URI, false if literal value
example: true
description: Literal value (when t=l)
example: John Doe
d:
type: string
description: Datatype IRI for literal (when t=l, optional)
example: http://www.w3.org/2001/XMLSchema#integer
l:
type: string
description: Language tag for literal (when t=l, optional)
example: en
r:
type: object
description: Quoted triple (when t=r) - contains s, p, o as nested Term objects with the same structure
properties:
s:
type: object
description: Subject term
p:
type: object
description: Predicate term
o:
type: object
description: Object term
required:
- v
- e
example:
v: http://schema.org/name
e: true
- t
examples:
- description: IRI term
value:
t: i
i: http://schema.org/name
- description: Literal term
value:
t: l
v: John Doe
- description: Literal with language tag
value:
t: l
v: Bonjour
l: fr


@ -1,6 +1,7 @@
type: object
description: |
RDF triple representing a subject-predicate-object statement in the knowledge graph.
RDF triple representing a subject-predicate-object statement in the knowledge graph,
optionally scoped to a named graph.
Example: (Person1) -[has name]-> ("John Doe")
properties:
@ -13,17 +14,26 @@ properties:
o:
$ref: './RdfValue.yaml'
description: Object - the value or target entity
g:
type: string
description: |
Named graph URI (optional). When absent, the triple is in the default graph.
Well-known graphs:
- (empty/absent): Core knowledge facts
- urn:graph:source: Extraction provenance
- urn:graph:retrieval: Query-time explainability
example: urn:graph:source
required:
- s
- p
- o
example:
s:
v: http://example.com/Person1
e: true
t: i
i: http://example.com/Person1
p:
v: http://schema.org/name
e: true
t: i
i: http://schema.org/name
o:
t: l
v: John Doe
e: false


@ -1,12 +1,22 @@
type: object
description: Document embeddings query response
description: Document embeddings query response with matching chunks and similarity scores
properties:
chunks:
type: array
description: Similar document chunks (text strings)
description: Matching document chunks with similarity scores
items:
type: string
type: object
properties:
chunk_id:
type: string
description: Chunk identifier URI
example: "urn:trustgraph:chunk:abc123"
score:
type: number
description: Similarity score (higher is more similar)
example: 0.89
example:
- "Quantum computing uses quantum mechanics principles for computation..."
- "Neural networks are computing systems inspired by biological neurons..."
- "Machine learning algorithms learn patterns from data..."
- chunk_id: "urn:trustgraph:chunk:abc123"
score: 0.95
- chunk_id: "urn:trustgraph:chunk:def456"
score: 0.82


@ -1,12 +1,21 @@
type: object
description: Graph embeddings query response
description: Graph embeddings query response with matching entities and similarity scores
properties:
entities:
type: array
description: Similar entities (RDF values)
description: Matching graph entities with similarity scores
items:
$ref: '../../common/RdfValue.yaml'
type: object
properties:
entity:
$ref: '../../common/RdfValue.yaml'
description: Matching graph entity
score:
type: number
description: Similarity score (higher is more similar)
example: 0.92
example:
- {v: "https://example.com/person/alice", e: true}
- {v: "https://example.com/person/bob", e: true}
- {v: "https://example.com/concept/quantum", e: true}
- entity: {t: i, i: "https://example.com/person/alice"}
score: 0.95
- entity: {t: i, i: "https://example.com/concept/quantum"}
score: 0.82


@ -28,3 +28,23 @@ properties:
description: Collection to query
default: default
example: research
g:
type: string
description: |
Named graph filter (optional).
- Omitted/null: all graphs
- Empty string: default graph only
- URI string: specific named graph (e.g., urn:graph:source, urn:graph:retrieval)
example: urn:graph:source
streaming:
type: boolean
description: Enable streaming response delivery
default: false
example: true
batch-size:
type: integer
description: Number of triples per streaming batch
default: 20
minimum: 1
maximum: 1000
example: 50


@ -1,13 +1,31 @@
type: object
description: Document RAG response
description: Document RAG response message
properties:
message_type:
type: string
description: Type of message - "chunk" for LLM response chunks, "explain" for explainability events
enum: [chunk, explain]
example: chunk
response:
type: string
description: Generated response based on retrieved documents
example: The research papers found three key findings...
description: Generated response text (for chunk messages)
example: Based on the policy documents, customers can return items within 30 days...
explain_id:
type: string
description: Explainability node URI (for explain messages)
example: urn:trustgraph:question:abc123
explain_graph:
type: string
description: Named graph containing the explainability data
example: urn:graph:retrieval
end-of-stream:
type: boolean
description: Indicates streaming is complete (streaming mode)
description: Indicates LLM response stream is complete
default: false
example: true
end_of_session:
type: boolean
description: Indicates entire session is complete (all messages sent)
default: false
example: true
error:


@ -1,13 +1,31 @@
type: object
description: Graph RAG response
description: Graph RAG response message
properties:
message_type:
type: string
description: Type of message - "chunk" for LLM response chunks, "explain" for explainability events
enum: [chunk, explain]
example: chunk
response:
type: string
description: Generated response based on retrieved knowledge graph
description: Generated response text (for chunk messages)
example: Quantum physics and computer science intersect in quantum computing...
end-of-stream:
explain_id:
type: string
description: Explainability node URI (for explain messages)
example: urn:trustgraph:question:abc123
explain_graph:
type: string
description: Named graph containing the explainability data
example: urn:graph:retrieval
end_of_stream:
type: boolean
description: Indicates streaming is complete (streaming mode)
description: Indicates LLM response stream is complete
default: false
example: true
end_of_session:
type: boolean
description: Indicates entire session is complete (all messages sent)
default: false
example: true
error:


@ -2,7 +2,7 @@ openapi: 3.1.0
info:
title: TrustGraph API Gateway
version: "1.8"
version: "2.1"
description: |
REST API for TrustGraph - an AI-powered knowledge graph and RAG system.
@ -28,7 +28,7 @@ info:
Require running flow instance, accessed via `/api/v1/flow/{flow}/service/{kind}`:
- AI services: agent, text-completion, prompt, RAG (document/graph)
- Embeddings: embeddings, graph-embeddings, document-embeddings
- Query: triples, objects, nlp-query, structured-query
- Query: triples, rows, nlp-query, structured-query, row-embeddings
- Data loading: text-load, document-load
- Utilities: mcp-tool, structured-diag
@ -140,6 +140,10 @@ paths:
/api/v1/flow/{flow}/service/document-load:
$ref: './paths/flow/document-load.yaml'
# Document streaming
/api/v1/document-stream:
$ref: './paths/document-stream.yaml'
# Import/Export endpoints
/api/v1/import-core:
$ref: './paths/import-core.yaml'


@ -0,0 +1,53 @@
get:
tags:
- Import/Export
summary: Stream document content from library
description: |
Streams the raw content of a document stored in the library.
Returns the document content in chunked transfer encoding.
## Parameters
- `user`: User identifier (required)
- `document-id`: Document IRI to retrieve (required)
- `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB)
operationId: documentStream
security:
- bearerAuth: []
parameters:
- name: user
in: query
required: true
schema:
type: string
description: User identifier
example: trustgraph
- name: document-id
in: query
required: true
schema:
type: string
description: Document IRI to retrieve
example: "urn:trustgraph:doc:abc123"
- name: chunk-size
in: query
required: false
schema:
type: integer
default: 1048576
description: Chunk size in bytes (default 1MB)
responses:
'200':
description: Document content streamed as raw bytes
content:
application/octet-stream:
schema:
type: string
format: binary
'400':
description: Missing required parameters
'401':
$ref: '../components/responses/Unauthorized.yaml'
'500':
$ref: '../components/responses/Error.yaml'


@ -24,7 +24,7 @@ echo
# Build WebSocket API documentation
echo "Building WebSocket API documentation (AsyncAPI)..."
cd ../websocket
npx --yes -p @asyncapi/cli asyncapi generate fromTemplate asyncapi.yaml @asyncapi/html-template@3.0.0 --use-new-generator -o /tmp/asyncapi-build -p singleFile=true --force-write
npx --yes -p @asyncapi/cli asyncapi generate fromTemplate asyncapi.yaml @asyncapi/html-template -o /tmp/asyncapi-build -p singleFile=true --force-write
mv /tmp/asyncapi-build/index.html ../../docs/websocket.html
rm -rf /tmp/asyncapi-build
echo "✓ WebSocket API docs generated: docs/websocket.html"


@ -2,7 +2,7 @@ asyncapi: 3.0.0
info:
title: TrustGraph WebSocket API
version: "1.8"
version: "2.1"
description: |
WebSocket API for TrustGraph - providing multiplexed, asynchronous access to all services.
@ -31,7 +31,7 @@ info:
**Flow-Hosted Services** (require `flow` parameter):
- agent, text-completion, prompt, document-rag, graph-rag
- embeddings, graph-embeddings, document-embeddings
- triples, objects, nlp-query, structured-query, structured-diag
- triples, rows, nlp-query, structured-query, structured-diag, row-embeddings
- text-load, document-load, mcp-tool
## Schema Reuse


@ -95,8 +95,7 @@ def sample_message_data():
"Metadata": {
"id": "test-doc-123",
"user": "test_user",
"collection": "test_collection",
"metadata": []
"collection": "test_collection"
},
"Term": {
"type": IRI,


@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
import pytest
from unittest.mock import MagicMock
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
from trustgraph.messaging.translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator,
DocumentEmbeddingsResponseTranslator
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
"""Test that DocumentEmbeddingsRequest has expected fields"""
# Create a request
request = DocumentEmbeddingsRequest(
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
)
# Verify all expected fields exist
assert hasattr(request, 'vectors')
assert hasattr(request, 'vector')
assert hasattr(request, 'limit')
assert hasattr(request, 'user')
assert hasattr(request, 'collection')
# Verify field values
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
assert request.vector == [0.1, 0.2, 0.3]
assert request.limit == 10
assert request.user == "test_user"
assert request.collection == "test_collection"
@ -41,18 +41,18 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_translator_to_pulsar(self):
"""Test request translator converts dict to Pulsar schema"""
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2], [0.3, 0.4]],
"vector": [0.1, 0.2, 0.3, 0.4],
"limit": 5,
"user": "custom_user",
"collection": "custom_collection"
}
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
assert result.vector == [0.1, 0.2, 0.3, 0.4]
assert result.limit == 5
assert result.user == "custom_user"
assert result.collection == "custom_collection"
@ -60,16 +60,16 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_translator_to_pulsar_with_defaults(self):
"""Test request translator uses correct defaults"""
translator = DocumentEmbeddingsRequestTranslator()
data = {
"vectors": [[0.1, 0.2]]
"vector": [0.1, 0.2]
# No limit, user, or collection provided
}
result = translator.to_pulsar(data)
assert isinstance(result, DocumentEmbeddingsRequest)
assert result.vectors == [[0.1, 0.2]]
assert result.vector == [0.1, 0.2]
assert result.limit == 10 # Default
assert result.user == "trustgraph" # Default
assert result.collection == "default" # Default
@ -77,18 +77,18 @@ class TestDocumentEmbeddingsRequestContract:
def test_request_translator_from_pulsar(self):
"""Test request translator converts Pulsar schema to dict"""
translator = DocumentEmbeddingsRequestTranslator()
request = DocumentEmbeddingsRequest(
vectors=[[0.5, 0.6]],
vector=[0.5, 0.6],
limit=20,
user="test_user",
collection="test_collection"
)
result = translator.from_pulsar(request)
assert isinstance(result, dict)
assert result["vectors"] == [[0.5, 0.6]]
assert result["vector"] == [0.5, 0.6]
assert result["limit"] == 20
assert result["user"] == "test_user"
assert result["collection"] == "test_collection"
@ -102,16 +102,22 @@ class TestDocumentEmbeddingsResponseContract:
# Create a response with chunks
response = DocumentEmbeddingsResponse(
error=None,
chunks=["chunk1", "chunk2", "chunk3"]
chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8),
ChunkMatch(chunk_id="chunk3", score=0.7)
]
)
# Verify all expected fields exist
assert hasattr(response, 'error')
assert hasattr(response, 'chunks')
# Verify field values
assert response.error is None
assert response.chunks == ["chunk1", "chunk2", "chunk3"]
assert len(response.chunks) == 3
assert response.chunks[0].chunk_id == "chunk1"
assert response.chunks[0].score == 0.9
def test_response_schema_with_error(self):
"""Test response schema with error"""
@ -119,52 +125,47 @@ class TestDocumentEmbeddingsResponseContract:
type="query_error",
message="Database connection failed"
)
response = DocumentEmbeddingsResponse(
error=error,
chunks=None
chunks=[]
)
assert response.error == error
assert response.chunks is None
assert response.chunks == []
def test_response_translator_from_pulsar_with_chunks(self):
"""Test response translator converts Pulsar schema with chunks to dict"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunks=["doc1", "doc2", "doc3"]
chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85),
ChunkMatch(chunk_id="doc3/c3", score=0.75)
]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == ["doc1", "doc2", "doc3"]
def test_response_translator_from_pulsar_with_bytes(self):
"""Test response translator handles byte chunks correctly"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = [b"byte_chunk1", b"byte_chunk2"]
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == ["byte_chunk1", "byte_chunk2"]
assert len(result["chunks"]) == 3
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
assert result["chunks"][0]["score"] == 0.95
def test_response_translator_from_pulsar_with_empty_chunks(self):
"""Test response translator handles empty chunks list"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = []
response = DocumentEmbeddingsResponse(
error=None,
chunks=[]
)
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == []
@ -172,37 +173,41 @@ class TestDocumentEmbeddingsResponseContract:
def test_response_translator_from_pulsar_with_none_chunks(self):
"""Test response translator handles None chunks"""
translator = DocumentEmbeddingsResponseTranslator()
response = MagicMock()
response.chunks = None
result = translator.from_pulsar(response)
assert isinstance(result, dict)
assert "chunks" not in result or result.get("chunks") is None
def test_response_translator_from_response_with_completion(self):
"""Test response translator with completion flag"""
translator = DocumentEmbeddingsResponseTranslator()
response = DocumentEmbeddingsResponse(
error=None,
chunks=["chunk1", "chunk2"]
chunks=[
ChunkMatch(chunk_id="chunk1", score=0.9),
ChunkMatch(chunk_id="chunk2", score=0.8)
]
)
result, is_final = translator.from_response_with_completion(response)
assert isinstance(result, dict)
assert "chunks" in result
assert result["chunks"] == ["chunk1", "chunk2"]
assert len(result["chunks"]) == 2
assert result["chunks"][0]["chunk_id"] == "chunk1"
assert is_final is True # Document embeddings responses are always final
def test_response_translator_to_pulsar_not_implemented(self):
"""Test that to_pulsar raises NotImplementedError for responses"""
translator = DocumentEmbeddingsResponseTranslator()
with pytest.raises(NotImplementedError):
translator.to_pulsar({"chunks": ["test"]})
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
class TestDocumentEmbeddingsMessageCompatibility:
@ -212,26 +217,29 @@ class TestDocumentEmbeddingsMessageCompatibility:
"""Test complete request-response flow maintains data integrity"""
# Create request
request_data = {
"vectors": [[0.1, 0.2, 0.3]],
"vector": [0.1, 0.2, 0.3],
"limit": 5,
"user": "test_user",
"collection": "test_collection"
}
# Convert to Pulsar request
req_translator = DocumentEmbeddingsRequestTranslator()
pulsar_request = req_translator.to_pulsar(request_data)
# Simulate service processing and creating response
response = DocumentEmbeddingsResponse(
error=None,
chunks=["relevant chunk 1", "relevant chunk 2"]
chunks=[
ChunkMatch(chunk_id="doc1/c1", score=0.95),
ChunkMatch(chunk_id="doc2/c2", score=0.85)
]
)
# Convert response back to dict
resp_translator = DocumentEmbeddingsResponseTranslator()
response_data = resp_translator.from_pulsar(response)
# Verify data integrity
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
assert isinstance(response_data, dict)
@ -245,17 +253,18 @@ class TestDocumentEmbeddingsMessageCompatibility:
type="vector_db_error",
message="Collection not found"
)
response = DocumentEmbeddingsResponse(
error=error,
chunks=None
chunks=[]
)
# Convert response to dict
translator = DocumentEmbeddingsResponseTranslator()
response_data = translator.from_pulsar(response)
# Verify error handling
assert isinstance(response_data, dict)
# The translator doesn't include error in the dict, only chunks
assert "chunks" not in response_data or response_data.get("chunks") is None
assert "chunks" in response_data
assert response_data["chunks"] == []
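The response contract these tests exercise — `chunks` is always a list of `ChunkMatch` objects, never `None`, and the translator emits `chunk_id`/`score` dicts — can be sketched like this. Field and class names follow the tests above; this is not a verified copy of the library types.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ChunkMatch:
    chunk_id: str
    score: float

@dataclass
class DocumentEmbeddingsResponse:
    error: Optional[object]
    chunks: List[ChunkMatch]   # always a list now, never None

class DocumentEmbeddingsResponseTranslator:
    """Sketch: Pulsar-style response -> plain dict of chunk_id/score entries."""
    def from_pulsar(self, response: DocumentEmbeddingsResponse) -> dict:
        return {
            "chunks": [
                {"chunk_id": c.chunk_id, "score": c.score}
                for c in response.chunks
            ]
        }

resp = DocumentEmbeddingsResponse(
    error=None,
    chunks=[ChunkMatch("doc1/c1", 0.95), ChunkMatch("doc2/c2", 0.85)],
)
result = DocumentEmbeddingsResponseTranslator().from_pulsar(resp)
# result["chunks"][0] == {"chunk_id": "doc1/c1", "score": 0.95}
```

An error response simply carries `chunks=[]`, so consumers can always iterate `result["chunks"]` without a `None` check.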

View file

@ -401,25 +401,6 @@ class TestMetadataMessageContracts:
assert metadata.id == "test-doc-123"
assert metadata.user == "test_user"
assert metadata.collection == "test_collection"
assert isinstance(metadata.metadata, list)
def test_metadata_with_triples_contract(self, sample_message_data):
"""Test Metadata with embedded triples contract"""
# Arrange
triple = Triple(**sample_message_data["Triple"])
metadata_data = {
"id": "doc-with-triples",
"user": "test_user",
"collection": "test_collection",
"metadata": [triple]
}
# Act & Assert
assert validate_schema_contract(Metadata, metadata_data)
metadata = Metadata(**metadata_data)
assert len(metadata.metadata) == 1
assert metadata.metadata[0].s.iri == "http://example.com/subject"
def test_error_schema_contract(self):
"""Test Error schema contract"""

View file

@ -24,7 +24,6 @@ class TestRowsCassandraContracts:
id="test-doc-001",
user="test_user",
collection="test_collection",
metadata=[]
)
test_object = ExtractedObject(
@ -50,7 +49,6 @@ class TestRowsCassandraContracts:
assert hasattr(test_object.metadata, 'id')
assert hasattr(test_object.metadata, 'user')
assert hasattr(test_object.metadata, 'collection')
assert hasattr(test_object.metadata, 'metadata')
# Verify types
assert isinstance(test_object.schema_name, str)
@ -154,7 +152,6 @@ class TestRowsCassandraContracts:
id="serial-001",
user="test_user",
collection="test_coll",
metadata=[]
),
schema_name="test_schema",
values=[{"field1": "value1", "field2": "123"}],
@ -234,7 +231,6 @@ class TestRowsCassandraContracts:
id="meta-001",
user="user123", # -> keyspace
collection="coll456", # -> partition key
metadata=[{"key": "value"}]
),
schema_name="table789", # -> table name
values=[{"field": "value"}],
@ -262,7 +258,6 @@ class TestRowsCassandraContractsBatch:
id="batch-doc-001",
user="test_user",
collection="test_collection",
metadata=[]
)
batch_object = ExtractedObject(
@ -308,10 +303,9 @@ class TestRowsCassandraContractsBatch:
test_metadata = Metadata(
id="empty-batch-001",
user="test_user",
collection="test_collection",
metadata=[]
collection="test_collection",
)
empty_batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="empty_schema",
@ -332,9 +326,8 @@ class TestRowsCassandraContractsBatch:
id="single-batch-001",
user="test_user",
collection="test_collection",
metadata=[]
)
single_batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="customer_records",
@ -362,12 +355,11 @@ class TestRowsCassandraContractsBatch:
id="batch-serial-001",
user="test_user",
collection="test_coll",
metadata=[]
),
schema_name="test_schema",
values=[
{"field1": "value1", "field2": "123"},
{"field1": "value2", "field2": "456"},
{"field1": "value2", "field2": "456"},
{"field1": "value3", "field2": "789"}
],
confidence=0.92,
@ -436,9 +428,8 @@ class TestRowsCassandraContractsBatch:
id="partition-test-001",
user="consistent_user", # Same keyspace
collection="consistent_collection", # Same partition
metadata=[]
)
batch_object = ExtractedObject(
metadata=test_metadata,
schema_name="partition_test",

View file

@ -95,9 +95,8 @@ class TestStructuredDataSchemaContracts:
id="structured-data-001",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act
submission = StructuredDataSubmission(
metadata=metadata,
@ -121,9 +120,8 @@ class TestStructuredDataSchemaContracts:
id="extracted-obj-001",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act
obj = ExtractedObject(
metadata=metadata,
@ -147,9 +145,8 @@ class TestStructuredDataSchemaContracts:
id="extracted-batch-001",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act - create object with multiple values
obj = ExtractedObject(
metadata=metadata,
@ -180,11 +177,10 @@ class TestStructuredDataSchemaContracts:
# Arrange
metadata = Metadata(
id="extracted-empty-001",
user="test_user",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act - create object with empty values array
obj = ExtractedObject(
metadata=metadata,
@ -283,13 +279,12 @@ class TestStructuredEmbeddingsContracts:
id="struct-embed-001",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act
embedding = StructuredObjectEmbedding(
metadata=metadata,
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3],
schema_name="customer_records",
object_id="customer_123",
field_embeddings={
@ -301,7 +296,7 @@ class TestStructuredEmbeddingsContracts:
# Assert
assert embedding.schema_name == "customer_records"
assert embedding.object_id == "customer_123"
assert len(embedding.vectors) == 2
assert len(embedding.vector) == 3
assert len(embedding.field_embeddings) == 2
assert "name" in embedding.field_embeddings
@ -313,7 +308,7 @@ class TestStructuredDataSerializationContracts:
def test_structured_data_submission_serialization(self):
"""Test StructuredDataSubmission serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
metadata = Metadata(id="test", user="user", collection="col")
submission_data = {
"metadata": metadata,
"format": "json",
@ -328,7 +323,7 @@ class TestStructuredDataSerializationContracts:
def test_extracted_object_serialization(self):
"""Test ExtractedObject serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
metadata = Metadata(id="test", user="user", collection="col")
object_data = {
"metadata": metadata,
"schema_name": "test_schema",
@ -378,7 +373,7 @@ class TestStructuredDataSerializationContracts:
def test_extracted_object_batch_serialization(self):
"""Test ExtractedObject batch serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
metadata = Metadata(id="test", user="user", collection="col")
batch_object_data = {
"metadata": metadata,
"schema_name": "test_schema",
@ -397,7 +392,7 @@ class TestStructuredDataSerializationContracts:
def test_extracted_object_empty_batch_serialization(self):
"""Test ExtractedObject empty batch serialization contract"""
# Arrange
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
metadata = Metadata(id="test", user="user", collection="col")
empty_batch_data = {
"metadata": metadata,
"schema_name": "test_schema",

View file

@ -17,16 +17,18 @@ from trustgraph.messaging import TranslatorRegistry
class TestRAGTranslatorCompletionFlags:
"""Contract tests for RAG response translator completion flags"""
def test_graph_rag_translator_is_final_with_end_of_stream_true(self):
def test_graph_rag_translator_is_final_with_end_of_session_true(self):
"""
Test that GraphRagResponseTranslator returns is_final=True
when end_of_stream=True.
when end_of_session=True.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="A small domesticated mammal.",
message_type="chunk",
end_of_stream=True,
end_of_session=True,
error=None
)
@ -34,20 +36,23 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_stream=True"
assert is_final is True, "is_final must be True when end_of_session=True"
assert response_dict["response"] == "A small domesticated mammal."
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is True
assert response_dict["message_type"] == "chunk"
def test_graph_rag_translator_is_final_with_end_of_stream_false(self):
def test_graph_rag_translator_is_final_with_end_of_session_false(self):
"""
Test that GraphRagResponseTranslator returns is_final=False
when end_of_stream=False.
when end_of_session=False (even if end_of_stream=True).
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="Chunk 1",
message_type="chunk",
end_of_stream=False,
end_of_session=False,
error=None
)
@ -55,20 +60,67 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "is_final must be False when end_of_stream=False"
assert is_final is False, "is_final must be False when end_of_session=False"
assert response_dict["response"] == "Chunk 1"
assert response_dict["end_of_stream"] is False
assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_true(self):
def test_graph_rag_translator_provenance_message(self):
"""
Test that GraphRagResponseTranslator handles provenance messages.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="",
message_type="explain",
explain_id="urn:trustgraph:session:abc123",
end_of_stream=False,
end_of_session=False,
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False
assert response_dict["message_type"] == "explain"
assert response_dict["explain_id"] == "urn:trustgraph:session:abc123"
def test_graph_rag_translator_end_of_stream_not_final(self):
"""
Test that end_of_stream=True alone does NOT make is_final=True.
The session continues with provenance messages after LLM stream completes.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("graph-rag")
response = GraphRagResponse(
response="Final chunk",
message_type="chunk",
end_of_stream=True,
end_of_session=False, # Session continues with provenance
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False
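The completion rule these tests pin down is that `is_final` tracks `end_of_session` alone: `end_of_stream` only marks the end of the LLM token stream, after which provenance ("explain") messages may still arrive. A minimal sketch of that logic, with field names taken from the tests (the real translator may carry more fields):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class GraphRagResponse:
    response: str
    message_type: str
    end_of_stream: bool
    end_of_session: bool
    explain_id: Optional[str] = None

def from_response_with_completion(resp: GraphRagResponse) -> Tuple[dict, bool]:
    """Sketch: the session is final only when end_of_session is set,
    never on end_of_stream alone."""
    d = {
        "response": resp.response,
        "message_type": resp.message_type,
        "end_of_stream": resp.end_of_stream,
        "end_of_session": resp.end_of_session,
    }
    if resp.explain_id is not None:
        d["explain_id"] = resp.explain_id
    return d, resp.end_of_session

# LLM stream done, but provenance still to come: not final
_, is_final = from_response_with_completion(
    GraphRagResponse("Final chunk", "chunk",
                     end_of_stream=True, end_of_session=False))
# is_final is False
```

Keying completion on `end_of_session` lets the websocket layer keep the subscription open for provenance messages after the answer text has finished streaming.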
def test_document_rag_translator_is_final_with_end_of_session_true(self):
"""
Test that DocumentRagResponseTranslator returns is_final=True
when end_of_stream=True.
when end_of_session=True.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("document-rag")
response = DocumentRagResponse(
response="A document about cats.",
end_of_stream=True,
end_of_session=True,
error=None
)
@ -76,9 +128,31 @@ class TestRAGTranslatorCompletionFlags:
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True, "is_final must be True when end_of_stream=True"
assert is_final is True, "is_final must be True when end_of_session=True"
assert response_dict["response"] == "A document about cats."
assert response_dict["end_of_session"] is True
def test_document_rag_translator_end_of_stream_not_final(self):
"""
Test that end_of_stream=True alone does NOT make is_final=True.
The session continues with provenance messages after LLM stream completes.
"""
# Arrange
translator = TranslatorRegistry.get_response_translator("document-rag")
response = DocumentRagResponse(
response="Final chunk",
end_of_stream=True,
end_of_session=False, # Session continues with provenance
error=None
)
# Act
response_dict, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
assert response_dict["end_of_stream"] is True
assert response_dict["end_of_session"] is False
def test_document_rag_translator_is_final_with_end_of_stream_false(self):
"""

View file

@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.template.prompt_manager import PromptManager
@ -31,7 +31,7 @@ class TestAgentKgExtractionIntegration:
agent_client = AsyncMock()
# Mock successful agent response in JSONL format
def mock_agent_response(recipient, question):
def mock_agent_response(question):
# Simulate agent processing and return structured JSONL response
mock_response = MagicMock()
mock_response.error = None
@ -76,13 +76,6 @@ class TestAgentKgExtractionIntegration:
chunk=text.encode('utf-8'),
metadata=Metadata(
id="doc123",
metadata=[
Triple(
s=Term(type=IRI, iri="doc123"),
p=Term(type=IRI, iri="http://example.org/type"),
o=Term(type=LITERAL, value="document")
)
]
)
)
@ -131,16 +124,12 @@ class TestAgentKgExtractionIntegration:
# Get agent response (the mock returns a string directly)
agent_client = flow("agent-request")
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
agent_response = agent_client.invoke(question=prompt)
# Parse and process
extraction_data = extractor.parse_jsonl(agent_response)
triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)
# Add metadata triples
for t in v.metadata.metadata:
triples.append(t)
triples, entity_contexts, extracted_triples = extractor.process_extraction_data(extraction_data, v.metadata)
# Emit outputs
if triples:
await extractor.emit_triples(flow("triples"), v.metadata, triples)
@ -185,10 +174,6 @@ class TestAgentKgExtractionIntegration:
label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL]
assert len(label_triples) >= 2 # Should have labels for entities
# Check subject-of relationships
subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) >= 2 # Entities should be linked to document
# Verify entity contexts were emitted
entity_contexts_publisher = mock_flow_context("entity-contexts")
entity_contexts_publisher.send.assert_called_once()
@ -208,7 +193,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock agent error response
agent_client = mock_flow_context("agent-request")
def mock_error_response(recipient, question):
def mock_error_response(question):
# Simulate agent error by raising an exception
raise RuntimeError("Agent processing failed")
@ -230,7 +215,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock invalid JSON response
agent_client = mock_flow_context("agent-request")
def mock_invalid_json_response(recipient, question):
def mock_invalid_json_response(question):
return "This is not valid JSON at all"
agent_client.invoke = mock_invalid_json_response
@ -242,9 +227,9 @@ class TestAgentKgExtractionIntegration:
# Act - JSONL parsing is lenient, invalid lines are skipped
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
# Assert - should emit triples (with just metadata) but no entity contexts
# Assert - with no valid extraction data, nothing is emitted
triples_publisher = mock_flow_context("triples")
triples_publisher.send.assert_called_once()
triples_publisher.send.assert_not_called()
entity_contexts_publisher = mock_flow_context("entity-contexts")
entity_contexts_publisher.send.assert_not_called()
@ -255,7 +240,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock empty extraction response
agent_client = mock_flow_context("agent-request")
def mock_empty_response(recipient, question):
def mock_empty_response(question):
# Return empty JSONL (just empty/whitespace)
return ''
@ -268,17 +253,12 @@ class TestAgentKgExtractionIntegration:
# Act
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
# Assert
# Should still emit outputs (even if empty) to maintain flow consistency
# Assert - with empty extraction results, nothing is emitted
triples_publisher = mock_flow_context("triples")
entity_contexts_publisher = mock_flow_context("entity-contexts")
# Triples should include metadata triples at minimum
triples_publisher.send.assert_called_once()
sent_triples = triples_publisher.send.call_args[0][0]
assert isinstance(sent_triples, Triples)
# Entity contexts should not be sent if empty
# No triples or entity contexts emitted for empty results
triples_publisher.send.assert_not_called()
entity_contexts_publisher.send.assert_not_called()
@pytest.mark.asyncio
@ -287,7 +267,7 @@ class TestAgentKgExtractionIntegration:
# Arrange - mock malformed extraction response
agent_client = mock_flow_context("agent-request")
def mock_malformed_response(recipient, question):
def mock_malformed_response(question):
# JSONL with definition missing required field
return '{"type": "definition", "entity": "Missing Definition"}'
@ -308,12 +288,12 @@ class TestAgentKgExtractionIntegration:
test_text = "Test text for prompt rendering"
chunk = Chunk(
chunk=test_text.encode('utf-8'),
metadata=Metadata(id="test-doc", metadata=[])
metadata=Metadata(id="test-doc")
)
agent_client = mock_flow_context("agent-request")
def capture_prompt(recipient, question):
def capture_prompt(question):
# Verify the prompt contains the test text
assert test_text in question
return '' # Empty JSONL response
@ -340,13 +320,13 @@ class TestAgentKgExtractionIntegration:
text = f"Test document {i} content"
chunks.append(Chunk(
chunk=text.encode('utf-8'),
metadata=Metadata(id=f"doc{i}", metadata=[])
metadata=Metadata(id=f"doc{i}")
))
agent_client = mock_flow_context("agent-request")
responses = []
def mock_response(recipient, question):
def mock_response(question):
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
responses.append(response)
return response
@ -375,12 +355,12 @@ class TestAgentKgExtractionIntegration:
unicode_text = "Machine Learning (学习机器) は人工知能の一分野です。"
chunk = Chunk(
chunk=unicode_text.encode('utf-8'),
metadata=Metadata(id="unicode-doc", metadata=[])
metadata=Metadata(id="unicode-doc")
)
agent_client = mock_flow_context("agent-request")
def mock_unicode_response(recipient, question):
def mock_unicode_response(question):
# Verify unicode text was properly decoded and included
assert "学习机器" in question
assert "人工知能" in question
@ -411,12 +391,12 @@ class TestAgentKgExtractionIntegration:
large_text = "Machine Learning is important. " * 1000 # Repeat to create large text
chunk = Chunk(
chunk=large_text.encode('utf-8'),
metadata=Metadata(id="large-doc", metadata=[])
metadata=Metadata(id="large-doc")
)
agent_client = mock_flow_context("agent-request")
def mock_large_text_response(recipient, question):
def mock_large_text_response(question):
# Verify large text was included
assert len(question) > 10000
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'

View file

@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration:
pulsar_client=AsyncMock(),
max_iterations=3
)
# Mock the client method for structured query
proc.client = MagicMock()
# Mock librarian to avoid hanging on save operations
proc.save_answer_content = AsyncMock(return_value=None)
return proc
@pytest.fixture

View file

@ -9,6 +9,15 @@ Following the TEST_STRATEGY.md approach for integration testing.
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
}
@pytest.mark.integration
@ -19,23 +28,35 @@ class TestDocumentRagIntegration:
def mock_embeddings_client(self):
"""Mock embeddings client that returns realistic vector embeddings"""
client = AsyncMock()
# New batch format: [[[vectors_for_text1], ...]]
# One text input returns one vector set containing two vectors
client.embed.return_value = [
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
[0.6, 0.7, 0.8, 0.9, 1.0] # Second embedding for testing
[
[0.1, 0.2, 0.3, 0.4, 0.5], # First vector for text
[0.6, 0.7, 0.8, 0.9, 1.0] # Second vector for text
]
]
return client
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns realistic document chunks"""
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
# Returns ChunkMatch objects with chunk_id and score
client.query.return_value = [
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
@ -48,17 +69,19 @@ class TestDocumentRagIntegration:
return client
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@pytest.mark.asyncio
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
"""Test complete DocumentRAG pipeline from query to response"""
# Arrange
@ -76,15 +99,16 @@ class TestDocumentRagIntegration:
)
# Assert - Verify service coordination
mock_embeddings_client.embed.assert_called_once_with(query)
mock_embeddings_client.embed.assert_called_once_with([query])
mock_doc_embeddings_client.query.assert_called_once_with(
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
user=user,
collection=collection
)
# Documents are fetched from librarian using chunk_ids
mock_prompt_client.document_prompt.assert_called_once_with(
query=query,
documents=[
@ -101,17 +125,19 @@ class TestDocumentRagIntegration:
assert "artificial intelligence" in result.lower()
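The pipeline this test drives — embed the query, vector-search for chunk matches, fetch each chunk's content from the librarian, then prompt the LLM — can be sketched as below. The client signatures are assumed from the mocks above, not taken from the real `DocumentRag` class.

```python
import asyncio

class DocumentRagSketch:
    """Sketch of the retrieve-then-read flow exercised in the test above."""

    def __init__(self, embeddings_client, doc_embeddings_client,
                 prompt_client, fetch_chunk):
        self.embeddings = embeddings_client
        self.doc_embeddings = doc_embeddings_client
        self.prompt = prompt_client
        self.fetch_chunk = fetch_chunk

    async def query(self, question, user="trustgraph",
                    collection="default", doc_limit=10):
        # 1. Embed the query (batch API: list of texts in, one vector set per text)
        vectors = (await self.embeddings.embed([question]))[0]
        # 2. Vector search returns ChunkMatch-style (chunk_id, score) hits
        matches = await self.doc_embeddings.query(
            vector=vectors, limit=doc_limit, user=user, collection=collection)
        # 3. Resolve chunk ids to content via the librarian
        documents = [await self.fetch_chunk(m.chunk_id, user) for m in matches]
        # 4. Hand query + retrieved documents to the LLM prompt service
        return await self.prompt.document_prompt(
            query=question, documents=documents)
```

Separating the vector index (ids and scores) from content storage is the point of the `fetch_chunk` hook: the embeddings store no longer has to carry chunk text at all.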
@pytest.mark.asyncio
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents found
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
@ -125,92 +151,98 @@ class TestDocumentRagIntegration:
query="very obscure query",
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when embeddings service fails"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Embeddings service unavailable" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_not_called()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when document service fails"""
# Arrange
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Document service connection failed" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when prompt service fails"""
# Arrange
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "LLM service rate limited" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once()
@pytest.mark.asyncio
async def test_document_rag_with_different_document_limits(self, document_rag,
async def test_document_rag_with_different_document_limits(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG with various document limit configurations"""
# Test different document limits
test_cases = [1, 5, 10, 25, 50]
for limit in test_cases:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
@ -230,14 +262,14 @@ class TestDocumentRagIntegration:
for user, collection in test_scenarios:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(
f"query from {user} in {collection}",
user=user,
collection=collection
)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
@ -245,19 +277,21 @@ class TestDocumentRagIntegration:
assert call_args.kwargs['collection'] == collection
@pytest.mark.asyncio
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk,
caplog):
"""Test DocumentRAG verbose logging functionality"""
import logging
# Arrange - Configure logging to capture debug messages
caplog.set_level(logging.DEBUG)
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@ -269,25 +303,25 @@ class TestDocumentRagIntegration:
assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages
assert "chunks" in log_messages.lower()
assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages
@pytest.mark.asyncio
@pytest.mark.slow
async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large chunk match set (100 chunks)
large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_matches
# Act
import time
start_time = time.time()
result = await document_rag.query("performance test query", doc_limit=100)
end_time = time.time()
execution_time = end_time - start_time
@ -309,4 +343,4 @@ class TestDocumentRagIntegration:
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == "trustgraph"
assert call_args.kwargs['collection'] == "default"
assert call_args.kwargs['limit'] == 20

View file

@ -8,12 +8,21 @@ response delivery through the complete pipeline.
import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
)
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of AI.",
"doc/c2": "Deep learning uses neural networks.",
"doc/c3": "Supervised learning needs labeled data.",
}
@pytest.mark.integration
class TestDocumentRagStreaming:
"""Integration tests for DocumentRAG streaming"""
@ -22,20 +31,29 @@ class TestDocumentRagStreaming:
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
# New batch format: [[[vectors_for_text1]]]
client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
return client
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
# Returns ChunkMatch objects with chunk_id and score
client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support"""
@ -66,12 +84,13 @@ class TestDocumentRagStreaming:
@pytest.fixture
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with streaming support"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@ -190,7 +209,7 @@ class TestDocumentRagStreaming:
mock_doc_embeddings_client):
"""Test streaming with no documents found"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids
callback = AsyncMock()
# Act

View file

@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
@pytest.mark.integration
@ -21,8 +22,12 @@ class TestGraphRagIntegration:
def mock_embeddings_client(self):
"""Mock embeddings client that returns realistic vector embeddings"""
client = AsyncMock()
# New batch format: [[[vectors_for_text1], ...]]
# One text input returns one vector set containing one vector
client.embed.return_value = [
[
[0.1, 0.2, 0.3, 0.4, 0.5], # Vector for text
]
]
return client
@ -31,9 +36,9 @@ class TestGraphRagIntegration:
"""Mock graph embeddings client that returns realistic entities"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
]
return client
@ -43,7 +48,7 @@ class TestGraphRagIntegration:
client = AsyncMock()
# Mock different queries return different triples
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
# Mock label queries
if p == "http://www.w3.org/2000/01/rdf-schema#label":
if s == "http://trustgraph.ai/e/machine-learning":
@ -71,18 +76,37 @@ class TestGraphRagIntegration:
return []
client.query_stream.side_effect = query_stream_side_effect
# Also mock query for label lookups (maybe_label uses query, not query_stream)
client.query.side_effect = query_stream_side_effect
return client
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses for two-step process"""
client = AsyncMock()
# Mock responses for the multi-step process:
# 1. extract-concepts extracts key concepts from the query
# 2. kg-edge-scoring scores edges for relevance
# 3. kg-edge-reasoning provides reasoning for selected edges
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
return ""
client.prompt.side_effect = mock_prompt
return client
@pytest.fixture
@ -101,7 +125,7 @@ class TestGraphRagIntegration:
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
mock_graph_embeddings_client, mock_triples_client,
mock_prompt_client):
"""Test complete GraphRAG pipeline from query to response with real-time provenance"""
# Arrange
query = "What is machine learning?"
user = "test_user"
@ -109,41 +133,51 @@ class TestGraphRagIntegration:
entity_limit = 50
triple_limit = 30
# Collect provenance events
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act
response = await graph_rag.query(
query=query,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
explain_callback=collect_provenance
)
# Assert - Verify service coordination
# 1. Should compute embeddings for query (now expects list of texts)
mock_embeddings_client.embed.assert_called_once_with([query])
# 2. Should query graph embeddings to find relevant entities
mock_graph_embeddings_client.query.assert_called_once()
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 5
for triples, prov_id in provenance_events:
assert isinstance(triples, list)
assert prov_id.startswith("urn:trustgraph:")
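The provenance assertions above imply a simple contract: the query pipeline invokes `explain_callback(triples, prov_id)` once per stage, five times in total. A minimal sketch of that contract — the stage names are taken from the test comment, and the `prov_id` value is a made-up example matching only the `startswith` assertion:

```python
# Stage names inferred from the test comment; the real pipeline internals may differ.
STAGES = ["question", "grounding", "exploration", "focus", "synthesis"]


async def run_query_with_provenance(explain_callback):
    """Emit one provenance event per pipeline stage, then return the answer."""
    prov_id = "urn:trustgraph:query-0001"  # hypothetical id, real format unknown
    for stage in STAGES:
        stage_triples = []  # triples gathered while processing this stage
        await explain_callback(stage_triples, prov_id)
    return "final answer"
```

A test collector then just appends each `(triples, prov_id)` pair, exactly as `collect_provenance` does in the tests.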
@pytest.mark.asyncio
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
@ -197,21 +231,27 @@ class TestGraphRagIntegration:
"""Test GraphRAG handles empty knowledge graph gracefully"""
# Arrange
mock_graph_embeddings_client.query.return_value = [] # No entities found
mock_triples_client.query_stream.return_value = [] # No triples found
# Collect provenance
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act
response = await graph_rag.query(
query="unknown topic",
user="test_user",
collection="test_collection",
explain_callback=collect_provenance
)
# Assert
# Should still call prompt client
assert response is not None
# Provenance should still be emitted (5 events)
assert len(provenance_events) == 5
@pytest.mark.asyncio
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
@ -226,7 +266,7 @@ class TestGraphRagIntegration:
collection="test_collection"
)
first_call_count = mock_triples_client.query_stream.call_count
mock_triples_client.reset_mock()
# Second identical query
@ -236,7 +276,7 @@ class TestGraphRagIntegration:
collection="test_collection"
)
second_call_count = mock_triples_client.query_stream.call_count
# Assert - Second query should make fewer triple queries due to caching
# Note: This is a weak assertion because caching behavior depends on

View file

@ -8,6 +8,7 @@ response delivery through the complete pipeline.
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -24,7 +25,8 @@ class TestGraphRagStreaming:
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
# New batch format: [[[vectors_for_text1]]]
client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
return client
@pytest.fixture
@ -32,7 +34,7 @@ class TestGraphRagStreaming:
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
]
return client
@ -51,30 +53,38 @@ class TestGraphRagStreaming:
@pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support for two-stage GraphRAG"""
client = AsyncMock()
# Full synthesis text
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return "" # Falls back to raw query
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
chunks = []
async for chunk in mock_streaming_llm_response():
chunks.append(chunk)
# Send all chunks with end_of_stream=False except the last
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
else:
return full_text
return ""
client.prompt.side_effect = prompt_side_effect
return client
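The kg-edge-scoring and kg-edge-reasoning mock responses in this fixture are JSONL: one JSON object per line, e.g. `{"id": "abc12345", "score": 0.9}`. A small parser for that shape — a sketch only, the real pipeline's parser may be stricter about malformed lines:

```python
import json


def parse_jsonl(text):
    """Parse newline-delimited JSON, skipping blank lines."""
    records = []
    for line in text.splitlines():
        line = line.strip()
        if line:
            records.append(json.loads(line))
    return records


scores = parse_jsonl('{"id": "abc12345", "score": 0.9}\n')
reasons = parse_jsonl('{"id": "abc12345", "reasoning": "Relevant to query"}\n')
```

JSONL is convenient here because an LLM can emit one record per edge as it scores them, and a partial response still parses line by line.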
@pytest.fixture
@ -91,18 +101,25 @@ class TestGraphRagStreaming:
@pytest.mark.asyncio
async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
"""Test basic GraphRAG streaming functionality"""
"""Test basic GraphRAG streaming functionality with real-time provenance"""
# Arrange
query = "What is machine learning?"
collector = streaming_chunk_collector()
# Act
result = await graph_rag_streaming.query(
# Collect provenance events
provenance_events = []
async def collect_provenance(triples, prov_id):
provenance_events.append((triples, prov_id))
# Act - query() returns response, provenance via callback
response = await graph_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collector.collect
chunk_callback=collector.collect,
explain_callback=collect_provenance
)
# Assert
@ -114,10 +131,15 @@ class TestGraphRagStreaming:
# Verify full response matches concatenated chunks
full_from_chunks = collector.get_full_text()
assert response == full_from_chunks
# Verify content is reasonable
assert "machine" in response.lower() or "learning" in response.lower()
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
assert len(provenance_events) == 5
for triples, prov_id in provenance_events:
assert prov_id.startswith("urn:trustgraph:")
@pytest.mark.asyncio
async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
@ -128,7 +150,7 @@ class TestGraphRagStreaming:
collection = "test_collection"
# Act - Non-streaming
non_streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
@ -141,7 +163,7 @@ class TestGraphRagStreaming:
async def collect(chunk, end_of_stream):
streaming_chunks.append(chunk)
streaming_response = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
@ -150,9 +172,9 @@ class TestGraphRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_response == non_streaming_response
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_response
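The streaming tests above all rely on one callback protocol: `chunk_callback(chunk, end_of_stream)` is awaited per chunk, and the final call may carry an empty string with `end_of_stream=True`. A collector compatible with that protocol might look like this (a sketch of what the `streaming_chunk_collector` fixture plausibly provides, not its actual code):

```python
class ChunkCollector:
    """Accumulates streamed text chunks until end_of_stream=True."""

    def __init__(self):
        self.chunks = []
        self.done = False

    async def collect(self, chunk, end_of_stream):
        if chunk:  # the final chunk may be empty, carrying only the flag
            self.chunks.append(chunk)
        if end_of_stream:
            self.done = True

    def get_full_text(self):
        return "".join(self.chunks)


async def demo():
    c = ChunkCollector()
    stream = [("The", False), (" answer", False), (" is here.", False), ("", True)]
    for chunk, final in stream:
        await c.collect(chunk, final)
    return c
```

Concatenating the collected chunks must reproduce the non-streaming response exactly, which is what the equivalence test asserts.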
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -161,7 +183,7 @@ class TestGraphRagStreaming:
callback = AsyncMock()
# Act
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
@ -171,7 +193,7 @@ class TestGraphRagStreaming:
# Assert
assert callback.call_count > 0
assert response is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@ -181,7 +203,7 @@ class TestGraphRagStreaming:
async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
"""Test streaming parameter without callback (should fall back to non-streaming)"""
# Arrange & Act
response = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
@ -190,8 +212,8 @@ class TestGraphRagStreaming:
)
# Assert - Should complete without error
assert response is not None
assert isinstance(response, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
@ -202,7 +224,7 @@ class TestGraphRagStreaming:
callback = AsyncMock()
# Act
response = await graph_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
@ -211,7 +233,7 @@ class TestGraphRagStreaming:
)
# Assert - Should still produce streamed response
assert response is not None
assert callback.call_count > 0
@pytest.mark.asyncio

View file

@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
triples_obj = Triples(
metadata=Metadata(
id=f"export-msg-{i}",
user=msg_data["metadata"]["user"],
collection=msg_data["metadata"]["collection"],
),

View file

@ -17,7 +17,7 @@ from trustgraph.extract.kg.relationships.extract import Processor as Relationshi
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
@pytest.mark.integration
@ -92,7 +92,6 @@ class TestKnowledgeGraphPipelineIntegration:
id="doc-123",
user="test_user",
collection="test_collection",
),
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
)
@ -243,13 +242,12 @@ class TestKnowledgeGraphPipelineIntegration:
id="test-doc",
user="test_user",
collection="test_collection",
)
# Act
triples = []
entities = []
for defn in sample_definitions_response:
s = defn["entity"]
o = defn["definition"]
@ -302,12 +300,11 @@ class TestKnowledgeGraphPipelineIntegration:
id="test-doc",
user="test_user",
collection="test_collection",
)
# Act
triples = []
for rel in sample_relationships_response:
s = rel["subject"]
p = rel["predicate"]
@ -373,7 +370,6 @@ class TestKnowledgeGraphPipelineIntegration:
id="test-doc",
user="test_user",
collection="test_collection",
),
triples=[
Triple(
@ -406,12 +402,11 @@ class TestKnowledgeGraphPipelineIntegration:
id="test-doc",
user="test_user",
collection="test_collection",
),
entities=[
EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/entity"),
vector=[0.1, 0.2, 0.3]
)
]
)
@ -542,7 +537,7 @@ class TestKnowledgeGraphPipelineIntegration:
]
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),
chunk=b"Test chunk"
)
@ -569,7 +564,7 @@ class TestKnowledgeGraphPipelineIntegration:
# Arrange
large_chunk_batch = [
Chunk(
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
)
for i in range(100) # Large batch
@ -608,15 +603,8 @@ class TestKnowledgeGraphPipelineIntegration:
id="test-doc-123",
user="test_user",
collection="test_collection",
)
sample_chunk = Chunk(
metadata=original_metadata,
chunk=b"Test content for metadata propagation"

View file

@ -231,7 +231,6 @@ class TestObjectExtractionServiceIntegration:
id="customer-doc-001",
user="integration_test",
collection="test_documents",
)
chunk_text = """
@ -299,7 +298,6 @@ class TestObjectExtractionServiceIntegration:
id="product-doc-001",
user="integration_test",
collection="test_documents",
)
chunk_text = """
@ -373,7 +371,6 @@ class TestObjectExtractionServiceIntegration:
id=chunk_id,
user="concurrent_test",
collection="test_collection",
)
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
chunks.append(chunk)
@ -470,7 +467,7 @@ class TestObjectExtractionServiceIntegration:
await processor.on_schema_config(integration_config, version=1)
# Create test chunk
metadata = Metadata(id="error-test", user="test", collection="test")
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
mock_msg = MagicMock()
@ -507,7 +504,6 @@ class TestObjectExtractionServiceIntegration:
id="metadata-test-chunk",
user="test_user",
collection="test_collection",
)
chunk = Chunk(

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
class TestGraphRagStreamingProtocol:
@ -18,14 +19,17 @@ class TestGraphRagStreamingProtocol:
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
client.embed.return_value = [[[0.1, 0.2, 0.3]]]
return client
@pytest.fixture
def mock_graph_embeddings_client(self):
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = [
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
]
return client
@pytest.fixture
@ -40,18 +44,23 @@ class TestGraphRagStreamingProtocol:
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
client = AsyncMock()
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected)
return ""
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
await chunk_callback("The", False)
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
else:
return "The answer is here."
return ""
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
@ -197,16 +206,26 @@ class TestDocumentRagStreamingProtocol:
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
client.embed.return_value = [[[0.1, 0.2, 0.3]]]
return client
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90)
]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return f"Content for {chunk_id}"
return fetch
@pytest.fixture
def mock_streaming_prompt_client(self):
"""Mock prompt client with streaming support"""
@ -227,12 +246,13 @@ class TestDocumentRagStreamingProtocol:
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
@ -312,20 +332,24 @@ class TestStreamingProtocolEdgeCases:
# Arrange
client = AsyncMock()
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return ""
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
else:
return "textmore"
return ""
client.prompt.side_effect = prompt_with_empties
rag = GraphRag(
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
prompt_client=client,

View file
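Across these test files the embeddings mocks return the batch shape `[[[...]]]`: `embed()` takes a list of texts and returns, for each input text, a list of vectors. A minimal fake making the indexing concrete (a sketch; the fake class name is invented for illustration):

```python
class FakeEmbeddingsClient:
    """embed(texts) -> one entry per text; each entry is a list of vectors."""

    async def embed(self, texts):
        # One 3-dimensional vector per input text
        return [[[0.1, 0.2, 0.3]] for _ in texts]


async def demo_embed():
    client = FakeEmbeddingsClient()
    result = await client.embed(["What is machine learning?"])
    # result[0] is the list of vectors for the first (and only) text
    return result[0]
```

This is why the updated assertions expect `embed.assert_called_once_with([query])` (a list, not a bare string) and index the result before passing vectors on to the entity or chunk search.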

@ -120,7 +120,6 @@ class TestRowsCassandraIntegration:
id="doc-001",
user="test_user",
collection="import_2024",
),
schema_name="customer_records",
values=[{
@ -201,7 +200,7 @@ class TestRowsCassandraIntegration:
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog"),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
@ -209,7 +208,7 @@ class TestRowsCassandraIntegration:
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales"),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
@ -254,7 +253,7 @@ class TestRowsCassandraIntegration:
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test"),
schema_name="indexed_data",
values=[{
"id": "123",
@ -337,7 +336,6 @@ class TestRowsCassandraIntegration:
id="batch-001",
user="test_user",
collection="batch_import",
),
schema_name="batch_customers",
values=[
@ -391,7 +389,7 @@ class TestRowsCassandraIntegration:
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty"),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
@@ -426,7 +424,7 @@ class TestRowsCassandraIntegration:
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
metadata=Metadata(id="t1", user="test", collection="test"),
schema_name="map_test",
values=[{"id": "123", "name": "Test Item", "count": "42"}],
confidence=0.9,
@@ -470,7 +468,7 @@ class TestRowsCassandraIntegration:
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
metadata=Metadata(id="t1", user="test", collection="my_collection"),
schema_name="partition_test",
values=[{"id": "123", "category": "test"}],
confidence=0.9,


@@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10
)
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent
sent_responses = []
@@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10
)
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent
sent_responses = []
@@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming:
max_iterations=10
)
# Mock librarian to avoid hanging on save operations
processor.save_answer_content = AsyncMock(return_value=None)
# Track all responses sent
sent_responses = []
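The hunks above stub `save_answer_content` so the agent tests no longer block on a live librarian service. A minimal, self-contained sketch of the pattern (the `Processor` class here is a hypothetical stand-in for the real agent processor, not the trustgraph implementation):

```python
import asyncio
from unittest.mock import AsyncMock

class Processor:
    """Hypothetical stand-in for the agent processor under test."""
    async def save_answer_content(self, answer):
        # Simulates a save call that would hang a test without a broker.
        await asyncio.sleep(3600)

async def main():
    processor = Processor()
    # Replace the coroutine method with an AsyncMock so it returns immediately.
    processor.save_answer_content = AsyncMock(return_value=None)
    assert await processor.save_answer_content("answer") is None
    processor.save_answer_content.assert_awaited_once()

asyncio.run(main())
```

`AsyncMock` is awaitable out of the box, so no wrapper coroutine is needed; `assert_awaited_once()` also lets the test verify the save path was reached.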


@@ -0,0 +1,495 @@
"""
Unit tests for Tool Service functionality
Tests the dynamically pluggable tool services feature including:
- Tool service configuration parsing
- ToolServiceImpl initialization
- Request/response format
- Config parameter handling
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import json
class TestToolServiceConfigParsing:
"""Test cases for tool service configuration parsing"""
def test_tool_service_config_structure(self):
"""Test that tool-service config has required fields"""
# Arrange
valid_config = {
"id": "joke-service",
"request-queue": "non-persistent://tg/request/joke",
"response-queue": "non-persistent://tg/response/joke",
"config-params": [
{"name": "style", "required": False}
]
}
# Act & Assert
assert "id" in valid_config
assert "request-queue" in valid_config
assert "response-queue" in valid_config
assert valid_config["request-queue"].startswith("non-persistent://")
assert valid_config["response-queue"].startswith("non-persistent://")
def test_tool_service_config_without_queues_is_invalid(self):
"""Test that tool-service config requires request-queue and response-queue"""
# Arrange
invalid_config = {
"id": "joke-service",
"config-params": []
}
# Act & Assert
def validate_config(config):
request_queue = config.get("request-queue")
response_queue = config.get("response-queue")
if not request_queue or not response_queue:
raise RuntimeError("Tool-service must define 'request-queue' and 'response-queue'")
return True
with pytest.raises(RuntimeError) as exc_info:
validate_config(invalid_config)
assert "request-queue" in str(exc_info.value)
def test_tool_config_references_tool_service(self):
"""Test that tool config correctly references a tool-service"""
# Arrange
tool_services = {
"joke-service": {
"id": "joke-service",
"request-queue": "non-persistent://tg/request/joke",
"response-queue": "non-persistent://tg/response/joke",
"config-params": [{"name": "style", "required": False}]
}
}
tool_config = {
"type": "tool-service",
"name": "tell-joke",
"description": "Tell a joke on a given topic",
"service": "joke-service",
"style": "pun",
"arguments": [
{"name": "topic", "type": "string", "description": "The topic for the joke"}
]
}
# Act
service_ref = tool_config.get("service")
service_config = tool_services.get(service_ref)
# Assert
assert service_ref == "joke-service"
assert service_config is not None
assert service_config["request-queue"] == "non-persistent://tg/request/joke"
def test_tool_config_extracts_config_values(self):
"""Test that config values are extracted from tool config"""
# Arrange
tool_services = {
"joke-service": {
"id": "joke-service",
"request-queue": "non-persistent://tg/request/joke",
"response-queue": "non-persistent://tg/response/joke",
"config-params": [
{"name": "style", "required": False},
{"name": "max-length", "required": False}
]
}
}
tool_config = {
"type": "tool-service",
"name": "tell-joke",
"description": "Tell a joke",
"service": "joke-service",
"style": "pun",
"max-length": 100,
"arguments": []
}
# Act - simulate config extraction
service_config = tool_services[tool_config["service"]]
config_params = service_config.get("config-params", [])
config_values = {}
for param in config_params:
param_name = param.get("name") if isinstance(param, dict) else param
if param_name in tool_config:
config_values[param_name] = tool_config[param_name]
# Assert
assert config_values == {"style": "pun", "max-length": 100}
def test_required_config_param_validation(self):
"""Test that required config params are validated"""
# Arrange
tool_services = {
"custom-service": {
"id": "custom-service",
"request-queue": "non-persistent://tg/request/custom",
"response-queue": "non-persistent://tg/response/custom",
"config-params": [
{"name": "collection", "required": True},
{"name": "optional-param", "required": False}
]
}
}
tool_config_missing_required = {
"type": "tool-service",
"name": "custom-tool",
"description": "Custom tool",
"service": "custom-service",
# Missing required "collection" param
"optional-param": "value"
}
# Act & Assert
def validate_config_params(tool_config, service_config):
config_params = service_config.get("config-params", [])
for param in config_params:
param_name = param.get("name")
if param.get("required", False) and param_name not in tool_config:
raise RuntimeError(f"Missing required config param '{param_name}'")
return True
service_config = tool_services["custom-service"]
with pytest.raises(RuntimeError) as exc_info:
validate_config_params(tool_config_missing_required, service_config)
assert "collection" in str(exc_info.value)
class TestToolServiceRequest:
"""Test cases for tool service request format"""
def test_request_format(self):
"""Test that request is properly formatted with user, config, and arguments"""
# Arrange
user = "alice"
config_values = {"style": "pun", "collection": "jokes"}
arguments = {"topic": "programming"}
# Act - simulate request building
request = {
"user": user,
"config": json.dumps(config_values),
"arguments": json.dumps(arguments)
}
# Assert
assert request["user"] == "alice"
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
assert json.loads(request["arguments"]) == {"topic": "programming"}
def test_request_with_empty_config(self):
"""Test request when no config values are provided"""
# Arrange
user = "bob"
config_values = {}
arguments = {"query": "test"}
# Act
request = {
"user": user,
"config": json.dumps(config_values) if config_values else "{}",
"arguments": json.dumps(arguments) if arguments else "{}"
}
# Assert
assert request["config"] == "{}"
assert json.loads(request["arguments"]) == {"query": "test"}
class TestToolServiceResponse:
"""Test cases for tool service response handling"""
def test_success_response_handling(self):
"""Test handling of successful response"""
# Arrange
response = {
"error": None,
"response": "Hey alice! Here's a pun for you:\n\nWhy do programmers prefer dark mode?",
"end_of_stream": True
}
# Act & Assert
assert response["error"] is None
assert "pun" in response["response"]
assert response["end_of_stream"] is True
def test_error_response_handling(self):
"""Test handling of error response"""
# Arrange
response = {
"error": {
"type": "tool-service-error",
"message": "Service unavailable"
},
"response": "",
"end_of_stream": True
}
# Act & Assert
assert response["error"] is not None
assert response["error"]["type"] == "tool-service-error"
assert response["error"]["message"] == "Service unavailable"
def test_string_response_passthrough(self):
"""Test that string responses are passed through directly"""
# Arrange
response_text = "This is a joke response"
# Act - simulate response handling
def handle_response(response):
if isinstance(response, str):
return response
else:
return json.dumps(response)
result = handle_response(response_text)
# Assert
assert result == response_text
def test_dict_response_json_serialization(self):
"""Test that dict responses are JSON serialized"""
# Arrange
response_data = {"joke": "Why did the chicken cross the road?", "category": "classic"}
# Act
def handle_response(response):
if isinstance(response, str):
return response
else:
return json.dumps(response)
result = handle_response(response_data)
# Assert
assert result == json.dumps(response_data)
assert json.loads(result) == response_data
class TestToolServiceImpl:
"""Test cases for ToolServiceImpl class"""
def test_tool_service_impl_initialization(self):
"""Test ToolServiceImpl stores queues and config correctly"""
# Arrange
class MockArgument:
def __init__(self, name, type, description):
self.name = name
self.type = type
self.description = description
# Simulate ToolServiceImpl initialization
class MockToolServiceImpl:
def __init__(self, context, request_queue, response_queue, config_values=None, arguments=None, processor=None):
self.context = context
self.request_queue = request_queue
self.response_queue = response_queue
self.config_values = config_values or {}
self.arguments = arguments or []
self.processor = processor
self._client = None
def get_arguments(self):
return self.arguments
# Act
arguments = [
MockArgument("topic", "string", "The topic for the joke")
]
impl = MockToolServiceImpl(
context=lambda x: None,
request_queue="non-persistent://tg/request/joke",
response_queue="non-persistent://tg/response/joke",
config_values={"style": "pun"},
arguments=arguments,
processor=Mock()
)
# Assert
assert impl.request_queue == "non-persistent://tg/request/joke"
assert impl.response_queue == "non-persistent://tg/response/joke"
assert impl.config_values == {"style": "pun"}
assert len(impl.get_arguments()) == 1
assert impl.get_arguments()[0].name == "topic"
def test_tool_service_impl_client_caching(self):
"""Test that client is cached and reused"""
# Arrange
client_key = "non-persistent://tg/request/joke|non-persistent://tg/response/joke"
# Simulate client caching behavior
tool_service_clients = {}
def get_or_create_client(request_queue, response_queue, clients_cache):
client_key = f"{request_queue}|{response_queue}"
if client_key in clients_cache:
return clients_cache[client_key], False # False = not created
client = Mock()
clients_cache[client_key] = client
return client, True # True = newly created
# Act
client1, created1 = get_or_create_client(
"non-persistent://tg/request/joke",
"non-persistent://tg/response/joke",
tool_service_clients
)
client2, created2 = get_or_create_client(
"non-persistent://tg/request/joke",
"non-persistent://tg/response/joke",
tool_service_clients
)
# Assert
assert created1 is True
assert created2 is False
assert client1 is client2
class TestJokeServiceLogic:
"""Test cases for the joke service example"""
def test_topic_to_category_mapping(self):
"""Test that topics are mapped to categories correctly"""
# Arrange
def map_topic_to_category(topic):
topic = topic.lower()
if "program" in topic or "code" in topic or "computer" in topic or "software" in topic:
return "programming"
elif "llama" in topic:
return "llama"
elif "animal" in topic or "dog" in topic or "cat" in topic or "bird" in topic:
return "animals"
elif "food" in topic or "eat" in topic or "cook" in topic or "drink" in topic:
return "food"
else:
return "default"
# Act & Assert
assert map_topic_to_category("programming") == "programming"
assert map_topic_to_category("software engineering") == "programming"
assert map_topic_to_category("llamas") == "llama"
assert map_topic_to_category("llama") == "llama"
assert map_topic_to_category("animals") == "animals"
assert map_topic_to_category("my dog") == "animals"
assert map_topic_to_category("food") == "food"
assert map_topic_to_category("cooking recipes") == "food"
assert map_topic_to_category("random topic") == "default"
assert map_topic_to_category("") == "default"
def test_joke_response_personalization(self):
"""Test that joke responses include user personalization"""
# Arrange
user = "alice"
style = "pun"
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
# Act
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
# Assert
assert "Hey alice!" in response
assert "pun" in response
assert joke in response
def test_style_normalization(self):
"""Test that invalid styles fall back to valid ones"""
import random
# Arrange
valid_styles = ["pun", "dad-joke", "one-liner"]
def normalize_style(style):
if style not in valid_styles:
return random.choice(valid_styles)
return style
# Act & Assert
assert normalize_style("pun") == "pun"
assert normalize_style("dad-joke") == "dad-joke"
assert normalize_style("one-liner") == "one-liner"
assert normalize_style("invalid-style") in valid_styles
assert normalize_style("") in valid_styles
class TestDynamicToolServiceBase:
"""Test cases for DynamicToolService base class behavior"""
def test_topic_to_pulsar_path_conversion(self):
"""Test that topic names are converted to full Pulsar paths"""
# Arrange
topic = "joke"
# Act
request_topic = f"non-persistent://tg/request/{topic}"
response_topic = f"non-persistent://tg/response/{topic}"
# Assert
assert request_topic == "non-persistent://tg/request/joke"
assert response_topic == "non-persistent://tg/response/joke"
def test_request_parsing(self):
"""Test parsing of incoming request"""
# Arrange
request_data = {
"user": "alice",
"config": '{"style": "pun"}',
"arguments": '{"topic": "programming"}'
}
# Act
user = request_data.get("user", "trustgraph")
config = json.loads(request_data["config"]) if request_data["config"] else {}
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
# Assert
assert user == "alice"
assert config == {"style": "pun"}
assert arguments == {"topic": "programming"}
def test_response_building(self):
"""Test building of response message"""
# Arrange
response_text = "Hey alice! Here's a joke"
error = None
# Act
response = {
"error": error,
"response": response_text if isinstance(response_text, str) else json.dumps(response_text),
"end_of_stream": True
}
# Assert
assert response["error"] is None
assert response["response"] == "Hey alice! Here's a joke"
assert response["end_of_stream"] is True
def test_error_response_building(self):
"""Test building of error response"""
# Arrange
error_message = "Service temporarily unavailable"
# Act
response = {
"error": {
"type": "tool-service-error",
"message": error_message
},
"response": "",
"end_of_stream": True
}
# Assert
assert response["error"]["type"] == "tool-service-error"
assert response["error"]["message"] == error_message
assert response["response"] == ""
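The request/response shape exercised throughout this file can be condensed into two small helpers. This is a sketch of the conventions the tests assume (JSON-encoded `config`/`arguments`, user defaulting, string passthrough versus JSON serialization), not the production implementation:

```python
import json

def build_request(user, config=None, arguments=None):
    # Empty user defaults to "trustgraph"; empty config/arguments become "{}".
    return {
        "user": user or "trustgraph",
        "config": json.dumps(config) if config else "{}",
        "arguments": json.dumps(arguments) if arguments else "{}",
    }

def encode_response(result):
    # Strings pass through untouched; anything else is JSON-serialized.
    return result if isinstance(result, str) else json.dumps(result)

req = build_request("alice", {"style": "pun"}, {"topic": "cats"})
assert json.loads(req["config"]) == {"style": "pun"}
assert build_request("")["user"] == "trustgraph"
assert encode_response("a joke") == "a joke"
assert json.loads(encode_response({"result": 42})) == {"result": 42}
```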


@@ -0,0 +1,624 @@
"""
Tests for tool service lifecycle, invoke contract, streaming responses,
multi-tenancy, and error propagation.
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
classes rather than plain dicts.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
ToolServiceRequest, ToolServiceResponse, Error,
ToolRequest, ToolResponse,
)
from trustgraph.exceptions import TooManyRequests
# ---------------------------------------------------------------------------
# DynamicToolService tests
# ---------------------------------------------------------------------------
class TestDynamicToolServiceInvokeContract:
@pytest.mark.asyncio
async def test_base_invoke_raises_not_implemented(self):
"""Base class invoke() should raise NotImplementedError."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
with pytest.raises(NotImplementedError):
await svc.invoke("user", {}, {})
@pytest.mark.asyncio
async def test_on_request_calls_invoke_with_parsed_args(self):
"""on_request should JSON-parse config/arguments and pass to invoke."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
calls = []
async def tracking_invoke(user, config, arguments):
calls.append({"user": user, "config": config, "arguments": arguments})
return "ok"
svc.invoke = tracking_invoke
# Ensure the class-level metric exists
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user="alice",
config='{"style": "pun"}',
arguments='{"topic": "cats"}',
)
msg.properties.return_value = {"id": "req-1"}
await svc.on_request(msg, MagicMock(), None)
assert len(calls) == 1
assert calls[0]["user"] == "alice"
assert calls[0]["config"] == {"style": "pun"}
assert calls[0]["arguments"] == {"topic": "cats"}
@pytest.mark.asyncio
async def test_on_request_empty_user_defaults_to_trustgraph(self):
"""Empty user field should default to 'trustgraph'."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
received_user = None
async def capture_invoke(user, config, arguments):
nonlocal received_user
received_user = user
return "ok"
svc.invoke = capture_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
msg.properties.return_value = {"id": "req-2"}
await svc.on_request(msg, MagicMock(), None)
assert received_user == "trustgraph"
@pytest.mark.asyncio
async def test_on_request_string_response_sent_directly(self):
"""String return from invoke → response field is the string."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def string_invoke(user, config, arguments):
return "hello world"
svc.invoke = string_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r1"}
await svc.on_request(msg, MagicMock(), None)
sent = svc.producer.send.call_args[0][0]
assert isinstance(sent, ToolServiceResponse)
assert sent.response == "hello world"
assert sent.end_of_stream is True
assert sent.error is None
@pytest.mark.asyncio
async def test_on_request_dict_response_json_encoded(self):
"""Dict return from invoke → response field is JSON-encoded."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def dict_invoke(user, config, arguments):
return {"result": 42}
svc.invoke = dict_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r2"}
await svc.on_request(msg, MagicMock(), None)
sent = svc.producer.send.call_args[0][0]
assert json.loads(sent.response) == {"result": 42}
@pytest.mark.asyncio
async def test_on_request_error_sends_error_response(self):
"""Exception in invoke → error response sent."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def failing_invoke(user, config, arguments):
raise ValueError("bad input")
svc.invoke = failing_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r3"}
await svc.on_request(msg, MagicMock(), None)
sent = svc.producer.send.call_args[0][0]
assert sent.error is not None
assert sent.error.type == "tool-service-error"
assert "bad input" in sent.error.message
assert sent.response == ""
@pytest.mark.asyncio
async def test_on_request_too_many_requests_propagates(self):
"""TooManyRequests should propagate (not caught as error response)."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def rate_limited_invoke(user, config, arguments):
raise TooManyRequests("rate limited")
svc.invoke = rate_limited_invoke
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "r4"}
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), None)
@pytest.mark.asyncio
async def test_on_request_preserves_message_id(self):
"""Response should include the original message id in properties."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test-svc"
svc.producer = AsyncMock()
async def ok_invoke(user, config, arguments):
return "ok"
svc.invoke = ok_invoke
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
msg.properties.return_value = {"id": "unique-42"}
await svc.on_request(msg, MagicMock(), None)
props = svc.producer.send.call_args[1]["properties"]
assert props["id"] == "unique-42"
# ---------------------------------------------------------------------------
# ToolService (flow-based) tests
# ---------------------------------------------------------------------------
class TestToolServiceOnRequest:
@pytest.mark.asyncio
async def test_string_response_sent_as_text(self):
"""String return from invoke_tool → ToolResponse.text is set."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
return "tool result"
svc.invoke_tool = mock_invoke
if not hasattr(ToolService, "tool_invocation_metric"):
ToolService.tool_invocation_metric = MagicMock()
mock_response_pub = AsyncMock()
def flow_callable(name):
if name == "response":
return mock_response_pub
return MagicMock()
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
msg.properties.return_value = {"id": "t1"}
await svc.on_request(msg, MagicMock(), flow_callable)
sent = mock_response_pub.send.call_args[0][0]
assert isinstance(sent, ToolResponse)
assert sent.text == "tool result"
assert sent.object is None
@pytest.mark.asyncio
async def test_dict_response_sent_as_json_object(self):
"""Dict return from invoke_tool → ToolResponse.object is JSON."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def mock_invoke(name, params):
return {"data": [1, 2, 3]}
svc.invoke_tool = mock_invoke
if not hasattr(ToolService, "tool_invocation_metric"):
ToolService.tool_invocation_metric = MagicMock()
mock_response_pub = AsyncMock()
def flow_callable(name):
if name == "response":
return mock_response_pub
return MagicMock()
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
msg.properties.return_value = {"id": "t2"}
await svc.on_request(msg, MagicMock(), flow_callable)
sent = mock_response_pub.send.call_args[0][0]
assert sent.text is None
assert json.loads(sent.object) == {"data": [1, 2, 3]}
@pytest.mark.asyncio
async def test_error_sends_error_response(self):
"""Exception in invoke_tool → error response via flow producer."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def failing_invoke(name, params):
raise RuntimeError("tool broke")
svc.invoke_tool = failing_invoke
mock_response_pub = AsyncMock()
def flow_callable(name):
return MagicMock()
flow_callable.producer = {"response": mock_response_pub}
flow_callable.name = "test-flow"
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
msg.properties.return_value = {"id": "t3"}
await svc.on_request(msg, MagicMock(), flow_callable)
sent = mock_response_pub.send.call_args[0][0]
assert sent.error is not None
assert sent.error.type == "tool-error"
assert "tool broke" in sent.error.message
@pytest.mark.asyncio
async def test_too_many_requests_propagates(self):
"""TooManyRequests should propagate from ToolService.on_request."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
async def rate_limited(name, params):
raise TooManyRequests("slow down")
svc.invoke_tool = rate_limited
msg = MagicMock()
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
msg.properties.return_value = {"id": "t4"}
flow = MagicMock()
flow.producer = {"response": AsyncMock()}
flow.name = "test-flow"
with pytest.raises(TooManyRequests):
await svc.on_request(msg, MagicMock(), flow)
@pytest.mark.asyncio
async def test_parameters_json_parsed(self):
"""Parameters should be JSON-parsed before passing to invoke_tool."""
from trustgraph.base.tool_service import ToolService
svc = ToolService.__new__(ToolService)
svc.id = "test-tool"
received = {}
async def capture_invoke(name, params):
received["name"] = name
received["params"] = params
return "ok"
svc.invoke_tool = capture_invoke
if not hasattr(ToolService, "tool_invocation_metric"):
ToolService.tool_invocation_metric = MagicMock()
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow.producer = {"response": mock_pub}
flow.name = "f"
msg = MagicMock()
msg.value.return_value = ToolRequest(
name="search",
parameters='{"query": "test", "limit": 10}',
)
msg.properties.return_value = {"id": "t5"}
await svc.on_request(msg, MagicMock(), flow)
assert received["name"] == "search"
assert received["params"] == {"query": "test", "limit": 10}
# ---------------------------------------------------------------------------
# ToolServiceClient tests
# ---------------------------------------------------------------------------
class TestToolServiceClientCall:
@pytest.mark.asyncio
async def test_call_sends_request_and_returns_response(self):
"""call() should send ToolServiceRequest and return response string."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="joke result", end_of_stream=True,
))
result = await client.call(
user="alice",
config={"style": "pun"},
arguments={"topic": "cats"},
)
assert result == "joke result"
req = client.request.call_args[0][0]
assert isinstance(req, ToolServiceRequest)
assert req.user == "alice"
assert json.loads(req.config) == {"style": "pun"}
assert json.loads(req.arguments) == {"topic": "cats"}
@pytest.mark.asyncio
async def test_call_raises_on_error(self):
"""call() should raise RuntimeError when response has error."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=Error(type="tool-service-error", message="service down"),
response="",
))
with pytest.raises(RuntimeError, match="service down"):
await client.call(user="u", config={}, arguments={})
@pytest.mark.asyncio
async def test_call_empty_config_sends_empty_json(self):
"""Empty config/arguments should be sent as '{}'."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="u", config=None, arguments=None)
req = client.request.call_args[0][0]
assert req.config == "{}"
assert req.arguments == "{}"
@pytest.mark.asyncio
async def test_call_passes_timeout(self):
"""call() should forward timeout to underlying request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="u", config={}, arguments={}, timeout=30)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 30
class TestToolServiceClientStreaming:
@pytest.mark.asyncio
async def test_call_streaming_collects_chunks(self):
"""call_streaming should accumulate chunks and return full result."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
# Simulate streaming: request() calls recipient with each chunk
chunks = [
ToolServiceResponse(error=None, response="chunk1", end_of_stream=False),
ToolServiceResponse(error=None, response="chunk2", end_of_stream=True),
]
async def mock_request(req, timeout=600, recipient=None):
for chunk in chunks:
done = await recipient(chunk)
if done:
break
client.request = mock_request
received = []
async def callback(text):
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
)
assert result == "chunk1chunk2"
assert received == ["chunk1", "chunk2"]
@pytest.mark.asyncio
async def test_call_streaming_raises_on_error(self):
"""call_streaming should raise RuntimeError on error chunk."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
async def mock_request(req, timeout=600, recipient=None):
error_resp = ToolServiceResponse(
error=Error(type="tool-service-error", message="stream failed"),
response="",
end_of_stream=True,
)
await recipient(error_resp)
client.request = mock_request
with pytest.raises(RuntimeError, match="stream failed"):
await client.call_streaming(
user="u", config={}, arguments={},
callback=AsyncMock(),
)
@pytest.mark.asyncio
async def test_call_streaming_skips_empty_response(self):
"""Empty response chunks should not be added to result."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
chunks = [
ToolServiceResponse(error=None, response="", end_of_stream=False),
ToolServiceResponse(error=None, response="data", end_of_stream=True),
]
async def mock_request(req, timeout=600, recipient=None):
for chunk in chunks:
done = await recipient(chunk)
if done:
break
client.request = mock_request
received = []
async def callback(text):
received.append(text)
result = await client.call_streaming(
user="u", config={}, arguments={}, callback=callback,
)
# Empty response is falsy, so callback shouldn't be called for it
assert result == "data"
assert received == ["data"]
# ---------------------------------------------------------------------------
# Multi-tenancy
# ---------------------------------------------------------------------------
class TestMultiTenancy:
@pytest.mark.asyncio
async def test_user_propagated_to_invoke(self):
"""User from request should reach the invoke method."""
from trustgraph.base.dynamic_tool_service import DynamicToolService
svc = DynamicToolService.__new__(DynamicToolService)
svc.id = "test"
svc.producer = AsyncMock()
users_seen = []
async def tracking(user, config, arguments):
users_seen.append(user)
return "ok"
svc.invoke = tracking
if not hasattr(DynamicToolService, "tool_service_metric"):
DynamicToolService.tool_service_metric = MagicMock()
for u in ["tenant-a", "tenant-b", "tenant-c"]:
msg = MagicMock()
msg.value.return_value = ToolServiceRequest(
user=u, config="{}", arguments="{}",
)
msg.properties.return_value = {"id": f"req-{u}"}
await svc.on_request(msg, MagicMock(), None)
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
@pytest.mark.asyncio
async def test_client_sends_user_in_request(self):
"""ToolServiceClient.call should include user in request."""
from trustgraph.base.tool_service_client import ToolServiceClient
client = ToolServiceClient.__new__(ToolServiceClient)
client.request = AsyncMock(return_value=ToolServiceResponse(
error=None, response="ok",
))
await client.call(user="isolated-tenant", config={}, arguments={})
req = client.request.call_args[0][0]
assert req.user == "isolated-tenant"

View file

@@ -23,27 +23,27 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
# Mock the request method
client.request = AsyncMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = await client.query(
vectors=vectors,
vector=vector,
limit=10,
user="test_user",
collection="test_collection",
timeout=30
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.request.assert_called_once()
call_args = client.request.call_args[0][0]
assert isinstance(call_args, DocumentEmbeddingsRequest)
assert call_args.vectors == vectors
assert call_args.vector == vector
assert call_args.limit == 10
assert call_args.user == "test_user"
assert call_args.collection == "test_collection"
@@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
# Act & Assert
with pytest.raises(RuntimeError, match="Database connection failed"):
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@@ -76,12 +76,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = []
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@@ -94,11 +94,11 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
client.request.assert_called_once()
@@ -116,15 +116,15 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["chunk1"]
client.request = AsyncMock(return_value=mock_response)
# Act
await client.query(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=60
)
# Assert
assert client.request.call_args[1]["timeout"] == 60
@@ -137,13 +137,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
mock_response.error = None
mock_response.chunks = ["test_chunk"]
client.request = AsyncMock(return_value=mock_response)
# Act
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
result = await client.query(vector=[0.1, 0.2, 0.3])
# Assert
mock_logger.debug.assert_called_once()
assert "Document embeddings response" in str(mock_logger.debug.call_args)
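The hunks above migrate `query` from a batched `vectors` parameter to a single flat `vector`. A minimal sketch of the call shape these tests assume — the stand-in class and its mocked `request` are hypothetical; only the argument names and defaults are taken from the tests:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

class DocumentEmbeddingsClient:
    """Stand-in mirroring only the surface the tests above exercise."""

    async def query(self, vector, limit=10, user="trustgraph",
                    collection="default", timeout=300):
        # Ship one flat query vector, not a batch of vectors.
        request = {"vector": vector, "limit": limit,
                   "user": user, "collection": collection}
        response = await self.request(request, timeout=timeout)
        if response.error is not None:
            raise RuntimeError(response.error.message)
        return response.chunks

client = DocumentEmbeddingsClient()
client.request = AsyncMock(
    return_value=MagicMock(error=None, chunks=["chunk1", "chunk2"]))
chunks = asyncio.run(client.query(vector=[0.1, 0.2, 0.3], limit=10))
```

As in the tests, the mocked transport keeps the example runnable without a live embeddings service.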

View file

@@ -28,7 +28,6 @@ def sample_text_document():
"""Sample document with moderate length text."""
metadata = Metadata(
id="test-doc-1",
metadata=[],
user="test-user",
collection="test-collection"
)
@@ -44,7 +43,6 @@ def long_text_document():
"""Long document for testing multiple chunks."""
metadata = Metadata(
id="test-doc-long",
metadata=[],
user="test-user",
collection="test-collection"
)
@@ -61,7 +59,6 @@ def unicode_text_document():
"""Document with various unicode characters."""
metadata = Metadata(
id="test-doc-unicode",
metadata=[],
user="test-user",
collection="test-collection"
)
@@ -87,7 +84,6 @@ def empty_text_document():
"""Empty document for edge case testing."""
metadata = Metadata(
id="test-doc-empty",
metadata=[],
user="test-user",
collection="test-collection"
)

View file

@@ -17,13 +17,17 @@ class MockAsyncProcessor:
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
self.pubsub = MagicMock()
self.taskgroup = params.get('taskgroup', MagicMock())
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
"""Test Recursive chunker functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self):
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
"""Test basic processor initialization"""
# Arrange
config = {
@@ -47,8 +51,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self):
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-size parameter override"""
# Arrange
config = {
@@ -79,8 +85,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 2000 # Should use overridden value
assert chunk_overlap == 100 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self):
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-overlap parameter override"""
# Arrange
config = {
@@ -111,8 +119,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 200 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self):
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
# Arrange
config = {
@@ -143,9 +153,11 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 1500 # Should use overridden value
assert chunk_overlap == 150 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
"""Test that on_message method uses parameters from flow"""
# Arrange
mock_splitter = MagicMock()
@@ -164,26 +176,31 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
# Mock message with TextDocument
mock_message = MagicMock()
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
metadata=[],
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"
mock_text_doc.document_id = "" # No librarian fetch needed
mock_message.value.return_value = mock_text_doc
# Mock consumer and flow with parameter overrides
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 1500,
"chunk-overlap": 150,
"output": mock_producer
"output": mock_producer,
"triples": mock_triples_producer,
}.get(param)
# Act
@@ -202,8 +219,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self):
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
# Arrange
config = {

View file

@@ -17,13 +17,17 @@ class MockAsyncProcessor:
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
self.pubsub = MagicMock()
self.taskgroup = params.get('taskgroup', MagicMock())
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
"""Test Token chunker functionality"""
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self):
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
"""Test basic processor initialization"""
# Arrange
config = {
@@ -47,8 +51,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self):
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-size parameter override"""
# Arrange
config = {
@@ -79,8 +85,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 400 # Should use overridden value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self):
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
"""Test chunk_document with chunk-overlap parameter override"""
# Arrange
config = {
@@ -111,8 +119,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 25 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self):
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
# Arrange
config = {
@@ -143,9 +153,11 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 350 # Should use overridden value
assert chunk_overlap == 30 # Should use overridden value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
"""Test that on_message method uses parameters from flow"""
# Arrange
mock_splitter = MagicMock()
@@ -164,26 +176,31 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Mock save_child_document to avoid librarian producer interactions
processor.save_child_document = AsyncMock(return_value="chunk-id")
# Mock message with TextDocument
mock_message = MagicMock()
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
metadata=[],
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"
mock_text_doc.document_id = "" # No librarian fetch needed
mock_message.value.return_value = mock_text_doc
# Mock consumer and flow with parameter overrides
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 400,
"chunk-overlap": 40,
"output": mock_producer
"output": mock_producer,
"triples": mock_triples_producer,
}.get(param)
# Act
@@ -206,8 +223,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self):
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
# Arrange
config = {
@@ -235,8 +254,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_token_chunker_uses_different_defaults(self):
def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
"""Test that token chunker has different defaults than recursive chunker"""
# Arrange & Act
config = {

View file

@@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# Act
result = client.request(
vectors=vectors,
vector=vector,
user="test_user",
collection="test_collection",
limit=10,
timeout=300
)
# Assert
assert result == ["chunk1", "chunk2", "chunk3"]
client.call.assert_called_once_with(
user="test_user",
collection="test_collection",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)
@@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["test_chunk"]
client.call = MagicMock(return_value=mock_response)
vectors = [[0.1, 0.2, 0.3]]
vector = [0.1, 0.2, 0.3]
# Act
result = client.request(vectors=vectors)
result = client.request(vector=vector)
# Assert
assert result == ["test_chunk"]
client.call.assert_called_once_with(
user="trustgraph",
collection="default",
vectors=vectors,
vector=vector,
limit=10,
timeout=300
)
@@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = []
client.call = MagicMock(return_value=mock_response)
# Act
result = client.request(vectors=[[0.1, 0.2, 0.3]])
result = client.request(vector=[0.1, 0.2, 0.3])
# Assert
assert result == []
@@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = None
client.call = MagicMock(return_value=mock_response)
# Act
result = client.request(vectors=[[0.1, 0.2, 0.3]])
result = client.request(vector=[0.1, 0.2, 0.3])
# Assert
assert result is None
@@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
mock_response = MagicMock()
mock_response.chunks = ["chunk1"]
client.call = MagicMock(return_value=mock_response)
# Act
client.request(
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
timeout=600
)
# Assert
assert client.call.call_args[1]["timeout"] == 600

View file

@@ -0,0 +1 @@

View file

@@ -0,0 +1,286 @@
"""
Tests for Consumer concurrency: TaskGroup-based concurrent message processing,
rate-limit retry with backpressure, and message acknowledgement.
"""
import asyncio
import time
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.base.consumer import Consumer
from trustgraph.exceptions import TooManyRequests
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_consumer(
concurrency=1,
handler=None,
rate_limit_retry_time=0.01,
rate_limit_timeout=1,
):
"""Create a Consumer with mocked infrastructure."""
taskgroup = MagicMock()
flow = MagicMock()
backend = MagicMock()
schema = MagicMock()
handler = handler or AsyncMock()
consumer = Consumer(
taskgroup=taskgroup,
flow=flow,
backend=backend,
topic="test-topic",
subscriber="test-sub",
schema=schema,
handler=handler,
rate_limit_retry_time=rate_limit_retry_time,
rate_limit_timeout=rate_limit_timeout,
concurrency=concurrency,
)
return consumer
def _make_msg():
"""Create a mock Pulsar message."""
return MagicMock()
# ---------------------------------------------------------------------------
# Concurrency configuration tests
# ---------------------------------------------------------------------------
class TestConcurrencyConfiguration:
def test_default_concurrency_is_1(self):
consumer = _make_consumer()
assert consumer.concurrency == 1
def test_custom_concurrency(self):
consumer = _make_consumer(concurrency=10)
assert consumer.concurrency == 10
def test_concurrency_stored(self):
for n in [1, 5, 20, 100]:
consumer = _make_consumer(concurrency=n)
assert consumer.concurrency == n
class TestTaskGroupConcurrency:
@pytest.mark.asyncio
async def test_creates_n_concurrent_tasks(self):
"""consumer_run should create exactly N concurrent consume_from_queue tasks."""
concurrency = 5
consumer = _make_consumer(concurrency=concurrency)
# Track how many consume_from_queue calls are made
call_count = 0
async def mock_consume():
nonlocal call_count
call_count += 1
# Wait a bit to let all tasks start, then signal stop
await asyncio.sleep(0.05)
consumer.running = False
consumer.consume_from_queue = mock_consume
# Mock the backend.create_consumer
consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
# Run consumer_run - it will create TaskGroup with N tasks
consumer.running = True
await consumer.consumer_run()
assert call_count == concurrency
@pytest.mark.asyncio
async def test_single_concurrency_creates_one_task(self):
"""With concurrency=1, only one consume_from_queue task is created."""
consumer = _make_consumer(concurrency=1)
call_count = 0
async def mock_consume():
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
consumer.running = False
consumer.consume_from_queue = mock_consume
consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
consumer.running = True
await consumer.consumer_run()
assert call_count == 1
# ---------------------------------------------------------------------------
# Rate-limit retry tests
# ---------------------------------------------------------------------------
class TestRateLimitRetry:
@pytest.mark.asyncio
async def test_rate_limit_retries_then_succeeds(self):
"""TooManyRequests should cause retry, then succeed on next attempt."""
call_count = 0
async def handler_with_retry(msg, consumer_ref, flow):
nonlocal call_count
call_count += 1
if call_count == 1:
raise TooManyRequests("rate limited")
# Second call succeeds
consumer = _make_consumer(
handler=handler_with_retry,
rate_limit_retry_time=0.01,
)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
assert call_count == 2
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@pytest.mark.asyncio
async def test_rate_limit_timeout_negative_acks(self):
"""If rate limit retries exhaust the timeout, message is negative-acked."""
async def always_rate_limited(msg, consumer_ref, flow):
raise TooManyRequests("rate limited")
consumer = _make_consumer(
handler=always_rate_limited,
rate_limit_retry_time=0.01,
rate_limit_timeout=0.05,
)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
consumer.consumer.acknowledge.assert_not_called()
@pytest.mark.asyncio
async def test_non_rate_limit_error_negative_acks_immediately(self):
"""Non-TooManyRequests errors should negative-ack immediately (no retry)."""
call_count = 0
async def failing_handler(msg, consumer_ref, flow):
nonlocal call_count
call_count += 1
raise ValueError("bad data")
consumer = _make_consumer(handler=failing_handler)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
assert call_count == 1
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
@pytest.mark.asyncio
async def test_successful_message_acknowledged(self):
"""Successfully processed messages are acknowledged."""
consumer = _make_consumer(handler=AsyncMock())
mock_msg = _make_msg()
consumer.consumer = MagicMock()
await consumer.handle_one_from_queue(mock_msg)
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
# ---------------------------------------------------------------------------
# Metrics integration
# ---------------------------------------------------------------------------
class TestMetricsIntegration:
@pytest.mark.asyncio
async def test_success_metric_on_success(self):
consumer = _make_consumer(handler=AsyncMock())
mock_msg = _make_msg()
consumer.consumer = MagicMock()
mock_metrics = MagicMock()
mock_metrics.record_time.return_value.__enter__ = MagicMock()
mock_metrics.record_time.return_value.__exit__ = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
mock_metrics.process.assert_called_once_with("success")
@pytest.mark.asyncio
async def test_error_metric_on_failure(self):
async def failing(msg, c, f):
raise ValueError("fail")
consumer = _make_consumer(handler=failing)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
mock_metrics = MagicMock()
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
mock_metrics.process.assert_called_once_with("error")
@pytest.mark.asyncio
async def test_rate_limit_metric_on_too_many_requests(self):
call_count = 0
async def handler(msg, c, f):
nonlocal call_count
call_count += 1
if call_count == 1:
raise TooManyRequests("limited")
consumer = _make_consumer(
handler=handler,
rate_limit_retry_time=0.01,
)
mock_msg = _make_msg()
consumer.consumer = MagicMock()
mock_metrics = MagicMock()
mock_metrics.record_time.return_value.__enter__ = MagicMock()
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
consumer.metrics = mock_metrics
await consumer.handle_one_from_queue(mock_msg)
mock_metrics.rate_limit.assert_called_once()
# ---------------------------------------------------------------------------
# Stop / running flag
# ---------------------------------------------------------------------------
class TestStopBehaviour:
@pytest.mark.asyncio
async def test_stop_sets_running_false(self):
consumer = _make_consumer()
consumer.running = True
await consumer.stop()
assert consumer.running is False
def test_initial_running_state(self):
consumer = _make_consumer()
assert consumer.running is True

View file

@@ -0,0 +1,136 @@
"""
Tests for MessageDispatcher semaphore-based concurrency enforcement.
Verifies that the dispatcher limits concurrent message processing to
max_workers via asyncio.Semaphore.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
class TestSemaphoreEnforcement:
@pytest.mark.asyncio
async def test_semaphore_limits_concurrent_processing(self):
"""Only max_workers messages should be processed concurrently."""
max_workers = 2
dispatcher = MessageDispatcher(max_workers=max_workers)
concurrent_count = 0
max_concurrent = 0
processing_event = asyncio.Event()
async def slow_process(message):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.05)
concurrent_count -= 1
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = slow_process
# Launch more tasks than max_workers
messages = [
{"id": f"msg-{i}", "service": "test", "request": {}}
for i in range(5)
]
tasks = [
asyncio.create_task(dispatcher.handle_message(m))
for m in messages
]
await asyncio.gather(*tasks)
# At no point should more than max_workers have been active
assert max_concurrent <= max_workers
@pytest.mark.asyncio
async def test_semaphore_value_matches_max_workers(self):
for n in [1, 5, 20]:
dispatcher = MessageDispatcher(max_workers=n)
assert dispatcher.semaphore._value == n
@pytest.mark.asyncio
async def test_active_tasks_tracked(self):
"""Active tasks should be added/removed during processing."""
dispatcher = MessageDispatcher(max_workers=5)
task_was_tracked = False
async def tracking_process(message):
nonlocal task_was_tracked
# During processing, our task should be in active_tasks
if len(dispatcher.active_tasks) > 0:
task_was_tracked = True
return {"id": message.get("id"), "response": {"ok": True}}
dispatcher._process_message = tracking_process
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
)
assert task_was_tracked
# After completion, task should be discarded
assert len(dispatcher.active_tasks) == 0
@pytest.mark.asyncio
async def test_semaphore_released_on_error(self):
"""Semaphore should be released even if processing raises."""
dispatcher = MessageDispatcher(max_workers=2)
async def failing_process(message):
raise RuntimeError("process failed")
dispatcher._process_message = failing_process
# Should not deadlock — semaphore must be released on error
with pytest.raises(RuntimeError):
await dispatcher.handle_message(
{"id": "test", "service": "test", "request": {}}
)
# Semaphore should be back at max
assert dispatcher.semaphore._value == 2
@pytest.mark.asyncio
async def test_single_worker_serializes_processing(self):
"""With max_workers=1, messages are processed one at a time."""
dispatcher = MessageDispatcher(max_workers=1)
order = []
async def ordered_process(message):
msg_id = message["id"]
order.append(f"start-{msg_id}")
await asyncio.sleep(0.02)
order.append(f"end-{msg_id}")
return {"id": msg_id, "response": {"ok": True}}
dispatcher._process_message = ordered_process
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
await asyncio.gather(*tasks)
# With semaphore=1, each message should complete before next starts
# Check that no two "start" entries appear without an intervening "end"
active = 0
max_active = 0
for event in order:
if event.startswith("start"):
active += 1
max_active = max(max_active, active)
elif event.startswith("end"):
active -= 1
assert max_active == 1

View file

@@ -0,0 +1,268 @@
"""
Tests for Graph RAG concurrent query execution.
Covers: execute_batch_triple_queries concurrent task spawning,
exception handling in gather, and result aggregation.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_query(
triples_client=None,
entity_limit=50,
triple_limit=30,
max_subgraph_size=1000,
max_path_length=2,
):
"""Create a Query object with mocked rag dependencies."""
rag = MagicMock()
rag.triples_client = triples_client or AsyncMock()
rag.label_cache = LRUCacheWithTTL()
query = Query(
rag=rag,
user="test-user",
collection="test-collection",
verbose=False,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
)
return query
def _make_triple(s, p, o):
"""Create a simple mock triple."""
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBatchTripleQueries:
@pytest.mark.asyncio
async def test_three_queries_per_entity(self):
"""Each entity should generate 3 concurrent queries (s, p, o positions)."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
entities = ["entity-1"]
await query.execute_batch_triple_queries(entities, limit_per_entity=10)
assert client.query_stream.call_count == 3
@pytest.mark.asyncio
async def test_multiple_entities_multiply_queries(self):
"""N entities should produce N*3 concurrent queries."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
entities = ["e1", "e2", "e3"]
await query.execute_batch_triple_queries(entities, limit_per_entity=10)
assert client.query_stream.call_count == 9 # 3 * 3
@pytest.mark.asyncio
async def test_queries_executed_concurrently(self):
"""All queries should run concurrently via asyncio.gather."""
concurrent_count = 0
max_concurrent = 0
async def tracking_query(**kwargs):
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
await asyncio.sleep(0.02)
concurrent_count -= 1
return []
client = AsyncMock()
client.query_stream = tracking_query
query = _make_query(triples_client=client)
entities = ["e1", "e2", "e3"]
await query.execute_batch_triple_queries(entities, limit_per_entity=5)
# All 9 queries should have run concurrently
assert max_concurrent == 9
@pytest.mark.asyncio
async def test_results_aggregated(self):
"""Results from all queries should be combined into a single list."""
triple_a = _make_triple("a", "p", "b")
triple_b = _make_triple("c", "p", "d")
call_count = 0
async def alternating_results(**kwargs):
nonlocal call_count
call_count += 1
if call_count % 2 == 0:
return [triple_a]
return [triple_b]
client = AsyncMock()
client.query_stream = alternating_results
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# 3 queries, alternating results
assert len(result) == 3
@pytest.mark.asyncio
async def test_exception_in_one_query_does_not_block_others(self):
"""If one query raises, other results are still collected."""
good_triple = _make_triple("a", "p", "b")
call_count = 0
async def mixed_results(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 2:
raise RuntimeError("query failed")
return [good_triple]
client = AsyncMock()
client.query_stream = mixed_results
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# 3 queries: 2 succeed, 1 fails → 2 triples
assert len(result) == 2
@pytest.mark.asyncio
async def test_none_results_filtered(self):
"""None results from queries should be filtered out."""
call_count = 0
async def sometimes_none(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return None
return [_make_triple("a", "p", "b")]
client = AsyncMock()
client.query_stream = sometimes_none
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# 3 queries: 1 returns None, 2 return triples
assert len(result) == 2
@pytest.mark.asyncio
async def test_empty_entities_no_queries(self):
"""Empty entity list should produce no queries."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries([], limit_per_entity=10)
assert result == []
client.query_stream.assert_not_called()
@pytest.mark.asyncio
async def test_query_params_correct(self):
"""Each query should use correct s/p/o positions and params."""
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[])
query = _make_query(triples_client=client)
entities = ["ent-1"]
await query.execute_batch_triple_queries(entities, limit_per_entity=15)
calls = client.query_stream.call_args_list
assert len(calls) == 3
# First call: s=entity, p=None, o=None
assert calls[0].kwargs["s"] == "ent-1"
assert calls[0].kwargs["p"] is None
assert calls[0].kwargs["o"] is None
assert calls[0].kwargs["limit"] == 15
assert calls[0].kwargs["user"] == "test-user"
assert calls[0].kwargs["collection"] == "test-collection"
assert calls[0].kwargs["batch_size"] == 20
# Second call: s=None, p=entity, o=None
assert calls[1].kwargs["s"] is None
assert calls[1].kwargs["p"] == "ent-1"
assert calls[1].kwargs["o"] is None
# Third call: s=None, p=None, o=entity
assert calls[2].kwargs["s"] is None
assert calls[2].kwargs["p"] is None
assert calls[2].kwargs["o"] == "ent-1"
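The fan-out these tests pin down (three positional queries per entity, run concurrently, with failures and `None` results dropped) can be sketched as follows. This is a minimal illustration, not TrustGraph's implementation: `query_fn` stands in for `triples_client.query_stream`, and the parameter names mirror the assertions above.

```python
import asyncio

async def execute_batch_triple_queries(query_fn, entities, limit_per_entity,
                                       user="test-user",
                                       collection="test-collection",
                                       batch_size=20):
    # One query per s/p/o position for each entity: 3 * len(entities) total.
    coros = []
    for e in entities:
        for s, p, o in ((e, None, None), (None, e, None), (None, None, e)):
            coros.append(query_fn(s=s, p=p, o=o, limit=limit_per_entity,
                                  user=user, collection=collection,
                                  batch_size=batch_size))
    # Run everything concurrently; return_exceptions=True means one failed
    # query cannot block or discard the others' results.
    results = await asyncio.gather(*coros, return_exceptions=True)
    triples = []
    for r in results:
        if isinstance(r, Exception) or r is None:
            continue  # drop failed queries and None results
        triples.extend(r)
    return triples
```

With two entities this issues six queries; a raising query and a `None`-returning query simply contribute nothing to the aggregated list.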
class TestLRUCacheWithTTL:
def test_put_and_get(self):
cache = LRUCacheWithTTL(max_size=10, ttl=60)
cache.put("key1", "value1")
assert cache.get("key1") == "value1"
def test_get_missing_returns_none(self):
cache = LRUCacheWithTTL()
assert cache.get("nonexistent") is None
def test_max_size_eviction(self):
cache = LRUCacheWithTTL(max_size=2, ttl=60)
cache.put("a", 1)
cache.put("b", 2)
cache.put("c", 3) # Should evict "a"
assert cache.get("a") is None
assert cache.get("b") == 2
assert cache.get("c") == 3
def test_lru_order(self):
cache = LRUCacheWithTTL(max_size=2, ttl=60)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a") # Access "a" — now "b" is LRU
cache.put("c", 3) # Should evict "b"
assert cache.get("a") == 1
assert cache.get("b") is None
assert cache.get("c") == 3
def test_ttl_expiration(self):
cache = LRUCacheWithTTL(max_size=10, ttl=0) # TTL=0 means instant expiry
cache.put("key", "value")
# With TTL=0, any time check > 0 means expired
import time
time.sleep(0.01)
assert cache.get("key") is None
def test_update_existing_key(self):
cache = LRUCacheWithTTL(max_size=10, ttl=60)
cache.put("key", "v1")
cache.put("key", "v2")
assert cache.get("key") == "v2"
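A cache satisfying the behaviors above (size-bounded LRU eviction, lazy TTL expiry on read, re-insert on update) fits in a few lines. This is an illustrative sketch only; the real `LRUCacheWithTTL` may differ in defaults, locking, and internals.

```python
import time
from collections import OrderedDict

class LRUCacheWithTTL:
    """LRU cache with per-entry TTL, expired lazily on get()."""

    def __init__(self, max_size=128, ttl=300):
        self.max_size = max_size
        self.ttl = ttl
        self._data = OrderedDict()  # key -> (value, inserted_at)

    def put(self, key, value):
        if key in self._data:
            del self._data[key]          # updating resets position and TTL
        elif len(self._data) >= self.max_size:
            self._data.popitem(last=False)  # evict least recently used
        self._data[key] = (value, time.monotonic())

    def get(self, key):
        item = self._data.get(key)
        if item is None:
            return None
        value, inserted = item
        if time.monotonic() - inserted > self.ttl:
            del self._data[key]          # expired: drop and report a miss
            return None
        self._data.move_to_end(key)      # mark as most recently used
        return value
```

`get()` moves the hit to the MRU end, which is what makes `put("c", 3)` evict `"b"` rather than `"a"` in the LRU-order test.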


@ -73,7 +73,6 @@ def sample_triples():
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
metadata=[]
),
triples=[
Triple(
@ -93,12 +92,11 @@ def sample_graph_embeddings():
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
metadata=[]
),
entities=[
EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/john"),
vectors=[[0.1, 0.2, 0.3]]
vector=[0.1, 0.2, 0.3]
)
]
)


@ -12,218 +12,184 @@ from trustgraph.decoding.pdf.pdf_decoder import Processor
from trustgraph.schema import Document, TextDocument, Metadata
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
self.pubsub = MagicMock()
self.taskgroup = params.get('taskgroup', MagicMock())
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"""Test PDF decoder processor functionality"""
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_processor_initialization(self, mock_flow_init):
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
"""Test PDF decoder processor initialization"""
# Arrange
mock_flow_init.return_value = None
config = {
'id': 'test-pdf-decoder',
'taskgroup': AsyncMock()
}
# Act
with patch.object(Processor, 'register_specification') as mock_register:
processor = Processor(**config)
processor = Processor(**config)
# Assert
mock_flow_init.assert_called_once()
# Verify register_specification was called twice (consumer and producer)
assert mock_register.call_count == 2
# Check consumer spec
consumer_call = mock_register.call_args_list[0]
consumer_spec = consumer_call[0][0]
assert consumer_spec.name == "input"
assert consumer_spec.schema == Document
assert consumer_spec.handler == processor.on_message
# Check producer spec
producer_call = mock_register.call_args_list[1]
producer_spec = producer_call[0][0]
assert producer_spec.name == "output"
assert producer_spec.schema == TextDocument
consumer_specs = [s for s in processor.specifications if hasattr(s, 'handler')]
assert len(consumer_specs) >= 1
assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_on_message_success(self, mock_flow_init, mock_pdf_loader_class):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
"""Test successful PDF processing"""
# Arrange
mock_flow_init.return_value = None
# Mock PDF content
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader
mock_loader = MagicMock()
mock_page1 = MagicMock(page_content="Page 1 content")
mock_page2 = MagicMock(page_content="Page 2 content")
mock_loader.load.return_value = [mock_page1, mock_page2]
mock_pdf_loader_class.return_value = mock_loader
# Mock message
mock_metadata = Metadata(id="test-doc")
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
# Mock flow - needs to be a callable that returns an object with send method
# Mock flow - separate mocks for output and triples
mock_output_flow = AsyncMock()
mock_flow = MagicMock(return_value=mock_output_flow)
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
config = {
'id': 'test-pdf-decoder',
'taskgroup': AsyncMock()
}
with patch.object(Processor, 'register_specification'):
processor = Processor(**config)
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
# Act
await processor.on_message(mock_msg, None, mock_flow)
# Assert
# Verify PyPDFLoader was called
mock_pdf_loader_class.assert_called_once()
mock_loader.load.assert_called_once()
# Verify output was sent for each page
assert mock_output_flow.send.call_count == 2
# Check first page output
first_call = mock_output_flow.send.call_args_list[0]
first_output = first_call[0][0]
assert isinstance(first_output, TextDocument)
assert first_output.metadata == mock_metadata
assert first_output.text == b"Page 1 content"
# Check second page output
second_call = mock_output_flow.send.call_args_list[1]
second_output = second_call[0][0]
assert isinstance(second_output, TextDocument)
assert second_output.metadata == mock_metadata
assert second_output.text == b"Page 2 content"
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_on_message_empty_pdf(self, mock_flow_init, mock_pdf_loader_class):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
"""Test handling of empty PDF"""
# Arrange
mock_flow_init.return_value = None
# Mock PDF content
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader with no pages
mock_loader = MagicMock()
mock_loader.load.return_value = []
mock_pdf_loader_class.return_value = mock_loader
# Mock message
mock_metadata = Metadata(id="test-doc")
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
# Mock flow - needs to be a callable that returns an object with send method
mock_output_flow = AsyncMock()
mock_flow = MagicMock(return_value=mock_output_flow)
config = {
'id': 'test-pdf-decoder',
'taskgroup': AsyncMock()
}
with patch.object(Processor, 'register_specification'):
processor = Processor(**config)
processor = Processor(**config)
# Act
await processor.on_message(mock_msg, None, mock_flow)
# Assert
# Verify PyPDFLoader was called
mock_pdf_loader_class.assert_called_once()
mock_loader.load.assert_called_once()
# Verify no output was sent
mock_output_flow.send.assert_not_called()
@patch('trustgraph.base.chunking_service.Consumer')
@patch('trustgraph.base.chunking_service.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
async def test_on_message_unicode_content(self, mock_flow_init, mock_pdf_loader_class):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
"""Test handling of unicode content in PDF"""
# Arrange
mock_flow_init.return_value = None
# Mock PDF content
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader with unicode content
mock_loader = MagicMock()
mock_page = MagicMock(page_content="Page with unicode: 你好世界 🌍")
mock_loader.load.return_value = [mock_page]
mock_pdf_loader_class.return_value = mock_loader
# Mock message
mock_metadata = Metadata(id="test-doc")
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
# Mock flow - needs to be a callable that returns an object with send method
# Mock flow - separate mocks for output and triples
mock_output_flow = AsyncMock()
mock_flow = MagicMock(return_value=mock_output_flow)
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
config = {
'id': 'test-pdf-decoder',
'taskgroup': AsyncMock()
}
with patch.object(Processor, 'register_specification'):
processor = Processor(**config)
processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
# Act
await processor.on_message(mock_msg, None, mock_flow)
# Assert
# Verify output was sent
mock_output_flow.send.assert_called_once()
# Check output
call_args = mock_output_flow.send.call_args[0][0]
assert isinstance(call_args, TextDocument)
assert call_args.text == "Page with unicode: 你好世界 🌍".encode('utf-8')
# PDF decoder now forwards document_id, chunker fetches content from librarian
assert call_args.document_id == "test-doc/p1"
assert call_args.text == b"" # Content stored in librarian, not inline
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args):
"""Test add_args calls parent method"""
# Arrange
mock_parser = MagicMock()
# Act
Processor.add_args(mock_parser)
# Assert
mock_parent_add_args.assert_called_once_with(mock_parser)
@patch('trustgraph.decoding.pdf.pdf_decoder.Processor.launch')
def test_run(self, mock_launch):
"""Test run function"""
# Act
from trustgraph.decoding.pdf.pdf_decoder import run
run()
# Assert
mock_launch.assert_called_once_with("pdf-decoder",
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n")
mock_launch.assert_called_once_with("pdf-decoder",
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n\nSupports both inline document data and fetching from librarian via Pulsar\nfor large documents.\n")
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])


@ -305,9 +305,8 @@ class TestEntityCentricKnowledgeGraph:
mock_session.execute.assert_called()
def test_graph_wildcard_returns_all_graphs(self, entity_kg):
"""Test that g='*' returns quads from all graphs"""
from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD
def test_graph_none_returns_all_graphs(self, entity_kg):
"""Test that g=None returns quads from all graphs"""
kg, mock_session = entity_kg
mock_result = [
@ -320,7 +319,7 @@ class TestEntityCentricKnowledgeGraph:
]
mock_session.execute.return_value = mock_result
results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD)
results = kg.get_s('test_collection', 'http://example.org/Alice', g=None)
# Should return quads from both graphs
assert len(results) == 2
@ -547,21 +546,21 @@ class TestServiceHelperFunctions:
"""Test cases for helper functions in service.py"""
def test_create_term_with_uri_otype(self):
"""Test create_term creates IRI Term for otype='u'"""
"""Test create_term creates IRI Term for term_type='u'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import IRI
term = create_term('http://example.org/Alice', otype='u')
term = create_term('http://example.org/Alice', term_type='u')
assert term.type == IRI
assert term.iri == 'http://example.org/Alice'
def test_create_term_with_literal_otype(self):
"""Test create_term creates LITERAL Term for otype='l'"""
"""Test create_term creates LITERAL Term for term_type='l'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import LITERAL
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en')
assert term.type == LITERAL
assert term.value == 'Alice Smith'
@ -569,14 +568,24 @@ class TestServiceHelperFunctions:
assert term.language == 'en'
def test_create_term_with_triple_otype(self):
"""Test create_term creates IRI Term for otype='t'"""
"""Test create_term creates TRIPLE Term for term_type='t' with valid JSON"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import IRI
from trustgraph.schema import TRIPLE, IRI
import json
term = create_term('http://example.org/statement1', otype='t')
# Valid JSON triple data
triple_json = json.dumps({
"s": {"type": "i", "iri": "http://example.org/Alice"},
"p": {"type": "i", "iri": "http://example.org/knows"},
"o": {"type": "i", "iri": "http://example.org/Bob"},
})
assert term.type == IRI
assert term.iri == 'http://example.org/statement1'
term = create_term(triple_json, term_type='t')
assert term.type == TRIPLE
assert term.triple is not None
assert term.triple.s.type == IRI
assert term.triple.s.iri == "http://example.org/Alice"
def test_create_term_heuristic_fallback_uri(self):
"""Test create_term uses URL heuristic when otype not provided"""

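The `create_term` dispatch these tests describe (`'u'` builds an IRI term, `'l'` a literal with datatype and language, `'t'` a quoted triple parsed from JSON, and a URL heuristic when no type is given) can be sketched as below. Dict-shaped terms stand in for the real schema classes; the function body is inferred from the tests, not taken from the service.

```python
import json

def create_term(value, term_type=None, datatype=None, language=None):
    """Build a term record from a stored value and its type tag."""
    if term_type is None:
        # Heuristic fallback: URLs become IRIs, everything else a literal.
        term_type = 'u' if value.startswith(('http://', 'https://')) else 'l'
    if term_type == 'u':
        return {"type": "IRI", "iri": value}
    if term_type == 'l':
        return {"type": "LITERAL", "value": value,
                "datatype": datatype, "language": language}
    if term_type == 't':
        # Quoted triples are stored as JSON and parsed on the way out.
        return {"type": "TRIPLE", "triple": json.loads(value)}
    raise ValueError(f"unknown term_type: {term_type!r}")
```

Note the behavior change the diff documents: `'t'` now yields a TRIPLE term parsed from valid JSON rather than an IRI term.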

@ -0,0 +1,441 @@
"""
Tests for entity-centric KG write amplification, delete collection batching,
in-partition filtering, and term type metadata round-trips.
Complements test_entity_centric_kg.py with deeper verification of the
2-table schema mechanics.
"""
import pytest
from unittest.mock import MagicMock, patch, call
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_cassandra():
"""Provide mocked Cassandra cluster, session, and BatchStatement."""
with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cls, \
patch('trustgraph.direct.cassandra_kg.BatchStatement') as mock_batch_cls:
mock_cluster = MagicMock()
mock_session = MagicMock()
mock_cluster.connect.return_value = mock_session
mock_cls.return_value = mock_cluster
# Track batch.add calls per batch instance
batches = []
def make_batch():
batch = MagicMock()
batch._adds = []
def tracking_add(stmt, params):
batch._adds.append((stmt, params))
batch.add = tracking_add
batches.append(batch)
return batch
mock_batch_cls.side_effect = make_batch
yield {
"cluster_cls": mock_cls,
"cluster": mock_cluster,
"session": mock_session,
"batch_cls": mock_batch_cls,
"batches": batches,
}
@pytest.fixture
def entity_kg(mock_cassandra):
"""Create an EntityCentricKnowledgeGraph with mocked Cassandra."""
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_ks')
return kg, mock_cassandra
# ---------------------------------------------------------------------------
# Write amplification: row count verification
# ---------------------------------------------------------------------------
class TestWriteAmplification:
def test_uri_object_produces_4_entity_rows_plus_collection(self, entity_kg):
"""URI object → S + P + O + G-if-non-default entity rows + 1 collection row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/Alice',
p='http://ex.org/knows',
o='http://ex.org/Bob',
g='http://ex.org/g1',
otype='u',
)
# Should be exactly one batch
assert len(ctx["batches"]) == 1
batch = ctx["batches"][0]
# 4 entity rows (S, P, O, G) + 1 collection row = 5
assert len(batch._adds) == 5
# Check roles assigned
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'S' in roles
assert 'P' in roles
assert 'O' in roles
assert 'G' in roles
def test_literal_object_produces_3_entity_rows(self, entity_kg):
"""Literal object → S + P entity rows (no O row) + collection row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/Alice',
p='http://ex.org/name',
o='Alice Smith',
g=None, # default graph
otype='l',
)
batch = ctx["batches"][0]
# S + P entity rows + 1 collection = 3 (no O row for literal, no G for default)
assert len(batch._adds) == 3
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'S' in roles
assert 'P' in roles
assert 'O' not in roles
assert 'G' not in roles
def test_triple_otype_gets_object_entity_row(self, entity_kg):
"""otype='t' (quoted triple) → object gets entity row like URI."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='{"s":{},"p":{},"o":{}}',
g=None,
otype='t',
)
batch = ctx["batches"][0]
# S + P + O entity rows + collection = 4 (no G for default graph)
assert len(batch._adds) == 4
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'O' in roles
def test_default_graph_no_g_row(self, entity_kg):
"""Default graph (g=None) → no G entity row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='http://ex.org/o',
g=None,
otype='u',
)
batch = ctx["batches"][0]
# S + P + O entity rows + collection = 4 (no G)
assert len(batch._adds) == 4
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'G' not in roles
def test_non_default_graph_gets_g_row(self, entity_kg):
"""Non-default graph → gets G entity row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='http://ex.org/o',
g='http://ex.org/graph1',
otype='u',
)
batch = ctx["batches"][0]
# S + P + O + G entity rows + collection = 5
assert len(batch._adds) == 5
roles = [params[2] for _, params in batch._adds if len(params) == 10]
assert 'G' in roles
def test_dtype_and_lang_passed_to_all_rows(self, entity_kg):
"""dtype and lang should be stored in every entity row."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/label',
o='thing',
g=None,
otype='l',
dtype='xsd:string',
lang='en',
)
batch = ctx["batches"][0]
# Check entity rows carry dtype and lang
for _, params in batch._adds:
if len(params) == 10:
# Entity row: (collection, entity, role, p, otype, s, o, d, dtype, lang)
assert params[8] == 'xsd:string'
assert params[9] == 'en'
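The row fan-out these tests count can be written as a pure function: one entity row per role that names an entity (S and P always; O only for URI or quoted-triple objects; G only for non-default graphs), plus one collection-wide row. The tuple layout and the empty-string encoding of the default graph are assumptions inferred from the assertions above, not the real table schema.

```python
def build_insert_rows(collection, s, p, o, g=None, otype='u',
                      dtype='', lang=''):
    """Return (entity_rows, collection_row) for one quad insert."""
    d = g or ''                      # default graph stored as empty string
    roles = [('S', s), ('P', p)]
    if otype in ('u', 't'):          # literal objects get no entity row
        roles.append(('O', o))
    if g:                            # default graph gets no G row
        roles.append(('G', g))
    entity_rows = [
        # (collection, entity, role, p, otype, s, o, d, dtype, lang)
        (collection, entity, role, p, otype, s, o, d, dtype, lang)
        for role, entity in roles
    ]
    collection_row = (collection, d, s, p, o, otype, dtype, lang)
    return entity_rows, collection_row
```

So a URI object in a named graph produces 4 entity rows + 1 collection row, while a literal in the default graph produces only 2 + 1, matching the batch sizes asserted above.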
# ---------------------------------------------------------------------------
# In-partition filtering: get_os, get_spo
# ---------------------------------------------------------------------------
class TestInPartitionFiltering:
def test_get_os_filters_by_object(self, entity_kg):
"""get_os should filter results by matching object value."""
kg, ctx = entity_kg
# Simulate rows returned from subject partition (all have same s)
mock_rows = [
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
MagicMock(p='http://ex.org/likes', o='http://ex.org/Charlie',
d='', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_os('col', 'http://ex.org/Bob', 'http://ex.org/Alice')
# Only the Bob row should pass the filter
assert len(results) == 1
assert results[0].o == 'http://ex.org/Bob'
assert results[0].p == 'http://ex.org/knows'
def test_get_os_returns_empty_when_no_match(self, entity_kg):
"""get_os should return empty list when object doesn't match any row."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_os('col', 'http://ex.org/Charlie', 'http://ex.org/Alice')
assert len(results) == 0
def test_get_spo_filters_by_object(self, entity_kg):
"""get_spo should filter results by matching object value."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(o='http://ex.org/Bob', d='', otype='u', dtype='', lang=''),
MagicMock(o='http://ex.org/Charlie', d='', otype='u', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_spo(
'col', 'http://ex.org/Alice', 'http://ex.org/knows',
'http://ex.org/Bob',
)
assert len(results) == 1
assert results[0].o == 'http://ex.org/Bob'
def test_get_os_with_graph_filter(self, entity_kg):
"""get_os with specific graph should filter both object and graph."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='http://ex.org/g1', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
d='http://ex.org/g2', otype='u', dtype='', lang='',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_os(
'col', 'http://ex.org/Bob', 'http://ex.org/Alice',
g='http://ex.org/g1',
)
assert len(results) == 1
assert results[0].g == 'http://ex.org/g1'
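The filtering these `get_os`/`get_spo` tests exercise happens client-side: Cassandra returns the whole entity partition, and the object and graph constraints are applied in the application. A sketch with dict-shaped rows (field names assumed from the mocks above):

```python
def filter_in_partition(rows, o=None, g=None):
    """Keep only rows matching the object and/or graph constraints."""
    out = []
    for row in rows:
        if o is not None and row["o"] != o:
            continue  # object mismatch: drop
        if g is not None and row["d"] != g:
            continue  # graph mismatch: drop
        out.append(row)
    return out
```

This trades read-side work for the single-partition query pattern: the partition key stays `(collection, entity)` and no secondary index is needed for object lookups.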
# ---------------------------------------------------------------------------
# Delete collection batching
# ---------------------------------------------------------------------------
class TestDeleteCollectionBatching:
def test_extracts_unique_entities_from_quads(self, entity_kg):
"""delete_collection should extract s, p, and URI o as entities."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/knows',
o='http://ex.org/B', otype='u', dtype='', lang=''),
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
o='Alice', otype='l', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
ctx["batches"].clear()
kg.delete_collection('col')
# Unique entities: A, knows, B, name (literal 'Alice' excluded)
# The batches should include entity partition deletes
all_adds = []
for batch in ctx["batches"]:
all_adds.extend(batch._adds)
# We expect entity deletes + collection row deletes + metadata delete
# Just verify the function completes and calls execute
assert ctx["session"].execute.called
def test_literal_objects_not_treated_as_entities(self, entity_kg):
"""Literal objects (otype='l') should not get entity partition deletes."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
o='Alice', otype='l', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
ctx["batches"].clear()
kg.delete_collection('col')
# Entity partition deletes should only include A and name, not Alice
entity_deletes = []
for batch in ctx["batches"]:
for _, params in batch._adds:
if len(params) == 2: # delete_entity_partition takes (collection, entity)
entity_deletes.append(params[1])
assert 'http://ex.org/A' in entity_deletes
assert 'http://ex.org/name' in entity_deletes
assert 'Alice' not in entity_deletes
def test_non_default_graph_treated_as_entity(self, entity_kg):
"""Non-default graphs should get entity partition deletes."""
kg, ctx = entity_kg
mock_rows = [
MagicMock(d='http://ex.org/g1', s='http://ex.org/A',
p='http://ex.org/p', o='http://ex.org/B',
otype='u', dtype='', lang=''),
]
ctx["session"].execute.return_value = mock_rows
ctx["batches"].clear()
kg.delete_collection('col')
entity_deletes = []
for batch in ctx["batches"]:
for _, params in batch._adds:
if len(params) == 2:
entity_deletes.append(params[1])
assert 'http://ex.org/g1' in entity_deletes
def test_empty_collection_delete_completes(self, entity_kg):
"""Deleting an empty collection should not error."""
kg, ctx = entity_kg
ctx["session"].execute.return_value = []
ctx["batches"].clear()
# Should not raise
kg.delete_collection('empty-col')
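The entity-extraction step these delete tests verify can be isolated as a small function: subjects, predicates, URI (and quoted-triple) objects, and non-default graph names each need an entity-partition delete, while literal objects are skipped. Field names and the `'t'` handling are assumptions drawn from the tests.

```python
def entities_to_delete(quads):
    """Collect the unique entity partition keys to delete for a collection."""
    entities = set()
    for q in quads:
        entities.add(q["s"])
        entities.add(q["p"])
        if q["otype"] in ("u", "t"):  # literal objects are not entities
            entities.add(q["o"])
        if q["d"]:                    # non-default graphs are entities too
            entities.add(q["d"])
    return entities
```

Deduplicating first keeps the delete batches proportional to the number of distinct entities rather than the number of quads.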
# ---------------------------------------------------------------------------
# Term type metadata round-trip
# ---------------------------------------------------------------------------
class TestTermTypeMetadata:
def test_query_results_include_otype(self, entity_kg):
"""Query results should include otype from Cassandra rows."""
kg, ctx = entity_kg
from trustgraph.direct.cassandra_kg import QuadResult
mock_rows = [
MagicMock(p='http://ex.org/name', o='Alice',
d='', otype='l', dtype='xsd:string', lang='en',
s='http://ex.org/Alice'),
]
ctx["session"].execute.return_value = mock_rows
results = kg.get_s('col', 'http://ex.org/Alice')
assert len(results) == 1
assert results[0].otype == 'l'
assert results[0].dtype == 'xsd:string'
assert results[0].lang == 'en'
def test_auto_detect_otype_uri(self, entity_kg):
"""Auto-detect should classify http:// as URI."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='http://ex.org/o',
)
batch = ctx["batches"][0]
# Check otype in entity rows (position 4)
for _, params in batch._adds:
if len(params) == 10:
assert params[4] == 'u'
def test_auto_detect_otype_literal(self, entity_kg):
"""Auto-detect should classify non-http:// as literal."""
kg, ctx = entity_kg
ctx["batches"].clear()
kg.insert(
collection='col',
s='http://ex.org/s',
p='http://ex.org/p',
o='plain text',
)
batch = ctx["batches"][0]
for _, params in batch._adds:
if len(params) == 10:
assert params[4] == 'l'


@ -0,0 +1,164 @@
"""
Tests for document embeddings processor single-chunk embedding via batch API.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.embeddings.document_embeddings.embeddings import Processor
from trustgraph.schema import (
Chunk, DocumentEmbeddings, ChunkEmbeddings,
EmbeddingsRequest, EmbeddingsResponse, Metadata,
)
@pytest.fixture
def processor():
return Processor(
taskgroup=AsyncMock(),
id="test-doc-embeddings",
)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock()
msg.value.return_value = value
return msg
class TestDocumentEmbeddingsProcessor:
@pytest.mark.asyncio
async def test_sends_single_text_as_list(self, processor):
"""Document embeddings should wrap single chunk in a list for the API."""
msg = _make_chunk_message("test chunk text")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.1, 0.2, 0.3]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
# Should send EmbeddingsRequest with texts=[chunk]
mock_request.assert_called_once()
req = mock_request.call_args[0][0]
assert isinstance(req, EmbeddingsRequest)
assert req.texts == ["test chunk text"]
@pytest.mark.asyncio
async def test_extracts_first_vector(self, processor):
"""Should use vectors[0] from the response."""
msg = _make_chunk_message("chunk")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[1.0, 2.0, 3.0]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert isinstance(result, DocumentEmbeddings)
assert len(result.chunks) == 1
assert result.chunks[0].vector == [1.0, 2.0, 3.0]
@pytest.mark.asyncio
async def test_empty_vectors_response(self, processor):
"""Should handle empty vectors response gracefully."""
msg = _make_chunk_message("chunk")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.chunks[0].vector == []
@pytest.mark.asyncio
async def test_chunk_id_is_document_id(self, processor):
"""ChunkEmbeddings should use document_id as chunk_id."""
msg = _make_chunk_message(doc_id="my-doc-42")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.chunks[0].chunk_id == "my-doc-42"
@pytest.mark.asyncio
async def test_metadata_preserved(self, processor):
"""Output should carry the original metadata."""
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
mock_request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]]
))
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports"
assert result.metadata.id == "d1"
@pytest.mark.asyncio
async def test_error_propagates(self, processor):
"""Embedding errors should propagate for retry."""
msg = _make_chunk_message()
mock_request = AsyncMock(side_effect=RuntimeError("service down"))
def flow(name):
if name == "embeddings-request":
return MagicMock(request=mock_request)
return MagicMock()
with pytest.raises(RuntimeError, match="service down"):
await processor.on_message(msg, MagicMock(), flow)
@@ -0,0 +1,109 @@
"""
Tests for EmbeddingsClient, the client interface for batch embeddings.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base.embeddings_client import EmbeddingsClient
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
class TestEmbeddingsClient:
@pytest.mark.asyncio
async def test_embed_sends_request_and_returns_vectors(self):
"""embed() should send an EmbeddingsRequest and return vectors."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None,
vectors=[[0.1, 0.2], [0.3, 0.4]],
))
result = await client.embed(texts=["hello", "world"])
assert result == [[0.1, 0.2], [0.3, 0.4]]
client.request.assert_called_once()
req = client.request.call_args[0][0]
assert isinstance(req, EmbeddingsRequest)
assert req.texts == ["hello", "world"]
@pytest.mark.asyncio
async def test_embed_single_text(self):
"""embed() should work with a single text."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None,
vectors=[[1.0, 2.0, 3.0]],
))
result = await client.embed(texts=["single"])
assert result == [[1.0, 2.0, 3.0]]
@pytest.mark.asyncio
async def test_embed_raises_on_error_response(self):
"""embed() should raise RuntimeError when response contains an error."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=Error(type="embeddings-error", message="model not found"),
vectors=[],
))
with pytest.raises(RuntimeError, match="model not found"):
await client.embed(texts=["test"])
@pytest.mark.asyncio
async def test_embed_passes_timeout(self):
"""embed() should pass timeout to the underlying request."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]],
))
await client.embed(texts=["test"], timeout=60)
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 60
@pytest.mark.asyncio
async def test_embed_default_timeout(self):
"""embed() should use 300s default timeout."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[[0.0]],
))
await client.embed(texts=["test"])
_, kwargs = client.request.call_args
assert kwargs["timeout"] == 300
@pytest.mark.asyncio
async def test_embed_empty_texts(self):
"""embed() with empty list should still make the request."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=[],
))
result = await client.embed(texts=[])
assert result == []
@pytest.mark.asyncio
async def test_embed_large_batch(self):
"""embed() should handle large batches."""
client = EmbeddingsClient.__new__(EmbeddingsClient)
n = 100
vectors = [[float(i)] for i in range(n)]
client.request = AsyncMock(return_value=EmbeddingsResponse(
error=None, vectors=vectors,
))
texts = [f"text {i}" for i in range(n)]
result = await client.embed(texts=texts)
assert len(result) == n
req = client.request.call_args[0][0]
assert len(req.texts) == n
@@ -0,0 +1,135 @@
"""
Tests for EmbeddingsService.on_request, the request handler that dispatches
to on_embeddings and sends responses.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.base import EmbeddingsService
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
from trustgraph.exceptions import TooManyRequests
class StubEmbeddingsService(EmbeddingsService):
"""Minimal concrete implementation for testing on_request."""
def __init__(self, embed_result=None, embed_error=None):
# Skip super().__init__ to avoid taskgroup/registration
self.embed_result = embed_result or [[0.1, 0.2]]
self.embed_error = embed_error
async def on_embeddings(self, texts, model=None):
if self.embed_error:
raise self.embed_error
return self.embed_result
def _make_msg(texts, msg_id="req-1"):
request = EmbeddingsRequest(texts=texts)
msg = MagicMock()
msg.value.return_value = request
msg.properties.return_value = {"id": msg_id}
return msg
def _make_flow(model="test-model"):
mock_response_producer = AsyncMock()
mock_flow = MagicMock()
def flow_callable(name):
if name == "model":
return model
if name == "response":
return mock_response_producer
return MagicMock()
flow_callable.producer = {"response": mock_response_producer}
return flow_callable, mock_response_producer
class TestEmbeddingsServiceOnRequest:
@pytest.mark.asyncio
async def test_successful_request(self):
"""on_request should call on_embeddings and send response."""
service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]])
msg = _make_msg(["hello", "world"], msg_id="r1")
flow, mock_response = _make_flow(model="my-model")
await service.on_request(msg, MagicMock(), flow)
mock_response.send.assert_called_once()
resp = mock_response.send.call_args[0][0]
assert isinstance(resp, EmbeddingsResponse)
assert resp.error is None
assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]]
# Check id is passed through
props = mock_response.send.call_args[1]["properties"]
assert props["id"] == "r1"
@pytest.mark.asyncio
async def test_passes_model_from_flow(self):
"""on_request should pass model parameter from flow to on_embeddings."""
calls = []
class TrackingService(EmbeddingsService):
def __init__(self):
pass
async def on_embeddings(self, texts, model=None):
calls.append({"texts": texts, "model": model})
return [[0.0]]
service = TrackingService()
msg = _make_msg(["test"])
flow, _ = _make_flow(model="custom-model-v2")
await service.on_request(msg, MagicMock(), flow)
assert len(calls) == 1
assert calls[0]["model"] == "custom-model-v2"
assert calls[0]["texts"] == ["test"]
@pytest.mark.asyncio
async def test_error_sends_error_response(self):
"""Non-rate-limit errors should send an error response."""
service = StubEmbeddingsService(
embed_error=ValueError("dimension mismatch")
)
msg = _make_msg(["test"], msg_id="r2")
flow, mock_response = _make_flow()
await service.on_request(msg, MagicMock(), flow)
mock_response.send.assert_called_once()
resp = mock_response.send.call_args[0][0]
assert resp.error is not None
assert resp.error.type == "embeddings-error"
assert "dimension mismatch" in resp.error.message
assert resp.vectors == []
@pytest.mark.asyncio
async def test_rate_limit_propagates(self):
"""TooManyRequests should propagate (not caught as error response)."""
service = StubEmbeddingsService(
embed_error=TooManyRequests("rate limited")
)
msg = _make_msg(["test"])
flow, _ = _make_flow()
with pytest.raises(TooManyRequests):
await service.on_request(msg, MagicMock(), flow)
@pytest.mark.asyncio
async def test_message_id_preserved(self):
"""The request message id should be forwarded in the response properties."""
service = StubEmbeddingsService()
msg = _make_msg(["test"], msg_id="unique-id-42")
flow, mock_response = _make_flow()
await service.on_request(msg, MagicMock(), flow)
props = mock_response.send.call_args[1]["properties"]
assert props["id"] == "unique-id-42"
@@ -103,7 +103,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
mock_text_embedding_class.reset_mock()
# Act
result = await processor.on_embeddings("test text")
result = await processor.on_embeddings(["test text"])
# Assert
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
@@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
mock_text_embedding_class.reset_mock()
# Act
result = await processor.on_embeddings("test text", model="custom-model")
result = await processor.on_embeddings(["test text"], model="custom-model")
# Assert
mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
@@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
initial_call_count = mock_text_embedding_class.call_count
# Act - switch between models
await processor.on_embeddings("text1", model="model-a")
await processor.on_embeddings(["text1"], model="model-a")
call_count_after_a = mock_text_embedding_class.call_count
await processor.on_embeddings("text2", model="model-a") # Same, no reload
await processor.on_embeddings(["text2"], model="model-a") # Same, no reload
call_count_after_a_repeat = mock_text_embedding_class.call_count
await processor.on_embeddings("text3", model="model-b") # Different, reload
await processor.on_embeddings(["text3"], model="model-b") # Different, reload
call_count_after_b = mock_text_embedding_class.call_count
await processor.on_embeddings("text4", model="model-a") # Back to A, reload
await processor.on_embeddings(["text4"], model="model-a") # Back to A, reload
call_count_after_a_again = mock_text_embedding_class.call_count
# Assert
@@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
initial_count = mock_text_embedding_class.call_count
# Act
result = await processor.on_embeddings("test text", model=None)
result = await processor.on_embeddings(["test text"], model=None)
# Assert
# No reload, using cached default
@@ -0,0 +1,233 @@
"""
Tests for the graph embeddings processor's batch embedding of entity contexts.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.embeddings.graph_embeddings.embeddings import Processor
from trustgraph.schema import (
EntityContexts, EntityEmbeddings, GraphEmbeddings,
Term, IRI, Metadata,
)
@pytest.fixture
def processor():
return Processor(
taskgroup=AsyncMock(),
id="test-graph-embeddings",
batch_size=3,
)
def _make_entity_context(name, context, chunk_id="chunk-1"):
"""Create an entity context for testing."""
entity = Term(type=IRI, iri=f"urn:entity:{name}")
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock()
msg.value.return_value = value
return msg
class TestGraphEmbeddingsInit:
def test_default_batch_size(self):
p = Processor(taskgroup=AsyncMock(), id="test")
assert p.batch_size == 5
def test_custom_batch_size(self):
p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20)
assert p.batch_size == 20
class TestGraphEmbeddingsBatchProcessing:
@pytest.mark.asyncio
async def test_single_batch_call_for_all_entities(self, processor):
"""All entity contexts should be embedded in a single API call."""
entities = [
_make_entity_context("Alice", "Alice is a person"),
_make_entity_context("Bob", "Bob is a developer"),
_make_entity_context("Acme", "Acme is a company"),
]
msg = _make_message(entities)
mock_embed = AsyncMock(return_value=[
[0.1, 0.2], [0.3, 0.4], [0.5, 0.6],
])
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
# Single batch call with all three texts
mock_embed.assert_called_once_with(
texts=["Alice is a person", "Bob is a developer", "Acme is a company"]
)
@pytest.mark.asyncio
async def test_vectors_paired_with_correct_entities(self, processor):
"""Each vector should be paired with its corresponding entity."""
entities = [
_make_entity_context("Alice", "ctx-A", chunk_id="c1"),
_make_entity_context("Bob", "ctx-B", chunk_id="c2"),
]
msg = _make_message(entities)
vectors = [[1.0, 2.0], [3.0, 4.0]]
mock_embed = AsyncMock(return_value=vectors)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
# With batch_size=3, all 2 entities fit in one output message
mock_output.send.assert_called_once()
result = mock_output.send.call_args[0][0]
assert isinstance(result, GraphEmbeddings)
assert len(result.entities) == 2
assert result.entities[0].vector == [1.0, 2.0]
assert result.entities[0].entity.iri == "urn:entity:Alice"
assert result.entities[0].chunk_id == "c1"
assert result.entities[1].vector == [3.0, 4.0]
assert result.entities[1].entity.iri == "urn:entity:Bob"
@pytest.mark.asyncio
async def test_output_batching(self, processor):
"""Output should be split into batches of batch_size."""
# batch_size=3, 7 entities -> 3 output messages (3+3+1)
entities = [
_make_entity_context(f"E{i}", f"context {i}")
for i in range(7)
]
msg = _make_message(entities)
vectors = [[float(i)] for i in range(7)]
mock_embed = AsyncMock(return_value=vectors)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
assert mock_output.send.call_count == 3
# First batch has 3 entities
batch1 = mock_output.send.call_args_list[0][0][0]
assert len(batch1.entities) == 3
# Second batch has 3 entities
batch2 = mock_output.send.call_args_list[1][0][0]
assert len(batch2.entities) == 3
# Third batch has 1 entity
batch3 = mock_output.send.call_args_list[2][0][0]
assert len(batch3.entities) == 1
@pytest.mark.asyncio
async def test_output_batches_preserve_metadata(self, processor):
"""Each output batch should carry the original metadata."""
entities = [
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5)
]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
for call in mock_output.send.call_args_list:
result = call[0][0]
assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main"
@pytest.mark.asyncio
async def test_single_entity(self, processor):
"""Single entity should work with one embed call and one output."""
entities = [_make_entity_context("Solo", "solo context")]
msg = _make_message(entities)
mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]])
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
mock_embed.assert_called_once_with(texts=["solo context"])
mock_output.send.assert_called_once()
@pytest.mark.asyncio
async def test_embed_error_propagates(self, processor):
"""Embedding service errors should propagate for retry."""
entities = [_make_entity_context("E", "ctx")]
msg = _make_message(entities)
mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed"))
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
return MagicMock()
with pytest.raises(RuntimeError, match="embedding failed"):
await processor.on_message(msg, MagicMock(), flow)
@pytest.mark.asyncio
async def test_exact_batch_size(self, processor):
"""When entity count equals batch_size, exactly one output message."""
entities = [
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(3) # batch_size=3
]
msg = _make_message(entities)
mock_embed = AsyncMock(return_value=[[0.0]] * 3)
mock_output = AsyncMock()
def flow(name):
if name == "embeddings-request":
return MagicMock(embed=mock_embed)
elif name == "output":
return mock_output
return MagicMock()
await processor.on_message(msg, MagicMock(), flow)
mock_output.send.assert_called_once()
assert len(mock_output.send.call_args[0][0].entities) == 3
@@ -53,12 +53,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act
result = await processor.on_embeddings("test text")
result = await processor.on_embeddings(["test text"])
# Assert
mock_ollama_client.embed.assert_called_once_with(
model="test-model",
input="test text"
input=["test text"]
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@@ -79,12 +79,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act
result = await processor.on_embeddings("test text", model="custom-model")
result = await processor.on_embeddings(["test text"], model="custom-model")
# Assert
mock_ollama_client.embed.assert_called_once_with(
model="custom-model",
input="test text"
input=["test text"]
)
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
@@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act - switch between different models
await processor.on_embeddings("text1", model="model-a")
await processor.on_embeddings("text2", model="model-b")
await processor.on_embeddings("text3", model="model-a")
await processor.on_embeddings("text4") # Use default
await processor.on_embeddings(["text1"], model="model-a")
await processor.on_embeddings(["text2"], model="model-b")
await processor.on_embeddings(["text3"], model="model-a")
await processor.on_embeddings(["text4"]) # Use default
# Assert
calls = mock_ollama_client.embed.call_args_list
@@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
# Act
result = await processor.on_embeddings("test text", model=None)
result = await processor.on_embeddings(["test text"], model=None)
# Assert
mock_ollama_client.embed.assert_called_once_with(
model="test-model",
input="test text"
input=["test text"]
)
@patch('trustgraph.embeddings.ollama.processor.Client')
@@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
# Mock the flow
mock_embeddings_request = AsyncMock()
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
# Return batch of vector sets (one per text)
# 4 unique texts: CUST001, John Doe, CUST002, Jane Smith
mock_embeddings_request.embed.return_value = [
[[0.1, 0.2, 0.3]], # vectors for text 1
[[0.2, 0.3, 0.4]], # vectors for text 2
[[0.3, 0.4, 0.5]], # vectors for text 3
[[0.4, 0.5, 0.6]], # vectors for text 4
]
mock_output = AsyncMock()
@@ -368,9 +375,12 @@
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Should have called embed for each unique text
# 4 values: CUST001, John Doe, CUST002, Jane Smith
assert mock_embeddings_request.embed.call_count == 4
# Should have called embed once with all texts in a batch
assert mock_embeddings_request.embed.call_count == 1
# Verify it was called with a list of texts
call_args = mock_embeddings_request.embed.call_args
assert 'texts' in call_args.kwargs
assert len(call_args.kwargs['texts']) == 4
# Should have sent output
mock_output.send.assert_called()
@@ -0,0 +1 @@
@@ -0,0 +1,407 @@
"""
Tests for streaming triple and entity context batching in the definitions
KG extractor.
Covers: triples batch splitting, entity context batch splitting,
metadata preservation, provenance, and empty/null filtering.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_processor(triples_batch_size=default_triples_batch_size,
entity_batch_size=default_entity_batch_size):
proc = Processor.__new__(Processor)
proc.triples_batch_size = triples_batch_size
proc.entity_batch_size = entity_batch_size
return proc
def _make_defn(entity, definition):
return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
)
msg = MagicMock()
msg.value.return_value = chunk
return msg
def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
)
def flow(name):
if name == "prompt-request":
return mock_prompt_client
if name == "triples":
return mock_triples_pub
if name == "entity-contexts":
return mock_ecs_pub
if name == "llm-model":
return llm_model
if name == "ontology":
return ontology_uri
return MagicMock()
return flow, mock_triples_pub, mock_ecs_pub, mock_prompt_client
def _sent_triples(mock_pub):
return [call.args[0] for call in mock_pub.send.call_args_list]
def _sent_ecs(mock_pub):
return [call.args[0] for call in mock_pub.send.call_args_list]
def _all_triples_flat(mock_pub):
result = []
for triples_msg in _sent_triples(mock_pub):
result.extend(triples_msg.triples)
return result
def _all_entities_flat(mock_pub):
result = []
for ecs_msg in _sent_ecs(mock_pub):
result.extend(ecs_msg.entities)
return result
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDefaults:
def test_default_triples_batch_size(self):
assert default_triples_batch_size == 50
def test_default_entity_batch_size(self):
assert default_entity_batch_size == 5
class TestTriplesBatching:
@pytest.mark.asyncio
async def test_single_batch_when_under_limit(self):
proc = _make_processor(triples_batch_size=100)
defs = [_make_defn("Cat", "A feline animal")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 1
@pytest.mark.asyncio
async def test_multiple_triples_batches(self):
proc = _make_processor(triples_batch_size=2)
defs = [
_make_defn("Cat", "A feline"),
_make_defn("Dog", "A canine"),
]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
# 2 defs → 2 labels + 2 definitions = 4 triples + provenance
# With batch_size=2, should produce multiple batches
assert triples_pub.send.call_count > 1
@pytest.mark.asyncio
async def test_triples_batch_sizes_within_limit(self):
batch_size = 3
proc = _make_processor(triples_batch_size=batch_size)
defs = [
_make_defn("A", "def A"),
_make_defn("B", "def B"),
_make_defn("C", "def C"),
]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(triples_pub):
assert len(triples_msg.triples) <= batch_size
class TestEntityContextBatching:
@pytest.mark.asyncio
async def test_single_entity_batch_when_under_limit(self):
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
# 1 def → 2 entity contexts (name + definition)
assert ecs_pub.send.call_count == 1
@pytest.mark.asyncio
async def test_multiple_entity_batches(self):
proc = _make_processor(entity_batch_size=2)
defs = [
_make_defn("Cat", "A feline"),
_make_defn("Dog", "A canine"),
]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
# 2 defs → 4 entity contexts, batch_size=2 → 2 batches
assert ecs_pub.send.call_count == 2
@pytest.mark.asyncio
async def test_entity_batch_sizes_within_limit(self):
batch_size = 3
proc = _make_processor(entity_batch_size=batch_size)
defs = [
_make_defn("A", "def A"),
_make_defn("B", "def B"),
_make_defn("C", "def C"),
]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
for ecs_msg in _sent_ecs(ecs_pub):
assert len(ecs_msg.entities) <= batch_size
@pytest.mark.asyncio
async def test_entity_contexts_have_name_and_definition(self):
"""Each definition produces 2 entity contexts: name and definition."""
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline animal")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
entities = _all_entities_flat(ecs_pub)
assert len(entities) == 2
contexts = {e.context for e in entities}
assert "Cat" in contexts
assert "A feline animal" in contexts
class TestMetadataPreservation:
@pytest.mark.asyncio
async def test_triples_metadata(self):
proc = _make_processor(triples_batch_size=2)
defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio
async def test_entity_contexts_metadata(self):
proc = _make_processor(entity_batch_size=1)
defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2",
user="u-2", collection="coll-2",
)
await proc.on_message(msg, MagicMock(), flow)
for ecs_msg in _sent_ecs(ecs_pub):
assert ecs_msg.metadata.id == "c-2"
assert ecs_msg.metadata.root == "r-2"
class TestEmptyAndNullFiltering:
@pytest.mark.asyncio
async def test_empty_entity_skipped(self):
proc = _make_processor()
defs = [
_make_defn("", "some definition"),
_make_defn("Valid", "a valid definition"),
]
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(triples_pub)
all_e = _all_entities_flat(ecs_pub)
# Only "Valid" should be present
entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
assert any("valid" in iri for iri in entity_iris)
assert len(all_e) == 2 # name + definition for "Valid" only
@pytest.mark.asyncio
async def test_empty_definition_skipped(self):
proc = _make_processor()
defs = [
_make_defn("Entity", ""),
_make_defn("Good", "good definition"),
]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(triples_pub)
entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
assert any("good" in iri for iri in entity_iris)
# "Entity" with empty def should have been skipped
assert not any("entity" in iri and "good" not in iri for iri in entity_iris)
@pytest.mark.asyncio
async def test_none_fields_skipped(self):
proc = _make_processor()
defs = [
_make_defn(None, "some definition"),
_make_defn("Entity", None),
]
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
@pytest.mark.asyncio
async def test_all_filtered_no_output(self):
proc = _make_processor()
defs = [_make_defn("", ""), _make_defn(None, None)]
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
@pytest.mark.asyncio
async def test_empty_prompt_response(self):
proc = _make_processor()
flow, triples_pub, ecs_pub, _ = _make_flow([])
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
class TestProvenanceInclusion:
@pytest.mark.asyncio
async def test_provenance_triples_present(self):
proc = _make_processor(triples_batch_size=200)
defs = [_make_defn("Cat", "A feline")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(triples_pub)
# 1 def → 1 label + 1 definition = 2 content triples
# Provenance adds more
assert len(all_t) > 2
class TestErrorHandling:
@pytest.mark.asyncio
async def test_prompt_error_caught(self):
proc = _make_processor()
flow, triples_pub, ecs_pub, prompt = _make_flow([])
prompt.extract_definitions = AsyncMock(
side_effect=RuntimeError("LLM error")
)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
@pytest.mark.asyncio
async def test_non_list_response_caught(self):
proc = _make_processor()
flow, triples_pub, ecs_pub, prompt = _make_flow("not a list")
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert triples_pub.send.call_count == 0
assert ecs_pub.send.call_count == 0
class TestDocumentIdProvenance:
@pytest.mark.asyncio
async def test_document_id_used_for_chunk_id(self):
"""When document_id is set, entity contexts should use it as chunk_id."""
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text", document_id="doc-123")
await proc.on_message(msg, MagicMock(), flow)
entities = _all_entities_flat(ecs_pub)
for e in entities:
assert e.chunk_id == "doc-123"
@pytest.mark.asyncio
async def test_metadata_id_fallback_for_chunk_id(self):
"""When document_id is empty, metadata.id is used as chunk_id."""
proc = _make_processor(entity_batch_size=100)
defs = [_make_defn("Cat", "A feline")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg("text", meta_id="chunk-42", document_id="")
await proc.on_message(msg, MagicMock(), flow)
entities = _all_entities_flat(ecs_pub)
for e in entities:
assert e.chunk_id == "chunk-42"
@@ -0,0 +1,408 @@
"""
Tests for streaming triple batching in the relationships KG extractor.
Covers: batch size configuration, output splitting, metadata preservation,
provenance inclusion, empty/null filtering, and error propagation.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.relationships.extract import (
Processor, default_triples_batch_size,
)
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_processor(triples_batch_size=default_triples_batch_size):
"""Create a Processor without triggering FlowProcessor.__init__."""
proc = Processor.__new__(Processor)
proc.triples_batch_size = triples_batch_size
return proc
def _make_rel(subject, predicate, obj, object_entity=True):
"""Build a relationship dict as returned by the prompt client."""
return {
"subject": subject,
"predicate": predicate,
"object": obj,
"object-entity": object_entity,
}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk."""
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
)
msg = MagicMock()
msg.value.return_value = chunk
return msg
def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
"""Build a mock flow callable that provides prompt client, triples
producer, and parameter specs."""
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
)
def flow(name):
if name == "prompt-request":
return mock_prompt_client
if name == "triples":
return mock_triples_pub
if name == "llm-model":
return llm_model
if name == "ontology":
return ontology_uri
return MagicMock()
return flow, mock_triples_pub, mock_prompt_client
def _sent_triples(mock_pub):
"""Collect all Triples objects sent to a mock publisher."""
return [call.args[0] for call in mock_pub.send.call_args_list]
def _all_triples_flat(mock_pub):
"""Flatten all batches into one list of Triple objects."""
result = []
for triples_msg in _sent_triples(mock_pub):
result.extend(triples_msg.triples)
return result
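The splitting behavior the tests below verify can be sketched as a plain list-chunking loop (an illustration of the expected semantics, not the processor's actual code):

```python
def batch_triples(triples, batch_size):
    """Yield successive batches of at most batch_size items."""
    for i in range(0, len(triples), batch_size):
        yield triples[i:i + batch_size]

batches = list(batch_triples(list(range(8)), 3))
assert [len(b) for b in batches] == [3, 3, 2]   # no batch exceeds the limit
assert sum(len(b) for b in batches) == 8        # nothing is dropped
```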
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDefaultBatchSize:
def test_default_is_50(self):
assert default_triples_batch_size == 50
def test_processor_uses_default(self):
proc = _make_processor()
assert proc.triples_batch_size == 50
class TestBatchSplitting:
@pytest.mark.asyncio
async def test_single_batch_when_under_limit(self):
"""Few triples → single send call."""
proc = _make_processor(triples_batch_size=50)
rels = [_make_rel("A", "knows", "B")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("some text")
await proc.on_message(msg, MagicMock(), flow)
# One relationship produces: rel triple + 3 labels + provenance
# All should fit in one batch of 50
assert pub.send.call_count == 1
@pytest.mark.asyncio
async def test_multiple_batches_with_small_batch_size(self):
"""With batch_size=3 and many triples, multiple batches are sent."""
proc = _make_processor(triples_batch_size=3)
# 2 relationships → 2 rel triples + 6 labels = 8 triples + provenance
rels = [
_make_rel("A", "knows", "B"),
_make_rel("C", "likes", "D"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("some text")
await proc.on_message(msg, MagicMock(), flow)
# Should have more than one batch
assert pub.send.call_count > 1
@pytest.mark.asyncio
async def test_batch_sizes_respect_limit(self):
"""No batch should exceed the configured batch size."""
batch_size = 3
proc = _make_processor(triples_batch_size=batch_size)
rels = [
_make_rel("A", "knows", "B"),
_make_rel("C", "likes", "D"),
_make_rel("E", "has", "F"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(pub):
assert len(triples_msg.triples) <= batch_size
@pytest.mark.asyncio
async def test_all_triples_present_across_batches(self):
"""Total triples across batches equals expected count."""
proc = _make_processor(triples_batch_size=2)
# 1 relationship with object-entity=True → 1 rel + 3 labels = 4 triples
# + provenance triples
rels = [_make_rel("A", "knows", "B", object_entity=True)]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# At minimum: 1 rel + 3 labels = 4 content triples
assert len(all_t) >= 4
@pytest.mark.asyncio
async def test_custom_batch_size(self):
"""Processor respects custom triples_batch_size parameter."""
proc = _make_processor(triples_batch_size=100)
assert proc.triples_batch_size == 100
class TestMetadataPreservation:
@pytest.mark.asyncio
async def test_metadata_forwarded_to_all_batches(self):
"""Every batch should carry the original chunk metadata."""
proc = _make_processor(triples_batch_size=2)
rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
class TestRelationshipTriples:
@pytest.mark.asyncio
async def test_entity_object_produces_iri(self):
"""object-entity=True → object is an IRI, with label triple."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("Alice", "knows", "Bob", object_entity=True)]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Find the relationship triple (not a label)
rel_triples = [
t for t in all_t
if t.o.type == IRI and "bob" in t.o.iri
]
assert len(rel_triples) >= 1
@pytest.mark.asyncio
async def test_literal_object_produces_literal(self):
"""object-entity=False → object is a LITERAL, no label for object."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("Alice", "age", "30", object_entity=False)]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Find the relationship triple with literal object
lit_triples = [
t for t in all_t
if t.o.type == LITERAL and t.o.value == "30"
]
assert len(lit_triples) == 1
@pytest.mark.asyncio
async def test_labels_emitted_for_subject_and_predicate(self):
"""Every relationship should produce label triples for s and p."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("Alice", "knows", "Bob")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
label_triples = [
t for t in all_t
if t.p.type == IRI and "label" in t.p.iri.lower()
]
labels = {t.o.value for t in label_triples}
assert "Alice" in labels
assert "knows" in labels
assert "Bob" in labels # object-entity default is True
class TestEmptyAndNullFiltering:
@pytest.mark.asyncio
async def test_empty_string_fields_skipped(self):
"""Relationships with empty string s/p/o are skipped."""
proc = _make_processor(triples_batch_size=200)
rels = [
_make_rel("", "knows", "Bob"),
_make_rel("Alice", "", "Bob"),
_make_rel("Alice", "knows", ""),
_make_rel("Good", "triple", "Here"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Only the "Good triple Here" relationship should produce content triples
rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
assert any("good" in iri for iri in rel_iris)
assert not any("alice" in iri for iri in rel_iris)
@pytest.mark.asyncio
async def test_none_fields_skipped(self):
"""Relationships with None s/p/o are skipped."""
proc = _make_processor(triples_batch_size=200)
rels = [
_make_rel(None, "knows", "Bob"),
_make_rel("Alice", None, "Bob"),
_make_rel("Alice", "knows", None),
_make_rel("Valid", "rel", "Here"),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
assert any("valid" in iri for iri in rel_iris)
assert not any("alice" in iri for iri in rel_iris)
@pytest.mark.asyncio
async def test_all_filtered_produces_no_output(self):
"""If all relationships are empty/null, nothing is emitted."""
proc = _make_processor(triples_batch_size=200)
rels = [
_make_rel("", "", ""),
_make_rel(None, None, None),
]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
@pytest.mark.asyncio
async def test_empty_prompt_response_produces_no_output(self):
"""Empty relationship list from prompt → no triples emitted."""
proc = _make_processor()
flow, pub, _ = _make_flow([])
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
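The filtering rule this class exercises amounts to a truthiness check on all three fields. A sketch of that rule (not the extractor's actual code):

```python
def valid_relationships(rels):
    # Drop any relationship whose subject, predicate, or object is
    # empty ("" or None); only complete relationships produce triples.
    return [
        r for r in rels
        if r.get("subject") and r.get("predicate") and r.get("object")
    ]

rels = [
    {"subject": "", "predicate": "knows", "object": "Bob"},
    {"subject": "Alice", "predicate": None, "object": "Bob"},
    {"subject": "Good", "predicate": "triple", "object": "Here"},
]
assert valid_relationships(rels) == [rels[2]]
```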
class TestProvenanceInclusion:
@pytest.mark.asyncio
async def test_provenance_triples_present(self):
"""Extracted relationships should include provenance triples."""
proc = _make_processor(triples_batch_size=200)
rels = [_make_rel("A", "knows", "B")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
all_t = _all_triples_flat(pub)
# Provenance triples use GRAPH_SOURCE graph context
# They contain terms referencing prov: namespace or subgraph URIs
# We just check that total count > 4 (1 rel + 3 labels)
assert len(all_t) > 4
@pytest.mark.asyncio
async def test_no_provenance_when_no_extracted_triples(self):
"""Empty relationships → no provenance generated."""
proc = _make_processor()
flow, pub, _ = _make_flow([_make_rel("", "x", "y")])
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
class TestErrorPropagation:
@pytest.mark.asyncio
async def test_prompt_error_is_caught(self):
"""Errors from the prompt client are caught (logged, not raised)."""
proc = _make_processor()
flow, pub, prompt = _make_flow([])
prompt.extract_relationships = AsyncMock(
side_effect=RuntimeError("LLM unavailable")
)
msg = _make_chunk_msg("text")
# The outer try/except in on_message catches and logs
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
@pytest.mark.asyncio
async def test_non_list_response_is_caught(self):
"""Non-list prompt response triggers RuntimeError, caught by handler."""
proc = _make_processor()
flow, pub, prompt = _make_flow("not a list")
msg = _make_chunk_msg("text")
await proc.on_message(msg, MagicMock(), flow)
assert pub.send.call_count == 0
class TestToUri:
def test_spaces_replaced_with_hyphens(self):
proc = _make_processor()
uri = proc.to_uri("hello world")
assert "hello-world" in uri
def test_lowercased(self):
proc = _make_processor()
uri = proc.to_uri("Hello World")
assert "hello-world" in uri
def test_special_chars_encoded(self):
proc = _make_processor()
# urllib.parse.quote keeps / as safe by default
uri = proc.to_uri("a/b")
assert "a/b" in uri
# Spaces never reach the percent-encoder: they are replaced with hyphens first
uri2 = proc.to_uri("hello world")
assert " " not in uri2

View file

@ -55,13 +55,6 @@ def sample_objects_message():
return {
"metadata": {
"id": "obj-123",
"metadata": [
{
"s": {"v": "obj-123", "e": False},
"p": {"v": "source", "e": False},
"o": {"v": "test", "e": False}
}
],
"user": "testuser",
"collection": "testcollection"
},
@ -244,7 +237,6 @@ class TestRowsImportMessageProcessing:
assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.user == "testuser"
assert sent_object.metadata.collection == "testcollection"
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
@ -277,7 +269,6 @@ class TestRowsImportMessageProcessing:
assert sent_object.values[0]["field1"] == "value1"
assert sent_object.confidence == 1.0 # Default value
assert sent_object.source_span == "" # Default value
assert len(sent_object.metadata.metadata) == 0 # Default empty list
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio

View file

@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator:
assert is_final is False
assert result["end_of_stream"] is False
# Test final chunk with empty content
# Test final message with end_of_session=True
final_response = GraphRagResponse(
response="",
end_of_stream=True,
end_of_session=True,
error=None
)
# Act
result, is_final = translator.from_response_with_completion(final_response)
# Assert
# Assert - is_final is based on end_of_session, not end_of_stream
assert is_final is True
assert result["response"] == ""
assert result["end_of_stream"] is True
assert result["end_of_session"] is True
class TestDocumentRagResponseTranslator:

View file

@ -29,11 +29,11 @@ class Triple:
self.o = o
class Metadata:
def __init__(self, id, user, collection, metadata):
def __init__(self, id, user, collection, root=""):
self.id = id
self.root = root
self.user = user
self.collection = collection
self.metadata = metadata
class Triples:
def __init__(self, metadata, triples):
@ -110,7 +110,6 @@ def sample_triples(sample_triple):
id="test-doc-123",
user="test_user",
collection="test_collection",
metadata=[]
)
return Triples(
@ -126,7 +125,6 @@ def sample_chunk():
id="test-chunk-456",
user="test_user",
collection="test_collection",
metadata=[]
)
return Chunk(

View file

@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.template.prompt_manager import PromptManager
@ -51,13 +51,6 @@ class TestAgentKgExtractor:
"""Sample metadata for testing"""
return Metadata(
id="doc123",
metadata=[
Triple(
s=Term(type=IRI, iri="doc123"),
p=Term(type=IRI, iri="http://example.org/type"),
o=Term(type=LITERAL, value="document")
)
]
)
@pytest.fixture
@ -175,7 +168,7 @@ This is not JSON at all
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
# Check entity label triple
label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None)
@ -190,12 +183,6 @@ This is not JSON at all
assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert def_triple.o.value == "A subset of AI that enables learning from data."
# Check subject-of triple
subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None)
assert subject_of_triple is not None
assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert subject_of_triple.o.iri == "doc123"
# Check entity context
assert len(entity_contexts) == 1
assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
@ -213,7 +200,7 @@ This is not JSON at all
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
# Check that subject, predicate, and object labels are created
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
@ -235,10 +222,6 @@ This is not JSON at all
assert rel_triple.o.iri == object_uri
assert rel_triple.o.type == IRI
# Check subject-of relationships
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"]
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
"""Test processing of relationships with literal objects"""
data = [
@ -251,7 +234,7 @@ This is not JSON at all
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
# Check that object labels are not created for literal objects
object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"]
@ -260,7 +243,7 @@ This is not JSON at all
def test_process_extraction_data_combined(self, agent_extractor, sample_metadata, sample_extraction_data):
"""Test processing of combined definitions and relationships"""
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
triples, entity_contexts, _ = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
# Check that we have both definition and relationship triples
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
@ -274,16 +257,12 @@ This is not JSON at all
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
"""Test processing when metadata has no ID"""
metadata = Metadata(id=None, metadata=[])
metadata = Metadata(id=None)
data = [
{"type": "definition", "entity": "Test Entity", "definition": "Test definition"}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
# Should not create subject-of relationships when no metadata ID
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) == 0
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Should still create entity contexts
assert len(entity_contexts) == 1
@ -292,7 +271,7 @@ This is not JSON at all
"""Test processing of empty extraction data"""
data = []
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
# Should have no entity contexts
assert len(entity_contexts) == 0
@ -307,7 +286,7 @@ This is not JSON at all
{"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
# Should process valid items and ignore unknown types
assert len(entity_contexts) == 1 # Only the definition creates entity context
@ -345,8 +324,6 @@ This is not JSON at all
assert sent_triples.metadata.id == sample_metadata.id
assert sent_triples.metadata.user == sample_metadata.user
assert sent_triples.metadata.collection == sample_metadata.collection
# Note: metadata.metadata is now empty array in the new implementation
assert sent_triples.metadata.metadata == []
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.iri == "test:subject"
@ -371,8 +348,6 @@ This is not JSON at all
assert sent_contexts.metadata.id == sample_metadata.id
assert sent_contexts.metadata.user == sample_metadata.user
assert sent_contexts.metadata.collection == sample_metadata.collection
# Note: metadata.metadata is now empty array in the new implementation
assert sent_contexts.metadata.metadata == []
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.iri == "test:entity"

View file

@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
@pytest.mark.unit
@ -168,7 +168,7 @@ class TestAgentKgExtractionEdgeCases:
"""Test processing with empty or minimal metadata"""
# Test with None metadata - may not raise AttributeError depending on implementation
try:
triples, contexts = agent_extractor.process_extraction_data([], None)
triples, contexts, _ = agent_extractor.process_extraction_data([], None)
# If it doesn't raise, check the results
assert len(triples) == 0
assert len(contexts) == 0
@ -177,23 +177,19 @@ class TestAgentKgExtractionEdgeCases:
pass
# Test with metadata without ID
metadata = Metadata(id=None, metadata=[])
triples, contexts = agent_extractor.process_extraction_data([], metadata)
metadata = Metadata(id=None)
triples, contexts, _ = agent_extractor.process_extraction_data([], metadata)
assert len(triples) == 0
assert len(contexts) == 0
# Test with metadata with empty string ID
metadata = Metadata(id="", metadata=[])
metadata = Metadata(id="")
data = [{"type": "definition", "entity": "Test", "definition": "Test def"}]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should not create subject-of triples when ID is empty string
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) == 0
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
def test_process_extraction_data_special_entity_names(self, agent_extractor):
"""Test processing with special characters in entity names"""
metadata = Metadata(id="doc123", metadata=[])
metadata = Metadata(id="doc123")
special_entities = [
"Entity with spaces",
@ -213,7 +209,7 @@ class TestAgentKgExtractionEdgeCases:
for entity in special_entities
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Verify all entities were processed
assert len(contexts) == len(special_entities)
@ -225,7 +221,7 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_very_long_definitions(self, agent_extractor):
"""Test processing with very long entity definitions"""
metadata = Metadata(id="doc123", metadata=[])
metadata = Metadata(id="doc123")
# Create very long definition
long_definition = "This is a very long definition. " * 1000
@ -234,7 +230,7 @@ class TestAgentKgExtractionEdgeCases:
{"type": "definition", "entity": "Test Entity", "definition": long_definition}
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Should handle long definitions without issues
assert len(contexts) == 1
@ -247,7 +243,7 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_duplicate_entities(self, agent_extractor):
"""Test processing with duplicate entity names"""
metadata = Metadata(id="doc123", metadata=[])
metadata = Metadata(id="doc123")
data = [
{"type": "definition", "entity": "Machine Learning", "definition": "First definition"},
@ -256,7 +252,7 @@ class TestAgentKgExtractionEdgeCases:
{"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Should process all entries (including duplicates)
assert len(contexts) == 4
@ -269,7 +265,7 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_empty_strings(self, agent_extractor):
"""Test processing with empty strings in data"""
metadata = Metadata(id="doc123", metadata=[])
metadata = Metadata(id="doc123")
data = [
{"type": "definition", "entity": "", "definition": "Definition for empty entity"},
@ -280,7 +276,7 @@ class TestAgentKgExtractionEdgeCases:
{"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True},
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Should handle empty strings by creating URIs (even if empty)
assert len(contexts) == 3
@ -291,7 +287,7 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_nested_json_in_strings(self, agent_extractor):
"""Test processing when definitions contain JSON-like strings"""
metadata = Metadata(id="doc123", metadata=[])
metadata = Metadata(id="doc123")
data = [
{
@ -306,7 +302,7 @@ class TestAgentKgExtractionEdgeCases:
}
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Should handle JSON strings in definitions without parsing them
assert len(contexts) == 2
@ -315,7 +311,7 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor):
"""Test processing with various boolean values for object-entity"""
metadata = Metadata(id="doc123", metadata=[])
metadata = Metadata(id="doc123")
data = [
# Explicit True
@ -334,16 +330,16 @@ class TestAgentKgExtractionEdgeCases:
{"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
# Should process all relationships
# Note: The current implementation has some logic issues that these tests document
assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7
assert len([t for t in triples if t.p.iri != RDF_LABEL]) >= 7
@pytest.mark.asyncio
async def test_emit_empty_collections(self, agent_extractor):
"""Test emitting empty triples and entity contexts"""
metadata = Metadata(id="test", metadata=[])
metadata = Metadata(id="test")
# Test emitting empty triples
mock_publisher = AsyncMock()
@ -389,7 +385,7 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_performance_large_dataset(self, agent_extractor):
"""Test performance with large extraction datasets"""
metadata = Metadata(id="large-doc", metadata=[])
metadata = Metadata(id="large-doc")
# Create large dataset in JSONL format
num_definitions = 1000
@ -416,7 +412,7 @@ class TestAgentKgExtractionEdgeCases:
import time
start_time = time.time()
triples, contexts = agent_extractor.process_extraction_data(large_data, metadata)
triples, contexts, _ = agent_extractor.process_extraction_data(large_data, metadata)
end_time = time.time()
processing_time = end_time - start_time

View file

@ -314,7 +314,6 @@ class TestObjectExtractionBusinessLogic:
id="test-extraction-001",
user="test_user",
collection="test_collection",
metadata=[]
)
values = [{

View file

@ -373,7 +373,6 @@ class TestTripleConstructionLogic:
id="test-doc-123",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act

View file

@ -0,0 +1,716 @@
"""
Tests for librarian chunked upload operations:
begin_upload, upload_chunk, complete_upload, abort_upload, get_upload_status,
list_uploads, and stream_document.
"""
import base64
import json
import math
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.librarian.librarian import Librarian, DEFAULT_CHUNK_SIZE
from trustgraph.exceptions import RequestError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_librarian(min_chunk_size=1):
"""Create a Librarian with mocked blob_store and table_store."""
lib = Librarian.__new__(Librarian)
lib.blob_store = MagicMock()
lib.table_store = AsyncMock()
lib.load_document = AsyncMock()
lib.min_chunk_size = min_chunk_size
return lib
def _make_doc_metadata(
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
):
meta = MagicMock()
meta.id = doc_id
meta.kind = kind
meta.user = user
meta.title = title
meta.time = 1700000000
meta.comments = ""
meta.tags = []
return meta
def _make_begin_request(
doc_id="doc-1", kind="application/pdf", user="alice",
total_size=10_000_000, chunk_size=0
):
req = MagicMock()
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
req.total_size = total_size
req.chunk_size = chunk_size
return req
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
req = MagicMock()
req.upload_id = upload_id
req.chunk_index = chunk_index
req.user = user
req.content = base64.b64encode(content)
return req
def _make_session(
user="alice", total_chunks=5, chunk_size=2_000_000,
total_size=10_000_000, chunks_received=None, object_id="obj-1",
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
):
if chunks_received is None:
chunks_received = {}
if document_metadata is None:
document_metadata = json.dumps({
"id": document_id, "kind": "application/pdf",
"user": user, "title": "Test", "time": 1700000000,
"comments": "", "tags": [],
})
return {
"user": user,
"total_chunks": total_chunks,
"chunk_size": chunk_size,
"total_size": total_size,
"chunks_received": chunks_received,
"object_id": object_id,
"s3_upload_id": s3_upload_id,
"document_metadata": document_metadata,
"document_id": document_id,
}
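The sizing and part-numbering arithmetic the tests below rely on can be sketched as follows (the DEFAULT_CHUNK_SIZE value is assumed for illustration; the real constant lives in `librarian.py`):

```python
import math

DEFAULT_CHUNK_SIZE = 2_000_000  # assumed value for this sketch

def plan_upload(total_size, chunk_size=0):
    """Return (effective_chunk_size, total_chunks) for a chunked upload."""
    if total_size <= 0:
        raise ValueError("total_size must be positive")
    size = chunk_size or DEFAULT_CHUNK_SIZE
    return size, math.ceil(total_size / size)

def part_number(chunk_index):
    # S3 multipart parts are 1-indexed; chunk indices start at 0.
    return chunk_index + 1

assert plan_upload(10_000, 3000) == (3000, 4)
assert (part_number(0), part_number(3)) == (1, 4)
```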
# ---------------------------------------------------------------------------
# begin_upload
# ---------------------------------------------------------------------------
class TestBeginUpload:
@pytest.mark.asyncio
async def test_creates_session(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-upload-id"
req = _make_begin_request(total_size=10_000_000)
resp = await lib.begin_upload(req)
assert resp.error is None
assert resp.upload_id is not None
assert resp.total_chunks == math.ceil(10_000_000 / DEFAULT_CHUNK_SIZE)
assert resp.chunk_size == DEFAULT_CHUNK_SIZE
@pytest.mark.asyncio
async def test_custom_chunk_size(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(total_size=10_000, chunk_size=3000)
resp = await lib.begin_upload(req)
assert resp.chunk_size == 3000
assert resp.total_chunks == math.ceil(10_000 / 3000)
@pytest.mark.asyncio
async def test_rejects_invalid_kind(self):
lib = _make_librarian()
req = _make_begin_request(kind="image/png")
with pytest.raises(RequestError, match="Invalid document kind"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_rejects_duplicate_document(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = True
req = _make_begin_request()
with pytest.raises(RequestError, match="already exists"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_rejects_zero_size(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
req = _make_begin_request(total_size=0)
with pytest.raises(RequestError, match="positive"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_rejects_chunk_below_minimum(self):
lib = _make_librarian(min_chunk_size=1024)
lib.table_store.document_exists.return_value = False
req = _make_begin_request(total_size=10_000, chunk_size=512)
with pytest.raises(RequestError, match="below minimum"):
await lib.begin_upload(req)
@pytest.mark.asyncio
async def test_calls_s3_create_multipart(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(kind="application/pdf")
await lib.begin_upload(req)
lib.blob_store.create_multipart_upload.assert_called_once()
# create_multipart_upload(object_id, kind) — positional args
args = lib.blob_store.create_multipart_upload.call_args[0]
assert args[1] == "application/pdf"
@pytest.mark.asyncio
async def test_stores_session_in_cassandra(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(total_size=5_000_000)
resp = await lib.begin_upload(req)
lib.table_store.create_upload_session.assert_called_once()
kwargs = lib.table_store.create_upload_session.call_args[1]
assert kwargs["upload_id"] == resp.upload_id
assert kwargs["total_size"] == 5_000_000
assert kwargs["total_chunks"] == resp.total_chunks
@pytest.mark.asyncio
async def test_accepts_text_plain(self):
lib = _make_librarian()
lib.table_store.document_exists.return_value = False
lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(kind="text/plain", total_size=1000)
resp = await lib.begin_upload(req)
assert resp.error is None
# ---------------------------------------------------------------------------
# upload_chunk
# ---------------------------------------------------------------------------
class TestUploadChunk:
@pytest.mark.asyncio
async def test_successful_chunk_upload(self):
lib = _make_librarian()
session = _make_session(total_chunks=5, chunks_received={})
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag-1"
req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
resp = await lib.upload_chunk(req)
assert resp.error is None
assert resp.chunk_index == 0
assert resp.total_chunks == 5
# The chunk is added to the dict (len=1), then +1 applied => 2
assert resp.chunks_received == 2
@pytest.mark.asyncio
async def test_s3_part_number_is_1_indexed(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag"
req = _make_upload_chunk_request(chunk_index=0)
await lib.upload_chunk(req)
kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part
@pytest.mark.asyncio
async def test_chunk_index_3_becomes_part_4(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag"
req = _make_upload_chunk_request(chunk_index=3)
await lib.upload_chunk(req)
kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["part_number"] == 4
@pytest.mark.asyncio
async def test_rejects_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = _make_upload_chunk_request()
with pytest.raises(RequestError, match="not found"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(user="bob")
with pytest.raises(RequestError, match="Not authorized"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_rejects_negative_chunk_index(self):
lib = _make_librarian()
session = _make_session(total_chunks=5)
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(chunk_index=-1)
with pytest.raises(RequestError, match="Invalid chunk index"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_rejects_out_of_range_chunk_index(self):
lib = _make_librarian()
session = _make_session(total_chunks=5)
lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(chunk_index=5)
with pytest.raises(RequestError, match="Invalid chunk index"):
await lib.upload_chunk(req)
@pytest.mark.asyncio
async def test_progress_tracking(self):
lib = _make_librarian()
session = _make_session(
total_chunks=4, chunk_size=1000, total_size=3500,
chunks_received={0: "e1", 1: "e2"},
)
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "e3"
req = _make_upload_chunk_request(chunk_index=2)
resp = await lib.upload_chunk(req)
# Dict gets chunk 2 added (len=3), then +1 => 4
assert resp.chunks_received == 4
assert resp.total_chunks == 4
assert resp.total_bytes == 3500
@pytest.mark.asyncio
async def test_bytes_capped_at_total_size(self):
"""bytes_received should not exceed total_size for the final chunk."""
lib = _make_librarian()
session = _make_session(
total_chunks=2, chunk_size=3000, total_size=5000,
chunks_received={0: "e1"},
)
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "e2"
req = _make_upload_chunk_request(chunk_index=1)
resp = await lib.upload_chunk(req)
        # chunks_received reports 3 (len 2, plus the +1), so 3 × 3000 = 9000 > 5000 and the value is capped
assert resp.bytes_received <= 5000
@pytest.mark.asyncio
async def test_base64_decodes_content(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
lib.blob_store.upload_part.return_value = "etag"
raw = b"hello world binary data"
req = _make_upload_chunk_request(content=raw)
await lib.upload_chunk(req)
kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["data"] == raw
# ---------------------------------------------------------------------------
# complete_upload
# ---------------------------------------------------------------------------
class TestCompleteUpload:
@pytest.mark.asyncio
async def test_successful_completion(self):
lib = _make_librarian()
session = _make_session(
total_chunks=3,
chunks_received={0: "e1", 1: "e2", 2: "e3"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.complete_upload(req)
assert resp.error is None
assert resp.document_id == "doc-1"
lib.blob_store.complete_multipart_upload.assert_called_once()
lib.table_store.add_document.assert_called_once()
lib.table_store.delete_upload_session.assert_called_once_with("up-1")
@pytest.mark.asyncio
async def test_parts_sorted_by_index(self):
lib = _make_librarian()
# Chunks received out of order
session = _make_session(
total_chunks=3,
chunks_received={2: "e3", 0: "e1", 1: "e2"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
await lib.complete_upload(req)
parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
part_numbers = [p[0] for p in parts]
assert part_numbers == [1, 2, 3] # Sorted, 1-indexed
@pytest.mark.asyncio
async def test_rejects_missing_chunks(self):
lib = _make_librarian()
session = _make_session(
total_chunks=3,
chunks_received={0: "e1", 2: "e3"}, # chunk 1 missing
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
with pytest.raises(RequestError, match="Missing chunks"):
await lib.complete_upload(req)
@pytest.mark.asyncio
async def test_rejects_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.complete_upload(req)
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.complete_upload(req)
# ---------------------------------------------------------------------------
# abort_upload
# ---------------------------------------------------------------------------
class TestAbortUpload:
@pytest.mark.asyncio
async def test_aborts_and_cleans_up(self):
lib = _make_librarian()
session = _make_session()
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.abort_upload(req)
assert resp.error is None
lib.blob_store.abort_multipart_upload.assert_called_once_with(
object_id="obj-1", upload_id="s3-up-1"
)
lib.table_store.delete_upload_session.assert_called_once_with("up-1")
@pytest.mark.asyncio
async def test_rejects_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = MagicMock()
req.upload_id = "up-gone"
req.user = "alice"
with pytest.raises(RequestError, match="not found"):
await lib.abort_upload(req)
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.abort_upload(req)
# ---------------------------------------------------------------------------
# get_upload_status
# ---------------------------------------------------------------------------
class TestGetUploadStatus:
@pytest.mark.asyncio
async def test_in_progress_status(self):
lib = _make_librarian()
session = _make_session(
total_chunks=5, chunk_size=2000, total_size=10_000,
chunks_received={0: "e1", 2: "e3", 4: "e5"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.get_upload_status(req)
assert resp.upload_state == "in-progress"
assert resp.chunks_received == 3
assert resp.total_chunks == 5
assert sorted(resp.received_chunks) == [0, 2, 4]
assert sorted(resp.missing_chunks) == [1, 3]
assert resp.total_bytes == 10_000
@pytest.mark.asyncio
async def test_expired_session(self):
lib = _make_librarian()
lib.table_store.get_upload_session.return_value = None
req = MagicMock()
req.upload_id = "up-expired"
req.user = "alice"
resp = await lib.get_upload_status(req)
assert resp.upload_state == "expired"
@pytest.mark.asyncio
async def test_all_chunks_received(self):
lib = _make_librarian()
session = _make_session(
total_chunks=3, chunk_size=1000, total_size=2500,
chunks_received={0: "e1", 1: "e2", 2: "e3"},
)
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "alice"
resp = await lib.get_upload_status(req)
assert resp.missing_chunks == []
assert resp.chunks_received == 3
# 3 * 1000 = 3000 > 2500, so capped
assert resp.bytes_received <= 2500
@pytest.mark.asyncio
async def test_rejects_wrong_user(self):
lib = _make_librarian()
session = _make_session(user="alice")
lib.table_store.get_upload_session.return_value = session
req = MagicMock()
req.upload_id = "up-1"
req.user = "bob"
with pytest.raises(RequestError, match="Not authorized"):
await lib.get_upload_status(req)
# ---------------------------------------------------------------------------
# stream_document
# ---------------------------------------------------------------------------
class TestStreamDocument:
@pytest.mark.asyncio
async def test_streams_chunks_with_progress(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=5000)
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert len(chunks) == 3 # ceil(5000/2000)
assert chunks[0].chunk_index == 0
assert chunks[0].total_chunks == 3
assert chunks[0].is_final is False
assert chunks[-1].is_final is True
assert chunks[-1].chunk_index == 2
@pytest.mark.asyncio
async def test_single_chunk_document(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=500)
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].bytes_received == 500
assert chunks[0].total_bytes == 500
@pytest.mark.asyncio
async def test_byte_ranges_correct(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=5000)
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 2000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
# Verify the byte ranges passed to get_range
calls = lib.blob_store.get_range.call_args_list
assert calls[0][0] == ("obj-1", 0, 2000)
assert calls[1][0] == ("obj-1", 2000, 2000)
assert calls[2][0] == ("obj-1", 4000, 1000) # Last chunk: 5000-4000
@pytest.mark.asyncio
async def test_default_chunk_size(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=2_000_000)
lib.blob_store.get_range = AsyncMock(return_value=b"x")
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 0 # Should use default 1MB
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert len(chunks) == 2 # ceil(2MB / 1MB)
@pytest.mark.asyncio
async def test_content_is_base64_encoded(self):
lib = _make_librarian()
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=100)
raw = b"hello world"
lib.blob_store.get_range = AsyncMock(return_value=raw)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 1000
chunks = []
async for resp in lib.stream_document(req):
chunks.append(resp)
assert chunks[0].content == base64.b64encode(raw)
@pytest.mark.asyncio
async def test_rejects_chunk_below_minimum(self):
lib = _make_librarian(min_chunk_size=1024)
lib.table_store.get_document_object_id.return_value = "obj-1"
lib.blob_store.get_size = AsyncMock(return_value=5000)
req = MagicMock()
req.user = "alice"
req.document_id = "doc-1"
req.chunk_size = 512
with pytest.raises(RequestError, match="below minimum"):
async for _ in lib.stream_document(req):
pass
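
# The byte ranges asserted in test_byte_ranges_correct follow a simple
# offset/length walk; a minimal sketch (hypothetical helper, not the
# stream_document implementation):

def byte_ranges(total_size: int, chunk_size: int):
    # yields (offset, length) pairs; the final chunk is truncated to fit
    offset = 0
    while offset < total_size:
        yield offset, min(chunk_size, total_size - offset)
        offset += chunk_size
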
# ---------------------------------------------------------------------------
# list_uploads
# ---------------------------------------------------------------------------
class TestListUploads:
@pytest.mark.asyncio
async def test_returns_sessions(self):
lib = _make_librarian()
lib.table_store.list_upload_sessions.return_value = [
{
"upload_id": "up-1",
"document_id": "doc-1",
"document_metadata": '{"id":"doc-1"}',
"total_size": 10000,
"chunk_size": 2000,
"total_chunks": 5,
"chunks_received": {0: "e1", 1: "e2"},
"created_at": "2024-01-01",
},
]
req = MagicMock()
req.user = "alice"
resp = await lib.list_uploads(req)
assert resp.error is None
assert len(resp.upload_sessions) == 1
assert resp.upload_sessions[0].upload_id == "up-1"
assert resp.upload_sessions[0].total_chunks == 5
@pytest.mark.asyncio
async def test_empty_uploads(self):
lib = _make_librarian()
lib.table_store.list_upload_sessions.return_value = []
req = MagicMock()
req.user = "alice"
resp = await lib.list_uploads(req)
assert resp.upload_sessions == []

"""
Tests for agent provenance triple builder functions.
"""
import json
import pytest
from trustgraph.schema import Triple, Term, IRI, LITERAL
from trustgraph.provenance.agent import (
agent_session_triples,
agent_iteration_triples,
agent_final_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_AGENT_QUESTION,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
return [
t for t in triples
if t.p.iri == predicate and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
for t in triples:
if (t.s.iri == subject and t.p.iri == RDF_TYPE
and t.o.type == IRI and t.o.iri == rdf_type):
return True
return False
# ---------------------------------------------------------------------------
# agent_session_triples
# ---------------------------------------------------------------------------
class TestAgentSessionTriples:
SESSION_URI = "urn:trustgraph:agent:test-session"
def test_session_types(self):
triples = agent_session_triples(
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
)
assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
def test_session_query_text(self):
triples = agent_session_triples(
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
)
query = find_triple(triples, TG_QUERY, self.SESSION_URI)
assert query is not None
assert query.o.value == "What is X?"
def test_session_timestamp(self):
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-06-15T10:00:00Z"
)
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
assert ts is not None
assert ts.o.value == "2024-06-15T10:00:00Z"
def test_session_default_timestamp(self):
triples = agent_session_triples(self.SESSION_URI, "Q")
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
assert ts is not None
assert len(ts.o.value) > 0
def test_session_label(self):
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
)
label = find_triple(triples, RDFS_LABEL, self.SESSION_URI)
assert label is not None
assert label.o.value == "Agent Question"
def test_session_triple_count(self):
triples = agent_session_triples(
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
)
assert len(triples) == 6
# ---------------------------------------------------------------------------
# agent_iteration_triples
# ---------------------------------------------------------------------------
class TestAgentIterationTriples:
ITER_URI = "urn:trustgraph:agent:test-session/i1"
SESSION_URI = "urn:trustgraph:agent:test-session"
PREV_URI = "urn:trustgraph:agent:test-session/i0"
def test_iteration_types(self):
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
def test_first_iteration_generated_by_question(self):
"""First iteration uses wasGeneratedBy to link to question activity."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
# Should NOT have wasDerivedFrom
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is None
def test_subsequent_iteration_derived_from_previous(self):
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
triples = agent_iteration_triples(
self.ITER_URI, previous_uri=self.PREV_URI,
action="search",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
assert derived is not None
assert derived.o.iri == self.PREV_URI
# Should NOT have wasGeneratedBy
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
assert gen is None
def test_iteration_label_includes_action(self):
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query",
)
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
assert label is not None
assert "graph-rag-query" in label.o.value
def test_iteration_thought_sub_entity(self):
"""Thought is a sub-entity with Reflection and Thought types."""
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
thought_doc = "urn:doc:thought-1"
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
thought_uri=thought_uri,
thought_document_id=thought_doc,
)
# Iteration links to thought sub-entity
thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
assert thought_link is not None
assert thought_link.o.iri == thought_uri
# Thought has correct types
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
# Thought was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Thought has document reference
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
assert doc is not None
assert doc.o.iri == thought_doc
def test_iteration_observation_sub_entity(self):
"""Observation is a sub-entity with Reflection and Observation types."""
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
obs_doc = "urn:doc:obs-1"
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
observation_uri=obs_uri,
observation_document_id=obs_doc,
)
# Iteration links to observation sub-entity
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert obs_link is not None
assert obs_link.o.iri == obs_uri
# Observation has correct types
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
# Observation was generated by iteration
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
assert gen is not None
assert gen.o.iri == self.ITER_URI
# Observation has document reference
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
assert doc is not None
assert doc.o.iri == obs_doc
def test_iteration_action_recorded(self):
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="graph-rag-query",
)
action = find_triple(triples, TG_ACTION, self.ITER_URI)
assert action is not None
assert action.o.value == "graph-rag-query"
def test_iteration_arguments_json_encoded(self):
args = {"query": "test query", "limit": 10}
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
arguments=args,
)
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
assert arguments is not None
parsed = json.loads(arguments.o.value)
assert parsed == args
def test_iteration_default_arguments_empty_dict(self):
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="search",
)
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
assert arguments is not None
parsed = json.loads(arguments.o.value)
assert parsed == {}
def test_iteration_no_thought_or_observation(self):
"""Minimal iteration with just action — no thought or observation triples."""
triples = agent_iteration_triples(
self.ITER_URI, question_uri=self.SESSION_URI,
action="noop",
)
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
assert thought is None
assert obs is None
def test_iteration_chaining(self):
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
iter1_uri = "urn:trustgraph:agent:sess/i1"
iter2_uri = "urn:trustgraph:agent:sess/i2"
triples1 = agent_iteration_triples(
iter1_uri, question_uri=self.SESSION_URI, action="step1",
)
triples2 = agent_iteration_triples(
iter2_uri, previous_uri=iter1_uri, action="step2",
)
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
assert gen1.o.iri == self.SESSION_URI
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
assert derived2.o.iri == iter1_uri
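
# The chaining rule these tests fix can be stated in one line: the first
# node links back to the question activity via prov:wasGeneratedBy, and
# every later node chains via prov:wasDerivedFrom. A minimal sketch using
# string stand-ins for the namespace constants (hypothetical helper):

def chain_predicate(previous_uri=None):
    # first node: generated by the question activity; later nodes: derived
    # from the previous entity in the chain
    return "prov:wasDerivedFrom" if previous_uri else "prov:wasGeneratedBy"
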
# ---------------------------------------------------------------------------
# agent_final_triples
# ---------------------------------------------------------------------------
class TestAgentFinalTriples:
FINAL_URI = "urn:trustgraph:agent:test-session/final"
PREV_URI = "urn:trustgraph:agent:test-session/i3"
SESSION_URI = "urn:trustgraph:agent:test-session"
def test_final_types(self):
triples = agent_final_triples(
self.FINAL_URI, previous_uri=self.PREV_URI,
)
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
def test_final_derived_from_previous(self):
"""Conclusion with iterations uses wasDerivedFrom."""
triples = agent_final_triples(
self.FINAL_URI, previous_uri=self.PREV_URI,
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is not None
assert derived.o.iri == self.PREV_URI
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is None
def test_final_generated_by_question_when_no_iterations(self):
"""When agent answers immediately, final uses wasGeneratedBy."""
triples = agent_final_triples(
self.FINAL_URI, question_uri=self.SESSION_URI,
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
assert gen is not None
assert gen.o.iri == self.SESSION_URI
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
assert derived is None
def test_final_label(self):
triples = agent_final_triples(
self.FINAL_URI, previous_uri=self.PREV_URI,
)
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
assert label is not None
assert label.o.value == "Conclusion"
def test_final_document_reference(self):
triples = agent_final_triples(
self.FINAL_URI, previous_uri=self.PREV_URI,
document_id="urn:trustgraph:agent:sess/answer",
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is not None
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
def test_final_no_document(self):
triples = agent_final_triples(
self.FINAL_URI, previous_uri=self.PREV_URI,
)
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
assert doc is None

"""
Tests for the explainability API (entity parsing, wire format conversion,
and ExplainabilityClient).
"""
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.api.explainability import (
EdgeSelection,
ExplainEntity,
Question,
Grounding,
Exploration,
Focus,
Synthesis,
Reflection,
Analysis,
Conclusion,
parse_edge_selection_triples,
extract_term_value,
wire_triples_to_tuples,
ExplainabilityClient,
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
RDF_TYPE, RDFS_LABEL,
)
# ---------------------------------------------------------------------------
# Entity from_triples parsing
# ---------------------------------------------------------------------------
class TestExplainEntityFromTriples:
"""Test ExplainEntity.from_triples dispatches to correct subclass."""
def test_graphrag_question(self):
triples = [
("urn:q:1", RDF_TYPE, TG_QUESTION),
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:q:1", TG_QUERY, "What is AI?"),
("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
]
entity = ExplainEntity.from_triples("urn:q:1", triples)
assert isinstance(entity, Question)
assert entity.query == "What is AI?"
assert entity.timestamp == "2024-01-01T00:00:00Z"
assert entity.question_type == "graph-rag"
def test_docrag_question(self):
triples = [
("urn:q:2", RDF_TYPE, TG_QUESTION),
("urn:q:2", RDF_TYPE, TG_DOC_RAG_QUESTION),
("urn:q:2", TG_QUERY, "Find info"),
]
entity = ExplainEntity.from_triples("urn:q:2", triples)
assert isinstance(entity, Question)
assert entity.question_type == "document-rag"
def test_agent_question(self):
triples = [
("urn:q:3", RDF_TYPE, TG_QUESTION),
("urn:q:3", RDF_TYPE, TG_AGENT_QUESTION),
("urn:q:3", TG_QUERY, "Agent query"),
]
entity = ExplainEntity.from_triples("urn:q:3", triples)
assert isinstance(entity, Question)
assert entity.question_type == "agent"
def test_grounding(self):
triples = [
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
("urn:gnd:1", TG_CONCEPT, "machine learning"),
("urn:gnd:1", TG_CONCEPT, "neural networks"),
]
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
assert isinstance(entity, Grounding)
assert len(entity.concepts) == 2
assert "machine learning" in entity.concepts
assert "neural networks" in entity.concepts
def test_exploration(self):
triples = [
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
("urn:exp:1", TG_EDGE_COUNT, "15"),
]
entity = ExplainEntity.from_triples("urn:exp:1", triples)
assert isinstance(entity, Exploration)
assert entity.edge_count == 15
def test_exploration_with_chunk_count(self):
triples = [
("urn:exp:2", RDF_TYPE, TG_EXPLORATION),
("urn:exp:2", TG_CHUNK_COUNT, "5"),
]
entity = ExplainEntity.from_triples("urn:exp:2", triples)
assert isinstance(entity, Exploration)
assert entity.chunk_count == 5
def test_exploration_with_entities(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
("urn:exp:3", TG_EDGE_COUNT, "10"),
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
]
entity = ExplainEntity.from_triples("urn:exp:3", triples)
assert isinstance(entity, Exploration)
assert len(entity.entities) == 2
def test_exploration_invalid_count(self):
triples = [
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
("urn:exp:3", TG_EDGE_COUNT, "not-a-number"),
]
entity = ExplainEntity.from_triples("urn:exp:3", triples)
assert isinstance(entity, Exploration)
assert entity.edge_count == 0
def test_focus(self):
triples = [
("urn:foc:1", RDF_TYPE, TG_FOCUS),
("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:1"),
("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:2"),
]
entity = ExplainEntity.from_triples("urn:foc:1", triples)
assert isinstance(entity, Focus)
assert len(entity.selected_edge_uris) == 2
assert "urn:edge:1" in entity.selected_edge_uris
assert "urn:edge:2" in entity.selected_edge_uris
def test_synthesis_with_document(self):
triples = [
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
]
entity = ExplainEntity.from_triples("urn:syn:1", triples)
assert isinstance(entity, Synthesis)
assert entity.document == "urn:doc:answer-1"
def test_synthesis_no_document(self):
triples = [
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
]
entity = ExplainEntity.from_triples("urn:syn:2", triples)
assert isinstance(entity, Synthesis)
assert entity.document == ""
def test_reflection_thought(self):
triples = [
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
]
entity = ExplainEntity.from_triples("urn:ref:1", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "thought"
assert entity.document == "urn:doc:thought-1"
def test_reflection_observation(self):
triples = [
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ref:2", triples)
assert isinstance(entity, Reflection)
assert entity.reflection_type == "observation"
assert entity.document == "urn:doc:obs-1"
def test_analysis(self):
triples = [
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
("urn:ana:1", TG_ACTION, "graph-rag-query"),
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
]
entity = ExplainEntity.from_triples("urn:ana:1", triples)
assert isinstance(entity, Analysis)
assert entity.action == "graph-rag-query"
assert entity.arguments == '{"query": "test"}'
assert entity.thought == "urn:ref:thought-1"
assert entity.observation == "urn:ref:obs-1"
def test_conclusion_with_document(self):
triples = [
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
]
entity = ExplainEntity.from_triples("urn:conc:1", triples)
assert isinstance(entity, Conclusion)
assert entity.document == "urn:doc:final"
def test_conclusion_no_document(self):
triples = [
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
]
entity = ExplainEntity.from_triples("urn:conc:2", triples)
assert isinstance(entity, Conclusion)
assert entity.document == ""
def test_unknown_type(self):
triples = [
("urn:x:1", RDF_TYPE, "http://example.com/UnknownType"),
]
entity = ExplainEntity.from_triples("urn:x:1", triples)
assert isinstance(entity, ExplainEntity)
assert entity.entity_type == "unknown"
# ---------------------------------------------------------------------------
# parse_edge_selection_triples
# ---------------------------------------------------------------------------
class TestParseEdgeSelectionTriples:
def test_with_edge_and_reasoning(self):
triples = [
("urn:edge:1", TG_EDGE, {"s": "Alice", "p": "knows", "o": "Bob"}),
("urn:edge:1", TG_REASONING, "Alice and Bob are connected"),
]
result = parse_edge_selection_triples(triples)
assert isinstance(result, EdgeSelection)
assert result.uri == "urn:edge:1"
assert result.edge == {"s": "Alice", "p": "knows", "o": "Bob"}
assert result.reasoning == "Alice and Bob are connected"
def test_with_edge_only(self):
triples = [
("urn:edge:2", TG_EDGE, {"s": "A", "p": "r", "o": "B"}),
]
result = parse_edge_selection_triples(triples)
assert result.edge is not None
assert result.reasoning == ""
def test_with_reasoning_only(self):
triples = [
("urn:edge:3", TG_REASONING, "some reason"),
]
result = parse_edge_selection_triples(triples)
assert result.edge is None
assert result.reasoning == "some reason"
def test_empty_triples(self):
result = parse_edge_selection_triples([])
assert result.uri == ""
assert result.edge is None
assert result.reasoning == ""
def test_edge_must_be_dict(self):
"""Non-dict values for TG_EDGE should not be treated as edges."""
triples = [
("urn:edge:4", TG_EDGE, "not-a-dict"),
]
result = parse_edge_selection_triples(triples)
assert result.edge is None
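The parsing behaviour exercised above can be reconstructed as a short fold over the triples. A minimal sketch, assuming placeholder IRIs for `TG_EDGE` and `TG_REASONING`; only dict-valued edge objects are accepted, matching `test_edge_must_be_dict`.

```python
from dataclasses import dataclass
from typing import Optional

TG_EDGE = "urn:tg:edge"            # illustrative IRIs, not the real constants
TG_REASONING = "urn:tg:reasoning"

@dataclass
class EdgeSelection:
    uri: str = ""
    edge: Optional[dict] = None
    reasoning: str = ""

def parse_edge_selection_triples(triples):
    sel = EdgeSelection()
    for s, p, o in triples:
        if not sel.uri:
            sel.uri = s
        if p == TG_EDGE and isinstance(o, dict):
            sel.edge = o               # non-dict edge values are ignored
        elif p == TG_REASONING:
            sel.reasoning = o
    return sel
```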
# ---------------------------------------------------------------------------
# extract_term_value
# ---------------------------------------------------------------------------
class TestExtractTermValue:
def test_iri_short_format(self):
assert extract_term_value({"t": "i", "i": "urn:test"}) == "urn:test"
def test_iri_long_format(self):
assert extract_term_value({"type": "i", "iri": "urn:test"}) == "urn:test"
def test_literal_short_format(self):
assert extract_term_value({"t": "l", "v": "hello"}) == "hello"
def test_literal_long_format(self):
assert extract_term_value({"type": "l", "value": "hello"}) == "hello"
def test_quoted_triple(self):
term = {
"t": "t",
"tr": {
"s": {"t": "i", "i": "urn:s"},
"p": {"t": "i", "i": "urn:p"},
"o": {"t": "i", "i": "urn:o"},
}
}
result = extract_term_value(term)
assert result == {"s": "urn:s", "p": "urn:p", "o": "urn:o"}
def test_quoted_triple_long_format(self):
term = {
"type": "t",
"triple": {
"s": {"type": "i", "iri": "urn:s"},
"p": {"type": "i", "iri": "urn:p"},
"o": {"type": "l", "value": "val"},
}
}
result = extract_term_value(term)
assert result == {"s": "urn:s", "p": "urn:p", "o": "val"}
def test_unknown_type_fallback(self):
result = extract_term_value({"t": "x", "i": "urn:fallback"})
assert result == "urn:fallback"
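The term-unwrapping contract above (short keys `t`/`i`/`v`/`tr`, long keys `type`/`iri`/`value`/`triple`, recursion into quoted triples, and a best-effort fallback for unknown types) can be sketched in a few lines. A sketch inferred from the tests, not the package implementation.

```python
def extract_term_value(term):
    """Unwrap a wire-format term dict; quoted triples recurse into s/p/o."""
    t = term.get("t") or term.get("type")
    if t == "t":
        tr = term.get("tr") or term.get("triple") or {}
        return {k: extract_term_value(tr[k]) for k in ("s", "p", "o")}
    if t == "l":
        return term.get("v") or term.get("value")
    # IRIs; unknown types fall back to whichever field happens to be present
    return term.get("i") or term.get("iri") or term.get("v") or term.get("value")
```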
# ---------------------------------------------------------------------------
# wire_triples_to_tuples
# ---------------------------------------------------------------------------
class TestWireTriplesToTuples:
def test_basic_conversion(self):
wire = [
{
"s": {"t": "i", "i": "urn:s1"},
"p": {"t": "i", "i": "urn:p1"},
"o": {"t": "l", "v": "value1"},
},
]
result = wire_triples_to_tuples(wire)
assert len(result) == 1
assert result[0] == ("urn:s1", "urn:p1", "value1")
def test_multiple_triples(self):
wire = [
{
"s": {"t": "i", "i": "urn:s1"},
"p": {"t": "i", "i": "urn:p1"},
"o": {"t": "l", "v": "v1"},
},
{
"s": {"t": "i", "i": "urn:s2"},
"p": {"t": "i", "i": "urn:p2"},
"o": {"t": "i", "i": "urn:o2"},
},
]
result = wire_triples_to_tuples(wire)
assert len(result) == 2
assert result[0] == ("urn:s1", "urn:p1", "v1")
assert result[1] == ("urn:s2", "urn:p2", "urn:o2")
def test_empty_list(self):
assert wire_triples_to_tuples([]) == []
def test_missing_fields(self):
wire = [{"s": {}, "p": {}, "o": {}}]
result = wire_triples_to_tuples(wire)
assert len(result) == 1
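The conversion under test is a straightforward map over the wire list. A minimal sketch with a local term unwrapper (covering only the key variants the tests above use); the real code presumably delegates to `extract_term_value`.

```python
def _term(t):
    # Minimal unwrapper for the short/long wire keys used in these tests.
    if (t.get("t") or t.get("type")) == "l":
        return t.get("v") or t.get("value")
    return t.get("i") or t.get("iri")

def wire_triples_to_tuples(wire):
    """Flatten wire-format triples into plain (s, p, o) tuples."""
    return [
        (_term(w.get("s", {})), _term(w.get("p", {})), _term(w.get("o", {})))
        for w in wire
    ]
```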
# ---------------------------------------------------------------------------
# ExplainabilityClient
# ---------------------------------------------------------------------------
def _make_wire_triples(tuples):
"""Convert (s, p, o) tuples to wire format for mocking."""
result = []
for s, p, o in tuples:
entry = {
"s": {"t": "i", "i": s},
"p": {"t": "i", "i": p},
}
if o.startswith("urn:") or o.startswith("http"):
entry["o"] = {"t": "i", "i": o}
else:
entry["o"] = {"t": "l", "v": o}
result.append(entry)
return result
class TestExplainabilityClientFetchEntity:
def test_fetch_question_entity(self):
wire = _make_wire_triples([
("urn:q:1", RDF_TYPE, TG_QUESTION),
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:q:1", TG_QUERY, "What is AI?"),
("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
])
mock_flow = MagicMock()
# Return same results twice for quiescence
mock_flow.triples_query.side_effect = [wire, wire]
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
entity = client.fetch_entity("urn:q:1", graph="urn:graph:retrieval")
assert isinstance(entity, Question)
assert entity.query == "What is AI?"
assert entity.question_type == "graph-rag"
def test_fetch_returns_none_when_no_data(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = []
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
entity = client.fetch_entity("urn:nonexistent")
assert entity is None
def test_fetch_retries_on_empty_results(self):
wire = _make_wire_triples([
("urn:q:1", RDF_TYPE, TG_QUESTION),
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
("urn:q:1", TG_QUERY, "Q"),
])
mock_flow = MagicMock()
# Empty, then data, then same data (stable)
mock_flow.triples_query.side_effect = [[], wire, wire]
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
entity = client.fetch_entity("urn:q:1")
assert isinstance(entity, Question)
assert mock_flow.triples_query.call_count == 3
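The "return same results twice for quiescence" pattern in the mocks above implies a polling loop that only trusts a result once two consecutive queries agree. A sketch of that loop under those assumptions; the real client also applies `retry_delay` between attempts.

```python
def fetch_stable(query, max_retries=5):
    """Poll until two consecutive queries return the same non-empty
    result (quiescence), or give up after max_retries attempts."""
    prev = None
    for _ in range(max_retries):
        cur = query()
        if cur and cur == prev:
            return cur
        prev = cur
    return None
```

With side effects `[[], data, data]` this makes three calls, matching the `call_count == 3` assertion above.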
class TestExplainabilityClientResolveLabel:
def test_resolve_label_found(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = _make_wire_triples([
("urn:entity:1", RDFS_LABEL, "Entity One"),
])
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
label = client.resolve_label("urn:entity:1")
assert label == "Entity One"
def test_resolve_label_not_found(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = []
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
label = client.resolve_label("urn:entity:1")
assert label == "urn:entity:1"
def test_resolve_label_cached(self):
mock_flow = MagicMock()
mock_flow.triples_query.return_value = _make_wire_triples([
("urn:entity:1", RDFS_LABEL, "Entity One"),
])
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
client.resolve_label("urn:entity:1")
client.resolve_label("urn:entity:1")
# Only one query should be made
assert mock_flow.triples_query.call_count == 1
def test_resolve_label_non_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.resolve_label("plain text") == "plain text"
assert client.resolve_label("") == ""
mock_flow.triples_query.assert_not_called()
def test_resolve_edge_labels(self):
mock_flow = MagicMock()
def mock_query(s=None, p=None, **kwargs):
labels = {
"urn:e:Alice": "Alice",
"urn:r:knows": "knows",
"urn:e:Bob": "Bob",
}
if s in labels:
return _make_wire_triples([(s, RDFS_LABEL, labels[s])])
return []
mock_flow.triples_query.side_effect = mock_query
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
s, p, o = client.resolve_edge_labels(
{"s": "urn:e:Alice", "p": "urn:r:knows", "o": "urn:e:Bob"}
)
assert s == "Alice"
assert p == "knows"
assert o == "Bob"
class TestExplainabilityClientContentFetching:
def test_fetch_document_content_from_librarian(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
mock_api.library.return_value = mock_library
mock_library.get_document_content.return_value = b"librarian content"
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
result = client.fetch_document_content(
"urn:document:abc123", api=mock_api
)
assert result == "librarian content"
def test_fetch_document_content_truncated(self):
mock_flow = MagicMock()
mock_api = MagicMock()
mock_library = MagicMock()
mock_api.library.return_value = mock_library
mock_library.get_document_content.return_value = b"x" * 20000
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
result = client.fetch_document_content(
"urn:doc:1", api=mock_api, max_content=100
)
assert len(result) < 20000
assert result.endswith("... [truncated]")
def test_fetch_document_content_empty_uri(self):
mock_flow = MagicMock()
mock_api = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
result = client.fetch_document_content("", api=mock_api)
assert result == ""
class TestExplainabilityClientDetectSessionType:
def test_detect_agent_from_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:agent:abc") == "agent"
def test_detect_graphrag_from_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:question:abc") == "graphrag"
def test_detect_docrag_from_uri(self):
mock_flow = MagicMock()
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
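The session-type detection above is driven purely by the URI prefix. A sketch of that mapping; the fallback value is an assumption, since the tests only cover the three recognised prefixes.

```python
def detect_session_type(uri):
    """Read the session type off the URI prefix."""
    if uri.startswith("urn:trustgraph:agent:"):
        return "agent"
    if uri.startswith("urn:trustgraph:question:"):
        return "graphrag"
    if uri.startswith("urn:trustgraph:docrag:"):
        return "docrag"
    return "unknown"   # fallback is an assumption, not covered by the tests
```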

View file

@@ -0,0 +1,812 @@
"""
Tests for provenance triple builder functions (extraction-time and query-time).
"""
import pytest
from unittest.mock import patch
from trustgraph.schema import Triple, Term, IRI, LITERAL, TRIPLE
from trustgraph.provenance.triples import (
set_graph,
document_triples,
derived_entity_triples,
subgraph_provenance_triples,
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
docrag_question_triples,
docrag_exploration_triples,
docrag_synthesis_triples,
)
from trustgraph.provenance.namespaces import (
RDF_TYPE, RDFS_LABEL,
PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER,
TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_DOCUMENT,
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANSWER_TYPE,
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
GRAPH_SOURCE, GRAPH_RETRIEVAL,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
"""Find first triple matching predicate (and optionally subject)."""
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
"""Find all triples matching predicate (and optionally subject)."""
return [
t for t in triples
if t.p.iri == predicate and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
"""Check if subject has rdf:type rdf_type."""
for t in triples:
if (t.s.iri == subject and t.p.iri == RDF_TYPE
and t.o.type == IRI and t.o.iri == rdf_type):
return True
return False
# ---------------------------------------------------------------------------
# set_graph
# ---------------------------------------------------------------------------
class TestSetGraph:
def test_sets_graph_on_all_triples(self):
triples = [
Triple(
s=Term(type=IRI, iri="urn:s1"),
p=Term(type=IRI, iri="urn:p1"),
o=Term(type=LITERAL, value="v1"),
),
Triple(
s=Term(type=IRI, iri="urn:s2"),
p=Term(type=IRI, iri="urn:p2"),
o=Term(type=LITERAL, value="v2"),
),
]
result = set_graph(triples, GRAPH_RETRIEVAL)
assert len(result) == 2
for t in result:
assert t.g == GRAPH_RETRIEVAL
def test_does_not_modify_originals(self):
original = Triple(
s=Term(type=IRI, iri="urn:s"),
p=Term(type=IRI, iri="urn:p"),
o=Term(type=LITERAL, value="v"),
)
result = set_graph([original], "urn:graph:test")
assert original.g is None
assert result[0].g == "urn:graph:test"
def test_empty_list(self):
result = set_graph([], GRAPH_SOURCE)
assert result == []
def test_preserves_spo(self):
original = Triple(
s=Term(type=IRI, iri="urn:s"),
p=Term(type=IRI, iri="urn:p"),
o=Term(type=LITERAL, value="hello"),
)
result = set_graph([original], "urn:g")[0]
assert result.s.iri == "urn:s"
assert result.p.iri == "urn:p"
assert result.o.value == "hello"
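The copy semantics these tests require (set `.g` on every returned triple while leaving the originals' `g` as `None`) can be sketched with deep copies. `FakeTriple` is a hypothetical stand-in for `trustgraph.schema.Triple`, used only to keep the sketch self-contained.

```python
import copy
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class FakeTriple:                  # stand-in for trustgraph.schema.Triple
    s: Any
    p: Any
    o: Any
    g: Optional[str] = None

def set_graph(triples, graph):
    """Return copies with .g set to graph; originals are not modified."""
    out = []
    for t in triples:
        c = copy.deepcopy(t)
        c.g = graph
        out.append(c)
    return out
```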
# ---------------------------------------------------------------------------
# document_triples
# ---------------------------------------------------------------------------
class TestDocumentTriples:
DOC_URI = "https://example.com/doc/abc"
def test_minimal_document(self):
triples = document_triples(self.DOC_URI)
assert has_type(triples, self.DOC_URI, PROV_ENTITY)
assert has_type(triples, self.DOC_URI, TG_DOCUMENT_TYPE)
assert len(triples) == 2
def test_with_title(self):
triples = document_triples(self.DOC_URI, title="My Doc")
title_t = find_triple(triples, DC_TITLE)
assert title_t is not None
assert title_t.o.value == "My Doc"
# Title also creates an rdfs:label
label_t = find_triple(triples, RDFS_LABEL)
assert label_t is not None
assert label_t.o.value == "My Doc"
def test_with_source(self):
triples = document_triples(self.DOC_URI, source="https://source.com/f.pdf")
source_t = find_triple(triples, DC_SOURCE)
assert source_t is not None
assert source_t.o.type == IRI
assert source_t.o.iri == "https://source.com/f.pdf"
def test_with_date(self):
triples = document_triples(self.DOC_URI, date="2024-01-15")
date_t = find_triple(triples, DC_DATE)
assert date_t is not None
assert date_t.o.value == "2024-01-15"
def test_with_creator(self):
triples = document_triples(self.DOC_URI, creator="Alice")
creator_t = find_triple(triples, DC_CREATOR)
assert creator_t is not None
assert creator_t.o.value == "Alice"
def test_with_page_count(self):
triples = document_triples(self.DOC_URI, page_count=42)
pc_t = find_triple(triples, TG_PAGE_COUNT)
assert pc_t is not None
assert pc_t.o.value == "42"
def test_with_page_count_zero(self):
triples = document_triples(self.DOC_URI, page_count=0)
pc_t = find_triple(triples, TG_PAGE_COUNT)
assert pc_t is not None
assert pc_t.o.value == "0"
def test_with_mime_type(self):
triples = document_triples(self.DOC_URI, mime_type="application/pdf")
mt_t = find_triple(triples, TG_MIME_TYPE)
assert mt_t is not None
assert mt_t.o.value == "application/pdf"
def test_all_metadata(self):
triples = document_triples(
self.DOC_URI,
title="Test",
source="https://s.com",
date="2024-01-01",
creator="Bob",
page_count=10,
mime_type="application/pdf",
)
# 2 type triples + title + label + source + date + creator + page_count + mime_type
assert len(triples) == 9
def test_subject_is_doc_uri(self):
triples = document_triples(self.DOC_URI, title="T")
for t in triples:
assert t.s.iri == self.DOC_URI
# ---------------------------------------------------------------------------
# derived_entity_triples
# ---------------------------------------------------------------------------
class TestDerivedEntityTriples:
ENTITY_URI = "https://example.com/doc/abc/p1"
PARENT_URI = "https://example.com/doc/abc"
def test_page_entity_has_page_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
page_number=1,
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
def test_chunk_entity_has_chunk_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0",
chunk_index=0,
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
def test_no_specific_type_without_page_or_chunk(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"component", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert not has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
assert not has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
def test_was_derived_from_parent(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ENTITY_URI)
assert derived is not None
assert derived.o.iri == self.PARENT_URI
def test_activity_created(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
# Entity was generated by an activity
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
assert gen is not None
act_uri = gen.o.iri
# Activity has correct type and metadata
assert has_type(triples, act_uri, PROV_ACTIVITY)
# Activity used the parent
used = find_triple(triples, PROV_USED, act_uri)
assert used is not None
assert used.o.iri == self.PARENT_URI
# Activity has component version
version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
assert version is not None
assert version.o.value == "1.0"
def test_agent_created(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
# Find the agent URI via wasAssociatedWith
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
act_uri = gen.o.iri
assoc = find_triple(triples, PROV_WAS_ASSOCIATED_WITH, act_uri)
assert assoc is not None
agt_uri = assoc.o.iri
assert has_type(triples, agt_uri, PROV_AGENT)
label = find_triple(triples, RDFS_LABEL, agt_uri)
assert label is not None
assert label.o.value == "pdf-extractor"
def test_timestamp_recorded(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
timestamp="2024-06-15T12:30:00Z",
)
ts = find_triple(triples, PROV_STARTED_AT_TIME)
assert ts is not None
assert ts.o.value == "2024-06-15T12:30:00Z"
def test_default_timestamp_generated(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
)
ts = find_triple(triples, PROV_STARTED_AT_TIME)
assert ts is not None
assert len(ts.o.value) > 0
def test_optional_label(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
label="Page 1",
timestamp="2024-01-01T00:00:00Z",
)
label = find_triple(triples, RDFS_LABEL, self.ENTITY_URI)
assert label is not None
assert label.o.value == "Page 1"
def test_page_number_recorded(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"pdf-extractor", "1.0",
page_number=3,
timestamp="2024-01-01T00:00:00Z",
)
pn = find_triple(triples, TG_PAGE_NUMBER, self.ENTITY_URI)
assert pn is not None
assert pn.o.value == "3"
def test_chunk_metadata_recorded(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "2.0",
chunk_index=5,
char_offset=1000,
char_length=500,
chunk_size=512,
chunk_overlap=64,
timestamp="2024-01-01T00:00:00Z",
)
ci = find_triple(triples, TG_CHUNK_INDEX, self.ENTITY_URI)
assert ci is not None and ci.o.value == "5"
co = find_triple(triples, TG_CHAR_OFFSET, self.ENTITY_URI)
assert co is not None and co.o.value == "1000"
cl = find_triple(triples, TG_CHAR_LENGTH, self.ENTITY_URI)
assert cl is not None and cl.o.value == "500"
# chunk_size and chunk_overlap are on the activity, not the entity
cs = find_triple(triples, TG_CHUNK_SIZE)
assert cs is not None and cs.o.value == "512"
ov = find_triple(triples, TG_CHUNK_OVERLAP)
assert ov is not None and ov.o.value == "64"
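The PROV pattern these tests walk, entity to activity via `wasGeneratedBy`, activity to parent via `used` and to an agent via `wasAssociatedWith`, can be laid out as a flat list of links. A sketch only: the activity/agent URI shapes and the version IRI are assumptions, not the package's actual minting scheme.

```python
PROV = "http://www.w3.org/ns/prov#"

def prov_chain(entity_uri, parent_uri, component, version, timestamp):
    """Link entity -> activity -> agent, recording input, start time
    and component version on the activity."""
    act = entity_uri + "#activity"             # URI shapes are assumptions
    agt = "urn:trustgraph:agent:" + component
    return [
        (entity_uri, PROV + "wasDerivedFrom", parent_uri),
        (entity_uri, PROV + "wasGeneratedBy", act),
        (act, PROV + "used", parent_uri),
        (act, PROV + "wasAssociatedWith", agt),
        (act, PROV + "startedAtTime", timestamp),
        (act, "urn:tg:componentVersion", version),  # illustrative IRI
    ]
```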
# ---------------------------------------------------------------------------
# subgraph_provenance_triples
# ---------------------------------------------------------------------------
class TestSubgraphProvenanceTriples:
SG_URI = "https://trustgraph.ai/subgraph/test-sg"
CHUNK_URI = "https://example.com/doc/abc/p1/c0"
def _make_extracted_triple(self, s="urn:e:Alice", p="urn:r:knows", o="urn:e:Bob"):
return Triple(
s=Term(type=IRI, iri=s),
p=Term(type=IRI, iri=p),
o=Term(type=IRI, iri=o),
)
def test_contains_quoted_triples(self):
extracted = [self._make_extracted_triple()]
triples = subgraph_provenance_triples(
self.SG_URI, extracted, self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
assert len(contains) == 1
assert contains[0].o.type == TRIPLE
assert contains[0].o.triple.s.iri == "urn:e:Alice"
assert contains[0].o.triple.p.iri == "urn:r:knows"
assert contains[0].o.triple.o.iri == "urn:e:Bob"
def test_multiple_extracted_triples(self):
extracted = [
self._make_extracted_triple("urn:e:A", "urn:r:x", "urn:e:B"),
self._make_extracted_triple("urn:e:C", "urn:r:y", "urn:e:D"),
self._make_extracted_triple("urn:e:E", "urn:r:z", "urn:e:F"),
]
triples = subgraph_provenance_triples(
self.SG_URI, extracted, self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
assert len(contains) == 3
def test_empty_extracted_triples(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
assert len(contains) == 0
# Should still have subgraph provenance metadata
assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
def test_subgraph_has_correct_types(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
assert has_type(triples, self.SG_URI, PROV_ENTITY)
assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
def test_derived_from_chunk(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SG_URI)
assert derived is not None
assert derived.o.iri == self.CHUNK_URI
def test_activity_and_agent(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.SG_URI)
assert gen is not None
act_uri = gen.o.iri
assert has_type(triples, act_uri, PROV_ACTIVITY)
used = find_triple(triples, PROV_USED, act_uri)
assert used is not None
assert used.o.iri == self.CHUNK_URI
version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
assert version is not None
assert version.o.value == "1.0"
def test_optional_llm_model(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
llm_model="claude-3-opus",
timestamp="2024-01-01T00:00:00Z",
)
llm = find_triple(triples, TG_LLM_MODEL)
assert llm is not None
assert llm.o.value == "claude-3-opus"
def test_no_llm_model_when_omitted(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
timestamp="2024-01-01T00:00:00Z",
)
llm = find_triple(triples, TG_LLM_MODEL)
assert llm is None
def test_optional_ontology(self):
triples = subgraph_provenance_triples(
self.SG_URI, [], self.CHUNK_URI,
"kg-extractor", "1.0",
ontology_uri="https://example.com/ontology/v1",
timestamp="2024-01-01T00:00:00Z",
)
ont = find_triple(triples, TG_ONTOLOGY)
assert ont is not None
assert ont.o.type == IRI
assert ont.o.iri == "https://example.com/ontology/v1"
# ---------------------------------------------------------------------------
# GraphRAG query-time triples
# ---------------------------------------------------------------------------
class TestQuestionTriples:
Q_URI = "urn:trustgraph:question:test-session"
def test_question_types(self):
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
def test_question_query_text(self):
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
query = find_triple(triples, TG_QUERY, self.Q_URI)
assert query is not None
assert query.o.value == "What is AI?"
def test_question_timestamp(self):
triples = question_triples(self.Q_URI, "Q", "2024-06-15T10:00:00Z")
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
assert ts is not None
assert ts.o.value == "2024-06-15T10:00:00Z"
def test_question_default_timestamp(self):
triples = question_triples(self.Q_URI, "Q")
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
assert ts is not None
assert len(ts.o.value) > 0
def test_question_label(self):
triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
label = find_triple(triples, RDFS_LABEL, self.Q_URI)
assert label is not None
assert label.o.value == "GraphRAG Question"
def test_question_triple_count(self):
triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
assert len(triples) == 6
class TestGroundingTriples:
GND_URI = "urn:trustgraph:prov:grounding:test-session"
Q_URI = "urn:trustgraph:question:test-session"
def test_grounding_types(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
assert has_type(triples, self.GND_URI, PROV_ENTITY)
assert has_type(triples, self.GND_URI, TG_GROUNDING)
def test_grounding_generated_by_question(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
assert gen is not None
assert gen.o.iri == self.Q_URI
def test_grounding_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 3
values = {t.o.value for t in concepts}
assert values == {"AI", "ML", "robots"}
def test_grounding_empty_concepts(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
assert len(concepts) == 0
def test_grounding_label(self):
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
assert label is not None
assert label.o.value == "Grounding"
class TestExplorationTriples:
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
GND_URI = "urn:trustgraph:prov:grounding:test-session"
def test_exploration_types(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_exploration_derived_from_grounding(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived is not None
assert derived.o.iri == self.GND_URI
def test_exploration_edge_count(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "15"
def test_exploration_zero_edges(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
assert ec is not None
assert ec.o.value == "0"
def test_exploration_with_entities(self):
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
assert len(ent_triples) == 2
def test_exploration_triple_count(self):
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
assert len(triples) == 5
class TestFocusTriples:
FOC_URI = "urn:trustgraph:prov:focus:test-session"
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
SESSION_ID = "test-session"
def test_focus_types(self):
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
assert has_type(triples, self.FOC_URI, PROV_ENTITY)
assert has_type(triples, self.FOC_URI, TG_FOCUS)
def test_focus_derived_from_exploration(self):
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FOC_URI)
assert derived is not None
assert derived.o.iri == self.EXP_URI
def test_focus_no_edges(self):
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
selected = find_triples(triples, TG_SELECTED_EDGE)
assert len(selected) == 0
def test_focus_with_edges_and_reasoning(self):
edges = [
{
"edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
"reasoning": "Alice is connected to Bob",
},
{
"edge": ("urn:e:Bob", "urn:r:worksAt", "urn:e:Acme"),
"reasoning": "Bob works at Acme",
},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
# Two selectedEdge links
selected = find_triples(triples, TG_SELECTED_EDGE, self.FOC_URI)
assert len(selected) == 2
# Each edge selection has a quoted triple
edge_triples = find_triples(triples, TG_EDGE)
assert len(edge_triples) == 2
for et in edge_triples:
assert et.o.type == TRIPLE
# Each edge selection has reasoning
reasoning_triples = find_triples(triples, TG_REASONING)
assert len(reasoning_triples) == 2
def test_focus_edge_without_reasoning(self):
edges = [
{"edge": ("urn:e:A", "urn:r:x", "urn:e:B"), "reasoning": ""},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
reasoning = find_triples(triples, TG_REASONING)
assert len(reasoning) == 0
def test_focus_edge_without_edge_data(self):
edges = [
{"edge": None, "reasoning": "some reasoning"},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
selected = find_triples(triples, TG_SELECTED_EDGE)
assert len(selected) == 0
def test_focus_quoted_triple_content(self):
edges = [
{
"edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
"reasoning": "test",
},
]
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
edge_t = find_triple(triples, TG_EDGE)
qt = edge_t.o.triple
assert qt.s.iri == "urn:e:Alice"
assert qt.p.iri == "urn:r:knows"
assert qt.o.iri == "urn:e:Bob"
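The edge-selection shape `focus_triples` records can be sketched as flat rows: one selection node per edge, a quoted-triple payload, and reasoning only when non-empty, with entries lacking edge data skipped entirely. Selection-node URIs and predicate IRIs here are illustrative assumptions.

```python
def edge_selection_rows(focus_uri, edges):
    """Build selection rows matching the behaviour the tests above pin down."""
    rows = []
    for i, e in enumerate(edges):
        if not e.get("edge"):
            continue                       # {"edge": None, ...} -> nothing
        sel = f"{focus_uri}/edge/{i}"      # assumed URI shape
        rows.append((focus_uri, "urn:tg:selectedEdge", sel))
        rows.append((sel, "urn:tg:edge", e["edge"]))   # quoted-triple payload
        if e.get("reasoning"):
            rows.append((sel, "urn:tg:reasoning", e["reasoning"]))
    return rows
```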
class TestSynthesisTriples:
SYN_URI = "urn:trustgraph:prov:synthesis:test-session"
FOC_URI = "urn:trustgraph:prov:focus:test-session"
def test_synthesis_types(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_synthesis_derived_from_focus(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived is not None
assert derived.o.iri == self.FOC_URI
def test_synthesis_with_document_reference(self):
triples = synthesis_triples(
self.SYN_URI, self.FOC_URI,
document_id="urn:trustgraph:question:abc/answer",
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is not None
assert doc.o.type == IRI
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
def test_synthesis_no_document(self):
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is None
# ---------------------------------------------------------------------------
# DocumentRAG query-time triples
# ---------------------------------------------------------------------------
class TestDocRagQuestionTriples:
Q_URI = "urn:trustgraph:docrag:test-session"
def test_docrag_question_types(self):
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
assert has_type(triples, self.Q_URI, TG_QUESTION)
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)
def test_docrag_question_label(self):
triples = docrag_question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
label = find_triple(triples, RDFS_LABEL, self.Q_URI)
assert label.o.value == "DocumentRAG Question"
def test_docrag_question_query_text(self):
triples = docrag_question_triples(self.Q_URI, "search query", "2024-01-01T00:00:00Z")
query = find_triple(triples, TG_QUERY, self.Q_URI)
assert query.o.value == "search query"
class TestDocRagExplorationTriples:
EXP_URI = "urn:trustgraph:docrag:test/exploration"
GND_URI = "urn:trustgraph:docrag:test/grounding"
def test_docrag_exploration_types(self):
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
def test_docrag_exploration_derived_from_grounding(self):
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
assert derived.o.iri == self.GND_URI
def test_docrag_exploration_chunk_count(self):
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
assert cc.o.value == "7"
def test_docrag_exploration_without_chunk_ids(self):
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
chunks = find_triples(triples, TG_SELECTED_CHUNK)
assert len(chunks) == 0
def test_docrag_exploration_with_chunk_ids(self):
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
assert len(chunks) == 3
chunk_uris = {t.o.iri for t in chunks}
assert chunk_uris == set(chunk_ids)
class TestDocRagSynthesisTriples:
SYN_URI = "urn:trustgraph:docrag:test/synthesis"
EXP_URI = "urn:trustgraph:docrag:test/exploration"
def test_docrag_synthesis_types(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
def test_docrag_synthesis_derived_from_exploration(self):
"""DocRAG skips the focus step — synthesis derives from exploration."""
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
assert derived.o.iri == self.EXP_URI
def test_docrag_synthesis_has_answer_type(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
def test_docrag_synthesis_with_document(self):
triples = docrag_synthesis_triples(
self.SYN_URI, self.EXP_URI, document_id="urn:doc:ans"
)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc.o.iri == "urn:doc:ans"
def test_docrag_synthesis_no_document(self):
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
assert doc is None


@ -0,0 +1,292 @@
"""
Tests for provenance URI generation functions.
"""
import pytest
from unittest.mock import patch
from trustgraph.provenance.uris import (
TRUSTGRAPH_BASE,
_encode_id,
document_uri,
page_uri,
chunk_uri_from_page,
chunk_uri_from_doc,
activity_uri,
subgraph_uri,
agent_uri,
question_uri,
exploration_uri,
focus_uri,
synthesis_uri,
edge_selection_uri,
agent_session_uri,
agent_iteration_uri,
agent_final_uri,
docrag_question_uri,
docrag_exploration_uri,
docrag_synthesis_uri,
)
class TestEncodeId:
"""Tests for the _encode_id helper."""
def test_plain_string(self):
assert _encode_id("abc123") == "abc123"
def test_string_with_spaces(self):
assert _encode_id("hello world") == "hello%20world"
def test_string_with_slashes(self):
assert _encode_id("a/b/c") == "a%2Fb%2Fc"
def test_integer_input(self):
assert _encode_id(42) == "42"
def test_empty_string(self):
assert _encode_id("") == ""
def test_special_characters(self):
# Reserved characters such as "@" must not survive encoding.
assert "@" not in _encode_id("name@domain.com")
class TestDocumentUris:
"""Tests for document, page, and chunk URI generation."""
def test_document_uri_passthrough(self):
iri = "https://example.com/doc/123"
assert document_uri(iri) == iri
def test_page_uri_format(self):
result = page_uri("https://example.com/doc/123", 5)
assert result == "https://example.com/doc/123/p5"
def test_page_uri_page_zero(self):
result = page_uri("https://example.com/doc/123", 0)
assert result == "https://example.com/doc/123/p0"
def test_chunk_uri_from_page_format(self):
result = chunk_uri_from_page("https://example.com/doc/123", 2, 3)
assert result == "https://example.com/doc/123/p2/c3"
def test_chunk_uri_from_doc_format(self):
result = chunk_uri_from_doc("https://example.com/doc/123", 7)
assert result == "https://example.com/doc/123/c7"
def test_page_uri_preserves_doc_iri(self):
doc = "urn:isbn:978-3-16-148410-0"
result = page_uri(doc, 1)
assert result.startswith(doc)
def test_chunk_from_page_hierarchy(self):
"""Chunk URI should contain both page and chunk identifiers."""
result = chunk_uri_from_page("https://example.com/doc", 3, 5)
assert "/p3/" in result
assert result.endswith("/c5")
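The page and chunk URIs exercised here are simple suffix conventions on the document IRI. A hypothetical sketch consistent with the expected strings (the real builders may differ):

```python
def document_uri(doc_iri: str) -> str:
    # The document's own IRI is used unchanged.
    return doc_iri

def page_uri(doc_iri: str, page: int) -> str:
    # Pages hang off the document IRI as a /pN suffix.
    return f"{doc_iri}/p{page}"

def chunk_uri_from_page(doc_iri: str, page: int, chunk: int) -> str:
    # Chunk under a page: /pN/cM preserves the hierarchy.
    return f"{doc_iri}/p{page}/c{chunk}"

def chunk_uri_from_doc(doc_iri: str, chunk: int) -> str:
    # Chunk directly under a document: /cM with no page level.
    return f"{doc_iri}/c{chunk}"
```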
class TestActivityAndSubgraphUris:
"""Tests for activity_uri, subgraph_uri, and agent_uri."""
def test_activity_uri_with_id(self):
result = activity_uri("my-activity-id")
assert result == f"{TRUSTGRAPH_BASE}/activity/my-activity-id"
def test_activity_uri_auto_generates_uuid(self):
result = activity_uri()
assert result.startswith(f"{TRUSTGRAPH_BASE}/activity/")
# UUID part should be non-empty
uuid_part = result.split("/activity/")[1]
assert len(uuid_part) > 0
def test_activity_uri_unique_uuids(self):
r1 = activity_uri()
r2 = activity_uri()
assert r1 != r2
def test_activity_uri_encodes_special_chars(self):
result = activity_uri("id with spaces")
assert "id%20with%20spaces" in result
def test_subgraph_uri_with_id(self):
result = subgraph_uri("sg-123")
assert result == f"{TRUSTGRAPH_BASE}/subgraph/sg-123"
def test_subgraph_uri_auto_generates_uuid(self):
result = subgraph_uri()
assert result.startswith(f"{TRUSTGRAPH_BASE}/subgraph/")
uuid_part = result.split("/subgraph/")[1]
assert len(uuid_part) > 0
def test_subgraph_uri_unique_uuids(self):
r1 = subgraph_uri()
r2 = subgraph_uri()
assert r1 != r2
def test_agent_uri_format(self):
result = agent_uri("pdf-extractor")
assert result == f"{TRUSTGRAPH_BASE}/agent/pdf-extractor"
def test_agent_uri_encodes_special_chars(self):
result = agent_uri("my component")
assert "my%20component" in result
class TestGraphRagQueryUris:
"""Tests for GraphRAG query-time provenance URIs."""
FIXED_UUID = "550e8400-e29b-41d4-a716-446655440000"
def test_question_uri_with_session_id(self):
result = question_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:question:{self.FIXED_UUID}"
def test_question_uri_auto_generates(self):
result = question_uri()
assert result.startswith("urn:trustgraph:question:")
uuid_part = result.split("urn:trustgraph:question:")[1]
assert len(uuid_part) > 0
def test_question_uri_unique(self):
r1 = question_uri()
r2 = question_uri()
assert r1 != r2
def test_exploration_uri_format(self):
result = exploration_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:prov:exploration:{self.FIXED_UUID}"
def test_focus_uri_format(self):
result = focus_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:prov:focus:{self.FIXED_UUID}"
def test_synthesis_uri_format(self):
result = synthesis_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:prov:synthesis:{self.FIXED_UUID}"
def test_edge_selection_uri_format(self):
result = edge_selection_uri(self.FIXED_UUID, 3)
assert result == f"urn:trustgraph:prov:edge:{self.FIXED_UUID}:3"
def test_edge_selection_uri_zero_index(self):
result = edge_selection_uri(self.FIXED_UUID, 0)
assert result.endswith(":0")
def test_session_uris_share_session_id(self):
"""All URIs for a session should contain the same session ID."""
sid = self.FIXED_UUID
q = question_uri(sid)
e = exploration_uri(sid)
f = focus_uri(sid)
s = synthesis_uri(sid)
for uri in [q, e, f, s]:
assert sid in uri
class TestAgentProvenanceUris:
"""Tests for agent provenance URIs."""
FIXED_UUID = "661e8400-e29b-41d4-a716-446655440000"
def test_agent_session_uri_with_id(self):
result = agent_session_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}"
def test_agent_session_uri_auto_generates(self):
result = agent_session_uri()
assert result.startswith("urn:trustgraph:agent:")
def test_agent_session_uri_unique(self):
r1 = agent_session_uri()
r2 = agent_session_uri()
assert r1 != r2
def test_agent_iteration_uri_format(self):
result = agent_iteration_uri(self.FIXED_UUID, 1)
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/i1"
def test_agent_iteration_uri_numbering(self):
r1 = agent_iteration_uri(self.FIXED_UUID, 1)
r2 = agent_iteration_uri(self.FIXED_UUID, 2)
assert r1 != r2
assert r1.endswith("/i1")
assert r2.endswith("/i2")
def test_agent_final_uri_format(self):
result = agent_final_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/final"
def test_agent_uris_share_session_id(self):
sid = self.FIXED_UUID
session = agent_session_uri(sid)
iteration = agent_iteration_uri(sid, 1)
final = agent_final_uri(sid)
for uri in [session, iteration, final]:
assert sid in uri
class TestDocRagProvenanceUris:
"""Tests for Document RAG provenance URIs."""
FIXED_UUID = "772e8400-e29b-41d4-a716-446655440000"
def test_docrag_question_uri_with_id(self):
result = docrag_question_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}"
def test_docrag_question_uri_auto_generates(self):
result = docrag_question_uri()
assert result.startswith("urn:trustgraph:docrag:")
def test_docrag_question_uri_unique(self):
r1 = docrag_question_uri()
r2 = docrag_question_uri()
assert r1 != r2
def test_docrag_exploration_uri_format(self):
result = docrag_exploration_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/exploration"
def test_docrag_synthesis_uri_format(self):
result = docrag_synthesis_uri(self.FIXED_UUID)
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/synthesis"
def test_docrag_uris_share_session_id(self):
sid = self.FIXED_UUID
q = docrag_question_uri(sid)
e = docrag_exploration_uri(sid)
s = docrag_synthesis_uri(sid)
for uri in [q, e, s]:
assert sid in uri
class TestUriNamespaceIsolation:
"""Verify that different provenance types use distinct URI namespaces."""
FIXED_UUID = "883e8400-e29b-41d4-a716-446655440000"
def test_graphrag_vs_agent_namespace(self):
graphrag = question_uri(self.FIXED_UUID)
agent = agent_session_uri(self.FIXED_UUID)
assert graphrag != agent
assert "question" in graphrag
assert "agent" in agent
def test_graphrag_vs_docrag_namespace(self):
graphrag = question_uri(self.FIXED_UUID)
docrag = docrag_question_uri(self.FIXED_UUID)
assert graphrag != docrag
def test_agent_vs_docrag_namespace(self):
agent = agent_session_uri(self.FIXED_UUID)
docrag = docrag_question_uri(self.FIXED_UUID)
assert agent != docrag
def test_extraction_vs_query_namespace(self):
"""Extraction URIs use https://, query URIs use urn:."""
ext = activity_uri(self.FIXED_UUID)
query = question_uri(self.FIXED_UUID)
assert ext.startswith("https://")
assert query.startswith("urn:")
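The namespace-isolation tests above rely on extraction-time URIs living under an `https://` base while query-time URIs use `urn:` schemes. A sketch of a few generators consistent with these assertions (the `TRUSTGRAPH_BASE` value is assumed; only its `https://` prefix is tested, and the real module may differ):

```python
import uuid
from urllib.parse import quote

TRUSTGRAPH_BASE = "https://trustgraph.ai"  # assumed value; tests only check the prefix

def activity_uri(activity_id=None) -> str:
    # Extraction-time URIs live under the https:// base; a fresh UUID
    # is generated when no ID is supplied.
    aid = activity_id if activity_id is not None else str(uuid.uuid4())
    return f"{TRUSTGRAPH_BASE}/activity/{quote(str(aid), safe='')}"

def question_uri(session_id=None) -> str:
    # Query-time URIs use the urn: scheme, keeping namespaces disjoint.
    sid = session_id if session_id is not None else uuid.uuid4()
    return f"urn:trustgraph:question:{sid}"

def edge_selection_uri(session_id, index: int) -> str:
    # Per-edge URIs append the edge index to the session-scoped base.
    return f"urn:trustgraph:prov:edge:{session_id}:{index}"
```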


@ -0,0 +1,124 @@
"""
Tests for provenance vocabulary bootstrap.
"""
import pytest
from trustgraph.schema import Triple, Term, IRI, LITERAL
from trustgraph.provenance.vocabulary import (
get_vocabulary_triples,
PROV_CLASS_LABELS,
PROV_PREDICATE_LABELS,
DC_PREDICATE_LABELS,
SCHEMA_LABELS,
SKOS_LABELS,
TG_CLASS_LABELS,
TG_PREDICATE_LABELS,
)
from trustgraph.provenance.namespaces import (
RDFS_LABEL,
PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
)
class TestVocabularyTriples:
"""Tests for the vocabulary bootstrap function."""
def test_returns_list_of_triples(self):
result = get_vocabulary_triples()
assert isinstance(result, list)
assert len(result) > 0
for t in result:
assert isinstance(t, Triple)
def test_all_triples_are_label_triples(self):
"""Every vocabulary triple should use rdfs:label as predicate."""
for t in get_vocabulary_triples():
assert t.p.type == IRI
assert t.p.iri == RDFS_LABEL
def test_all_subjects_are_iris(self):
for t in get_vocabulary_triples():
assert t.s.type == IRI
assert len(t.s.iri) > 0
def test_all_objects_are_literals(self):
for t in get_vocabulary_triples():
assert t.o.type == LITERAL
assert len(t.o.value) > 0
def test_no_duplicate_subjects(self):
subjects = [t.s.iri for t in get_vocabulary_triples()]
assert len(subjects) == len(set(subjects))
def test_includes_prov_classes(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert PROV_ENTITY in subjects
assert PROV_ACTIVITY in subjects
assert PROV_AGENT in subjects
def test_includes_prov_predicates(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert PROV_WAS_DERIVED_FROM in subjects
assert PROV_WAS_GENERATED_BY in subjects
assert PROV_USED in subjects
assert PROV_WAS_ASSOCIATED_WITH in subjects
assert PROV_STARTED_AT_TIME in subjects
def test_includes_dc_predicates(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert DC_TITLE in subjects
assert DC_SOURCE in subjects
assert DC_DATE in subjects
assert DC_CREATOR in subjects
def test_includes_tg_classes(self):
subjects = {t.s.iri for t in get_vocabulary_triples()}
assert TG_DOCUMENT_TYPE in subjects
assert TG_PAGE_TYPE in subjects
assert TG_CHUNK_TYPE in subjects
assert TG_SUBGRAPH_TYPE in subjects
def test_component_lists_sum_to_total(self):
total = get_vocabulary_triples()
components = (
PROV_CLASS_LABELS +
PROV_PREDICATE_LABELS +
DC_PREDICATE_LABELS +
SCHEMA_LABELS +
SKOS_LABELS +
TG_CLASS_LABELS +
TG_PREDICATE_LABELS
)
assert len(total) == len(components)
def test_idempotent(self):
"""Calling twice should return equivalent triples."""
r1 = get_vocabulary_triples()
r2 = get_vocabulary_triples()
assert len(r1) == len(r2)
for t1, t2 in zip(r1, r2):
assert t1.s.iri == t2.s.iri
assert t1.o.value == t2.o.value
class TestNamespaceConstants:
"""Verify namespace constants are well-formed IRIs."""
def test_prov_namespace_prefix(self):
assert PROV_ENTITY.startswith("http://www.w3.org/ns/prov#")
def test_dc_namespace_prefix(self):
assert DC_TITLE.startswith("http://purl.org/dc/elements/1.1/")
def test_tg_namespace_prefix(self):
assert TG_DOCUMENT_TYPE.startswith("https://trustgraph.ai/ns/")
def test_rdfs_label_iri(self):
assert RDFS_LABEL == "http://www.w3.org/2000/01/rdf-schema#label"


@ -37,7 +37,7 @@ def mock_qdrant_client():
def mock_graph_embeddings_request():
"""Mock graph embeddings request message"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -46,9 +46,9 @@ def mock_graph_embeddings_request():
@pytest.fixture
def mock_graph_embeddings_multiple_vectors():
"""Mock graph embeddings request with multiple vectors"""
"""Mock graph embeddings request with multiple vectors (legacy name, now single vector)"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2, 0.3, 0.4]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
@ -82,7 +82,7 @@ def mock_graph_embeddings_uri_response():
def mock_document_embeddings_request():
"""Mock document embeddings request message"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2, 0.3]]
mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@ -91,9 +91,9 @@ def mock_document_embeddings_request():
@pytest.fixture
def mock_document_embeddings_multiple_vectors():
"""Mock document embeddings request with multiple vectors"""
"""Mock document embeddings request with multiple vectors (legacy name, now single vector)"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
mock_message.vector = [0.1, 0.2, 0.3, 0.4]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
@ -139,9 +139,9 @@ def mock_large_query_response():
@pytest.fixture
def mock_mixed_dimension_vectors():
"""Mock request with vectors of different dimensions"""
"""Mock request with vector (legacy name suggested mixed dimensions, now single vector)"""
mock_message = MagicMock()
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5]
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'


@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.doc_embeddings.milvus.service import Processor
from trustgraph.schema import DocumentEmbeddingsRequest
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
class TestMilvusDocEmbeddingsQueryProcessor:
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
)
return query
@ -71,15 +71,15 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
# Mock search results
mock_results = [
{"entity": {"doc": "First document chunk"}},
{"entity": {"doc": "Second document chunk"}},
{"entity": {"doc": "Third document chunk"}},
{"entity": {"chunk_id": "First document chunk"}},
{"entity": {"chunk_id": "Second document chunk"}},
{"entity": {"chunk_id": "Third document chunk"}},
]
processor.vecstore.search.return_value = mock_results
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
)
# Verify results are document chunks
# Verify results are ChunkMatch objects
assert len(result) == 3
assert result[0] == "First document chunk"
assert result[1] == "Second document chunk"
assert result[2] == "Third document chunk"
assert isinstance(result[0], ChunkMatch)
assert result[0].chunk_id == "First document chunk"
assert result[1].chunk_id == "Second document chunk"
assert result[2].chunk_id == "Third document chunk"
@pytest.mark.asyncio
async def test_query_document_embeddings_multiple_vectors(self, processor):
"""Test querying document embeddings with multiple vectors"""
async def test_query_document_embeddings_longer_vector(self, processor):
"""Test querying document embeddings with a longer vector"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=3
)
# Mock search results - different results for each vector
mock_results_1 = [
{"entity": {"doc": "Document from first vector"}},
{"entity": {"doc": "Another doc from first vector"}},
# Mock search results
mock_results = [
{"entity": {"chunk_id": "First document"}},
{"entity": {"chunk_id": "Second document"}},
{"entity": {"chunk_id": "Third document"}},
]
mock_results_2 = [
{"entity": {"doc": "Document from second vector"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify search was called twice with correct parameters including user/collection
expected_calls = [
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
]
assert processor.vecstore.search.call_count == 2
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
actual_call = processor.vecstore.search.call_args_list[i]
assert actual_call[0] == expected_args
assert actual_call[1] == expected_kwargs
# Verify results from all vectors are combined
# Verify search was called once with the full vector
processor.vecstore.search.assert_called_once_with(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
)
# Verify results are ChunkMatch objects
assert len(result) == 3
assert "Document from first vector" in result
assert "Another doc from first vector" in result
assert "Document from second vector" in result
chunk_ids = [r.chunk_id for r in result]
assert "First document" in chunk_ids
assert "Second document" in chunk_ids
assert "Third document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_with_limit(self, processor):
@ -141,16 +135,16 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=2
)
# Mock search results - more results than limit
mock_results = [
{"entity": {"doc": "Document 1"}},
{"entity": {"doc": "Document 2"}},
{"entity": {"doc": "Document 3"}},
{"entity": {"doc": "Document 4"}},
{"entity": {"chunk_id": "Document 1"}},
{"entity": {"chunk_id": "Document 2"}},
{"entity": {"chunk_id": "Document 3"}},
{"entity": {"chunk_id": "Document 4"}},
]
processor.vecstore.search.return_value = mock_results
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[],
vector=[],
limit=5
)
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -211,25 +205,26 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
# Mock search results with Unicode content
mock_results = [
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
{"entity": {"doc": "Regular ASCII document"}},
{"entity": {"doc": "Document with émojis: 😀🎉"}},
{"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}},
{"entity": {"chunk_id": "Regular ASCII document"}},
{"entity": {"chunk_id": "Document with émojis: 😀🎉"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify Unicode content is preserved
# Verify Unicode content is preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with Unicode: éñ中文🚀" in result
assert "Regular ASCII document" in result
assert "Document with émojis: 😀🎉" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with Unicode: éñ中文🚀" in chunk_ids
assert "Regular ASCII document" in chunk_ids
assert "Document with émojis: 😀🎉" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_large_documents(self, processor):
@ -237,24 +232,25 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
# Mock search results with large content
large_doc = "A" * 10000 # 10KB of content
mock_results = [
{"entity": {"doc": large_doc}},
{"entity": {"doc": "Small document"}},
{"entity": {"chunk_id": large_doc}},
{"entity": {"chunk_id": "Small document"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify large content is preserved
# Verify large content is preserved in ChunkMatch objects
assert len(result) == 2
assert large_doc in result
assert "Small document" in result
chunk_ids = [r.chunk_id for r in result]
assert large_doc in chunk_ids
assert "Small document" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_special_characters(self, processor):
@ -262,25 +258,26 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
# Mock search results with special characters
mock_results = [
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
{"entity": {"doc": "Document with special chars: @#$%^&*()"}},
{"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}},
{"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}},
{"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify special characters are preserved
# Verify special characters are preserved in ChunkMatch objects
assert len(result) == 3
assert "Document with \"quotes\" and 'apostrophes'" in result
assert "Document with\nnewlines\tand\ttabs" in result
assert "Document with special chars: @#$%^&*()" in result
chunk_ids = [r.chunk_id for r in result]
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
assert "Document with special chars: @#$%^&*()" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_zero_limit(self, processor):
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=0
)
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=-1
)
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3]],
vector=[0.1, 0.2, 0.3],
limit=5
)
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
[0.7, 0.8, 0.9] # 3D vector
],
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
limit=5
)
# Mock search results for each vector
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
# Mock search results
mock_results = [
{"entity": {"chunk_id": "Document 1"}},
{"entity": {"chunk_id": "Document 2"}},
]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Verify all vectors were searched
assert processor.vecstore.search.call_count == 3
# Verify results from all dimensions
assert len(result) == 3
assert "Document from 2D vector" in result
assert "Document from 4D vector" in result
assert "Document from 3D vector" in result
# Verify search was called with the vector
processor.vecstore.search.assert_called_once()
# Verify results are ChunkMatch objects
assert len(result) == 2
chunk_ids = [r.chunk_id for r in result]
assert "Document 1" in chunk_ids
assert "Document 2" in chunk_ids
@pytest.mark.asyncio
async def test_query_document_embeddings_duplicate_documents(self, processor):
"""Test querying document embeddings with duplicate documents in results"""
async def test_query_document_embeddings_multiple_results(self, processor):
"""Test querying document embeddings with multiple results"""
query = DocumentEmbeddingsRequest(
user='test_user',
collection='test_collection',
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
# Mock search results with duplicates across vectors
mock_results_1 = [
{"entity": {"doc": "Document A"}},
{"entity": {"doc": "Document B"}},
# Mock search results with multiple documents
mock_results = [
{"entity": {"chunk_id": "Document A"}},
{"entity": {"chunk_id": "Document B"}},
{"entity": {"chunk_id": "Document C"}},
]
mock_results_2 = [
{"entity": {"doc": "Document B"}}, # Duplicate
{"entity": {"doc": "Document C"}},
]
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
processor.vecstore.search.return_value = mock_results
result = await processor.query_document_embeddings(query)
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
# This preserves ranking and allows multiple occurrences
assert len(result) == 4
assert result.count("Document B") == 2 # Should appear twice
assert "Document A" in result
assert "Document C" in result
# Verify results are ChunkMatch objects
assert len(result) == 3
chunk_ids = [r.chunk_id for r in result]
assert "Document A" in chunk_ids
assert "Document B" in chunk_ids
assert "Document C" in chunk_ids
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
@ -458,5 +449,5 @@ class TestMilvusDocEmbeddingsQueryProcessor:
mock_launch.assert_called_once_with(
default_ident,
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
"\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n"
)


@ -18,10 +18,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
def mock_query_message(self):
"""Create a mock query message for testing"""
message = MagicMock()
message.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]
]
message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -103,7 +100,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_single_vector(self, processor):
"""Test querying document embeddings with a single vector"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
@ -179,7 +176,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@ -208,7 +205,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
@ -226,7 +223,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_negative_limit(self, processor):
"""Test querying with negative limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = -1
message.user = 'test_user'
message.collection = 'test_collection'
@ -242,12 +239,9 @@ class TestPineconeDocEmbeddingsQueryProcessor:
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions using same index"""
"""Test querying with single vector (legacy test name, schema now uses single vector)"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6] # 4D vector
]
message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -285,7 +279,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_empty_vectors_list(self, processor):
"""Test querying with empty vectors list"""
message = MagicMock()
message.vectors = []
message.vector = []
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@ -304,7 +298,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_no_results(self, processor):
"""Test querying when index returns no results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@@ -325,7 +319,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_unicode_content(self, processor):
"""Test querying document embeddings with Unicode content results"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@@ -351,7 +345,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_large_content(self, processor):
"""Test querying document embeddings with large content results"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 1
message.user = 'test_user'
message.collection = 'test_collection'
@@ -377,7 +371,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_mixed_content_types(self, processor):
"""Test querying document embeddings with mixed content types"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@@ -409,7 +403,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_exception_handling(self, processor):
"""Test that exceptions are properly raised"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@@ -425,7 +419,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
async def test_query_document_embeddings_index_access_failure(self, processor):
"""Test handling of index access failure"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@@ -437,13 +431,9 @@ class TestPineconeDocEmbeddingsQueryProcessor:
@pytest.mark.asyncio
async def test_query_document_embeddings_vector_accumulation(self, processor):
-"""Test that results from multiple vectors are properly accumulated"""
+"""Test that results from single vector query are returned (legacy multi-vector test)"""
message = MagicMock()
-message.vectors = [
-[0.1, 0.2, 0.3],
-[0.4, 0.5, 0.6],
-[0.7, 0.8, 0.9]
-]
+message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'


@@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.doc_embeddings.qdrant.service import Processor
+from trustgraph.schema import ChunkMatch
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@@ -77,9 +78,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response
mock_point1 = MagicMock()
-mock_point1.payload = {'doc': 'first document chunk'}
+mock_point1.payload = {'chunk_id': 'first document chunk'}
mock_point2 = MagicMock()
-mock_point2.payload = {'doc': 'second document chunk'}
+mock_point2.payload = {'chunk_id': 'second document chunk'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
@@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2, 0.3]]
+mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 5
mock_message.user = 'test_user'
mock_message.collection = 'test_collection'
@@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
with_payload=True
)
-# Verify result contains expected documents
+# Verify result contains expected ChunkMatch objects
assert len(result) == 2
-# Results should be strings (document chunks)
-assert isinstance(result[0], str)
-assert isinstance(result[1], str)
+# Results should be ChunkMatch objects
+assert isinstance(result[0], ChunkMatch)
+assert isinstance(result[1], ChunkMatch)
# Verify content
-assert result[0] == 'first document chunk'
-assert result[1] == 'second document chunk'
+assert result[0].chunk_id == 'first document chunk'
+assert result[1].chunk_id == 'second document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
-async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
-"""Test querying document embeddings with multiple vectors"""
+async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
+"""Test querying document embeddings returns multiple results"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
-# Mock query responses for different vectors
+# Mock query response with multiple results
mock_point1 = MagicMock()
-mock_point1.payload = {'doc': 'document from vector 1'}
+mock_point1.payload = {'chunk_id': 'document chunk 1'}
mock_point2 = MagicMock()
-mock_point2.payload = {'doc': 'document from vector 2'}
+mock_point2.payload = {'chunk_id': 'document chunk 2'}
mock_point3 = MagicMock()
-mock_point3.payload = {'doc': 'another document from vector 2'}
-mock_response1 = MagicMock()
-mock_response1.points = [mock_point1]
-mock_response2 = MagicMock()
-mock_response2.points = [mock_point2, mock_point3]
-mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
+mock_point3.payload = {'chunk_id': 'document chunk 3'}
+mock_response = MagicMock()
+mock_response.points = [mock_point1, mock_point2, mock_point3]
+mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
-# Create mock message with multiple vectors
+# Create mock message with single vector
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 3
mock_message.user = 'multi_user'
mock_message.collection = 'multi_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
-# Verify query was called twice
-assert mock_qdrant_instance.query_points.call_count == 2
+# Verify query was called once
+assert mock_qdrant_instance.query_points.call_count == 1
-# Verify both collections were queried (both 2-dimensional vectors)
+# Verify collection was queried correctly
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
calls = mock_qdrant_instance.query_points.call_args_list
assert calls[0][1]['collection_name'] == expected_collection
-assert calls[1][1]['collection_name'] == expected_collection
assert calls[0][1]['query'] == [0.1, 0.2]
-assert calls[1][1]['query'] == [0.3, 0.4]
-# Verify results from both vectors are combined
+# Verify results are ChunkMatch objects
assert len(result) == 3
-assert 'document from vector 1' in result
-assert 'document from vector 2' in result
-assert 'another document from vector 2' in result
+chunk_ids = [r.chunk_id for r in result]
+assert 'document chunk 1' in chunk_ids
+assert 'document chunk 2' in chunk_ids
+assert 'document chunk 3' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -192,7 +190,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_points = []
for i in range(10):
mock_point = MagicMock()
-mock_point.payload = {'doc': f'document chunk {i}'}
+mock_point.payload = {'chunk_id': f'document chunk {i}'}
mock_points.append(mock_point)
mock_response = MagicMock()
@@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with limit
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2, 0.3]]
+mock_message.vector = [0.1, 0.2, 0.3]
mock_message.limit = 3 # Should only return 3 results
mock_message.user = 'limit_user'
mock_message.collection = 'limit_collection'
@@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'empty_user'
mock_message.collection = 'empty_collection'
@@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
-"""Test querying document embeddings with different vector dimensions"""
+"""Test querying document embeddings with a higher dimension vector"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
-# Mock query responses
+# Mock query response
mock_point1 = MagicMock()
-mock_point1.payload = {'doc': 'document from 2D vector'}
+mock_point1.payload = {'chunk_id': 'document from 5D vector'}
mock_point2 = MagicMock()
-mock_point2.payload = {'doc': 'document from 3D vector'}
-mock_response1 = MagicMock()
-mock_response1.points = [mock_point1]
-mock_response2 = MagicMock()
-mock_response2.points = [mock_point2]
-mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
+mock_point2.payload = {'chunk_id': 'another 5D document'}
+mock_response = MagicMock()
+mock_response.points = [mock_point1, mock_point2]
+mock_qdrant_instance.query_points.return_value = mock_response
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
-# Create mock message with different dimension vectors
+# Create mock message with 5D vector
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
+mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
mock_message.limit = 5
mock_message.user = 'dim_user'
mock_message.collection = 'dim_collection'
# Act
result = await processor.query_document_embeddings(mock_message)
# Assert
-# Verify query was called twice with different collections
-assert mock_qdrant_instance.query_points.call_count == 2
+# Verify query was called once with correct collection
+assert mock_qdrant_instance.query_points.call_count == 1
calls = mock_qdrant_instance.query_points.call_args_list
-# First call should use 2D collection
-assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
-assert calls[0][1]['query'] == [0.1, 0.2]
+# Call should use 5D collection
+assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
+assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
-# Second call should use 3D collection
-assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
-assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
-# Verify results
+# Verify results are ChunkMatch objects
assert len(result) == 2
-assert 'document from 2D vector' in result
-assert 'document from 3D vector' in result
+chunk_ids = [r.chunk_id for r in result]
+assert 'document from 5D vector' in chunk_ids
+assert 'another 5D document' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -326,9 +319,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response with UTF-8 content
mock_point1 = MagicMock()
-mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
+mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'}
mock_point2 = MagicMock()
-mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
+mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
@@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'utf8_user'
mock_message.collection = 'utf8_collection'
@@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert len(result) == 2
-# Verify UTF-8 content works correctly
-assert 'Document with UTF-8: café, naïve, résumé' in result
-assert 'Chinese text: 你好世界' in result
+# Verify UTF-8 content works correctly in ChunkMatch objects
+chunk_ids = [r.chunk_id for r in result]
+assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
+assert 'Chinese text: 你好世界' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'error_user'
mock_message.collection = 'error_collection'
@@ -399,7 +393,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response
mock_point = MagicMock()
-mock_point.payload = {'doc': 'document chunk'}
+mock_point.payload = {'chunk_id': 'document chunk'}
mock_response = MagicMock()
mock_response.points = [mock_point]
mock_qdrant_instance.query_points.return_value = mock_response
@@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with zero limit
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 0
mock_message.user = 'zero_user'
mock_message.collection = 'zero_collection'
@@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 0
-# Result should contain all returned documents
+# Result should contain all returned documents as ChunkMatch objects
assert len(result) == 1
-assert result[0] == 'document chunk'
+assert isinstance(result[0], ChunkMatch)
+assert result[0].chunk_id == 'document chunk'
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -442,9 +437,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Mock query response with fewer results than limit
mock_point1 = MagicMock()
-mock_point1.payload = {'doc': 'document 1'}
+mock_point1.payload = {'chunk_id': 'document 1'}
mock_point2 = MagicMock()
-mock_point2.payload = {'doc': 'document 2'}
+mock_point2.payload = {'chunk_id': 'document 2'}
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
@@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message with large limit
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 1000 # Large limit
mock_message.user = 'large_user'
mock_message.collection = 'large_collection'
@@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance.query_points.assert_called_once()
call_args = mock_qdrant_instance.query_points.call_args
assert call_args[1]['limit'] == 1000
-# Result should contain all available documents
+# Result should contain all available documents as ChunkMatch objects
assert len(result) == 2
-assert 'document 1' in result
-assert 'document 2' in result
+chunk_ids = [r.chunk_id for r in result]
+assert 'document 1' in chunk_ids
+assert 'document 2' in chunk_ids
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
@@ -487,11 +483,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
-# Mock query response with missing 'doc' key
+# Mock query response with missing 'chunk_id' key
mock_point1 = MagicMock()
-mock_point1.payload = {'doc': 'valid document'}
+mock_point1.payload = {'chunk_id': 'valid document'}
mock_point2 = MagicMock()
-mock_point2.payload = {} # Missing 'doc' key
+mock_point2.payload = {} # Missing 'chunk_id' key
mock_point3 = MagicMock()
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
@@ -508,13 +504,13 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
-mock_message.vectors = [[0.1, 0.2]]
+mock_message.vector = [0.1, 0.2]
mock_message.limit = 5
mock_message.user = 'payload_user'
mock_message.collection = 'payload_collection'
# Act & Assert
-# This should raise a KeyError when trying to access payload['doc']
+# This should raise a KeyError when trying to access payload['chunk_id']
with pytest.raises(KeyError):
await processor.query_document_embeddings(mock_message)


@@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
-from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
+from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
class TestMilvusGraphEmbeddingsQueryProcessor:
@@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
+vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=10
)
return query
@@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3]],
+vector=[0.1, 0.2, 0.3],
limit=5
)
@@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
-# Verify results are converted to Term objects
+# Verify results are converted to EntityMatch objects
assert len(result) == 3
-assert isinstance(result[0], Term)
-assert result[0].iri == "http://example.com/entity1"
-assert result[0].type == IRI
-assert isinstance(result[1], Term)
-assert result[1].iri == "http://example.com/entity2"
-assert result[1].type == IRI
-assert isinstance(result[2], Term)
-assert result[2].value == "literal entity"
-assert result[2].type == LITERAL
+assert isinstance(result[0], EntityMatch)
+assert result[0].entity.iri == "http://example.com/entity1"
+assert result[0].entity.type == IRI
+assert isinstance(result[1], EntityMatch)
+assert result[1].entity.iri == "http://example.com/entity2"
+assert result[1].entity.type == IRI
+assert isinstance(result[2], EntityMatch)
+assert result[2].entity.value == "literal entity"
+assert result[2].entity.type == LITERAL
@pytest.mark.asyncio
-async def test_query_graph_embeddings_multiple_vectors(self, processor):
-"""Test querying graph embeddings with multiple vectors"""
+async def test_query_graph_embeddings_multiple_results(self, processor):
+"""Test querying graph embeddings returns multiple results"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
-limit=3
+vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+limit=5
)
-# Mock search results - different results for each vector
-mock_results_1 = [
+# Mock search results with multiple entities
+mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
]
-mock_results_2 = [
-{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
-{"entity": {"entity": "http://example.com/entity3"}},
-]
-processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
+processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
-# Verify search was called twice with correct parameters including user/collection
-expected_calls = [
-(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
-(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
-]
-assert processor.vecstore.search.call_count == 2
-for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
-actual_call = processor.vecstore.search.call_args_list[i]
-assert actual_call[0] == expected_args
-assert actual_call[1] == expected_kwargs
-# Verify results are deduplicated and limited
+# Verify search was called once with the full vector
+processor.vecstore.search.assert_called_once_with(
+[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
+)
+# Verify results are EntityMatch objects
assert len(result) == 3
-entity_values = [r.iri if r.type == IRI else r.value for r in result]
+entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3]],
+vector=[0.1, 0.2, 0.3],
limit=2
)
@@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 2
@pytest.mark.asyncio
-async def test_query_graph_embeddings_deduplication(self, processor):
-"""Test that duplicate entities are properly deduplicated"""
+async def test_query_graph_embeddings_preserves_order(self, processor):
+"""Test that query results preserve order from the vector store"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
+vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
limit=5
)
-# Mock search results with duplicates
-mock_results_1 = [
-{"entity": {"entity": "http://example.com/entity1"}},
-{"entity": {"entity": "http://example.com/entity2"}},
-]
-mock_results_2 = [
-{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
-{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
-{"entity": {"entity": "http://example.com/entity3"}}, # New
-]
-processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
result = await processor.query_graph_embeddings(query)
-# Verify duplicates are removed
-assert len(result) == 3
-entity_values = [r.iri if r.type == IRI else r.value for r in result]
-assert len(set(entity_values)) == 3 # All unique
-assert "http://example.com/entity1" in entity_values
-assert "http://example.com/entity2" in entity_values
-assert "http://example.com/entity3" in entity_values
-@pytest.mark.asyncio
-async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
-"""Test that querying stops early when limit is reached"""
-query = GraphEmbeddingsRequest(
-user='test_user',
-collection='test_collection',
-vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
-limit=2
-)
-# Mock search results - first vector returns enough results
-mock_results_1 = [
+# Mock search results in specific order
+mock_results = [
{"entity": {"entity": "http://example.com/entity1"}},
{"entity": {"entity": "http://example.com/entity2"}},
{"entity": {"entity": "http://example.com/entity3"}},
]
-processor.vecstore.search.return_value = mock_results_1
+processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
-# Verify only first vector was searched (limit reached)
-processor.vecstore.search.assert_called_once_with(
-[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
-)
+# Verify results are in the same order as returned by the store
+assert len(result) == 3
+assert result[0].entity.iri == "http://example.com/entity1"
+assert result[1].entity.iri == "http://example.com/entity2"
+assert result[2].entity.iri == "http://example.com/entity3"
+@pytest.mark.asyncio
+async def test_query_graph_embeddings_results_limited(self, processor):
+"""Test that results are properly limited when store returns more than requested"""
+query = GraphEmbeddingsRequest(
+user='test_user',
+collection='test_collection',
+vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+limit=2
+)
-# Verify results are limited
+# Mock search results - returns more results than limit
+mock_results = [
+{"entity": {"entity": "http://example.com/entity1"}},
+{"entity": {"entity": "http://example.com/entity2"}},
+{"entity": {"entity": "http://example.com/entity3"}},
+]
+processor.vecstore.search.return_value = mock_results
+result = await processor.query_graph_embeddings(query)
+# Verify search was called with the full vector
+processor.vecstore.search.assert_called_once_with(
+[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
+)
+# Verify results are limited to requested amount
assert len(result) == 2
@pytest.mark.asyncio
@@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[],
+vector=[],
limit=5
)
@@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3]],
+vector=[0.1, 0.2, 0.3],
limit=5
)
@@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3]],
+vector=[0.1, 0.2, 0.3],
limit=5
)
@@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify all results are properly typed
assert len(result) == 4
# Check URI entities
-uri_results = [r for r in result if r.type == IRI]
+uri_results = [r for r in result if r.entity.type == IRI]
assert len(uri_results) == 2
-uri_values = [r.iri for r in uri_results]
+uri_values = [r.entity.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
-literal_results = [r for r in result if not r.type == IRI]
+literal_results = [r for r in result if not r.entity.type == IRI]
assert len(literal_results) == 2
-literal_values = [r.value for r in literal_results]
+literal_values = [r.entity.value for r in literal_results]
assert "literal entity text" in literal_values
assert "another literal" in literal_values
@@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3]],
+vector=[0.1, 0.2, 0.3],
limit=5
)
@@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[[0.1, 0.2, 0.3]],
+vector=[0.1, 0.2, 0.3],
limit=0
)
@@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 0
@pytest.mark.asyncio
-async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
-"""Test querying graph embeddings with different vector dimensions"""
+async def test_query_graph_embeddings_longer_vector(self, processor):
+"""Test querying graph embeddings with a longer vector"""
query = GraphEmbeddingsRequest(
user='test_user',
collection='test_collection',
-vectors=[
-[0.1, 0.2], # 2D vector
-[0.3, 0.4, 0.5, 0.6], # 4D vector
-[0.7, 0.8, 0.9] # 3D vector
-],
+vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
limit=5
)
-# Mock search results for each vector
-mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
-mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
-mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
-processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
+# Mock search results
+mock_results = [
+{"entity": {"entity": "http://example.com/entity1"}},
+{"entity": {"entity": "http://example.com/entity2"}},
+]
+processor.vecstore.search.return_value = mock_results
result = await processor.query_graph_embeddings(query)
-# Verify all vectors were searched
-assert processor.vecstore.search.call_count == 3
-# Verify results from all dimensions
-assert len(result) == 3
-entity_values = [r.iri if r.type == IRI else r.value for r in result]
-assert "entity_2d" in entity_values
-assert "entity_4d" in entity_values
-assert "entity_3d" in entity_values
+# Verify search was called once with the full vector
+processor.vecstore.search.assert_called_once()
+# Verify results
+assert len(result) == 2
+entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
+assert "http://example.com/entity1" in entity_values
+assert "http://example.com/entity2" in entity_values


@@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor
-from trustgraph.schema import Term, IRI, LITERAL
+from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
class TestPineconeGraphEmbeddingsQueryProcessor:
@@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
def mock_query_message(self):
"""Create a mock query message for testing"""
message = MagicMock()
-message.vectors = [
-[0.1, 0.2, 0.3],
-[0.4, 0.5, 0.6]
-]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
@@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_single_vector(self, processor):
"""Test querying graph embeddings with a single vector"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 3
message.user = 'test_user'
message.collection = 'test_collection'
@@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
include_metadata=True
)
-# Verify results
+# Verify results use EntityMatch structure
assert len(entities) == 3
-assert entities[0].value == 'http://example.org/entity1'
-assert entities[0].type == IRI
-assert entities[1].value == 'entity2'
-assert entities[1].type == LITERAL
-assert entities[2].value == 'http://example.org/entity3'
-assert entities[2].type == IRI
+assert entities[0].entity.iri == 'http://example.org/entity1'
+assert entities[0].entity.type == IRI
+assert entities[1].entity.value == 'entity2'
+assert entities[1].entity.type == LITERAL
+assert entities[2].entity.iri == 'http://example.org/entity3'
+assert entities[2].entity.type == IRI
@pytest.mark.asyncio
-async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
-"""Test querying graph embeddings with multiple vectors"""
+async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
+"""Test basic graph embeddings query"""
# Mock index and query results
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
-# First query results
-mock_results1 = MagicMock()
-mock_results1.matches = [
+# Query results with distinct entities
+mock_results = MagicMock()
+mock_results.matches = [
MagicMock(metadata={'entity': 'entity1'}),
MagicMock(metadata={'entity': 'entity2'})
]
-# Second query results
-mock_results2 = MagicMock()
-mock_results2.matches = [
-MagicMock(metadata={'entity': 'entity2'}), # Duplicate
MagicMock(metadata={'entity': 'entity2'}),
MagicMock(metadata={'entity': 'entity3'})
]
-mock_index.query.side_effect = [mock_results1, mock_results2]
+mock_index.query.return_value = mock_results
entities = await processor.query_graph_embeddings(mock_query_message)
-# Verify both queries were made
-assert mock_index.query.call_count == 2
-# Verify deduplication occurred
-entity_values = [e.value for e in entities]
+# Verify query was made once
+assert mock_index.query.call_count == 1
+# Verify results with EntityMatch structure
+entity_values = [e.entity.value for e in entities]
assert len(entity_values) == 3
assert 'entity1' in entity_values
assert 'entity2' in entity_values
@@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_limit_handling(self, processor):
"""Test that query respects the limit parameter"""
message = MagicMock()
-message.vectors = [[0.1, 0.2, 0.3]]
+message.vector = [0.1, 0.2, 0.3]
message.limit = 2
message.user = 'test_user'
message.collection = 'test_collection'
@@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
async def test_query_graph_embeddings_zero_limit(self, processor):
"""Test querying with zero limit returns empty results"""
message = MagicMock()
message.vectors = [[0.1, 0.2, 0.3]]
message.vector = [0.1, 0.2, 0.3]
message.limit = 0
message.user = 'test_user'
message.collection = 'test_collection'
@@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
     async def test_query_graph_embeddings_negative_limit(self, processor):
         """Test querying with negative limit returns empty results"""
         message = MagicMock()
-        message.vectors = [[0.1, 0.2, 0.3]]
+        message.vector = [0.1, 0.2, 0.3]
         message.limit = -1
         message.user = 'test_user'
         message.collection = 'test_collection'
@@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
         assert entities == []
     @pytest.mark.asyncio
-    async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
-        """Test querying with vectors of different dimensions using same index"""
+    async def test_query_graph_embeddings_2d_vector(self, processor):
+        """Test querying with a 2D vector"""
         message = MagicMock()
-        message.vectors = [
-            [0.1, 0.2],  # 2D vector
-            [0.3, 0.4, 0.5, 0.6]  # 4D vector
-        ]
+        message.vector = [0.1, 0.2]  # 2D vector
         message.limit = 5
         message.user = 'test_user'
         message.collection = 'test_collection'
-        # Mock single index that handles all dimensions
+        # Mock index
         mock_index = MagicMock()
         processor.pinecone.Index.return_value = mock_index
-        # Mock results for different vector queries
-        mock_results_2d = MagicMock()
-        mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
+        # Mock results for 2D vector query
+        mock_results = MagicMock()
+        mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
-        mock_results_4d = MagicMock()
-        mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
-        mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
+        mock_index.query.return_value = mock_results
         entities = await processor.query_graph_embeddings(message)
-        # Verify different indexes used for different dimensions
-        assert processor.pinecone.Index.call_count == 2
-        index_calls = processor.pinecone.Index.call_args_list
-        index_names = [call[0][0] for call in index_calls]
-        assert "t-test_user-test_collection-2" in index_names  # 2D vector
-        assert "t-test_user-test_collection-4" in index_names  # 4D vector
+        # Verify correct index used for 2D vector
+        processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
-        # Verify both queries were made
-        assert mock_index.query.call_count == 2
+        # Verify query was made
+        assert mock_index.query.call_count == 1
-        # Verify results from both dimensions
-        entity_values = [e.value for e in entities]
+        # Verify results with EntityMatch structure
+        entity_values = [e.entity.value for e in entities]
         assert 'entity_2d' in entity_values
-        assert 'entity_4d' in entity_values
     @pytest.mark.asyncio
     async def test_query_graph_embeddings_empty_vectors_list(self, processor):
         """Test querying with empty vectors list"""
         message = MagicMock()
-        message.vectors = []
+        message.vector = []
         message.limit = 5
         message.user = 'test_user'
         message.collection = 'test_collection'
@@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
     async def test_query_graph_embeddings_no_results(self, processor):
         """Test querying when index returns no results"""
         message = MagicMock()
-        message.vectors = [[0.1, 0.2, 0.3]]
+        message.vector = [0.1, 0.2, 0.3]
         message.limit = 5
         message.user = 'test_user'
         message.collection = 'test_collection'
@@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
         assert entities == []
     @pytest.mark.asyncio
-    async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
-        """Test that deduplication works correctly across multiple vector queries"""
+    async def test_query_graph_embeddings_deduplication_in_results(self, processor):
+        """Test that deduplication works correctly within query results"""
         message = MagicMock()
-        message.vectors = [
-            [0.1, 0.2, 0.3],
-            [0.4, 0.5, 0.6]
-        ]
+        message.vector = [0.1, 0.2, 0.3]
         message.limit = 3
         message.user = 'test_user'
         message.collection = 'test_collection'
         mock_index = MagicMock()
         processor.pinecone.Index.return_value = mock_index
-        # Both queries return overlapping results
-        mock_results1 = MagicMock()
-        mock_results1.matches = [
+        # Query returns results with some duplicates
+        mock_results = MagicMock()
+        mock_results.matches = [
             MagicMock(metadata={'entity': 'entity1'}),
             MagicMock(metadata={'entity': 'entity2'}),
+            MagicMock(metadata={'entity': 'entity1'}),  # Duplicate
             MagicMock(metadata={'entity': 'entity3'}),
+            MagicMock(metadata={'entity': 'entity4'})
         ]
-        mock_results2 = MagicMock()
-        mock_results2.matches = [
-            MagicMock(metadata={'entity': 'entity2'}),  # Duplicate
-            MagicMock(metadata={'entity': 'entity3'}),  # Duplicate
-            MagicMock(metadata={'entity': 'entity5'})
-        ]
-        mock_index.query.side_effect = [mock_results1, mock_results2]
+        mock_index.query.return_value = mock_results
         entities = await processor.query_graph_embeddings(message)
         # Should get exactly 3 unique entities (respecting limit)
         assert len(entities) == 3
-        entity_values = [e.value for e in entities]
+        entity_values = [e.entity.value for e in entities]
         assert len(set(entity_values)) == 3  # All unique
     @pytest.mark.asyncio
-    async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
-        """Test that querying stops early when limit is reached"""
+    async def test_query_graph_embeddings_respects_limit(self, processor):
+        """Test that query respects limit parameter"""
         message = MagicMock()
-        message.vectors = [
-            [0.1, 0.2, 0.3],
-            [0.4, 0.5, 0.6],
-            [0.7, 0.8, 0.9]
-        ]
+        message.vector = [0.1, 0.2, 0.3]
         message.limit = 2
         message.user = 'test_user'
         message.collection = 'test_collection'
         mock_index = MagicMock()
         processor.pinecone.Index.return_value = mock_index
-        # First query returns enough results to meet limit
-        mock_results1 = MagicMock()
-        mock_results1.matches = [
+        # Query returns more results than limit
+        mock_results = MagicMock()
+        mock_results.matches = [
             MagicMock(metadata={'entity': 'entity1'}),
             MagicMock(metadata={'entity': 'entity2'}),
             MagicMock(metadata={'entity': 'entity3'})
         ]
-        mock_index.query.return_value = mock_results1
+        mock_index.query.return_value = mock_results
         entities = await processor.query_graph_embeddings(message)
-        # Should only make one query since limit was reached
+        # Should only return 2 entities (respecting limit)
         mock_index.query.assert_called_once()
         assert len(entities) == 2
@@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
     async def test_query_graph_embeddings_exception_handling(self, processor):
         """Test that exceptions are properly raised"""
         message = MagicMock()
-        message.vectors = [[0.1, 0.2, 0.3]]
+        message.vector = [0.1, 0.2, 0.3]
         message.limit = 5
         message.user = 'test_user'
         message.collection = 'test_collection'

View file

@@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
 # Import the service under test
 from trustgraph.query.graph_embeddings.qdrant.service import Processor
-from trustgraph.schema import IRI, LITERAL
+from trustgraph.schema import IRI, LITERAL, EntityMatch
 class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Create mock message
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2, 0.3]]
+        mock_message.vector = [0.1, 0.2, 0.3]
         mock_message.limit = 5
         mock_message.user = 'test_user'
         mock_message.collection = 'test_collection'
@@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
             with_payload=True
         )
-        # Verify result contains expected entities
+        # Verify result contains expected EntityMatch objects
         assert len(result) == 2
-        assert all(hasattr(entity, 'value') for entity in result)
-        entity_values = [entity.value for entity in result]
+        assert all(isinstance(entity, EntityMatch) for entity in result)
+        entity_values = [entity.entity.value for entity in result]
         assert 'entity1' in entity_values
         assert 'entity2' in entity_values
@@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         }
         processor = Processor(**config)
-        # Create mock message with multiple vectors
+        # Create mock message with single vector
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
+        mock_message.vector = [0.1, 0.2]
         mock_message.limit = 3
         mock_message.user = 'multi_user'
         mock_message.collection = 'multi_collection'
         # Act
         result = await processor.query_graph_embeddings(mock_message)
         # Assert
-        # Verify query was called twice
-        assert mock_qdrant_instance.query_points.call_count == 2
+        # Verify query was called once
+        assert mock_qdrant_instance.query_points.call_count == 1
-        # Verify both collections were queried (both 2-dimensional vectors)
+        # Verify collection was queried
         expected_collection = 't_multi_user_multi_collection_2'  # 2 dimensions
         calls = mock_qdrant_instance.query_points.call_args_list
         assert calls[0][1]['collection_name'] == expected_collection
-        assert calls[1][1]['collection_name'] == expected_collection
         assert calls[0][1]['query'] == [0.1, 0.2]
-        assert calls[1][1]['query'] == [0.3, 0.4]
-        # Verify deduplication - entity2 appears in both results but should only appear once
-        entity_values = [entity.value for entity in result]
+        # Verify results with EntityMatch structure
+        entity_values = [entity.entity.value for entity in result]
         assert len(set(entity_values)) == len(entity_values)  # All unique
         assert 'entity1' in entity_values
         assert 'entity2' in entity_values
         assert 'entity3' in entity_values
     @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
     @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Create mock message with limit
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2, 0.3]]
+        mock_message.vector = [0.1, 0.2, 0.3]
         mock_message.limit = 3  # Should only return 3 results
         mock_message.user = 'limit_user'
         mock_message.collection = 'limit_collection'
@@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Create mock message
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2]]
+        mock_message.vector = [0.1, 0.2]
         mock_message.limit = 5
         mock_message.user = 'empty_user'
         mock_message.collection = 'empty_collection'
@@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         }
         processor = Processor(**config)
-        # Create mock message with different dimension vectors
+        # Create mock message with single vector
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]]  # 2D and 3D
+        mock_message.vector = [0.1, 0.2]  # 2D vector
         mock_message.limit = 5
         mock_message.user = 'dim_user'
         mock_message.collection = 'dim_collection'
         # Act
         result = await processor.query_graph_embeddings(mock_message)
         # Assert
-        # Verify query was called twice with different collections
-        assert mock_qdrant_instance.query_points.call_count == 2
+        # Verify query was called once
+        assert mock_qdrant_instance.query_points.call_count == 1
         calls = mock_qdrant_instance.query_points.call_args_list
-        # First call should use 2D collection
+        # Call should use 2D collection
         assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2'  # 2 dimensions
         assert calls[0][1]['query'] == [0.1, 0.2]
-        # Second call should use 3D collection
-        assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3'  # 3 dimensions
-        assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
-        # Verify results
-        entity_values = [entity.value for entity in result]
+        # Verify results with EntityMatch structure
+        entity_values = [entity.entity.value for entity in result]
         assert 'entity2d' in entity_values
-        assert 'entity3d' in entity_values
     @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
     @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Create mock message
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2]]
+        mock_message.vector = [0.1, 0.2]
         mock_message.limit = 5
         mock_message.user = 'uri_user'
         mock_message.collection = 'uri_collection'
@@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Assert
         assert len(result) == 3
         # Check URI entities
-        uri_entities = [entity for entity in result if entity.type == IRI]
+        uri_entities = [entity for entity in result if entity.entity.type == IRI]
         assert len(uri_entities) == 2
-        uri_values = [entity.iri for entity in uri_entities]
+        uri_values = [entity.entity.iri for entity in uri_entities]
         assert 'http://example.com/entity1' in uri_values
         assert 'https://secure.example.com/entity2' in uri_values
         # Check regular entities
-        regular_entities = [entity for entity in result if entity.type == LITERAL]
+        regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
         assert len(regular_entities) == 1
-        assert regular_entities[0].value == 'regular entity'
+        assert regular_entities[0].entity.value == 'regular entity'
     @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
     @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Create mock message
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2]]
+        mock_message.vector = [0.1, 0.2]
         mock_message.limit = 5
         mock_message.user = 'error_user'
         mock_message.collection = 'error_collection'
@@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # Create mock message with zero limit
         mock_message = MagicMock()
-        mock_message.vectors = [[0.1, 0.2]]
+        mock_message.vector = [0.1, 0.2]
         mock_message.limit = 0
         mock_message.user = 'zero_user'
         mock_message.collection = 'zero_collection'
@@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
         # With zero limit, the logic still adds one entity before checking the limit
         # So it returns one result (current behavior, not ideal but actual)
         assert len(result) == 1
-        assert result[0].value == 'entity1'
+        assert result[0].entity.value == 'entity1'
     @patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
     @patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')

View file

@@ -118,8 +118,8 @@ class TestCassandraQueryProcessor:
         # Verify result contains the queried triple
         assert len(result) == 1
-        assert result[0].s.value == 'test_subject'
-        assert result[0].p.value == 'test_predicate'
+        assert result[0].s.iri == 'test_subject'
+        assert result[0].p.iri == 'test_predicate'
         assert result[0].o.value == 'test_object'
     def test_processor_initialization_with_defaults(self):
@@ -182,8 +182,8 @@ class TestCassandraQueryProcessor:
         mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
         assert len(result) == 1
-        assert result[0].s.value == 'test_subject'
-        assert result[0].p.value == 'test_predicate'
+        assert result[0].s.iri == 'test_subject'
+        assert result[0].p.iri == 'test_predicate'
         assert result[0].o.value == 'result_object'
     @pytest.mark.asyncio
@@ -219,8 +219,8 @@ class TestCassandraQueryProcessor:
         mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
         assert len(result) == 1
-        assert result[0].s.value == 'test_subject'
-        assert result[0].p.value == 'result_predicate'
+        assert result[0].s.iri == 'test_subject'
+        assert result[0].p.iri == 'result_predicate'
         assert result[0].o.value == 'result_object'
     @pytest.mark.asyncio
@@ -256,8 +256,8 @@ class TestCassandraQueryProcessor:
         mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
         assert len(result) == 1
-        assert result[0].s.value == 'result_subject'
-        assert result[0].p.value == 'test_predicate'
+        assert result[0].s.iri == 'result_subject'
+        assert result[0].p.iri == 'test_predicate'
         assert result[0].o.value == 'result_object'
     @pytest.mark.asyncio
@@ -293,8 +293,8 @@ class TestCassandraQueryProcessor:
         mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
         assert len(result) == 1
-        assert result[0].s.value == 'result_subject'
-        assert result[0].p.value == 'result_predicate'
+        assert result[0].s.iri == 'result_subject'
+        assert result[0].p.iri == 'result_predicate'
         assert result[0].o.value == 'test_object'
     @pytest.mark.asyncio
@@ -331,8 +331,8 @@ class TestCassandraQueryProcessor:
         mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
         assert len(result) == 1
-        assert result[0].s.value == 'all_subject'
-        assert result[0].p.value == 'all_predicate'
+        assert result[0].s.iri == 'all_subject'
+        assert result[0].p.iri == 'all_predicate'
         assert result[0].o.value == 'all_object'
     def test_add_args_method(self):
@@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations:
         )
         assert len(result) == 1
-        assert result[0].s.value == 'result_subject'
-        assert result[0].p.value == 'test_predicate'
+        assert result[0].s.iri == 'result_subject'
+        assert result[0].p.iri == 'test_predicate'
         assert result[0].o.value == 'test_object'
     @pytest.mark.asyncio
@@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations:
         )
         assert len(result) == 1
-        assert result[0].s.value == 'test_subject'
-        assert result[0].p.value == 'result_predicate'
+        assert result[0].s.iri == 'test_subject'
+        assert result[0].p.iri == 'result_predicate'
         assert result[0].o.value == 'test_object'
     @pytest.mark.asyncio
@@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations:
         # Verify all results were returned
         assert len(result) == 5
         for i, triple in enumerate(result):
-            assert triple.s.value == f'subject_{i}'  # Mock returns literal values
+            assert triple.s.iri == f'subject_{i}'  # Mock returns literal values
             assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
             assert triple.p.type == IRI
             assert triple.o.iri == 'http://example.com/Person'  # URIs use .iri

View file

@@ -0,0 +1 @@

View file

@@ -0,0 +1,309 @@
"""
Tests for RDF 1.2 type system primitives: Term dataclass (IRI, blank node,
typed literal, language-tagged literal, quoted triple), Triple/Quad dataclass
with named graph support, and the knowledge/defs helper types.
"""

import pytest

from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE


# ---------------------------------------------------------------------------
# Type constants
# ---------------------------------------------------------------------------

class TestTypeConstants:

    def test_iri_constant(self):
        assert IRI == "i"

    def test_blank_constant(self):
        assert BLANK == "b"

    def test_literal_constant(self):
        assert LITERAL == "l"

    def test_triple_constant(self):
        assert TRIPLE == "t"

    def test_constants_are_distinct(self):
        vals = {IRI, BLANK, LITERAL, TRIPLE}
        assert len(vals) == 4


# ---------------------------------------------------------------------------
# IRI terms
# ---------------------------------------------------------------------------

class TestIriTerm:

    def test_create_iri(self):
        t = Term(type=IRI, iri="http://example.org/Alice")
        assert t.type == IRI
        assert t.iri == "http://example.org/Alice"

    def test_iri_defaults_empty(self):
        t = Term(type=IRI)
        assert t.iri == ""

    def test_iri_with_fragment(self):
        t = Term(type=IRI, iri="http://example.org/ontology#Person")
        assert "#Person" in t.iri

    def test_iri_with_unicode(self):
        t = Term(type=IRI, iri="http://example.org/概念")
        assert "概念" in t.iri

    def test_iri_other_fields_default(self):
        t = Term(type=IRI, iri="http://example.org/x")
        assert t.id == ""
        assert t.value == ""
        assert t.datatype == ""
        assert t.language == ""
        assert t.triple is None


# ---------------------------------------------------------------------------
# Blank node terms
# ---------------------------------------------------------------------------

class TestBlankNodeTerm:

    def test_create_blank_node(self):
        t = Term(type=BLANK, id="_:b0")
        assert t.type == BLANK
        assert t.id == "_:b0"

    def test_blank_node_defaults_empty(self):
        t = Term(type=BLANK)
        assert t.id == ""

    def test_blank_node_arbitrary_id(self):
        t = Term(type=BLANK, id="node-abc-123")
        assert t.id == "node-abc-123"


# ---------------------------------------------------------------------------
# Typed literals (XSD datatypes)
# ---------------------------------------------------------------------------

class TestTypedLiteral:

    def test_plain_literal(self):
        t = Term(type=LITERAL, value="hello")
        assert t.type == LITERAL
        assert t.value == "hello"
        assert t.datatype == ""
        assert t.language == ""

    def test_xsd_integer(self):
        t = Term(
            type=LITERAL, value="42",
            datatype="http://www.w3.org/2001/XMLSchema#integer",
        )
        assert t.value == "42"
        assert "integer" in t.datatype

    def test_xsd_boolean(self):
        t = Term(
            type=LITERAL, value="true",
            datatype="http://www.w3.org/2001/XMLSchema#boolean",
        )
        assert t.datatype.endswith("#boolean")

    def test_xsd_date(self):
        t = Term(
            type=LITERAL, value="2026-03-13",
            datatype="http://www.w3.org/2001/XMLSchema#date",
        )
        assert t.value == "2026-03-13"
        assert t.datatype.endswith("#date")

    def test_xsd_double(self):
        t = Term(
            type=LITERAL, value="3.14",
            datatype="http://www.w3.org/2001/XMLSchema#double",
        )
        assert t.datatype.endswith("#double")

    def test_empty_value_literal(self):
        t = Term(type=LITERAL, value="")
        assert t.value == ""


# ---------------------------------------------------------------------------
# Language-tagged literals
# ---------------------------------------------------------------------------

class TestLanguageTaggedLiteral:

    def test_english_tag(self):
        t = Term(type=LITERAL, value="hello", language="en")
        assert t.language == "en"
        assert t.datatype == ""

    def test_french_tag(self):
        t = Term(type=LITERAL, value="bonjour", language="fr")
        assert t.language == "fr"

    def test_bcp47_subtag(self):
        t = Term(type=LITERAL, value="colour", language="en-GB")
        assert t.language == "en-GB"

    def test_language_and_datatype_mutually_exclusive(self):
        """Both can be set on the dataclass, but semantically only one should be used."""
        t = Term(type=LITERAL, value="x", language="en",
                 datatype="http://www.w3.org/2001/XMLSchema#string")
        # Dataclass allows both; translators should respect mutual exclusivity
        assert t.language == "en"
        assert t.datatype != ""


# ---------------------------------------------------------------------------
# Quoted triples (RDF-star)
# ---------------------------------------------------------------------------

class TestQuotedTriple:

    def test_term_with_nested_triple(self):
        inner = Triple(
            s=Term(type=IRI, iri="http://example.org/Alice"),
            p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/knows"),
            o=Term(type=IRI, iri="http://example.org/Bob"),
        )
        qt = Term(type=TRIPLE, triple=inner)
        assert qt.type == TRIPLE
        assert qt.triple is inner
        assert qt.triple.s.iri == "http://example.org/Alice"

    def test_quoted_triple_as_object(self):
        """A triple whose object is a quoted triple (RDF-star)."""
        inner = Triple(
            s=Term(type=IRI, iri="http://example.org/Hope"),
            p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
            o=Term(type=LITERAL, value="A feeling of expectation"),
        )
        outer = Triple(
            s=Term(type=IRI, iri="urn:subgraph:123"),
            p=Term(type=IRI, iri="http://trustgraph.ai/tg/contains"),
            o=Term(type=TRIPLE, triple=inner),
        )
        assert outer.o.type == TRIPLE
        assert outer.o.triple.o.value == "A feeling of expectation"

    def test_quoted_triple_none(self):
        t = Term(type=TRIPLE, triple=None)
        assert t.triple is None


# ---------------------------------------------------------------------------
# Triple / Quad (named graph)
# ---------------------------------------------------------------------------

class TestTripleQuad:

    def test_default_graph_is_none(self):
        t = Triple(
            s=Term(type=IRI, iri="http://example.org/s"),
            p=Term(type=IRI, iri="http://example.org/p"),
            o=Term(type=LITERAL, value="val"),
        )
        assert t.g is None

    def test_named_graph(self):
        t = Triple(
            s=Term(type=IRI, iri="http://example.org/s"),
            p=Term(type=IRI, iri="http://example.org/p"),
            o=Term(type=LITERAL, value="val"),
            g="urn:graph:source",
        )
        assert t.g == "urn:graph:source"

    def test_empty_string_graph(self):
        t = Triple(g="")
        assert t.g == ""

    def test_triple_with_all_none_terms(self):
        t = Triple()
        assert t.s is None
        assert t.p is None
        assert t.o is None
        assert t.g is None

    def test_triple_equality(self):
        """Dataclass equality based on field values."""
        t1 = Triple(
            s=Term(type=IRI, iri="http://example.org/A"),
            p=Term(type=IRI, iri="http://example.org/B"),
            o=Term(type=LITERAL, value="C"),
        )
        t2 = Triple(
            s=Term(type=IRI, iri="http://example.org/A"),
            p=Term(type=IRI, iri="http://example.org/B"),
            o=Term(type=LITERAL, value="C"),
        )
        assert t1 == t2


# ---------------------------------------------------------------------------
# knowledge/defs helper types
# ---------------------------------------------------------------------------

class TestKnowledgeDefs:

    def test_uri_type(self):
        from trustgraph.knowledge.defs import Uri
        u = Uri("http://example.org/x")
        assert u.is_uri() is True
        assert u.is_literal() is False
        assert u.is_triple() is False
        assert str(u) == "http://example.org/x"

    def test_literal_type(self):
        from trustgraph.knowledge.defs import Literal
        l = Literal("hello world")
        assert l.is_uri() is False
        assert l.is_literal() is True
        assert l.is_triple() is False
        assert str(l) == "hello world"

    def test_quoted_triple_type(self):
        from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
        qt = QuotedTriple(
            s=Uri("http://example.org/s"),
            p=Uri("http://example.org/p"),
            o=Literal("val"),
        )
        assert qt.is_uri() is False
        assert qt.is_literal() is False
        assert qt.is_triple() is True
        assert qt.s == "http://example.org/s"
        assert qt.o == "val"

    def test_quoted_triple_repr(self):
        from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
        qt = QuotedTriple(
            s=Uri("http://example.org/A"),
            p=Uri("http://example.org/B"),
            o=Literal("C"),
        )
        r = repr(qt)
        assert "<<" in r
        assert ">>" in r
        assert "http://example.org/A" in r

    def test_quoted_triple_nested(self):
        """QuotedTriple can contain another QuotedTriple as object."""
        from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
        inner = QuotedTriple(
            s=Uri("http://example.org/s"),
            p=Uri("http://example.org/p"),
            o=Literal("v"),
        )
        outer = QuotedTriple(
            s=Uri("http://example.org/s2"),
            p=Uri("http://example.org/p2"),
            o=inner,
        )
        assert outer.o.is_triple() is True