mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Merge branch 'release/v2.1'
This commit is contained in:
commit
824f993985
266 changed files with 33195 additions and 5834 deletions
2
.github/workflows/pull-request.yaml
vendored
2
.github/workflows/pull-request.yaml
vendored
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup packages
|
||||
run: make update-package-versions VERSION=2.0.999
|
||||
run: make update-package-versions VERSION=2.1.999
|
||||
|
||||
- name: Setup environment
|
||||
run: python3 -m venv env
|
||||
|
|
|
|||
49
README.md
49
README.md
|
|
@ -13,9 +13,17 @@
|
|||
|
||||
# The context backend for reliable AI
|
||||
|
||||
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
# The context backend for AI agents
|
||||
|
||||
</div>
|
||||
|
||||
LLMs alone hallucinate and diverge from ground truth. [TrustGraph](https://trustgraph.ai) is a context system that stores, enriches, and delivers context to LLMs to enable reliable AI agents. Think like [Supabase](https://github.com/supabase/supabase) but AI-native and powered by context graphs.
|
||||
Durable agent memory you can trust. Build, version, and retrieve grounded context from a context graph.
|
||||
|
||||
- Give agents **memory** that persists across sessions and deployments.
|
||||
- Reduce hallucinations with **grounded context retrieval**
|
||||
- Ship reusable, portable [Context Cores](#context-cores) (packaged context you can move between projects/environments).
|
||||
|
||||
The context backend:
|
||||
- [x] Multi-model and multimodal database system
|
||||
|
|
@ -45,21 +53,6 @@ The context backend:
|
|||
- [x] Websocket API [Docs](https://docs.trustgraph.ai/reference/apis/websocket.html)
|
||||
- [x] Python API [Docs](https://docs.trustgraph.ai/reference/apis/python)
|
||||
- [x] CLI [Docs](https://docs.trustgraph.ai/reference/cli/)
|
||||
|
||||
## No API Keys Required
|
||||
|
||||
How many times have you cloned a repo and opened the `.env.example` to see the dozens of API keys for 3rd party dependencies needed to make the services work? There are only 3 things in TrustGraph that might need an API key:
|
||||
|
||||
- 3rd party LLM services like Anthropic, Cohere, Gemini, Mistral, OpenAI, etc.
|
||||
- 3rd party OCR like Mistral OCR
|
||||
- The API key *you set* for the TrustGraph API gateway
|
||||
|
||||
Everything else is included.
|
||||
- [x] Managed Multi-model storage in [Cassandra](https://cassandra.apache.org/_/index.html)
|
||||
- [x] Managed Vector embedding storage in [Qdrant](https://github.com/qdrant/qdrant)
|
||||
- [x] Managed File and Object storage in [Garage](https://github.com/deuxfleurs-org/garage) (S3 compatible)
|
||||
- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar)
|
||||
- [x] Complete LLM inferencing stack for open LLMs with [vLLM](https://github.com/vllm-project/vllm), [TGI](https://github.com/huggingface/text-generation-inference), [Ollama](https://github.com/ollama/ollama), [LM Studio](https://github.com/lmstudio-ai), and [Llamafiles](https://github.com/mozilla-ai/llamafile)
|
||||
|
||||
## Quickstart
|
||||
|
||||
|
|
@ -76,8 +69,6 @@ TrustGraph downloads as Docker containers and can be run locally with Docker, Po
|
|||
width="80%" controls></video>
|
||||
</p>
|
||||
|
||||
For a browser based quickstart, try the [Configuration Terminal](https://config-ui.demo.trustgraph.ai/).
|
||||
|
||||
<details>
|
||||
<summary>Table of Contents</summary>
|
||||
<br>
|
||||
|
|
@ -181,24 +172,28 @@ TrustGraph provides component flexibility to optimize agent workflows.
|
|||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Multi-model storage</summary>
|
||||
<summary>Graph Storage</summary>
|
||||
<br>
|
||||
|
||||
- Apache Cassandra (default)<br>
|
||||
- Neo4j<br>
|
||||
- Memgraph<br>
|
||||
- FalkorDB<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>VectorDBs</summary>
|
||||
<br>
|
||||
|
||||
- Apache Cassandra<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>VectorDB</summary>
|
||||
<br>
|
||||
|
||||
- Qdrant<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>File and Object Storage</summary>
|
||||
<br>
|
||||
|
||||
- Garage<br>
|
||||
- Garage (default)<br>
|
||||
- MinIO<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
|
|
|
|||
108
docs/api-gateway-changes-v1.8-to-v2.1.md
Normal file
108
docs/api-gateway-changes-v1.8-to-v2.1.md
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# API Gateway Changes: v1.8 to v2.1
|
||||
|
||||
## Summary
|
||||
|
||||
The API gateway gained new WebSocket service dispatchers for embeddings
|
||||
queries, a new REST streaming endpoint for document content, and underwent
|
||||
a significant wire format change from `Value` to `Term`. The "objects"
|
||||
service was renamed to "rows".
|
||||
|
||||
---
|
||||
|
||||
## New WebSocket Service Dispatchers
|
||||
|
||||
These are new request/response services available through the WebSocket
|
||||
multiplexer at `/api/v1/socket` (flow-scoped):
|
||||
|
||||
| Service Key | Description |
|
||||
|-------------|-------------|
|
||||
| `document-embeddings` | Queries document chunks by text similarity. Request/response uses `DocumentEmbeddingsRequest`/`DocumentEmbeddingsResponse` schemas. |
|
||||
| `row-embeddings` | Queries structured data rows by text similarity on indexed fields. Request/response uses `RowEmbeddingsRequest`/`RowEmbeddingsResponse` schemas. |
|
||||
|
||||
These join the existing `graph-embeddings` dispatcher (which was already
|
||||
present in v1.8 but may have been updated).
|
||||
|
||||
### Full list of WebSocket flow service dispatchers (v2.1)
|
||||
|
||||
Request/response services (via `/api/v1/flow/{flow}/service/{kind}` or
|
||||
WebSocket mux):
|
||||
|
||||
- `agent`, `text-completion`, `prompt`, `mcp-tool`
|
||||
- `graph-rag`, `document-rag`
|
||||
- `embeddings`, `graph-embeddings`, `document-embeddings`
|
||||
- `triples`, `rows`, `nlp-query`, `structured-query`, `structured-diag`
|
||||
- `row-embeddings`
|
||||
|
||||
---
|
||||
|
||||
## New REST Endpoint
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `GET` | `/api/v1/document-stream` | Streams document content from the library as raw bytes. Query parameters: `user` (required), `document-id` (required), `chunk-size` (optional, default 1MB). Returns the document content in chunked transfer encoding, decoded from base64 internally. |
|
||||
|
||||
---
|
||||
|
||||
## Renamed Service: "objects" to "rows"
|
||||
|
||||
| v1.8 | v2.1 | Notes |
|
||||
|------|------|-------|
|
||||
| `objects_query.py` / `ObjectsQueryRequestor` | `rows_query.py` / `RowsQueryRequestor` | Schema changed from `ObjectsQueryRequest`/`ObjectsQueryResponse` to `RowsQueryRequest`/`RowsQueryResponse`. |
|
||||
| `objects_import.py` / `ObjectsImport` | `rows_import.py` / `RowsImport` | Import dispatcher for structured data. |
|
||||
|
||||
The WebSocket service key changed from `"objects"` to `"rows"`, and the
|
||||
import dispatcher key similarly changed from `"objects"` to `"rows"`.
|
||||
|
||||
---
|
||||
|
||||
## Wire Format Change: Value to Term
|
||||
|
||||
The serialization layer (`serialize.py`) was rewritten to use the new `Term`
|
||||
type instead of the old `Value` type.
|
||||
|
||||
### Old format (v1.8 — `Value`)
|
||||
|
||||
```json
|
||||
{"v": "http://example.org/entity", "e": true}
|
||||
```
|
||||
|
||||
- `v`: the value (string)
|
||||
- `e`: boolean flag indicating whether the value is a URI
|
||||
|
||||
### New format (v2.1 — `Term`)
|
||||
|
||||
IRIs:
|
||||
```json
|
||||
{"t": "i", "i": "http://example.org/entity"}
|
||||
```
|
||||
|
||||
Literals:
|
||||
```json
|
||||
{"t": "l", "v": "some text", "d": "datatype-uri", "l": "en"}
|
||||
```
|
||||
|
||||
Quoted triples (RDF-star):
|
||||
```json
|
||||
{"t": "r", "r": {"s": {...}, "p": {...}, "o": {...}}}
|
||||
```
|
||||
|
||||
- `t`: type discriminator — `"i"` (IRI), `"l"` (literal), `"r"` (quoted triple), `"b"` (blank node)
|
||||
- Serialization now delegates to `TermTranslator` and `TripleTranslator` from `trustgraph.messaging.translators.primitives`
|
||||
|
||||
### Other serialization changes
|
||||
|
||||
| Field | v1.8 | v2.1 |
|
||||
|-------|------|------|
|
||||
| Metadata | `metadata.metadata` (subgraph) | `metadata.root` (simple value) |
|
||||
| Graph embeddings entity | `entity.vectors` (plural) | `entity.vector` (singular) |
|
||||
| Document embeddings chunk | `chunk.vectors` + `chunk.chunk` (text) | `chunk.vector` + `chunk.chunk_id` (ID reference) |
|
||||
|
||||
---
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- **`Value` to `Term` wire format**: All clients sending/receiving triples, embeddings, or entity contexts through the gateway must update to the new Term format.
|
||||
- **`objects` to `rows` rename**: WebSocket service key and import key changed.
|
||||
- **Metadata field change**: `metadata.metadata` (a serialized subgraph) replaced by `metadata.root` (a simple value).
|
||||
- **Embeddings field changes**: `vectors` (plural) became `vector` (singular); document embeddings now reference `chunk_id` instead of inline `chunk` text.
|
||||
- **New `/api/v1/document-stream` endpoint**: Additive, not breaking.
|
||||
2018
docs/api.html
2018
docs/api.html
File diff suppressed because one or more lines are too long
112
docs/cli-changes-v1.8-to-v2.1.md
Normal file
112
docs/cli-changes-v1.8-to-v2.1.md
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
# CLI Changes: v1.8 to v2.1
|
||||
|
||||
## Summary
|
||||
|
||||
The CLI (`trustgraph-cli`) has significant additions focused on three themes:
|
||||
**explainability/provenance**, **embeddings access**, and **graph querying**.
|
||||
Two legacy tools were removed, one was renamed, and several existing tools
|
||||
gained new capabilities.
|
||||
|
||||
---
|
||||
|
||||
## New CLI Tools
|
||||
|
||||
### Explainability & Provenance
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `tg-list-explain-traces` | Lists all explainability sessions (GraphRAG and Agent) in a collection, showing session IDs, type, question text, and timestamps. |
|
||||
| `tg-show-explain-trace` | Displays the full explainability trace for a session. For GraphRAG: Question, Exploration, Focus, Synthesis stages. For Agent: Session, Iterations (thought/action/observation), Final Answer. Auto-detects trace type. Supports `--show-provenance` to trace edges back to source documents. |
|
||||
| `tg-show-extraction-provenance` | Given a document ID, traverses the provenance chain: Document -> Pages -> Chunks -> Edges, using `prov:wasDerivedFrom` relationships. Supports `--show-content` and `--max-content` options. |
|
||||
|
||||
### Embeddings
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `tg-invoke-embeddings` | Converts text to a vector embedding via the embeddings service. Accepts one or more text inputs, returns vectors as lists of floats. |
|
||||
| `tg-invoke-graph-embeddings` | Queries graph entities by text similarity using vector embeddings. Returns matching entities with similarity scores. |
|
||||
| `tg-invoke-document-embeddings` | Queries document chunks by text similarity using vector embeddings. Returns matching chunk IDs with similarity scores. |
|
||||
| `tg-invoke-row-embeddings` | Queries structured data rows by text similarity on indexed fields. Returns matching rows with index values and scores. Requires `--schema-name` and supports `--index-name`. |
|
||||
|
||||
### Graph Querying
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `tg-query-graph` | Pattern-based triple store query. Unlike `tg-show-graph` (which dumps everything), this allows selective queries by any combination of subject, predicate, object, and graph. Auto-detects value types: IRIs (`http://...`, `urn:...`, `<...>`), quoted triples (`<<s p o>>`), and literals. |
|
||||
| `tg-get-document-content` | Retrieves document content from the library by document ID. Can output to file or stdout, handles both text and binary content. |
|
||||
|
||||
---
|
||||
|
||||
## Removed CLI Tools
|
||||
|
||||
| Command | Notes |
|
||||
|---------|-------|
|
||||
| `tg-load-pdf` | Removed. Document loading is now handled through the library/processing pipeline. |
|
||||
| `tg-load-text` | Removed. Document loading is now handled through the library/processing pipeline. |
|
||||
|
||||
---
|
||||
|
||||
## Renamed CLI Tools
|
||||
|
||||
| Old Name | New Name | Notes |
|
||||
|----------|----------|-------|
|
||||
| `tg-invoke-objects-query` | `tg-invoke-rows-query` | Reflects the terminology rename from "objects" to "rows" for structured data. |
|
||||
|
||||
---
|
||||
|
||||
## Significant Changes to Existing Tools
|
||||
|
||||
### `tg-invoke-graph-rag`
|
||||
|
||||
- **Explainability support**: Now supports a 4-stage explainability pipeline (Question, Grounding/Exploration, Focus, Synthesis) with inline provenance event display.
|
||||
- **Streaming**: Uses WebSocket streaming for real-time output.
|
||||
- **Provenance tracing**: Can trace selected edges back to source documents via reification and `prov:wasDerivedFrom` chains.
|
||||
- Grew from ~30 lines to ~760 lines to accommodate the full explainability pipeline.
|
||||
|
||||
### `tg-invoke-document-rag`
|
||||
|
||||
- **Explainability support**: Added `question_explainable()` mode that streams Document RAG responses with inline provenance events (Question, Grounding, Exploration, Synthesis stages).
|
||||
|
||||
### `tg-invoke-agent`
|
||||
|
||||
- **Explainability support**: Added `question_explainable()` mode showing provenance events inline during agent execution (Question, Analysis, Conclusion, AgentThought, AgentObservation, AgentAnswer).
|
||||
- Verbose mode shows thought/observation streams with emoji prefixes.
|
||||
|
||||
### `tg-show-graph`
|
||||
|
||||
- **Streaming mode**: Now uses `triples_query_stream()` with configurable batch sizes for lower time-to-first-result and reduced memory overhead.
|
||||
- **Named graph support**: New `--graph` filter option. Recognises named graphs:
|
||||
- Default graph (empty): Core knowledge facts
|
||||
- `urn:graph:source`: Extraction provenance
|
||||
- `urn:graph:retrieval`: Query-time explainability
|
||||
- **Show graph column**: New `--show-graph` flag to display the named graph for each triple.
|
||||
- **Configurable limits**: New `--limit` and `--batch-size` options.
|
||||
|
||||
### `tg-graph-to-turtle`
|
||||
|
||||
- **RDF-star support**: Now handles quoted triples (RDF-star reification).
|
||||
- **Streaming mode**: Uses streaming for lower time-to-first-processing.
|
||||
- **Wire format handling**: Updated to use the new term wire format (`{"t": "i", "i": uri}` for IRIs, `{"t": "l", "v": value}` for literals, `{"t": "r", "r": {...}}` for quoted triples).
|
||||
- **Named graph support**: New `--graph` filter option.
|
||||
|
||||
### `tg-set-tool`
|
||||
|
||||
- **New tool type**: `row-embeddings-query` for semantic search on structured data indexes.
|
||||
- **New options**: `--schema-name`, `--index-name`, `--limit` for configuring row embeddings query tools.
|
||||
|
||||
### `tg-show-tools`
|
||||
|
||||
- Displays the new `row-embeddings-query` tool type with its `schema-name`, `index-name`, and `limit` fields.
|
||||
|
||||
### `tg-load-knowledge`
|
||||
|
||||
- **Progress reporting**: Now counts and reports triples and entity contexts loaded per file and in total.
|
||||
- **Term format update**: Entity contexts now use the new Term format (`{"t": "i", "i": uri}`) instead of the old Value format (`{"v": entity, "e": True}`).
|
||||
|
||||
---
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- **Terminology rename**: The `Value` schema was renamed to `Term` across the system (PR #622). This affects the wire format used by CLI tools that interact with the graph store. The new format uses `{"t": "i", "i": uri}` for IRIs and `{"t": "l", "v": value}` for literals, replacing the old `{"v": ..., "e": ...}` format.
|
||||
- **`tg-invoke-objects-query` renamed** to `tg-invoke-rows-query`.
|
||||
- **`tg-load-pdf` and `tg-load-text` removed**.
|
||||
2234
docs/python-api.md
2234
docs/python-api.md
File diff suppressed because it is too large
Load diff
272
docs/tech-specs/agent-explainability.md
Normal file
272
docs/tech-specs/agent-explainability.md
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
# Agent Explainability: Provenance Recording
|
||||
|
||||
## Overview
|
||||
|
||||
Add provenance recording to the React agent loop so agent sessions can be traced and debugged using the same explainability infrastructure as GraphRAG.
|
||||
|
||||
**Design Decisions:**
|
||||
- Write to `urn:graph:retrieval` (generic explainability graph)
|
||||
- Linear dependency chain for now (analysis N → wasDerivedFrom → analysis N-1)
|
||||
- Tools are opaque black boxes (record input/output only)
|
||||
- DAG support deferred to future iteration
|
||||
|
||||
## Entity Types
|
||||
|
||||
Both GraphRAG and Agent use PROV-O as the base ontology with TrustGraph-specific subtypes:
|
||||
|
||||
### GraphRAG Types
|
||||
| Entity | PROV-O Type | TG Types | Description |
|
||||
|--------|-------------|----------|-------------|
|
||||
| Question | `prov:Activity` | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
|
||||
| Exploration | `prov:Entity` | `tg:Exploration` | Edges retrieved from knowledge graph |
|
||||
| Focus | `prov:Entity` | `tg:Focus` | Selected edges with reasoning |
|
||||
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
|
||||
|
||||
### Agent Types
|
||||
| Entity | PROV-O Type | TG Types | Description |
|
||||
|--------|-------------|----------|-------------|
|
||||
| Question | `prov:Activity` | `tg:Question`, `tg:AgentQuestion` | The user's query |
|
||||
| Analysis | `prov:Entity` | `tg:Analysis` | Each think/act/observe cycle |
|
||||
| Conclusion | `prov:Entity` | `tg:Conclusion` | Final answer |
|
||||
|
||||
### Document RAG Types
|
||||
| Entity | PROV-O Type | TG Types | Description |
|
||||
|--------|-------------|----------|-------------|
|
||||
| Question | `prov:Activity` | `tg:Question`, `tg:DocRagQuestion` | The user's query |
|
||||
| Exploration | `prov:Entity` | `tg:Exploration` | Chunks retrieved from document store |
|
||||
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
|
||||
|
||||
**Note:** Document RAG uses a subset of GraphRAG's types (no Focus step since there's no edge selection/reasoning phase).
|
||||
|
||||
### Question Subtypes
|
||||
|
||||
All Question entities share `tg:Question` as a base type but have a specific subtype to identify the retrieval mechanism:
|
||||
|
||||
| Subtype | URI Pattern | Mechanism |
|
||||
|---------|-------------|-----------|
|
||||
| `tg:GraphRagQuestion` | `urn:trustgraph:question:{uuid}` | Knowledge graph RAG |
|
||||
| `tg:DocRagQuestion` | `urn:trustgraph:docrag:{uuid}` | Document/chunk RAG |
|
||||
| `tg:AgentQuestion` | `urn:trustgraph:agent:{uuid}` | ReAct agent |
|
||||
|
||||
This allows querying all questions via `tg:Question` while filtering by specific mechanism via the subtype.
|
||||
|
||||
## Provenance Model
|
||||
|
||||
```
|
||||
Question (urn:trustgraph:agent:{uuid})
|
||||
│
|
||||
│ tg:query = "User's question"
|
||||
│ prov:startedAtTime = timestamp
|
||||
│ rdf:type = prov:Activity, tg:Question
|
||||
│
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Analysis1 (urn:trustgraph:agent:{uuid}/i1)
|
||||
│
|
||||
│ tg:thought = "I need to query the knowledge base..."
|
||||
│ tg:action = "knowledge-query"
|
||||
│ tg:arguments = {"question": "..."}
|
||||
│ tg:observation = "Result from tool..."
|
||||
│ rdf:type = prov:Entity, tg:Analysis
|
||||
│
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Analysis2 (urn:trustgraph:agent:{uuid}/i2)
|
||||
│ ...
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Conclusion (urn:trustgraph:agent:{uuid}/final)
|
||||
│
|
||||
│ tg:answer = "The final response..."
|
||||
│ rdf:type = prov:Entity, tg:Conclusion
|
||||
```
|
||||
|
||||
### Document RAG Provenance Model
|
||||
|
||||
```
|
||||
Question (urn:trustgraph:docrag:{uuid})
|
||||
│
|
||||
│ tg:query = "User's question"
|
||||
│ prov:startedAtTime = timestamp
|
||||
│ rdf:type = prov:Activity, tg:Question
|
||||
│
|
||||
↓ prov:wasGeneratedBy
|
||||
│
|
||||
Exploration (urn:trustgraph:docrag:{uuid}/exploration)
|
||||
│
|
||||
│ tg:chunkCount = 5
|
||||
│ tg:selectedChunk = "chunk-id-1"
|
||||
│ tg:selectedChunk = "chunk-id-2"
|
||||
│ ...
|
||||
│ rdf:type = prov:Entity, tg:Exploration
|
||||
│
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Synthesis (urn:trustgraph:docrag:{uuid}/synthesis)
|
||||
│
|
||||
│ tg:content = "The synthesized answer..."
|
||||
│ rdf:type = prov:Entity, tg:Synthesis
|
||||
```
|
||||
|
||||
## Changes Required
|
||||
|
||||
### 1. Schema Changes
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/schema/services/agent.py`
|
||||
|
||||
Add `session_id` and `collection` fields to `AgentRequest`:
|
||||
```python
|
||||
@dataclass
|
||||
class AgentRequest:
|
||||
question: str = ""
|
||||
state: str = ""
|
||||
group: list[str] | None = None
|
||||
history: list[AgentStep] = field(default_factory=list)
|
||||
user: str = ""
|
||||
collection: str = "default" # NEW: Collection for provenance traces
|
||||
streaming: bool = False
|
||||
session_id: str = "" # NEW: For provenance tracking across iterations
|
||||
```
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/messaging/translators/agent.py`
|
||||
|
||||
Update translator to handle `session_id` and `collection` in both `to_pulsar()` and `from_pulsar()`.
|
||||
|
||||
### 2. Add Explainability Producer to Agent Service
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/agent/react/service.py`
|
||||
|
||||
Register an "explainability" producer (same pattern as GraphRAG):
|
||||
```python
|
||||
from ... base import ProducerSpec
|
||||
from ... schema import Triples
|
||||
|
||||
# In __init__:
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Provenance Triple Generation
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/provenance/agent.py`
|
||||
|
||||
Create helper functions (similar to GraphRAG's `question_triples`, `exploration_triples`, etc.):
|
||||
```python
|
||||
def agent_session_triples(session_uri, query, timestamp):
|
||||
"""Generate triples for agent Question."""
|
||||
return [
|
||||
Triple(s=session_uri, p=RDF_TYPE, o=PROV_ACTIVITY),
|
||||
Triple(s=session_uri, p=RDF_TYPE, o=TG_QUESTION),
|
||||
Triple(s=session_uri, p=TG_QUERY, o=query),
|
||||
Triple(s=session_uri, p=PROV_STARTED_AT_TIME, o=timestamp),
|
||||
]
|
||||
|
||||
def agent_iteration_triples(iteration_uri, parent_uri, thought, action, arguments, observation):
|
||||
"""Generate triples for one Analysis step."""
|
||||
return [
|
||||
Triple(s=iteration_uri, p=RDF_TYPE, o=PROV_ENTITY),
|
||||
Triple(s=iteration_uri, p=RDF_TYPE, o=TG_ANALYSIS),
|
||||
Triple(s=iteration_uri, p=TG_THOUGHT, o=thought),
|
||||
Triple(s=iteration_uri, p=TG_ACTION, o=action),
|
||||
Triple(s=iteration_uri, p=TG_ARGUMENTS, o=json.dumps(arguments)),
|
||||
Triple(s=iteration_uri, p=TG_OBSERVATION, o=observation),
|
||||
Triple(s=iteration_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
|
||||
]
|
||||
|
||||
def agent_final_triples(final_uri, parent_uri, answer):
|
||||
"""Generate triples for Conclusion."""
|
||||
return [
|
||||
Triple(s=final_uri, p=RDF_TYPE, o=PROV_ENTITY),
|
||||
Triple(s=final_uri, p=RDF_TYPE, o=TG_CONCLUSION),
|
||||
Triple(s=final_uri, p=TG_ANSWER, o=answer),
|
||||
Triple(s=final_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
|
||||
]
|
||||
```
|
||||
|
||||
### 4. Type Definitions
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/provenance/namespaces.py`
|
||||
|
||||
Add explainability entity types and agent predicates:
|
||||
```python
|
||||
# Explainability entity types (used by both GraphRAG and Agent)
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
TG_FOCUS = TG + "Focus"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
|
||||
# Agent predicates
|
||||
TG_THOUGHT = TG + "thought"
|
||||
TG_ACTION = TG + "action"
|
||||
TG_ARGUMENTS = TG + "arguments"
|
||||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
```
|
||||
|
||||
## Files Modified
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `trustgraph-base/trustgraph/schema/services/agent.py` | Add session_id and collection to AgentRequest |
|
||||
| `trustgraph-base/trustgraph/messaging/translators/agent.py` | Update translator for new fields |
|
||||
| `trustgraph-base/trustgraph/provenance/namespaces.py` | Add entity types, agent predicates, and Document RAG predicates |
|
||||
| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
|
||||
| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
|
||||
| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
|
||||
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id and explain_graph to DocumentRagResponse |
|
||||
| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields |
|
||||
| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
|
||||
| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
|
||||
| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |
|
||||
| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Handle agent trace types |
|
||||
| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | List agent sessions alongside GraphRAG |
|
||||
|
||||
## Files Created
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `trustgraph-base/trustgraph/provenance/agent.py` | Agent-specific triple generators |
|
||||
|
||||
## CLI Updates
|
||||
|
||||
**Detection:** Both GraphRAG and Agent Questions have `tg:Question` type. Distinguished by:
|
||||
1. URI pattern: `urn:trustgraph:agent:` vs `urn:trustgraph:question:`
|
||||
2. Derived entities: `tg:Analysis` (agent) vs `tg:Exploration` (GraphRAG)
|
||||
|
||||
**`list_explain_traces.py`:**
|
||||
- Shows Type column (Agent vs GraphRAG)
|
||||
|
||||
**`show_explain_trace.py`:**
|
||||
- Auto-detects trace type
|
||||
- Agent rendering shows: Question → Analysis step(s) → Conclusion
|
||||
|
||||
## Backwards Compatibility
|
||||
|
||||
- `session_id` defaults to `""` - old requests work, just won't have provenance
|
||||
- `collection` defaults to `"default"` - reasonable fallback
|
||||
- CLI gracefully handles both trace types
|
||||
|
||||
## Verification
|
||||
|
||||
```bash
|
||||
# Run an agent query
|
||||
tg-invoke-agent -q "What is the capital of France?"
|
||||
|
||||
# List traces (should show agent sessions with Type column)
|
||||
tg-list-explain-traces -U trustgraph -C default
|
||||
|
||||
# Show agent trace
|
||||
tg-show-explain-trace "urn:trustgraph:agent:xxx"
|
||||
```
|
||||
|
||||
## Future Work (Not This PR)
|
||||
|
||||
- DAG dependencies (when analysis N uses results from multiple prior analyses)
|
||||
- Tool-specific provenance linking (KnowledgeQuery → its GraphRAG trace)
|
||||
- Streaming provenance emission (emit as we go, not batch at end)
|
||||
136
docs/tech-specs/document-embeddings-chunk-id.md
Normal file
136
docs/tech-specs/document-embeddings-chunk-id.md
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
# Document Embeddings Chunk ID
|
||||
|
||||
## Overview
|
||||
|
||||
Document embeddings storage currently stores chunk text directly in the vector store payload, duplicating data that exists in Garage. This spec replaces chunk text storage with `chunk_id` references.
|
||||
|
||||
## Current State
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ChunkEmbeddings:
|
||||
chunk: bytes = b""
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class DocumentEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
chunks: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
Vector store payload:
|
||||
```python
|
||||
payload={"doc": chunk} # Duplicates Garage content
|
||||
```
|
||||
|
||||
## Design
|
||||
|
||||
### Schema Changes
|
||||
|
||||
**ChunkEmbeddings** - replace chunk with chunk_id:
|
||||
```python
|
||||
@dataclass
|
||||
class ChunkEmbeddings:
|
||||
chunk_id: str = ""
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
```
|
||||
|
||||
**DocumentEmbeddingsResponse** - return chunk_ids instead of chunks:
|
||||
```python
|
||||
@dataclass
|
||||
class DocumentEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
chunk_ids: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
### Vector Store Payload
|
||||
|
||||
All stores (Qdrant, Milvus, Pinecone):
|
||||
```python
|
||||
payload={"chunk_id": chunk_id}
|
||||
```
|
||||
|
||||
### Document RAG Changes
|
||||
|
||||
The document RAG processor fetches chunk content from Garage:
|
||||
|
||||
```python
|
||||
# Get chunk_ids from embeddings store
|
||||
chunk_ids = await self.rag.doc_embeddings_client.query(...)
|
||||
|
||||
# Fetch chunk content from Garage
|
||||
docs = []
|
||||
for chunk_id in chunk_ids:
|
||||
content = await self.rag.librarian_client.get_document_content(
|
||||
chunk_id, self.user
|
||||
)
|
||||
docs.append(content)
|
||||
```
|
||||
|
||||
### API/SDK Changes
|
||||
|
||||
**DocumentEmbeddingsClient** returns chunk_ids:
|
||||
```python
|
||||
return resp.chunk_ids # Changed from resp.chunks
|
||||
```
|
||||
|
||||
**Wire format** (DocumentEmbeddingsResponseTranslator):
|
||||
```python
|
||||
result["chunk_ids"] = obj.chunk_ids # Changed from chunks
|
||||
```
|
||||
|
||||
### CLI Changes
|
||||
|
||||
CLI tool displays chunk_ids (callers can fetch content separately if needed).
|
||||
|
||||
## Files to Modify
|
||||
|
||||
### Schema
|
||||
- `trustgraph-base/trustgraph/schema/knowledge/embeddings.py` - ChunkEmbeddings
|
||||
- `trustgraph-base/trustgraph/schema/services/query.py` - DocumentEmbeddingsResponse
|
||||
|
||||
### Messaging/Translators
|
||||
- `trustgraph-base/trustgraph/messaging/translators/embeddings_query.py` - DocumentEmbeddingsResponseTranslator
|
||||
|
||||
### Client
|
||||
- `trustgraph-base/trustgraph/base/document_embeddings_client.py` - return chunk_ids
|
||||
|
||||
### Python SDK/API
|
||||
- `trustgraph-base/trustgraph/api/flow.py` - document_embeddings_query
|
||||
- `trustgraph-base/trustgraph/api/socket_client.py` - document_embeddings_query
|
||||
- `trustgraph-base/trustgraph/api/async_flow.py` - if applicable
|
||||
- `trustgraph-base/trustgraph/api/bulk_client.py` - import/export document embeddings
|
||||
- `trustgraph-base/trustgraph/api/async_bulk_client.py` - import/export document embeddings
|
||||
|
||||
### Embeddings Service
|
||||
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py` - pass chunk_id
|
||||
|
||||
### Storage Writers
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
|
||||
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
|
||||
|
||||
### Query Services
|
||||
- `trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py`
|
||||
- `trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py`
|
||||
- `trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py`
|
||||
|
||||
### Gateway
|
||||
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_query.py`
|
||||
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py`
|
||||
- `trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py`
|
||||
|
||||
### Document RAG
|
||||
- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` - add librarian client
|
||||
- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` - fetch from Garage
|
||||
|
||||
### CLI
|
||||
- `trustgraph-cli/trustgraph/cli/invoke_document_embeddings.py`
|
||||
- `trustgraph-cli/trustgraph/cli/save_doc_embeds.py`
|
||||
- `trustgraph-cli/trustgraph/cli/load_doc_embeds.py`
|
||||
|
||||
## Benefits
|
||||
|
||||
1. Single source of truth - chunk text only in Garage
|
||||
2. Reduced vector store storage
|
||||
3. Enables query-time provenance via chunk_id
|
||||
667
docs/tech-specs/embeddings-batch-processing.md
Normal file
667
docs/tech-specs/embeddings-batch-processing.md
Normal file
|
|
@ -0,0 +1,667 @@
|
|||
# Embeddings Batch Processing Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes optimizations for the embeddings service to support batch processing of multiple texts in a single request. The current implementation processes one text at a time, missing the significant performance benefits that embedding models provide when processing batches.
|
||||
|
||||
1. **Single-Text Processing Inefficiency**: Current implementation wraps single texts in a list, underutilizing FastEmbed's batch capabilities
|
||||
2. **Request-Per-Text Overhead**: Each text requires a separate Pulsar message round-trip
|
||||
3. **Model Inference Inefficiency**: Embedding models have fixed per-batch overhead; small batches waste GPU/CPU resources
|
||||
4. **Serial Processing in Callers**: Key services loop over items and call embeddings one at a time
|
||||
|
||||
## Goals
|
||||
|
||||
- **Batch API Support**: Enable processing multiple texts in a single request
|
||||
- **Backward Compatibility**: Maintain support for single-text requests
|
||||
- **Significant Throughput Improvement**: Target 5-10x throughput improvement for bulk operations
|
||||
- **Reduced Latency per Text**: Lower amortized latency when embedding multiple texts
|
||||
- **Memory Efficiency**: Process batches without excessive memory consumption
|
||||
- **Provider Agnostic**: Support batching across FastEmbed, Ollama, and other providers
|
||||
- **Caller Migration**: Update all embedding callers to use batch API where beneficial
|
||||
|
||||
## Background
|
||||
|
||||
### Current Implementation - Embeddings Service
|
||||
|
||||
The embeddings implementation in `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py` exhibits a significant performance inefficiency:
|
||||
|
||||
```python
|
||||
# fastembed/processor.py line 56
|
||||
async def on_embeddings(self, text, model=None):
|
||||
use_model = model or self.default_model
|
||||
self._load_model(use_model)
|
||||
|
||||
vecs = self.embeddings.embed([text]) # Single text wrapped in list
|
||||
|
||||
return [v.tolist() for v in vecs]
|
||||
```
|
||||
|
||||
**Problems:**
|
||||
|
||||
1. **Batch Size 1**: FastEmbed's `embed()` method is optimized for batch processing, but we always call it with `[text]` - a batch of size 1
|
||||
|
||||
2. **Per-Request Overhead**: Each embedding request incurs:
|
||||
- Pulsar message serialization/deserialization
|
||||
- Network round-trip latency
|
||||
- Model inference startup overhead
|
||||
- Python async scheduling overhead
|
||||
|
||||
3. **Schema Limitation**: The `EmbeddingsRequest` schema only supports a single text:
|
||||
```python
|
||||
@dataclass
|
||||
class EmbeddingsRequest:
|
||||
text: str = "" # Single text only
|
||||
```
|
||||
|
||||
### Current Callers - Serial Processing
|
||||
|
||||
#### 1. API Gateway
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`
|
||||
|
||||
The gateway accepts single-text embedding requests via HTTP/WebSocket and forwards them to the embeddings service. Currently no batch endpoint exists.
|
||||
|
||||
```python
|
||||
class EmbeddingsRequestor(ServiceRequestor):
|
||||
# Handles single EmbeddingsRequest -> EmbeddingsResponse
|
||||
request_schema=EmbeddingsRequest, # Single text only
|
||||
response_schema=EmbeddingsResponse,
|
||||
```
|
||||
|
||||
**Impact:** External clients (web apps, scripts) must make N HTTP requests to embed N texts.
|
||||
|
||||
#### 2. Document Embeddings Service
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`
|
||||
|
||||
Processes document chunks one at a time:
|
||||
|
||||
```python
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
v = msg.value()
|
||||
|
||||
# Single chunk per request
|
||||
resp = await flow("embeddings-request").request(
|
||||
EmbeddingsRequest(text=v.chunk)
|
||||
)
|
||||
vectors = resp.vectors
|
||||
```
|
||||
|
||||
**Impact:** Each document chunk requires a separate embedding call. A document with 100 chunks = 100 embedding requests.
|
||||
|
||||
#### 3. Graph Embeddings Service
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
|
||||
|
||||
Loops over entities and embeds each one serially:
|
||||
|
||||
```python
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
for entity in v.entities:
|
||||
# Serial embedding - one entity at a time
|
||||
vectors = await flow("embeddings-request").embed(
|
||||
text=entity.context
|
||||
)
|
||||
entities.append(EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors,
|
||||
chunk_id=entity.chunk_id,
|
||||
))
|
||||
```
|
||||
|
||||
**Impact:** A message with 50 entities = 50 serial embedding requests. This is a major bottleneck during knowledge graph construction.
|
||||
|
||||
#### 4. Row Embeddings Service
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
|
||||
|
||||
Loops over unique texts and embeds each one serially:
|
||||
|
||||
```python
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
for text, (index_name, index_value) in texts_to_embed.items():
|
||||
# Serial embedding - one text at a time
|
||||
vectors = await flow("embeddings-request").embed(text=text)
|
||||
|
||||
embeddings_list.append(RowIndexEmbedding(
|
||||
index_name=index_name,
|
||||
index_value=index_value,
|
||||
text=text,
|
||||
vectors=vectors
|
||||
))
|
||||
```
|
||||
|
||||
**Impact:** Processing a table with 100 unique indexed values = 100 serial embedding requests.
|
||||
|
||||
#### 5. EmbeddingsClient (Base Client)
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/base/embeddings_client.py`
|
||||
|
||||
The client used by all flow processors only supports single-text embedding:
|
||||
|
||||
```python
|
||||
class EmbeddingsClient(RequestResponse):
|
||||
async def embed(self, text, timeout=30):
|
||||
resp = await self.request(
|
||||
EmbeddingsRequest(text=text), # Single text
|
||||
timeout=timeout
|
||||
)
|
||||
return resp.vectors
|
||||
```
|
||||
|
||||
**Impact:** All callers using this client are limited to single-text operations.
|
||||
|
||||
#### 6. Command-Line Tools
|
||||
|
||||
**File:** `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`
|
||||
|
||||
CLI tool accepts single text argument:
|
||||
|
||||
```python
|
||||
def query(url, flow_id, text, token=None):
|
||||
result = flow.embeddings(text=text) # Single text
|
||||
vectors = result.get("vectors", [])
|
||||
```
|
||||
|
||||
**Impact:** Users cannot batch-embed from command line. Processing a file of texts requires N invocations.
|
||||
|
||||
#### 7. Python SDK
|
||||
|
||||
The Python SDK provides two client classes for interacting with TrustGraph services. Both only support single-text embedding.
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/api/flow.py`
|
||||
|
||||
```python
|
||||
class FlowInstance:
|
||||
def embeddings(self, text):
|
||||
"""Get embeddings for a single text"""
|
||||
input = {"text": text}
|
||||
return self.request("service/embeddings", input)["vectors"]
|
||||
```
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/api/socket_client.py`
|
||||
|
||||
```python
|
||||
class SocketFlowInstance:
|
||||
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Get embeddings for a single text via WebSocket"""
|
||||
request = {"text": text}
|
||||
return self.client._send_request_sync(
|
||||
"embeddings", self.flow_id, request, False
|
||||
)
|
||||
```
|
||||
|
||||
**Impact:** Python developers using the SDK must loop over texts and make N separate API calls. No batch embedding support exists for SDK users.
|
||||
|
||||
### Performance Impact
|
||||
|
||||
For typical document ingestion (1000 text chunks):
|
||||
- **Current**: 1000 separate requests, 1000 model inference calls
|
||||
- **Batched (batch_size=32)**: 32 requests, 32 model inference calls (96.8% reduction)
|
||||
|
||||
For graph embedding (message with 50 entities):
|
||||
- **Current**: 50 serial await calls, ~5-10 seconds
|
||||
- **Batched**: 1-2 batch calls, ~0.5-1 second (5-10x improvement)
|
||||
|
||||
FastEmbed and similar libraries achieve near-linear throughput scaling with batch size up to hardware limits (typically 32-128 texts per batch).
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Architecture
|
||||
|
||||
The embeddings batch processing optimization requires changes to the following components:
|
||||
|
||||
#### 1. **Schema Enhancement**
|
||||
- Extend `EmbeddingsRequest` to support multiple texts
|
||||
- Extend `EmbeddingsResponse` to return multiple vector sets
|
||||
- Maintain backward compatibility with single-text requests
|
||||
|
||||
Module: `trustgraph-base/trustgraph/schema/services/llm.py`
|
||||
|
||||
#### 2. **Base Service Enhancement**
|
||||
- Update `EmbeddingsService` to handle batch requests
|
||||
- Add batch size configuration
|
||||
- Implement batch-aware request handling
|
||||
|
||||
Module: `trustgraph-base/trustgraph/base/embeddings_service.py`
|
||||
|
||||
#### 3. **Provider Processor Updates**
|
||||
- Update FastEmbed processor to pass full batch to `embed()`
|
||||
- Update Ollama processor to handle batches (if supported)
|
||||
- Add fallback sequential processing for providers without batch support
|
||||
|
||||
Modules:
|
||||
- `trustgraph-flow/trustgraph/embeddings/fastembed/processor.py`
|
||||
- `trustgraph-flow/trustgraph/embeddings/ollama/processor.py`
|
||||
|
||||
#### 4. **Client Enhancement**
|
||||
- Add batch embedding method to `EmbeddingsClient`
|
||||
- Support both single and batch APIs
|
||||
- Add automatic batching for large inputs
|
||||
|
||||
Module: `trustgraph-base/trustgraph/base/embeddings_client.py`
|
||||
|
||||
#### 5. **Caller Updates - Flow Processors**
|
||||
- Update `graph_embeddings` to batch entity contexts
|
||||
- Update `row_embeddings` to batch index texts
|
||||
- Update `document_embeddings` if message batching is feasible
|
||||
|
||||
Modules:
|
||||
- `trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py`
|
||||
- `trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py`
|
||||
- `trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py`
|
||||
|
||||
#### 6. **API Gateway Enhancement**
|
||||
- Add batch embedding endpoint
|
||||
- Support array of texts in request body
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py`
|
||||
|
||||
#### 7. **CLI Tool Enhancement**
|
||||
- Add support for multiple texts or file input
|
||||
- Add batch size parameter
|
||||
|
||||
Module: `trustgraph-cli/trustgraph/cli/invoke_embeddings.py`
|
||||
|
||||
#### 8. **Python SDK Enhancement**
|
||||
- Add `embeddings_batch()` method to `FlowInstance`
|
||||
- Add `embeddings_batch()` method to `SocketFlowInstance`
|
||||
- Support both single and batch APIs for SDK users
|
||||
|
||||
Modules:
|
||||
- `trustgraph-base/trustgraph/api/flow.py`
|
||||
- `trustgraph-base/trustgraph/api/socket_client.py`
|
||||
|
||||
### Data Models
|
||||
|
||||
#### EmbeddingsRequest
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class EmbeddingsRequest:
|
||||
texts: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
Usage:
|
||||
- Single text: `EmbeddingsRequest(texts=["hello world"])`
|
||||
- Batch: `EmbeddingsRequest(texts=["text1", "text2", "text3"])`
|
||||
|
||||
#### EmbeddingsResponse
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class EmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
vectors: list[list[list[float]]] = field(default_factory=list)
|
||||
```
|
||||
|
||||
Response structure:
|
||||
- `vectors[i]` contains the vector set for `texts[i]`
|
||||
- Each vector set is `list[list[float]]` (models may return multiple vectors per text)
|
||||
- Example: 3 texts → `vectors` has 3 entries, each containing that text's embeddings
|
||||
|
||||
### APIs
|
||||
|
||||
#### EmbeddingsClient
|
||||
|
||||
```python
|
||||
class EmbeddingsClient(RequestResponse):
|
||||
async def embed(
|
||||
self,
|
||||
texts: list[str],
|
||||
timeout: float = 300,
|
||||
) -> list[list[list[float]]]:
|
||||
"""
|
||||
Embed one or more texts in a single request.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
timeout: Timeout for the operation
|
||||
|
||||
Returns:
|
||||
List of vector sets, one per input text
|
||||
"""
|
||||
resp = await self.request(
|
||||
EmbeddingsRequest(texts=texts),
|
||||
timeout=timeout
|
||||
)
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
return resp.vectors
|
||||
```
|
||||
|
||||
#### API Gateway Embeddings Endpoint
|
||||
|
||||
Updated endpoint supporting single or batch embedding:
|
||||
|
||||
```
|
||||
POST /api/v1/embeddings
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"texts": ["text1", "text2", "text3"],
|
||||
"flow_id": "default"
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"vectors": [
|
||||
[[0.1, 0.2, ...]],
|
||||
[[0.3, 0.4, ...]],
|
||||
[[0.5, 0.6, ...]]
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Implementation Details
|
||||
|
||||
#### Phase 1: Schema Changes
|
||||
|
||||
**EmbeddingsRequest:**
|
||||
```python
|
||||
@dataclass
|
||||
class EmbeddingsRequest:
|
||||
texts: list[str] = field(default_factory=list)
|
||||
```
|
||||
|
||||
**EmbeddingsResponse:**
|
||||
```python
|
||||
@dataclass
|
||||
class EmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
vectors: list[list[list[float]]] = field(default_factory=list)
|
||||
```
|
||||
|
||||
**Updated EmbeddingsService.on_request:**
|
||||
```python
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
request = msg.value()
|
||||
id = msg.properties()["id"]
|
||||
model = flow("model")
|
||||
|
||||
vectors = await self.on_embeddings(request.texts, model=model)
|
||||
response = EmbeddingsResponse(error=None, vectors=vectors)
|
||||
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
```
|
||||
|
||||
#### Phase 2: FastEmbed Processor Update
|
||||
|
||||
**Current (Inefficient):**
|
||||
```python
|
||||
async def on_embeddings(self, text, model=None):
|
||||
use_model = model or self.default_model
|
||||
self._load_model(use_model)
|
||||
vecs = self.embeddings.embed([text]) # Batch of 1
|
||||
return [v.tolist() for v in vecs]
|
||||
```
|
||||
|
||||
**Updated:**
|
||||
```python
|
||||
async def on_embeddings(self, texts: list[str], model=None):
|
||||
"""Embed texts - processes all texts in single model call"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
self._load_model(use_model)
|
||||
|
||||
# FastEmbed handles the full batch efficiently
|
||||
all_vecs = list(self.embeddings.embed(texts))
|
||||
|
||||
# Return list of vector sets, one per input text
|
||||
return [[v.tolist()] for v in all_vecs]
|
||||
```
|
||||
|
||||
#### Phase 3: Graph Embeddings Service Update
|
||||
|
||||
**Current (Serial):**
|
||||
```python
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
entities = []
|
||||
for entity in v.entities:
|
||||
vectors = await flow("embeddings-request").embed(text=entity.context)
|
||||
entities.append(EntityEmbeddings(...))
|
||||
```
|
||||
|
||||
**Updated (Batch):**
|
||||
```python
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
# Collect all contexts
|
||||
contexts = [entity.context for entity in v.entities]
|
||||
|
||||
# Single batch embedding call
|
||||
all_vectors = await flow("embeddings-request").embed(texts=contexts)
|
||||
|
||||
# Pair results with entities
|
||||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors[0], # First vector from the set
|
||||
chunk_id=entity.chunk_id,
|
||||
)
|
||||
for entity, vectors in zip(v.entities, all_vectors)
|
||||
]
|
||||
```
|
||||
|
||||
#### Phase 4: Row Embeddings Service Update
|
||||
|
||||
**Current (Serial):**
|
||||
```python
|
||||
for text, (index_name, index_value) in texts_to_embed.items():
|
||||
vectors = await flow("embeddings-request").embed(text=text)
|
||||
embeddings_list.append(RowIndexEmbedding(...))
|
||||
```
|
||||
|
||||
**Updated (Batch):**
|
||||
```python
|
||||
# Collect texts and metadata
|
||||
texts = list(texts_to_embed.keys())
|
||||
metadata = list(texts_to_embed.values())
|
||||
|
||||
# Single batch embedding call
|
||||
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
||||
|
||||
# Pair results
|
||||
embeddings_list = [
|
||||
RowIndexEmbedding(
|
||||
index_name=meta[0],
|
||||
index_value=meta[1],
|
||||
text=text,
|
||||
vectors=vectors[0] # First vector from the set
|
||||
)
|
||||
for text, meta, vectors in zip(texts, metadata, all_vectors)
|
||||
]
|
||||
```
|
||||
|
||||
#### Phase 5: CLI Tool Enhancement
|
||||
|
||||
**Updated CLI:**
|
||||
```python
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(...)
|
||||
|
||||
parser.add_argument(
|
||||
'text',
|
||||
nargs='*', # Zero or more texts
|
||||
help='Text(s) to convert to embedding vectors',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--file',
|
||||
help='File containing texts (one per line)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=32,
|
||||
help='Batch size for processing (default: 32)',
|
||||
)
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
# Single text (existing)
|
||||
tg-invoke-embeddings "hello world"
|
||||
|
||||
# Multiple texts
|
||||
tg-invoke-embeddings "text one" "text two" "text three"
|
||||
|
||||
# From file
|
||||
tg-invoke-embeddings -f texts.txt --batch-size 64
|
||||
```
|
||||
|
||||
#### Phase 6: Python SDK Enhancement
|
||||
|
||||
**FlowInstance (HTTP client):**
|
||||
|
||||
```python
|
||||
class FlowInstance:
|
||||
def embeddings(self, texts: list[str]) -> list[list[list[float]]]:
|
||||
"""
|
||||
Get embeddings for one or more texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of vector sets, one per input text
|
||||
"""
|
||||
input = {"texts": texts}
|
||||
return self.request("service/embeddings", input)["vectors"]
|
||||
```
|
||||
|
||||
**SocketFlowInstance (WebSocket client):**
|
||||
|
||||
```python
|
||||
class SocketFlowInstance:
|
||||
def embeddings(self, texts: list[str], **kwargs: Any) -> list[list[list[float]]]:
|
||||
"""
|
||||
Get embeddings for one or more texts via WebSocket.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of vector sets, one per input text
|
||||
"""
|
||||
request = {"texts": texts}
|
||||
response = self.client._send_request_sync(
|
||||
"embeddings", self.flow_id, request, False
|
||||
)
|
||||
return response["vectors"]
|
||||
```
|
||||
|
||||
**SDK Usage Examples:**
|
||||
|
||||
```python
|
||||
# Single text
|
||||
vectors = flow.embeddings(["hello world"])
|
||||
print(f"Dimensions: {len(vectors[0][0])}")
|
||||
|
||||
# Batch embedding
|
||||
texts = ["text one", "text two", "text three"]
|
||||
all_vectors = flow.embeddings(texts)
|
||||
|
||||
# Process results
|
||||
for text, vecs in zip(texts, all_vectors):
|
||||
print(f"{text}: {len(vecs[0])} dimensions")
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- **Request Size Limits**: Enforce maximum batch size to prevent resource exhaustion
|
||||
- **Timeout Handling**: Scale timeouts appropriately for batch size
|
||||
- **Memory Limits**: Monitor memory usage for large batches
|
||||
- **Input Validation**: Validate all texts in batch before processing
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Expected Improvements
|
||||
|
||||
**Throughput:**
|
||||
- Single-text: ~10-50 texts/second (depending on model)
|
||||
- Batch (size 32): ~200-500 texts/second (5-10x improvement)
|
||||
|
||||
**Latency per Text:**
|
||||
- Single-text: 50-200ms per text
|
||||
- Batch (size 32): 5-20ms per text (amortized)
|
||||
|
||||
**Service-Specific Improvements:**
|
||||
|
||||
| Service | Current | Batched | Improvement |
|
||||
|---------|---------|---------|-------------|
|
||||
| Graph Embeddings (50 entities) | 5-10s | 0.5-1s | 5-10x |
|
||||
| Row Embeddings (100 texts) | 10-20s | 1-2s | 5-10x |
|
||||
| Document Ingestion (1000 chunks) | 100-200s | 10-30s | 5-10x |
|
||||
|
||||
### Configuration Parameters
|
||||
|
||||
```python
|
||||
# Recommended defaults
|
||||
DEFAULT_BATCH_SIZE = 32
|
||||
MAX_BATCH_SIZE = 128
|
||||
BATCH_TIMEOUT_MULTIPLIER = 2.0
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Testing
|
||||
- Single text embedding (backward compatibility)
|
||||
- Empty batch handling
|
||||
- Maximum batch size enforcement
|
||||
- Error handling for partial batch failures
|
||||
|
||||
### Integration Testing
|
||||
- End-to-end batch embedding through Pulsar
|
||||
- Graph embeddings service batch processing
|
||||
- Row embeddings service batch processing
|
||||
- API gateway batch endpoint
|
||||
|
||||
### Performance Testing
|
||||
- Benchmark single vs batch throughput
|
||||
- Memory usage under various batch sizes
|
||||
- Latency distribution analysis
|
||||
|
||||
## Migration Plan
|
||||
|
||||
This is a breaking change release. All phases are implemented together.
|
||||
|
||||
### Phase 1: Schema Changes
|
||||
- Replace `text: str` with `texts: list[str]` in EmbeddingsRequest
|
||||
- Change `vectors` type to `list[list[list[float]]]` in EmbeddingsResponse
|
||||
|
||||
### Phase 2: Processor Updates
|
||||
- Update `on_embeddings` signature in FastEmbed and Ollama processors
|
||||
- Process full batch in single model call
|
||||
|
||||
### Phase 3: Client Updates
|
||||
- Update `EmbeddingsClient.embed()` to accept `texts: list[str]`
|
||||
|
||||
### Phase 4: Caller Updates
|
||||
- Update graph_embeddings to batch entity contexts
|
||||
- Update row_embeddings to batch index texts
|
||||
- Update document_embeddings to use new schema
|
||||
- Update CLI tool
|
||||
|
||||
### Phase 5: API Gateway
|
||||
- Update embeddings endpoint for new schema
|
||||
|
||||
### Phase 6: Python SDK
|
||||
- Update `FlowInstance.embeddings()` signature
|
||||
- Update `SocketFlowInstance.embeddings()` signature
|
||||
|
||||
## Open Questions
|
||||
|
||||
- **Streaming Large Batches**: Should we support streaming results for very large batches (>100 texts)?
|
||||
- **Provider-Specific Limits**: How should we handle providers with different maximum batch sizes?
|
||||
- **Partial Failure Handling**: If one text in a batch fails, should we fail the entire batch or return partial results?
|
||||
- **Document Embeddings Batching**: Should we batch across multiple Chunk messages or keep per-message processing?
|
||||
|
||||
## References
|
||||
|
||||
- [FastEmbed Documentation](https://github.com/qdrant/fastembed)
|
||||
- [Ollama Embeddings API](https://github.com/ollama/ollama)
|
||||
- [EmbeddingsService Implementation](trustgraph-base/trustgraph/base/embeddings_service.py)
|
||||
- [GraphRAG Performance Optimization](graphrag-performance-optimization.md)
|
||||
|
|
@ -42,7 +42,7 @@ CREATE TABLE quads_by_entity (
|
|||
d text, -- Dataset/graph of the quad
|
||||
dtype text, -- XSD datatype (when otype = 'L'), e.g. 'xsd:string'
|
||||
lang text, -- Language tag (when otype = 'L'), e.g. 'en', 'fr'
|
||||
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d)
|
||||
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d, dtype, lang)
|
||||
);
|
||||
```
|
||||
|
||||
|
|
@ -54,6 +54,7 @@ CREATE TABLE quads_by_entity (
|
|||
2. **p** — next most common filter, "give me all `knows` relationships"
|
||||
3. **otype** — enables filtering by URI-valued vs literal-valued relationships
|
||||
4. **s, o, d** — remaining columns for uniqueness
|
||||
5. **dtype, lang** — distinguish literals with same value but different type metadata (e.g., `"thing"` vs `"thing"@en` vs `"thing"^^xsd:string`)
|
||||
|
||||
### Table 2: quads_by_collection
|
||||
|
||||
|
|
@ -69,11 +70,11 @@ CREATE TABLE quads_by_collection (
|
|||
otype text, -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
|
||||
dtype text, -- XSD datatype (when otype = 'L')
|
||||
lang text, -- Language tag (when otype = 'L')
|
||||
PRIMARY KEY (collection, d, s, p, o)
|
||||
PRIMARY KEY (collection, d, s, p, o, otype, dtype, lang)
|
||||
);
|
||||
```
|
||||
|
||||
Clustered by dataset first, enabling deletion at either collection or dataset granularity.
|
||||
Clustered by dataset first, enabling deletion at either collection or dataset granularity. The `otype`, `dtype`, and `lang` columns are included in the clustering key to distinguish literals with the same value but different type metadata — in RDF, `"thing"`, `"thing"@en`, and `"thing"^^xsd:string` are semantically distinct values.
|
||||
|
||||
## Write Path
|
||||
|
||||
|
|
|
|||
220
docs/tech-specs/explainability-cli.md
Normal file
220
docs/tech-specs/explainability-cli.md
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
# Explainability CLI Technical Specification
|
||||
|
||||
## Status
|
||||
|
||||
Draft
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes CLI tools for debugging and exploring explainability data in TrustGraph. These tools enable users to trace how answers were derived and debug the provenance chain from edges back to source documents.
|
||||
|
||||
Three CLI tools:
|
||||
|
||||
1. **`tg-show-document-hierarchy`** - Show document → page → chunk → edge hierarchy
|
||||
2. **`tg-list-explain-traces`** - List all GraphRAG sessions with questions
|
||||
3. **`tg-show-explain-trace`** - Show full explainability trace for a session
|
||||
|
||||
## Goals
|
||||
|
||||
- **Debugging**: Enable developers to inspect document processing results
|
||||
- **Auditability**: Trace any extracted fact back to its source document
|
||||
- **Transparency**: Show exactly how GraphRAG derived an answer
|
||||
- **Usability**: Simple CLI interface with sensible defaults
|
||||
|
||||
## Background
|
||||
|
||||
TrustGraph has two provenance systems:
|
||||
|
||||
1. **Extraction-time provenance** (see `extraction-time-provenance.md`): Records document → page → chunk → edge relationships during ingestion. Stored in `urn:graph:source` named graph using `prov:wasDerivedFrom`.
|
||||
|
||||
2. **Query-time explainability** (see `query-time-explainability.md`): Records question → exploration → focus → synthesis chain during GraphRAG queries. Stored in `urn:graph:retrieval` named graph.
|
||||
|
||||
Current limitations:
|
||||
- No easy way to visualize document hierarchy after processing
|
||||
- Must manually query triples to see explainability data
|
||||
- No consolidated view of a GraphRAG session
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Tool 1: tg-show-document-hierarchy
|
||||
|
||||
**Purpose**: Given a document ID, traverse and display all derived entities.
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
tg-show-document-hierarchy "urn:trustgraph:doc:abc123"
|
||||
tg-show-document-hierarchy --show-content --max-content 500 "urn:trustgraph:doc:abc123"
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
| Arg | Description |
|
||||
|-----|-------------|
|
||||
| `document_id` | Document URI (positional) |
|
||||
| `-u/--api-url` | Gateway URL (default: `$TRUSTGRAPH_URL`) |
|
||||
| `-t/--token` | Auth token (default: `$TRUSTGRAPH_TOKEN`) |
|
||||
| `-U/--user` | User ID (default: `trustgraph`) |
|
||||
| `-C/--collection` | Collection (default: `default`) |
|
||||
| `--show-content` | Include blob/document content |
|
||||
| `--max-content` | Max chars per blob (default: 200) |
|
||||
| `--format` | Output: `tree` (default), `json` |
|
||||
|
||||
**Implementation**:
|
||||
1. Query triples: `?child prov:wasDerivedFrom <document_id>` in `urn:graph:source`
|
||||
2. Recursively query children of each result
|
||||
3. Build tree structure: Document → Pages → Chunks
|
||||
4. If `--show-content`, fetch content from librarian API
|
||||
5. Display as indented tree or JSON
|
||||
|
||||
**Output Example**:
|
||||
```
|
||||
Document: urn:trustgraph:doc:abc123
|
||||
Title: "Sample PDF"
|
||||
Type: application/pdf
|
||||
|
||||
└── Page 1: urn:trustgraph:doc:abc123/p1
|
||||
├── Chunk 0: urn:trustgraph:doc:abc123/p1/c0
|
||||
│ Content: "The quick brown fox..." [truncated]
|
||||
└── Chunk 1: urn:trustgraph:doc:abc123/p1/c1
|
||||
Content: "Machine learning is..." [truncated]
|
||||
```
|
||||
|
||||
### Tool 2: tg-list-explain-traces
|
||||
|
||||
**Purpose**: List all GraphRAG sessions (questions) in a collection.
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
tg-list-explain-traces
|
||||
tg-list-explain-traces --limit 20 --format json
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
| Arg | Description |
|
||||
|-----|-------------|
|
||||
| `-u/--api-url` | Gateway URL |
|
||||
| `-t/--token` | Auth token |
|
||||
| `-U/--user` | User ID |
|
||||
| `-C/--collection` | Collection |
|
||||
| `--limit` | Max results (default: 50) |
|
||||
| `--format` | Output: `table` (default), `json` |
|
||||
|
||||
**Implementation**:
|
||||
1. Query: `?session tg:query ?text` in `urn:graph:retrieval`
|
||||
2. Query timestamps: `?session prov:startedAtTime ?time`
|
||||
3. Display as table
|
||||
|
||||
**Output Example**:
|
||||
```
|
||||
Session ID | Question | Time
|
||||
----------------------------------------------|--------------------------------|---------------------
|
||||
urn:trustgraph:question:abc123 | What was the War on Terror? | 2024-01-15 10:30:00
|
||||
urn:trustgraph:question:def456 | Who founded OpenAI? | 2024-01-15 09:15:00
|
||||
```
|
||||
|
||||
### Tool 3: tg-show-explain-trace
|
||||
|
||||
**Purpose**: Show full explainability cascade for a GraphRAG session.
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
tg-show-explain-trace "urn:trustgraph:question:abc123"
|
||||
tg-show-explain-trace --max-answer 1000 --show-provenance "urn:trustgraph:question:abc123"
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
| Arg | Description |
|
||||
|-----|-------------|
|
||||
| `question_id` | Question URI (positional) |
|
||||
| `-u/--api-url` | Gateway URL |
|
||||
| `-t/--token` | Auth token |
|
||||
| `-U/--user` | User ID |
|
||||
| `-C/--collection` | Collection |
|
||||
| `--max-answer` | Max chars for answer (default: 500) |
|
||||
| `--show-provenance` | Trace edges to source documents |
|
||||
| `--format` | Output: `text` (default), `json` |
|
||||
|
||||
**Implementation**:
|
||||
1. Get question text from `tg:query` predicate
|
||||
2. Find exploration: `?exp prov:wasGeneratedBy <question_id>`
|
||||
3. Find focus: `?focus prov:wasDerivedFrom <exploration_id>`
|
||||
4. Get selected edges: `<focus_id> tg:selectedEdge ?edge`
|
||||
5. For each edge, get `tg:edge` (quoted triple) and `tg:reasoning`
|
||||
6. Find synthesis: `?synth prov:wasDerivedFrom <focus_id>`
|
||||
7. Get answer from `tg:document` via librarian
|
||||
8. If `--show-provenance`, trace edges to source documents
|
||||
|
||||
**Output Example**:
|
||||
```
|
||||
=== GraphRAG Session: urn:trustgraph:question:abc123 ===
|
||||
|
||||
Question: What was the War on Terror?
|
||||
Time: 2024-01-15 10:30:00
|
||||
|
||||
--- Exploration ---
|
||||
Retrieved 50 edges from knowledge graph
|
||||
|
||||
--- Focus (Edge Selection) ---
|
||||
Selected 12 edges:
|
||||
|
||||
1. (War on Terror, definition, "A military campaign...")
|
||||
Reasoning: Directly defines the subject of the query
|
||||
Source: chunk → page 2 → "Beyond the Vigilant State"
|
||||
|
||||
2. (Guantanamo Bay, part_of, War on Terror)
|
||||
Reasoning: Shows key component of the campaign
|
||||
|
||||
--- Synthesis ---
|
||||
Answer:
|
||||
The War on Terror was a military campaign initiated...
|
||||
[truncated at 500 chars]
|
||||
```
|
||||
|
||||
## Files to Create
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `trustgraph-cli/trustgraph/cli/show_document_hierarchy.py` | Tool 1 |
|
||||
| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | Tool 2 |
|
||||
| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Tool 3 |
|
||||
|
||||
## Files to Modify
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `trustgraph-cli/setup.py` | Add console_scripts entries |
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
1. **Binary content safety**: Try UTF-8 decode; if fails, show `[Binary: {size} bytes]`
|
||||
2. **Truncation**: Respect `--max-content`/`--max-answer` with `[truncated]` indicator
|
||||
3. **Quoted triples**: Parse RDF-star format from `tg:edge` predicate
|
||||
4. **Patterns**: Follow existing CLI patterns from `query_graph.py`
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- All queries respect user/collection boundaries
|
||||
- Token authentication supported via `--token` or `$TRUSTGRAPH_TOKEN`
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
Manual verification with sample data:
|
||||
```bash
|
||||
# Load a test document
|
||||
tg-load-pdf -f test.pdf -c test-collection
|
||||
|
||||
# Verify hierarchy
|
||||
tg-show-document-hierarchy "urn:trustgraph:doc:test"
|
||||
|
||||
# Run a GraphRAG query with explainability
|
||||
tg-invoke-graph-rag --explainable -q "Test question"
|
||||
|
||||
# List and inspect traces
|
||||
tg-list-explain-traces
|
||||
tg-show-explain-trace "urn:trustgraph:question:xxx"
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- Query-time explainability: `docs/tech-specs/query-time-explainability.md`
|
||||
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
|
||||
- Existing CLI example: `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py`
|
||||
347
docs/tech-specs/extraction-flows.md
Normal file
347
docs/tech-specs/extraction-flows.md
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
# Extraction Flows
|
||||
|
||||
This document describes how data flows through the TrustGraph extraction pipeline, from document submission through to storage in knowledge stores.
|
||||
|
||||
## Overview
|
||||
|
||||
```
|
||||
┌──────────┐ ┌─────────────┐ ┌─────────┐ ┌────────────────────┐
|
||||
│ Librarian│────▶│ PDF Decoder │────▶│ Chunker │────▶│ Knowledge │
|
||||
│ │ │ (PDF only) │ │ │ │ Extraction │
|
||||
│ │────────────────────────▶│ │ │ │
|
||||
└──────────┘ └─────────────┘ └─────────┘ └────────────────────┘
|
||||
│ │
|
||||
│ ├──▶ Triples
|
||||
│ ├──▶ Entity Contexts
|
||||
│ └──▶ Rows
|
||||
│
|
||||
└──▶ Document Embeddings
|
||||
```
|
||||
|
||||
## Content Storage
|
||||
|
||||
### Blob Storage (S3/Minio)
|
||||
|
||||
Document content is stored in S3-compatible blob storage:
|
||||
- Path format: `doc/{object_id}` where object_id is a UUID
|
||||
- All document types stored here: source documents, pages, chunks
|
||||
|
||||
### Metadata Storage (Cassandra)
|
||||
|
||||
Document metadata stored in Cassandra includes:
|
||||
- Document ID, title, kind (MIME type)
|
||||
- `object_id` reference to blob storage
|
||||
- `parent_id` for child documents (pages, chunks)
|
||||
- `document_type`: "source", "page", "chunk", "answer"
|
||||
|
||||
### Inline vs Streaming Threshold
|
||||
|
||||
Content transmission uses a size-based strategy:
|
||||
- **< 2MB**: Content included inline in message (base64-encoded)
|
||||
- **≥ 2MB**: Only `document_id` sent; processor fetches via librarian API
|
||||
|
||||
## Stage 1: Document Submission (Librarian)
|
||||
|
||||
### Entry Point
|
||||
|
||||
Documents enter the system via librarian's `add-document` operation:
|
||||
1. Content uploaded to blob storage
|
||||
2. Metadata record created in Cassandra
|
||||
3. Returns document ID
|
||||
|
||||
### Triggering Extraction
|
||||
|
||||
The `add-processing` operation triggers extraction:
|
||||
- Specifies `document_id`, `flow` (pipeline ID), `collection` (target store)
|
||||
- Librarian's `load_document()` fetches content and publishes to flow input queue
|
||||
|
||||
### Schema: Document
|
||||
|
||||
```
|
||||
Document
|
||||
├── metadata: Metadata
|
||||
│ ├── id: str # Document identifier
|
||||
│ ├── user: str # Tenant/user ID
|
||||
│ ├── collection: str # Target collection
|
||||
│ └── metadata: list[Triple] # (largely unused, historical)
|
||||
├── data: bytes # PDF content (base64, if inline)
|
||||
└── document_id: str # Librarian reference (if streaming)
|
||||
```
|
||||
|
||||
**Routing**: Based on `kind` field:
|
||||
- `application/pdf` → `document-load` queue → PDF Decoder
|
||||
- `text/plain` → `text-load` queue → Chunker
|
||||
|
||||
## Stage 2: PDF Decoder
|
||||
|
||||
Converts PDF documents into text pages.
|
||||
|
||||
### Process
|
||||
|
||||
1. Fetch content (inline `data` or via `document_id` from librarian)
|
||||
2. Extract pages using PyPDF
|
||||
3. For each page:
|
||||
- Save as child document in librarian (`{doc_id}/p{page_num}`)
|
||||
- Emit provenance triples (page derived from document)
|
||||
- Forward to chunker
|
||||
|
||||
### Schema: TextDocument
|
||||
|
||||
```
|
||||
TextDocument
|
||||
├── metadata: Metadata
|
||||
│ ├── id: str # Page URI (e.g., https://trustgraph.ai/doc/xxx/p1)
|
||||
│ ├── user: str
|
||||
│ ├── collection: str
|
||||
│ └── metadata: list[Triple]
|
||||
├── text: bytes # Page text content (if inline)
|
||||
└── document_id: str # Librarian reference (e.g., "doc123/p1")
|
||||
```
|
||||
|
||||
## Stage 3: Chunker
|
||||
|
||||
Splits text into chunks at configured size.
|
||||
|
||||
### Parameters (flow-configurable)
|
||||
|
||||
- `chunk_size`: Target chunk size in characters (default: 2000)
|
||||
- `chunk_overlap`: Overlap between chunks (default: 100)
|
||||
|
||||
### Process
|
||||
|
||||
1. Fetch text content (inline or via librarian)
|
||||
2. Split using recursive character splitter
|
||||
3. For each chunk:
|
||||
- Save as child document in librarian (`{parent_id}/c{index}`)
|
||||
- Emit provenance triples (chunk derived from page/document)
|
||||
- Forward to extraction processors
|
||||
|
||||
### Schema: Chunk
|
||||
|
||||
```
|
||||
Chunk
|
||||
├── metadata: Metadata
|
||||
│ ├── id: str # Chunk URI
|
||||
│ ├── user: str
|
||||
│ ├── collection: str
|
||||
│ └── metadata: list[Triple]
|
||||
├── chunk: bytes # Chunk text content
|
||||
└── document_id: str # Librarian chunk ID (e.g., "doc123/p1/c3")
|
||||
```
|
||||
|
||||
### Document ID Hierarchy
|
||||
|
||||
Child documents encode their lineage in the ID:
|
||||
- Source: `doc123`
|
||||
- Page: `doc123/p5`
|
||||
- Chunk from page: `doc123/p5/c2`
|
||||
- Chunk from text: `doc123/c2`
|
||||
|
||||
## Stage 4: Knowledge Extraction
|
||||
|
||||
Multiple extraction patterns available, selected by flow configuration.
|
||||
|
||||
### Pattern A: Basic GraphRAG
|
||||
|
||||
Two parallel processors:
|
||||
|
||||
**kg-extract-definitions**
|
||||
- Input: Chunk
|
||||
- Output: Triples (entity definitions), EntityContexts
|
||||
- Extracts: entity labels, definitions
|
||||
|
||||
**kg-extract-relationships**
|
||||
- Input: Chunk
|
||||
- Output: Triples (relationships), EntityContexts
|
||||
- Extracts: subject-predicate-object relationships
|
||||
|
||||
### Pattern B: Ontology-Driven (kg-extract-ontology)
|
||||
|
||||
- Input: Chunk
|
||||
- Output: Triples, EntityContexts
|
||||
- Uses configured ontology to guide extraction
|
||||
|
||||
### Pattern C: Agent-Based (kg-extract-agent)
|
||||
|
||||
- Input: Chunk
|
||||
- Output: Triples, EntityContexts
|
||||
- Uses agent framework for extraction
|
||||
|
||||
### Pattern D: Row Extraction (kg-extract-rows)
|
||||
|
||||
- Input: Chunk
|
||||
- Output: Rows (structured data, not triples)
|
||||
- Uses schema definition to extract structured records
|
||||
|
||||
### Schema: Triples
|
||||
|
||||
```
|
||||
Triples
|
||||
├── metadata: Metadata
|
||||
│ ├── id: str
|
||||
│ ├── user: str
|
||||
│ ├── collection: str
|
||||
│ └── metadata: list[Triple] # (set to [] by extractors)
|
||||
└── triples: list[Triple]
|
||||
└── Triple
|
||||
├── s: Term # Subject
|
||||
├── p: Term # Predicate
|
||||
├── o: Term # Object
|
||||
└── g: str | None # Named graph
|
||||
```
|
||||
|
||||
### Schema: EntityContexts
|
||||
|
||||
```
|
||||
EntityContexts
|
||||
├── metadata: Metadata
|
||||
└── entities: list[EntityContext]
|
||||
└── EntityContext
|
||||
├── entity: Term # Entity identifier (IRI)
|
||||
├── context: str # Textual description for embedding
|
||||
└── chunk_id: str # Source chunk ID (provenance)
|
||||
```
|
||||
|
||||
### Schema: Rows
|
||||
|
||||
```
|
||||
Rows
|
||||
├── metadata: Metadata
|
||||
├── row_schema: RowSchema
|
||||
│ ├── name: str
|
||||
│ ├── description: str
|
||||
│ └── fields: list[Field]
|
||||
└── rows: list[dict[str, str]] # Extracted records
|
||||
```
|
||||
|
||||
## Stage 5: Embeddings Generation
|
||||
|
||||
### Graph Embeddings
|
||||
|
||||
Converts entity contexts into vector embeddings.
|
||||
|
||||
**Process:**
|
||||
1. Receive EntityContexts
|
||||
2. Call embeddings service with context text
|
||||
3. Output GraphEmbeddings (entity → vector mapping)
|
||||
|
||||
**Schema: GraphEmbeddings**
|
||||
|
||||
```
|
||||
GraphEmbeddings
|
||||
├── metadata: Metadata
|
||||
└── entities: list[EntityEmbeddings]
|
||||
└── EntityEmbeddings
|
||||
├── entity: Term # Entity identifier
|
||||
├── vector: list[float] # Embedding vector
|
||||
└── chunk_id: str # Source chunk (provenance)
|
||||
```
|
||||
|
||||
### Document Embeddings
|
||||
|
||||
Converts chunk text directly into vector embeddings.
|
||||
|
||||
**Process:**
|
||||
1. Receive Chunk
|
||||
2. Call embeddings service with chunk text
|
||||
3. Output DocumentEmbeddings
|
||||
|
||||
**Schema: DocumentEmbeddings**
|
||||
|
||||
```
|
||||
DocumentEmbeddings
|
||||
├── metadata: Metadata
|
||||
└── chunks: list[ChunkEmbeddings]
|
||||
└── ChunkEmbeddings
|
||||
├── chunk_id: str # Chunk identifier
|
||||
└── vector: list[float] # Embedding vector
|
||||
```
|
||||
|
||||
### Row Embeddings
|
||||
|
||||
Converts row index fields into vector embeddings.
|
||||
|
||||
**Process:**
|
||||
1. Receive Rows
|
||||
2. Embed configured index fields
|
||||
3. Output to row vector store
|
||||
|
||||
## Stage 6: Storage
|
||||
|
||||
### Triple Store
|
||||
|
||||
- Receives: Triples
|
||||
- Storage: Cassandra (entity-centric tables)
|
||||
- Named graphs separate core knowledge from provenance:
|
||||
- `""` (default): Core knowledge facts
|
||||
- `urn:graph:source`: Extraction provenance
|
||||
- `urn:graph:retrieval`: Query-time explainability
|
||||
|
||||
### Vector Store (Graph Embeddings)
|
||||
|
||||
- Receives: GraphEmbeddings
|
||||
- Storage: Qdrant, Milvus, or Pinecone
|
||||
- Indexed by: entity IRI
|
||||
- Metadata: chunk_id for provenance
|
||||
|
||||
### Vector Store (Document Embeddings)
|
||||
|
||||
- Receives: DocumentEmbeddings
|
||||
- Storage: Qdrant, Milvus, or Pinecone
|
||||
- Indexed by: chunk_id
|
||||
|
||||
### Row Store
|
||||
|
||||
- Receives: Rows
|
||||
- Storage: Cassandra
|
||||
- Schema-driven table structure
|
||||
|
||||
### Row Vector Store
|
||||
|
||||
- Receives: Row embeddings
|
||||
- Storage: Vector DB
|
||||
- Indexed by: row index fields
|
||||
|
||||
## Metadata Field Analysis
|
||||
|
||||
### Actively Used Fields
|
||||
|
||||
| Field | Usage |
|
||||
|-------|-------|
|
||||
| `metadata.id` | Document/chunk identifier, logging, provenance |
|
||||
| `metadata.user` | Multi-tenancy, storage routing |
|
||||
| `metadata.collection` | Target collection selection |
|
||||
| `document_id` | Librarian reference, provenance linking |
|
||||
| `chunk_id` | Provenance tracking through pipeline |
|
||||
|
||||
<<<<<<< HEAD
|
||||
### Potentially Redundant Fields
|
||||
|
||||
| Field | Status |
|
||||
|-------|--------|
|
||||
| `metadata.metadata` | Set to `[]` by all extractors; document-level metadata now handled by librarian at submission time |
|
||||
=======
|
||||
### Removed Fields
|
||||
|
||||
| Field | Status |
|
||||
|-------|--------|
|
||||
| `metadata.metadata` | Removed from `Metadata` class. Document-level metadata triples are now emitted directly by librarian to triple store at submission time, not carried through the extraction pipeline. |
|
||||
>>>>>>> e3bcbf73 (The metadata field (list of triples) in the pipeline Metadata class)
|
||||
|
||||
### Bytes Fields Pattern
|
||||
|
||||
All content fields (`data`, `text`, `chunk`) are `bytes` but immediately decoded to UTF-8 strings by all processors. No processor uses raw bytes.
|
||||
|
||||
## Flow Configuration
|
||||
|
||||
Flows are defined externally and provided to librarian via config service. Each flow specifies:
|
||||
|
||||
- Input queues (`text-load`, `document-load`)
|
||||
- Processor chain
|
||||
- Parameters (chunk size, extraction method, etc.)
|
||||
|
||||
Example flow patterns:
|
||||
- `pdf-graphrag`: PDF → Decoder → Chunker → Definitions + Relationships → Embeddings
|
||||
- `text-graphrag`: Text → Chunker → Definitions + Relationships → Embeddings
|
||||
- `pdf-ontology`: PDF → Decoder → Chunker → Ontology Extraction → Embeddings
|
||||
- `text-rows`: Text → Chunker → Row Extraction → Row Store
|
||||
205
docs/tech-specs/extraction-provenance-subgraph.md
Normal file
205
docs/tech-specs/extraction-provenance-subgraph.md
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
# Extraction Provenance: Subgraph Model
|
||||
|
||||
## Problem
|
||||
|
||||
Extraction-time provenance currently generates a full reification per
|
||||
extracted triple: a unique `stmt_uri`, `activity_uri`, and associated
|
||||
PROV-O metadata for every single knowledge fact. Processing one chunk
|
||||
that yields 20 relationships produces ~220 provenance triples on top of
|
||||
the ~20 knowledge triples — a roughly 10:1 overhead.
|
||||
|
||||
This is both expensive (storage, indexing, transmission) and semantically
|
||||
inaccurate. Each chunk is processed by a single LLM call that produces
|
||||
all its triples in one transaction. The current per-triple model
|
||||
obscures that by creating the illusion of 20 independent extraction
|
||||
events.
|
||||
|
||||
Additionally, two of the four extraction processors (kg-extract-ontology,
|
||||
kg-extract-agent) have no provenance at all, leaving gaps in the audit
|
||||
trail.
|
||||
|
||||
## Solution
|
||||
|
||||
Replace per-triple reification with a **subgraph model**: one provenance
|
||||
record per chunk extraction, shared across all triples produced from that
|
||||
chunk.
|
||||
|
||||
### Terminology Change
|
||||
|
||||
| Old | New |
|
||||
|-----|-----|
|
||||
| `stmt_uri` (`https://trustgraph.ai/stmt/{uuid}`) | `subgraph_uri` (`https://trustgraph.ai/subgraph/{uuid}`) |
|
||||
| `statement_uri()` | `subgraph_uri()` |
|
||||
| `tg:reifies` (1:1, identity) | `tg:contains` (1:many, containment) |
|
||||
|
||||
### Target Structure
|
||||
|
||||
All provenance triples go in the `urn:graph:source` named graph.
|
||||
|
||||
```
|
||||
# Subgraph contains each extracted triple (RDF-star quoted triples)
|
||||
<subgraph> tg:contains <<s1 p1 o1>> .
|
||||
<subgraph> tg:contains <<s2 p2 o2>> .
|
||||
<subgraph> tg:contains <<s3 p3 o3>> .
|
||||
|
||||
# Derivation from source chunk
|
||||
<subgraph> prov:wasDerivedFrom <chunk_uri> .
|
||||
<subgraph> prov:wasGeneratedBy <activity> .
|
||||
|
||||
# Activity: one per chunk extraction
|
||||
<activity> rdf:type prov:Activity .
|
||||
<activity> rdfs:label "{component_name} extraction" .
|
||||
<activity> prov:used <chunk_uri> .
|
||||
<activity> prov:wasAssociatedWith <agent> .
|
||||
<activity> prov:startedAtTime "2026-03-13T10:00:00Z" .
|
||||
<activity> tg:componentVersion "0.25.0" .
|
||||
<activity> tg:llmModel "gpt-4" . # if available
|
||||
<activity> tg:ontology <ontology_uri> . # if available
|
||||
|
||||
# Agent: stable per component
|
||||
<agent> rdf:type prov:Agent .
|
||||
<agent> rdfs:label "{component_name}" .
|
||||
```
|
||||
|
||||
### Volume Comparison
|
||||
|
||||
For a chunk producing N extracted triples:
|
||||
|
||||
| | Old (per-triple) | New (subgraph) |
|
||||
|---|---|---|
|
||||
| `tg:contains` / `tg:reifies` | N | N |
|
||||
| Activity triples | ~9 x N | ~9 |
|
||||
| Agent triples | 2 x N | 2 |
|
||||
| Statement/subgraph metadata | 2 x N | 2 |
|
||||
| **Total provenance triples** | **~13N** | **N + 13** |
|
||||
| **Example (N=20)** | **~260** | **33** |
|
||||
|
||||
## Scope
|
||||
|
||||
### Processors to Update (existing provenance, per-triple)
|
||||
|
||||
**kg-extract-definitions**
|
||||
(`trustgraph-flow/trustgraph/extract/kg/definitions/extract.py`)
|
||||
|
||||
Currently calls `statement_uri()` + `triple_provenance_triples()` inside
|
||||
the per-definition loop.
|
||||
|
||||
Changes:
|
||||
- Move `subgraph_uri()` and `activity_uri()` creation before the loop
|
||||
- Collect `tg:contains` triples inside the loop
|
||||
- Emit shared activity/agent/derivation block once after the loop
|
||||
|
||||
**kg-extract-relationships**
|
||||
(`trustgraph-flow/trustgraph/extract/kg/relationships/extract.py`)
|
||||
|
||||
Same pattern as definitions. Same changes.
|
||||
|
||||
### Processors to Add Provenance (currently missing)
|
||||
|
||||
**kg-extract-ontology**
|
||||
(`trustgraph-flow/trustgraph/extract/kg/ontology/extract.py`)
|
||||
|
||||
Currently emits triples with no provenance. Add subgraph provenance
|
||||
using the same pattern: one subgraph per chunk, `tg:contains` for each
|
||||
extracted triple.
|
||||
|
||||
**kg-extract-agent**
|
||||
(`trustgraph-flow/trustgraph/extract/kg/agent/extract.py`)
|
||||
|
||||
Currently emits triples with no provenance. Add subgraph provenance
|
||||
using the same pattern.
|
||||
|
||||
### Shared Provenance Library Changes
|
||||
|
||||
**`trustgraph-base/trustgraph/provenance/triples.py`**
|
||||
|
||||
- Replace `triple_provenance_triples()` with `subgraph_provenance_triples()`
|
||||
- New function accepts a list of extracted triples instead of a single one
|
||||
- Generates one `tg:contains` per triple, shared activity/agent block
|
||||
- Remove old `triple_provenance_triples()`
|
||||
|
||||
**`trustgraph-base/trustgraph/provenance/uris.py`**
|
||||
|
||||
- Replace `statement_uri()` with `subgraph_uri()`
|
||||
|
||||
**`trustgraph-base/trustgraph/provenance/namespaces.py`**
|
||||
|
||||
- Replace `TG_REIFIES` with `TG_CONTAINS`
|
||||
|
||||
### Not in Scope
|
||||
|
||||
- **kg-extract-topics**: older-style processor, not currently used in
|
||||
standard flows
|
||||
- **kg-extract-rows**: produces rows not triples, different provenance
|
||||
model
|
||||
- **Query-time provenance** (`urn:graph:retrieval`): separate concern,
|
||||
already uses a different pattern (question/exploration/focus/synthesis)
|
||||
- **Document/page/chunk provenance** (PDF decoder, chunker): already uses
|
||||
`derived_entity_triples()` which is per-entity, not per-triple — no
|
||||
redundancy issue
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
### Processor Loop Restructure
|
||||
|
||||
Before (per-triple, in relationships):
|
||||
```python
|
||||
for rel in rels:
|
||||
# ... build relationship_triple ...
|
||||
stmt_uri = statement_uri()
|
||||
prov_triples = triple_provenance_triples(
|
||||
stmt_uri=stmt_uri,
|
||||
extracted_triple=relationship_triple,
|
||||
...
|
||||
)
|
||||
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
|
||||
```
|
||||
|
||||
After (subgraph):
|
||||
```python
|
||||
sg_uri = subgraph_uri()
|
||||
|
||||
for rel in rels:
|
||||
# ... build relationship_triple ...
|
||||
extracted_triples.append(relationship_triple)
|
||||
|
||||
prov_triples = subgraph_provenance_triples(
|
||||
subgraph_uri=sg_uri,
|
||||
extracted_triples=extracted_triples,
|
||||
chunk_uri=chunk_uri,
|
||||
component_name=default_ident,
|
||||
component_version=COMPONENT_VERSION,
|
||||
llm_model=llm_model,
|
||||
ontology_uri=ontology_uri,
|
||||
)
|
||||
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
|
||||
```
|
||||
|
||||
### New Helper Signature
|
||||
|
||||
```python
|
||||
def subgraph_provenance_triples(
|
||||
subgraph_uri: str,
|
||||
extracted_triples: List[Triple],
|
||||
chunk_uri: str,
|
||||
component_name: str,
|
||||
component_version: str,
|
||||
llm_model: Optional[str] = None,
|
||||
ontology_uri: Optional[str] = None,
|
||||
timestamp: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build provenance triples for a subgraph of extracted knowledge.
|
||||
|
||||
Creates:
|
||||
- tg:contains link for each extracted triple (RDF-star quoted)
|
||||
- One prov:wasDerivedFrom link to source chunk
|
||||
- One activity with agent metadata
|
||||
"""
|
||||
```
|
||||
|
||||
### Breaking Change
|
||||
|
||||
This is a breaking change to the provenance model. Provenance has not
|
||||
been released, so no migration is needed. The old `tg:reifies` /
|
||||
`statement_uri` code can be removed outright.
|
||||
619
docs/tech-specs/extraction-time-provenance.md
Normal file
619
docs/tech-specs/extraction-time-provenance.md
Normal file
|
|
@ -0,0 +1,619 @@
|
|||
# Extraction-Time Provenance: Source Layer
|
||||
|
||||
## Overview
|
||||
|
||||
This document captures notes on extraction-time provenance for future specification work. Extraction-time provenance records the "source layer" - where data came from originally, how it was extracted and transformed.
|
||||
|
||||
This is separate from query-time provenance (see `query-time-provenance.md`) which records agent reasoning.
|
||||
|
||||
## Problem Statement
|
||||
|
||||
### Current Implementation
|
||||
|
||||
Provenance currently works as follows:
|
||||
- Document metadata is stored as RDF triples in the knowledge graph
|
||||
- A document ID ties metadata to the document, so the document appears as a node in the graph
|
||||
- When edges (relationships/facts) are extracted from documents, a `subjectOf` relationship links the extracted edge back to the source document
|
||||
|
||||
### Problems with Current Approach
|
||||
|
||||
1. **Repetitive metadata loading:** Document metadata is bundled and loaded repeatedly with every batch of triples extracted from that document. This is wasteful and redundant - the same metadata travels as cargo with every extraction output.
|
||||
|
||||
2. **Shallow provenance:** The current `subjectOf` relationship only links facts directly to the top-level document. There is no visibility into the transformation chain - which page the fact came from, which chunk, what extraction method was used.
|
||||
|
||||
### Desired State
|
||||
|
||||
1. **Load metadata once:** Document metadata should be loaded once and attached to the top-level document node, not repeated with every triple batch.
|
||||
|
||||
2. **Rich provenance DAG:** Capture the full transformation chain from source document through all intermediate artifacts down to extracted facts. For example, a PDF document transformation:
|
||||
|
||||
```
|
||||
PDF file (source document with metadata)
|
||||
→ Page 1 (decoded text)
|
||||
→ Chunk 1
|
||||
→ Extracted edge/fact (via subjectOf)
|
||||
→ Extracted edge/fact
|
||||
→ Chunk 2
|
||||
→ Extracted edge/fact
|
||||
→ Page 2
|
||||
→ Chunk 3
|
||||
→ ...
|
||||
```
|
||||
|
||||
3. **Unified storage:** The provenance DAG is stored in the same knowledge graph as the extracted knowledge. This allows provenance to be queried the same way as knowledge - following edges back up the chain from any fact to its exact source location.
|
||||
|
||||
4. **Stable IDs:** Each intermediate artifact (page, chunk) has a stable ID as a node in the graph.
|
||||
|
||||
5. **Parent-child linking:** Derived documents are linked to their parents all the way up to the top-level source document using consistent relationship types.
|
||||
|
||||
6. **Precise fact attribution:** The `subjectOf` relationship on extracted edges points to the immediate parent (chunk), not the top-level document. Full provenance is recovered by traversing up the DAG.
|
||||
|
||||
## Use Cases
|
||||
|
||||
### UC1: Source Attribution in GraphRAG Responses
|
||||
|
||||
**Scenario:** A user runs a GraphRAG query and receives a response from the agent.
|
||||
|
||||
**Flow:**
|
||||
1. User submits a query to the GraphRAG agent
|
||||
2. Agent retrieves relevant facts from the knowledge graph to formulate a response
|
||||
3. Per the query-time provenance spec, the agent reports which facts contributed to the response
|
||||
4. Each fact links to its source chunk via the provenance DAG
|
||||
5. Chunks link to pages, pages link to source documents
|
||||
|
||||
**UX Outcome:** The interface displays the LLM response alongside source attribution. The user can:
|
||||
- See which facts supported the response
|
||||
- Drill down from facts → chunks → pages → documents
|
||||
- Peruse the original source documents to verify claims
|
||||
- Understand exactly where in a document (which page, which section) a fact originated
|
||||
|
||||
**Value:** Users can verify AI-generated responses against primary sources, building trust and enabling fact-checking.
|
||||
|
||||
### UC2: Debugging Extraction Quality
|
||||
|
||||
A fact looks wrong. Trace back through chunk → page → document to see the original text. Was it a bad extraction, or was the source itself wrong?
|
||||
|
||||
### UC3: Incremental Re-extraction
|
||||
|
||||
Source document gets updated. Which chunks/facts were derived from it? Invalidate and regenerate just those, rather than re-processing everything.
|
||||
|
||||
### UC4: Data Deletion / Right to be Forgotten
|
||||
|
||||
A source document must be removed (GDPR, legal, etc.). Traverse the DAG to find and remove all derived facts.
|
||||
|
||||
### UC5: Conflict Resolution
|
||||
|
||||
Two facts contradict each other. Trace both back to their sources to understand why and decide which to trust (more authoritative source, more recent, etc.).
|
||||
|
||||
### UC6: Source Authority Weighting
|
||||
|
||||
Some sources are more authoritative than others. Facts can be weighted or filtered based on the authority/quality of their origin documents.
|
||||
|
||||
### UC7: Extraction Pipeline Comparison
|
||||
|
||||
Compare outputs from different extraction methods/versions. Which extractor produced better facts from the same source?
|
||||
|
||||
## Integration Points
|
||||
|
||||
### Librarian
|
||||
|
||||
The librarian component already provides document storage with unique document IDs. The provenance system integrates with this existing infrastructure.
|
||||
|
||||
#### Existing Capabilities (already implemented)
|
||||
|
||||
**Parent-Child Document Linking:**
|
||||
- `parent_id` field in `DocumentMetadata` - links child to parent document
|
||||
- `document_type` field - values: `"source"` (original) or `"extracted"` (derived)
|
||||
- `add-child-document` API - creates child document with automatic `document_type = "extracted"`
|
||||
- `list-children` API - retrieves all children of a parent document
|
||||
- Cascade deletion - removing a parent automatically deletes all child documents
|
||||
|
||||
**Document Identification:**
|
||||
- Document IDs are client-specified (not auto-generated)
|
||||
- Documents keyed by composite `(user, document_id)` in Cassandra
|
||||
- Object IDs (UUIDs) generated internally for blob storage
|
||||
|
||||
**Metadata Support:**
|
||||
- `metadata: list[Triple]` field - RDF triples for structured metadata
|
||||
- `title`, `comments`, `tags` - basic document metadata
|
||||
- `time` - timestamp, `kind` - MIME type
|
||||
|
||||
**Storage Architecture:**
|
||||
- Metadata stored in Cassandra (`librarian` keyspace, `document` table)
|
||||
- Content stored in MinIO/S3 blob storage (`library` bucket)
|
||||
- Smart content delivery: documents < 2MB embedded, larger documents streamed
|
||||
|
||||
#### Key Files
|
||||
|
||||
- `trustgraph-flow/trustgraph/librarian/librarian.py` - Core librarian operations
|
||||
- `trustgraph-flow/trustgraph/librarian/service.py` - Service processor, document loading
|
||||
- `trustgraph-flow/trustgraph/tables/library.py` - Cassandra table store
|
||||
- `trustgraph-base/trustgraph/schema/services/library.py` - Schema definitions
|
||||
|
||||
#### Gaps to Address
|
||||
|
||||
The librarian has the building blocks but currently:
|
||||
1. Parent-child linking is one level deep - no multi-level DAG traversal helpers
|
||||
2. No standard relationship type vocabulary (e.g., `derivedFrom`, `extractedFrom`)
|
||||
3. Provenance metadata (extraction method, confidence, chunk position) not standardized
|
||||
4. No query API to traverse the full provenance chain from a fact back to source
|
||||
|
||||
## End-to-End Flow Design
|
||||
|
||||
Each processor in the pipeline follows a consistent pattern:
|
||||
- Receive document ID from upstream
|
||||
- Fetch content from librarian
|
||||
- Produce child artifacts
|
||||
- For each child: save to librarian, emit edge to graph, forward ID downstream
|
||||
|
||||
### Processing Flows
|
||||
|
||||
There are two flows depending on document type:
|
||||
|
||||
#### PDF Document Flow
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Librarian (initiate processing) │
|
||||
│ 1. Emit root document metadata to knowledge graph (once) │
|
||||
│ 2. Send root document ID to PDF extractor │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ PDF Extractor (per page) │
|
||||
│ 1. Fetch PDF content from librarian using document ID │
|
||||
│ 2. Extract pages as text │
|
||||
│ 3. For each page: │
|
||||
│ a. Save page as child document in librarian (parent = root doc) │
|
||||
│ b. Emit parent-child edge to knowledge graph │
|
||||
│ c. Send page document ID to chunker │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Chunker (per chunk) │
|
||||
│ 1. Fetch page content from librarian using document ID │
|
||||
│ 2. Split text into chunks │
|
||||
│ 3. For each chunk: │
|
||||
│ a. Save chunk as child document in librarian (parent = page) │
|
||||
│ b. Emit parent-child edge to knowledge graph │
|
||||
│ c. Send chunk document ID + chunk content to next processor │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
|
||||
Post-chunker optimization: messages carry both
|
||||
chunk ID (for provenance) and content (to avoid
|
||||
librarian round-trip). Chunks are small (2-4KB).
|
||||
─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Knowledge Extractor (per chunk) │
|
||||
│ 1. Receive chunk ID + content directly (no librarian fetch needed) │
|
||||
│ 2. Extract facts/triples and embeddings from chunk content │
|
||||
│ 3. For each triple: │
|
||||
│ a. Emit triple to knowledge graph │
|
||||
│ b. Emit reified edge linking triple → chunk ID (edge pointing │
|
||||
│ to edge - first use of reification support) │
|
||||
│ 4. For each embedding: │
|
||||
│ a. Emit embedding with its entity ID │
|
||||
│ b. Link entity ID → chunk ID in knowledge graph │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
#### Text Document Flow
|
||||
|
||||
Text documents skip the PDF extractor and go directly to the chunker:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Librarian (initiate processing) │
|
||||
│ 1. Emit root document metadata to knowledge graph (once) │
|
||||
│ 2. Send root document ID directly to chunker (skip PDF extractor) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Chunker (per chunk) │
|
||||
│ 1. Fetch text content from librarian using document ID │
|
||||
│ 2. Split text into chunks │
|
||||
│ 3. For each chunk: │
|
||||
│ a. Save chunk as child document in librarian (parent = root doc) │
|
||||
│ b. Emit parent-child edge to knowledge graph │
|
||||
│ c. Send chunk document ID + chunk content to next processor │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Knowledge Extractor │
|
||||
│ (same as PDF flow) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
The resulting DAG is one level shorter:
|
||||
|
||||
```
|
||||
PDF: Document → Pages → Chunks → Triples/Embeddings
|
||||
Text: Document → Chunks → Triples/Embeddings
|
||||
```
|
||||
|
||||
The design accommodates both because the chunker treats its input generically - it uses whatever document ID it receives as the parent, regardless of whether that's a source document or a page.
|
||||
|
||||
### Metadata Schema (PROV-O)
|
||||
|
||||
Provenance metadata uses the W3C PROV-O ontology. This provides a standard vocabulary and enables future signing/authentication of extraction outputs.
|
||||
|
||||
#### PROV-O Core Concepts
|
||||
|
||||
| PROV-O Type | TrustGraph Usage |
|
||||
|-------------|------------------|
|
||||
| `prov:Entity` | Document, Page, Chunk, Triple, Embedding |
|
||||
| `prov:Activity` | Instances of extraction operations |
|
||||
| `prov:Agent` | TG components (PDF extractor, chunker, etc.) with versions |
|
||||
|
||||
#### PROV-O Relationships
|
||||
|
||||
| Predicate | Meaning | Example |
|
||||
|-----------|---------|---------|
|
||||
| `prov:wasDerivedFrom` | Entity derived from another entity | Page wasDerivedFrom Document |
|
||||
| `prov:wasGeneratedBy` | Entity generated by an activity | Page wasGeneratedBy PDFExtractionActivity |
|
||||
| `prov:used` | Activity used an entity as input | PDFExtractionActivity used Document |
|
||||
| `prov:wasAssociatedWith` | Activity performed by an agent | PDFExtractionActivity wasAssociatedWith tg:PDFExtractor |
|
||||
|
||||
#### Metadata at Each Level
|
||||
|
||||
**Source Document (emitted by Librarian):**
|
||||
```
|
||||
doc:123 a prov:Entity .
|
||||
doc:123 dc:title "Research Paper" .
|
||||
doc:123 dc:source <https://example.com/paper.pdf> .
|
||||
doc:123 dc:date "2024-01-15" .
|
||||
doc:123 dc:creator "Author Name" .
|
||||
doc:123 tg:pageCount 42 .
|
||||
doc:123 tg:mimeType "application/pdf" .
|
||||
```
|
||||
|
||||
**Page (emitted by PDF Extractor):**
|
||||
```
|
||||
page:123-1 a prov:Entity .
|
||||
page:123-1 prov:wasDerivedFrom doc:123 .
|
||||
page:123-1 prov:wasGeneratedBy activity:pdf-extract-456 .
|
||||
page:123-1 tg:pageNumber 1 .
|
||||
|
||||
activity:pdf-extract-456 a prov:Activity .
|
||||
activity:pdf-extract-456 prov:used doc:123 .
|
||||
activity:pdf-extract-456 prov:wasAssociatedWith tg:PDFExtractor .
|
||||
activity:pdf-extract-456 tg:componentVersion "1.2.3" .
|
||||
activity:pdf-extract-456 prov:startedAtTime "2024-01-15T10:30:00Z" .
|
||||
```
|
||||
|
||||
**Chunk (emitted by Chunker):**
|
||||
```
|
||||
chunk:123-1-1 a prov:Entity .
|
||||
chunk:123-1-1 prov:wasDerivedFrom page:123-1 .
|
||||
chunk:123-1-1 prov:wasGeneratedBy activity:chunk-789 .
|
||||
chunk:123-1-1 tg:chunkIndex 1 .
|
||||
chunk:123-1-1 tg:charOffset 0 .
|
||||
chunk:123-1-1 tg:charLength 2048 .
|
||||
|
||||
activity:chunk-789 a prov:Activity .
|
||||
activity:chunk-789 prov:used page:123-1 .
|
||||
activity:chunk-789 prov:wasAssociatedWith tg:Chunker .
|
||||
activity:chunk-789 tg:componentVersion "1.0.0" .
|
||||
activity:chunk-789 tg:chunkSize 2048 .
|
||||
activity:chunk-789 tg:chunkOverlap 200 .
|
||||
```
|
||||
|
||||
**Triple (emitted by Knowledge Extractor):**
|
||||
```
|
||||
# The extracted triple (edge)
|
||||
entity:JohnSmith rel:worksAt entity:AcmeCorp .
|
||||
|
||||
# Subgraph containing the extracted triples
|
||||
subgraph:001 tg:contains <<entity:JohnSmith rel:worksAt entity:AcmeCorp>> .
|
||||
subgraph:001 prov:wasDerivedFrom chunk:123-1-1 .
|
||||
subgraph:001 prov:wasGeneratedBy activity:extract-999 .
|
||||
|
||||
activity:extract-999 a prov:Activity .
|
||||
activity:extract-999 prov:used chunk:123-1-1 .
|
||||
activity:extract-999 prov:wasAssociatedWith tg:KnowledgeExtractor .
|
||||
activity:extract-999 tg:componentVersion "2.1.0" .
|
||||
activity:extract-999 tg:llmModel "claude-3" .
|
||||
activity:extract-999 tg:ontology <http://example.org/ontologies/business-v1> .
|
||||
```
|
||||
|
||||
**Embedding (stored in vector store, not triple store):**
|
||||
|
||||
Embeddings are stored in the vector store with metadata, not as RDF triples. Each embedding record contains:
|
||||
|
||||
| Field | Description | Example |
|
||||
|-------|-------------|---------|
|
||||
| vector | The embedding vector | [0.123, -0.456, ...] |
|
||||
| entity | Node URI the embedding represents | `entity:JohnSmith` |
|
||||
| chunk_id | Source chunk (provenance) | `chunk:123-1-1` |
|
||||
| model | Embedding model used | `text-embedding-ada-002` |
|
||||
| component_version | TG embedder version | `1.0.0` |
|
||||
|
||||
The `entity` field links the embedding to the knowledge graph (node URI). The `chunk_id` field provides provenance back to the source chunk, enabling traversal up the DAG to the original document.
|
||||
|
||||
#### TrustGraph Namespace Extensions
|
||||
|
||||
Custom predicates under the `tg:` namespace for extraction-specific metadata:
|
||||
|
||||
| Predicate | Domain | Description |
|
||||
|-----------|--------|-------------|
|
||||
| `tg:contains` | Subgraph | Points at a triple contained in this extraction subgraph |
|
||||
| `tg:pageCount` | Document | Total number of pages in source document |
|
||||
| `tg:mimeType` | Document | MIME type of source document |
|
||||
| `tg:pageNumber` | Page | Page number in source document |
|
||||
| `tg:chunkIndex` | Chunk | Index of chunk within parent |
|
||||
| `tg:charOffset` | Chunk | Character offset in parent text |
|
||||
| `tg:charLength` | Chunk | Length of chunk in characters |
|
||||
| `tg:chunkSize` | Activity | Configured chunk size |
|
||||
| `tg:chunkOverlap` | Activity | Configured overlap between chunks |
|
||||
| `tg:componentVersion` | Activity | Version of TG component |
|
||||
| `tg:llmModel` | Activity | LLM used for extraction |
|
||||
| `tg:ontology` | Activity | Ontology URI used to guide extraction |
|
||||
| `tg:embeddingModel` | Activity | Model used for embeddings |
|
||||
| `tg:sourceText` | Statement | Exact text from which a triple was extracted |
|
||||
| `tg:sourceCharOffset` | Statement | Character offset within chunk where source text starts |
|
||||
| `tg:sourceCharLength` | Statement | Length of source text in characters |
|
||||
|
||||
#### Vocabulary Bootstrap (Per Collection)
|
||||
|
||||
The knowledge graph is ontology-neutral and initialises empty. When writing PROV-O provenance data to a collection for the first time, the vocabulary must be bootstrapped with RDF labels for all classes and predicates. This ensures human-readable display in queries and UI.
|
||||
|
||||
**PROV-O Classes:**
|
||||
```
|
||||
prov:Entity rdfs:label "Entity" .
|
||||
prov:Activity rdfs:label "Activity" .
|
||||
prov:Agent rdfs:label "Agent" .
|
||||
```
|
||||
|
||||
**PROV-O Predicates:**
|
||||
```
|
||||
prov:wasDerivedFrom rdfs:label "was derived from" .
|
||||
prov:wasGeneratedBy rdfs:label "was generated by" .
|
||||
prov:used rdfs:label "used" .
|
||||
prov:wasAssociatedWith rdfs:label "was associated with" .
|
||||
prov:startedAtTime rdfs:label "started at" .
|
||||
```
|
||||
|
||||
**TrustGraph Predicates:**
|
||||
```
|
||||
tg:contains rdfs:label "contains" .
|
||||
tg:pageCount rdfs:label "page count" .
|
||||
tg:mimeType rdfs:label "MIME type" .
|
||||
tg:pageNumber rdfs:label "page number" .
|
||||
tg:chunkIndex rdfs:label "chunk index" .
|
||||
tg:charOffset rdfs:label "character offset" .
|
||||
tg:charLength rdfs:label "character length" .
|
||||
tg:chunkSize rdfs:label "chunk size" .
|
||||
tg:chunkOverlap rdfs:label "chunk overlap" .
|
||||
tg:componentVersion rdfs:label "component version" .
|
||||
tg:llmModel rdfs:label "LLM model" .
|
||||
tg:ontology rdfs:label "ontology" .
|
||||
tg:embeddingModel rdfs:label "embedding model" .
|
||||
tg:sourceText rdfs:label "source text" .
|
||||
tg:sourceCharOffset rdfs:label "source character offset" .
|
||||
tg:sourceCharLength rdfs:label "source character length" .
|
||||
```
|
||||
|
||||
**Implementation note:** This vocabulary bootstrap should be idempotent - safe to run multiple times without creating duplicates. Could be triggered on first document processing in a collection, or as a separate collection initialisation step.
|
||||
|
||||
#### Sub-Chunk Provenance (Aspirational)
|
||||
|
||||
For finer-grained provenance, it would be valuable to record exactly where within a chunk a triple was extracted from. This enables:
|
||||
|
||||
- Highlighting the exact source text in the UI
|
||||
- Verifying extraction accuracy against source
|
||||
- Debugging extraction quality at the sentence level
|
||||
|
||||
**Example with position tracking:**
|
||||
```
|
||||
# The extracted triple
|
||||
entity:JohnSmith rel:worksAt entity:AcmeCorp .
|
||||
|
||||
# Subgraph with sub-chunk provenance
|
||||
subgraph:001 tg:contains <<entity:JohnSmith rel:worksAt entity:AcmeCorp>> .
|
||||
subgraph:001 prov:wasDerivedFrom chunk:123-1-1 .
|
||||
subgraph:001 tg:sourceText "John Smith has worked at Acme Corp since 2019" .
|
||||
subgraph:001 tg:sourceCharOffset 1547 .
|
||||
subgraph:001 tg:sourceCharLength 46 .
|
||||
```
|
||||
|
||||
**Example with text range (alternative):**
|
||||
```
|
||||
subgraph:001 tg:contains <<entity:JohnSmith rel:worksAt entity:AcmeCorp>> .
|
||||
subgraph:001 prov:wasDerivedFrom chunk:123-1-1 .
|
||||
subgraph:001 tg:sourceRange "1547-1593" .
|
||||
subgraph:001 tg:sourceText "John Smith has worked at Acme Corp since 2019" .
|
||||
```
|
||||
|
||||
**Implementation considerations:**
|
||||
|
||||
- LLM-based extraction may not naturally provide character positions
|
||||
- Could prompt the LLM to return the source sentence/phrase alongside extracted triples
|
||||
- Alternatively, post-process to fuzzy-match extracted entities back to source text
|
||||
- Trade-off between extraction complexity and provenance granularity
|
||||
- May be easier to achieve with structured extraction methods than free-form LLM extraction
|
||||
|
||||
This is marked as aspirational - the basic chunk-level provenance should be implemented first, with sub-chunk tracking as a future enhancement if feasible.
|
||||
|
||||
### Dual Storage Model
|
||||
|
||||
The provenance DAG is built progressively as documents flow through the pipeline:
|
||||
|
||||
| Store | What's Stored | Purpose |
|
||||
|-------|---------------|---------|
|
||||
| Librarian | Document content + parent-child links | Content retrieval, cascade deletion |
|
||||
| Knowledge Graph | Parent-child edges + metadata | Provenance queries, fact attribution |
|
||||
|
||||
Both stores maintain the same DAG structure. The librarian holds content; the graph holds relationships and enables traversal queries.
|
||||
|
||||
### Key Design Principles
|
||||
|
||||
1. **Document ID as the unit of flow** - Processors pass IDs, not content. Content is fetched from librarian when needed.
|
||||
|
||||
2. **Emit once at source** - Metadata is written to the graph once when processing begins, not repeated downstream.
|
||||
|
||||
3. **Consistent processor pattern** - Every processor follows the same receive/fetch/produce/save/emit/forward pattern.
|
||||
|
||||
4. **Progressive DAG construction** - Each processor adds its level to the DAG. The full provenance chain is built incrementally.
|
||||
|
||||
5. **Post-chunker optimization** - After chunking, messages carry both ID and content. Chunks are small (2-4KB), so including content avoids unnecessary librarian round-trips while preserving provenance via the ID.
|
||||
|
||||
## Implementation Tasks
|
||||
|
||||
### Librarian Changes
|
||||
|
||||
#### Current State
|
||||
|
||||
- Initiates document processing by sending document ID to first processor
|
||||
- No connection to triple store - metadata is bundled with extraction outputs
|
||||
- `add-child-document` creates one-level parent-child links
|
||||
- `list-children` returns immediate children only
|
||||
|
||||
#### Required Changes
|
||||
|
||||
**1. New interface: Triple store connection**
|
||||
|
||||
Librarian needs to emit document metadata edges directly to the knowledge graph when initiating processing.
|
||||
- Add triple store client/publisher to librarian service
|
||||
- On processing initiation: emit root document metadata as graph edges (once)
|
||||
|
||||
**2. Document type vocabulary**
|
||||
|
||||
Standardize `document_type` values for child documents:
|
||||
- `source` - original uploaded document
|
||||
- `page` - page extracted from source (PDF, etc.)
|
||||
- `chunk` - text chunk derived from page or source
|
||||
|
||||
#### Interface Changes Summary
|
||||
|
||||
| Interface | Change |
|
||||
|-----------|--------|
|
||||
| Triple store | New outbound connection - emit document metadata edges |
|
||||
| Processing initiation | Emit metadata to graph before forwarding document ID |
|
||||
|
||||
### PDF Extractor Changes
|
||||
|
||||
#### Current State
|
||||
|
||||
- Receives document content (or streams large documents)
|
||||
- Extracts text from PDF pages
|
||||
- Forwards page content to chunker
|
||||
- No interaction with librarian or triple store
|
||||
|
||||
#### Required Changes
|
||||
|
||||
**1. New interface: Librarian client**
|
||||
|
||||
PDF extractor needs to save each page as a child document in librarian.
|
||||
- Add librarian client to PDF extractor service
|
||||
- For each page: call `add-child-document` with parent = root document ID
|
||||
|
||||
**2. New interface: Triple store connection**
|
||||
|
||||
PDF extractor needs to emit parent-child edges to knowledge graph.
|
||||
- Add triple store client/publisher
|
||||
- For each page: emit edge linking page document to parent document
|
||||
|
||||
**3. Change output format**
|
||||
|
||||
Instead of forwarding page content directly, forward page document ID.
|
||||
- Chunker will fetch content from librarian using the ID
|
||||
|
||||
#### Interface Changes Summary
|
||||
|
||||
| Interface | Change |
|
||||
|-----------|--------|
|
||||
| Librarian | New outbound - save child documents |
|
||||
| Triple store | New outbound - emit parent-child edges |
|
||||
| Output message | Change from content to document ID |
|
||||
|
||||
### Chunker Changes
|
||||
|
||||
#### Current State
|
||||
|
||||
- Receives page/text content
|
||||
- Splits into chunks
|
||||
- Forwards chunk content to downstream processors
|
||||
- No interaction with librarian or triple store
|
||||
|
||||
#### Required Changes
|
||||
|
||||
**1. Change input handling**
|
||||
|
||||
Receive document ID instead of content, fetch from librarian.
|
||||
- Add librarian client to chunker service
|
||||
- Fetch page content using document ID
|
||||
|
||||
**2. New interface: Librarian client (write)**
|
||||
|
||||
Save each chunk as a child document in librarian.
|
||||
- For each chunk: call `add-child-document` with parent = page document ID
|
||||
|
||||
**3. New interface: Triple store connection**
|
||||
|
||||
Emit parent-child edges to knowledge graph.
|
||||
- Add triple store client/publisher
|
||||
- For each chunk: emit edge linking chunk document to page document
|
||||
|
||||
**4. Change output format**
|
||||
|
||||
Forward both chunk document ID and chunk content (post-chunker optimization).
|
||||
- Downstream processors receive ID for provenance + content to work with
|
||||
|
||||
#### Interface Changes Summary
|
||||
|
||||
| Interface | Change |
|
||||
|-----------|--------|
|
||||
| Input message | Change from content to document ID |
|
||||
| Librarian | New outbound (read + write) - fetch content, save child documents |
|
||||
| Triple store | New outbound - emit parent-child edges |
|
||||
| Output message | Change from content-only to ID + content |
|
||||
|
||||
### Knowledge Extractor Changes
|
||||
|
||||
#### Current State
|
||||
|
||||
- Receives chunk content
|
||||
- Extracts triples and embeddings
|
||||
- Emits to triple store and embedding store
|
||||
- `subjectOf` relationship points to top-level document (not chunk)
|
||||
|
||||
#### Required Changes
|
||||
|
||||
**1. Change input handling**
|
||||
|
||||
Receive chunk document ID alongside content.
|
||||
- Use chunk ID for provenance linking (content already included per optimization)
|
||||
|
||||
**2. Update triple provenance**
|
||||
|
||||
Link extracted triples to chunk (not top-level document).
|
||||
- Use reification to create edge pointing to edge
|
||||
- `subjectOf` relationship: triple → chunk document ID
|
||||
- First use of existing reification support
|
||||
|
||||
**3. Update embedding provenance**
|
||||
|
||||
Link embedding entity IDs to chunk.
|
||||
- Emit edge: embedding entity ID → chunk document ID
|
||||
|
||||
#### Interface Changes Summary
|
||||
|
||||
| Interface | Change |
|
||||
|-----------|--------|
|
||||
| Input message | Expect chunk ID + content (not content only) |
|
||||
| Triple store | Use reification for triple → chunk provenance |
|
||||
| Embedding provenance | Link entity ID → chunk ID |
|
||||
|
||||
## References
|
||||
|
||||
- Query-time provenance: `docs/tech-specs/query-time-provenance.md`
|
||||
- PROV-O standard for provenance modeling
|
||||
- Existing source metadata in knowledge graph (needs audit)
|
||||
984
docs/tech-specs/large-document-loading.md
Normal file
984
docs/tech-specs/large-document-loading.md
Normal file
|
|
@ -0,0 +1,984 @@
|
|||
# Large Document Loading Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification addresses scalability and user experience issues when loading
|
||||
large documents into TrustGraph. The current architecture treats document upload
|
||||
as a single atomic operation, causing memory pressure at multiple points in the
|
||||
pipeline and providing no feedback or recovery options to users.
|
||||
|
||||
This implementation targets the following use cases:
|
||||
|
||||
1. **Large PDF Processing**: Upload and process multi-hundred-megabyte PDF files
|
||||
without exhausting memory
|
||||
2. **Resumable Uploads**: Allow interrupted uploads to continue from where they
|
||||
left off rather than restarting
|
||||
3. **Progress Feedback**: Provide users with real-time visibility into upload
|
||||
and processing progress
|
||||
4. **Memory-Efficient Processing**: Process documents in a streaming fashion
|
||||
without holding entire files in memory
|
||||
|
||||
## Goals
|
||||
|
||||
- **Incremental Upload**: Support chunked document upload via REST and WebSocket
|
||||
- **Resumable Transfers**: Enable recovery from interrupted uploads
|
||||
- **Progress Visibility**: Provide upload/processing progress feedback to clients
|
||||
- **Memory Efficiency**: Eliminate full-document buffering throughout the pipeline
|
||||
- **Backward Compatibility**: Existing small-document workflows continue unchanged
|
||||
- **Streaming Processing**: PDF decoding and text chunking operate on streams
|
||||
|
||||
## Background
|
||||
|
||||
### Current Architecture
|
||||
|
||||
Document submission flows through the following path:
|
||||
|
||||
1. **Client** submits document via REST (`POST /api/v1/librarian`) or WebSocket
|
||||
2. **API Gateway** receives complete request with base64-encoded document content
|
||||
3. **LibrarianRequestor** translates request to Pulsar message
|
||||
4. **Librarian Service** receives message, decodes document into memory
|
||||
5. **BlobStore** uploads document to Garage/S3
|
||||
6. **Cassandra** stores metadata with object reference
|
||||
7. For processing: document retrieved from S3, decoded, chunked—all in memory
|
||||
|
||||
Key files:
|
||||
- REST/WebSocket entry: `trustgraph-flow/trustgraph/gateway/service.py`
|
||||
- Librarian core: `trustgraph-flow/trustgraph/librarian/librarian.py`
|
||||
- Blob storage: `trustgraph-flow/trustgraph/librarian/blob_store.py`
|
||||
- Cassandra tables: `trustgraph-flow/trustgraph/tables/library.py`
|
||||
- API schema: `trustgraph-base/trustgraph/schema/services/library.py`
|
||||
|
||||
### Current Limitations
|
||||
|
||||
The current design has several compounding memory and UX issues:
|
||||
|
||||
1. **Atomic Upload Operation**: The entire document must be transmitted in a
|
||||
single request. Large documents require long-running requests with no
|
||||
progress indication and no retry mechanism if the connection fails.
|
||||
|
||||
2. **API Design**: Both REST and WebSocket APIs expect the complete document
|
||||
in a single message. The schema (`LibrarianRequest`) has a single `content`
|
||||
field containing the entire base64-encoded document.
|
||||
|
||||
3. **Librarian Memory**: The librarian service decodes the entire document
|
||||
into memory before uploading to S3. For a 500MB PDF, this means holding
|
||||
500MB+ in process memory.
|
||||
|
||||
4. **PDF Decoder Memory**: When processing begins, the PDF decoder loads the
|
||||
entire PDF into memory to extract text. PyPDF and similar libraries
|
||||
typically require full document access.
|
||||
|
||||
5. **Chunker Memory**: The text chunker receives the complete extracted text
|
||||
and holds it in memory while producing chunks.
|
||||
|
||||
**Memory Impact Example** (500MB PDF):
|
||||
- Gateway: ~700MB (base64 encoding overhead)
|
||||
- Librarian: ~500MB (decoded bytes)
|
||||
- PDF Decoder: ~500MB + extraction buffers
|
||||
- Chunker: extracted text (variable, potentially 100MB+)
|
||||
|
||||
Total peak memory can exceed 2GB for a single large document.
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Design Principles
|
||||
|
||||
1. **API Facade**: All client interaction goes through the librarian API. Clients
|
||||
have no direct access to or knowledge of the underlying S3/Garage storage.
|
||||
|
||||
2. **S3 Multipart Upload**: Use standard S3 multipart upload under the hood.
|
||||
This is widely supported across S3-compatible systems (AWS S3, MinIO, Garage,
|
||||
Ceph, DigitalOcean Spaces, Backblaze B2, etc.) ensuring portability.
|
||||
|
||||
3. **Atomic Completion**: S3 multipart uploads are inherently atomic - uploaded
|
||||
parts are invisible until `CompleteMultipartUpload` is called. No temporary
|
||||
files or rename operations needed.
|
||||
|
||||
4. **Trackable State**: Upload sessions tracked in Cassandra, providing
|
||||
visibility into incomplete uploads and enabling resume capability.
|
||||
|
||||
### Chunked Upload Flow
|
||||
|
||||
```
|
||||
Client Librarian API S3/Garage
|
||||
│ │ │
|
||||
│── begin-upload ───────────►│ │
|
||||
│ (metadata, size) │── CreateMultipartUpload ────►│
|
||||
│ │◄── s3_upload_id ─────────────│
|
||||
│◄── upload_id ──────────────│ (store session in │
|
||||
│ │ Cassandra) │
|
||||
│ │ │
|
||||
│── upload-chunk ───────────►│ │
|
||||
│ (upload_id, index, data) │── UploadPart ───────────────►│
|
||||
│ │◄── etag ─────────────────────│
|
||||
│◄── ack + progress ─────────│ (store etag in session) │
|
||||
│ ⋮ │ ⋮ │
|
||||
│ (repeat for all chunks) │ │
|
||||
│ │ │
|
||||
│── complete-upload ────────►│ │
|
||||
│ (upload_id) │── CompleteMultipartUpload ──►│
|
||||
│ │ (parts coalesced by S3) │
|
||||
│ │── store doc metadata ───────►│ Cassandra
|
||||
│◄── document_id ────────────│ (delete session) │
|
||||
```
|
||||
|
||||
The client never interacts with S3 directly. The librarian translates between
|
||||
our chunked upload API and S3 multipart operations internally.
|
||||
|
||||
### Librarian API Operations
|
||||
|
||||
#### `begin-upload`
|
||||
|
||||
Initialize a chunked upload session.
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "begin-upload",
|
||||
"document-metadata": {
|
||||
"id": "doc-123",
|
||||
"kind": "application/pdf",
|
||||
"title": "Large Document",
|
||||
"user": "user-id",
|
||||
"tags": ["tag1", "tag2"]
|
||||
},
|
||||
"total-size": 524288000,
|
||||
"chunk-size": 5242880
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"upload-id": "upload-abc-123",
|
||||
"chunk-size": 5242880,
|
||||
"total-chunks": 100
|
||||
}
|
||||
```
|
||||
|
||||
The librarian:
|
||||
1. Generates a unique `upload_id` and `object_id` (UUID for blob storage)
|
||||
2. Calls S3 `CreateMultipartUpload`, receives `s3_upload_id`
|
||||
3. Creates session record in Cassandra
|
||||
4. Returns `upload_id` to client
|
||||
|
||||
#### `upload-chunk`
|
||||
|
||||
Upload a single chunk.
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "upload-chunk",
|
||||
"upload-id": "upload-abc-123",
|
||||
"chunk-index": 0,
|
||||
"content": "<base64-encoded-chunk>"
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"upload-id": "upload-abc-123",
|
||||
"chunk-index": 0,
|
||||
"chunks-received": 1,
|
||||
"total-chunks": 100,
|
||||
"bytes-received": 5242880,
|
||||
"total-bytes": 524288000
|
||||
}
|
||||
```
|
||||
|
||||
The librarian:
|
||||
1. Looks up session by `upload_id`
|
||||
2. Validates ownership (user must match session creator)
|
||||
3. Calls S3 `UploadPart` with chunk data, receives `etag`
|
||||
4. Updates session record with chunk index and etag
|
||||
5. Returns progress to client
|
||||
|
||||
Failed chunks can be retried - just send the same `chunk-index` again.
|
||||
|
||||
#### `complete-upload`
|
||||
|
||||
Finalize the upload and create the document.
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "complete-upload",
|
||||
"upload-id": "upload-abc-123"
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"document-id": "doc-123",
|
||||
"object-id": "550e8400-e29b-41d4-a716-446655440000"
|
||||
}
|
||||
```
|
||||
|
||||
The librarian:
|
||||
1. Looks up session, verifies all chunks received
|
||||
2. Calls S3 `CompleteMultipartUpload` with part etags (S3 coalesces parts
|
||||
internally - zero memory cost to librarian)
|
||||
3. Creates document record in Cassandra with metadata and object reference
|
||||
4. Deletes upload session record
|
||||
5. Returns document ID to client
|
||||
|
||||
#### `abort-upload`
|
||||
|
||||
Cancel an in-progress upload.
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "abort-upload",
|
||||
"upload-id": "upload-abc-123"
|
||||
}
|
||||
```
|
||||
|
||||
The librarian:
|
||||
1. Calls S3 `AbortMultipartUpload` to clean up parts
|
||||
2. Deletes session record from Cassandra
|
||||
|
||||
#### `get-upload-status`
|
||||
|
||||
Query status of an upload (for resume capability).
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "get-upload-status",
|
||||
"upload-id": "upload-abc-123"
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"upload-id": "upload-abc-123",
|
||||
"state": "in-progress",
|
||||
"chunks-received": [0, 1, 2, 5, 6],
|
||||
"missing-chunks": [3, 4, 7, 8],
|
||||
"total-chunks": 100,
|
||||
"bytes-received": 36700160,
|
||||
"total-bytes": 524288000
|
||||
}
|
||||
```
|
||||
|
||||
#### `list-uploads`
|
||||
|
||||
List incomplete uploads for a user.
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "list-uploads"
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"uploads": [
|
||||
{
|
||||
"upload-id": "upload-abc-123",
|
||||
"document-metadata": { "title": "Large Document", ... },
|
||||
"progress": { "chunks-received": 43, "total-chunks": 100 },
|
||||
"created-at": "2024-01-15T10:30:00Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Upload Session Storage
|
||||
|
||||
Track in-progress uploads in Cassandra:
|
||||
|
||||
```sql
|
||||
CREATE TABLE upload_session (
|
||||
upload_id text PRIMARY KEY,
|
||||
user text,
|
||||
document_id text,
|
||||
document_metadata text, -- JSON: title, kind, tags, comments, etc.
|
||||
s3_upload_id text, -- internal, for S3 operations
|
||||
object_id uuid, -- target blob ID
|
||||
total_size bigint,
|
||||
chunk_size int,
|
||||
total_chunks int,
|
||||
chunks_received map<int, text>, -- chunk_index → etag
|
||||
created_at timestamp,
|
||||
updated_at timestamp
|
||||
) WITH default_time_to_live = 86400; -- 24 hour TTL
|
||||
|
||||
CREATE INDEX upload_session_user ON upload_session (user);
|
||||
```
|
||||
|
||||
**TTL Behavior:**
|
||||
- Sessions expire after 24 hours if not completed
|
||||
- When Cassandra TTL expires, the session record is deleted
|
||||
- Orphaned S3 parts are cleaned up by S3 lifecycle policy (configure on bucket)
|
||||
|
||||
### Failure Handling and Atomicity
|
||||
|
||||
**Chunk upload failure:**
|
||||
- Client retries the failed chunk (same `upload_id` and `chunk-index`)
|
||||
- S3 `UploadPart` is idempotent for the same part number
|
||||
- Session tracks which chunks succeeded
|
||||
|
||||
**Client disconnect mid-upload:**
|
||||
- Session remains in Cassandra with received chunks recorded
|
||||
- Client can call `get-upload-status` to see what's missing
|
||||
- Resume by uploading only missing chunks, then `complete-upload`
|
||||
|
||||
**Complete-upload failure:**
|
||||
- S3 `CompleteMultipartUpload` is atomic - either succeeds fully or fails
|
||||
- On failure, parts remain and client can retry `complete-upload`
|
||||
- No partial document is ever visible
|
||||
|
||||
**Session expiry:**
|
||||
- Cassandra TTL deletes session record after 24 hours
|
||||
- S3 bucket lifecycle policy cleans up incomplete multipart uploads
|
||||
- No manual cleanup required
|
||||
|
||||
### S3 Multipart Atomicity
|
||||
|
||||
S3 multipart uploads provide built-in atomicity:
|
||||
|
||||
1. **Parts are invisible**: Uploaded parts cannot be accessed as objects.
|
||||
They exist only as parts of an incomplete multipart upload.
|
||||
|
||||
2. **Atomic completion**: `CompleteMultipartUpload` either succeeds (object
|
||||
appears atomically) or fails (no object created). No partial state.
|
||||
|
||||
3. **No rename needed**: The final object key is specified at
|
||||
`CreateMultipartUpload` time. Parts are coalesced directly to that key.
|
||||
|
||||
4. **Server-side coalesce**: S3 combines parts internally. The librarian
|
||||
never reads parts back - zero memory overhead regardless of document size.
|
||||
|
||||
### BlobStore Extensions
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/librarian/blob_store.py`
|
||||
|
||||
Add multipart upload methods:
|
||||
|
||||
```python
|
||||
class BlobStore:
|
||||
# Existing methods...
|
||||
|
||||
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
|
||||
"""Initialize multipart upload, return s3_upload_id."""
|
||||
# minio client: create_multipart_upload()
|
||||
|
||||
def upload_part(
|
||||
self, object_id: UUID, s3_upload_id: str,
|
||||
part_number: int, data: bytes
|
||||
) -> str:
|
||||
"""Upload a single part, return etag."""
|
||||
# minio client: upload_part()
|
||||
# Note: S3 part numbers are 1-indexed
|
||||
|
||||
def complete_multipart_upload(
|
||||
self, object_id: UUID, s3_upload_id: str,
|
||||
parts: List[Tuple[int, str]] # [(part_number, etag), ...]
|
||||
) -> None:
|
||||
"""Finalize multipart upload."""
|
||||
# minio client: complete_multipart_upload()
|
||||
|
||||
def abort_multipart_upload(
|
||||
self, object_id: UUID, s3_upload_id: str
|
||||
) -> None:
|
||||
"""Cancel multipart upload, clean up parts."""
|
||||
# minio client: abort_multipart_upload()
|
||||
```
|
||||
|
||||
### Chunk Size Considerations
|
||||
|
||||
- **S3 minimum**: 5MB per part (except last part)
|
||||
- **S3 maximum**: 10,000 parts per upload
|
||||
- **Practical default**: 5MB chunks
|
||||
- 500MB document = 100 chunks
|
||||
- 5GB document = 1,000 chunks
|
||||
- **Progress granularity**: Smaller chunks = finer progress updates
|
||||
- **Network efficiency**: Larger chunks = fewer round trips
|
||||
|
||||
Chunk size could be client-configurable within bounds (5MB - 100MB).
|
||||
|
||||
### Document Processing: Streaming Retrieval
|
||||
|
||||
The upload flow addresses getting documents into storage efficiently. The
|
||||
processing flow addresses extracting and chunking documents without loading
|
||||
them entirely into memory.
|
||||
|
||||
#### Design Principle: Identifier, Not Content
|
||||
|
||||
Currently, when processing is triggered, document content flows through Pulsar
|
||||
messages. This loads entire documents into memory. Instead:
|
||||
|
||||
- Pulsar messages carry only the **document identifier**
|
||||
- Processors fetch document content directly from librarian
|
||||
- Fetching happens as a **stream to temporary file**
|
||||
- Document-specific parsing (PDF, text, etc.) works with files, not memory buffers
|
||||
|
||||
This keeps the librarian document-structure-agnostic. PDF parsing, text
|
||||
extraction, and other format-specific logic stays in the respective decoders.
|
||||
|
||||
#### Processing Flow
|
||||
|
||||
```
|
||||
Pulsar PDF Decoder Librarian S3
|
||||
│ │ │ │
|
||||
│── doc-id ───────────►│ │ │
|
||||
│ (processing msg) │ │ │
|
||||
│ │ │ │
|
||||
│ │── stream-document ──────►│ │
|
||||
│ │ (doc-id) │── GetObject ────►│
|
||||
│ │ │ │
|
||||
│ │◄── chunk ────────────────│◄── stream ───────│
|
||||
│ │ (write to temp file) │ │
|
||||
│ │◄── chunk ────────────────│◄── stream ───────│
|
||||
│ │ (append to temp file) │ │
|
||||
│ │ ⋮ │ ⋮ │
|
||||
│ │◄── EOF ──────────────────│ │
|
||||
│ │ │ │
|
||||
│ │ ┌──────────────────────────┐ │
|
||||
│ │ │ temp file on disk │ │
|
||||
│ │ │ (memory stays bounded) │ │
|
||||
│ │ └────────────┬─────────────┘ │
|
||||
│ │ │ │
|
||||
│ │ PDF library opens file │
|
||||
│ │ extract page 1 text ──► chunker │
|
||||
│ │ extract page 2 text ──► chunker │
|
||||
│ │ ⋮ │
|
||||
│ │ close file │
|
||||
│ │ delete temp file │
|
||||
```
|
||||
|
||||
#### Librarian Stream API
|
||||
|
||||
Add a streaming document retrieval operation:
|
||||
|
||||
**`stream-document`**
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"operation": "stream-document",
|
||||
"document-id": "doc-123"
|
||||
}
|
||||
```
|
||||
|
||||
Response: Streamed binary chunks (not a single response).
|
||||
|
||||
For REST API, this returns a streaming response with `Transfer-Encoding: chunked`.
|
||||
|
||||
For internal service-to-service calls (processor to librarian), this could be:
|
||||
- Direct S3 streaming via presigned URL (if internal network allows)
|
||||
- Chunked responses over the service protocol
|
||||
- A dedicated streaming endpoint
|
||||
|
||||
The key requirement: data flows in chunks, never fully buffered in librarian.
|
||||
|
||||
#### PDF Decoder Changes
|
||||
|
||||
**Current implementation** (memory-intensive):
|
||||
|
||||
```python
|
||||
def decode_pdf(document_content: bytes) -> str:
|
||||
reader = PdfReader(BytesIO(document_content)) # full doc in memory
|
||||
text = ""
|
||||
for page in reader.pages:
|
||||
text += page.extract_text() # accumulating
|
||||
return text # full text in memory
|
||||
```
|
||||
|
||||
**New implementation** (temp file, incremental):
|
||||
|
||||
```python
|
||||
def decode_pdf_streaming(doc_id: str, librarian_client) -> Iterator[str]:
|
||||
"""Yield extracted text page by page."""
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=True, suffix='.pdf') as tmp:
|
||||
# Stream document to temp file
|
||||
for chunk in librarian_client.stream_document(doc_id):
|
||||
tmp.write(chunk)
|
||||
tmp.flush()
|
||||
|
||||
# Open PDF from file (not memory)
|
||||
reader = PdfReader(tmp.name)
|
||||
|
||||
# Yield pages incrementally
|
||||
for page in reader.pages:
|
||||
yield page.extract_text()
|
||||
|
||||
# tmp file auto-deleted on context exit
|
||||
```
|
||||
|
||||
Memory profile:
|
||||
- Temp file on disk: size of PDF (disk is cheap)
|
||||
- In memory: one page's text at a time
|
||||
- Peak memory: bounded, independent of document size
|
||||
|
||||
#### Text Document Decoder Changes
|
||||
|
||||
For plain text documents, even simpler - no temp file needed:
|
||||
|
||||
```python
|
||||
def decode_text_streaming(doc_id: str, librarian_client) -> Iterator[str]:
|
||||
"""Yield text in chunks as it streams from storage."""
|
||||
|
||||
buffer = ""
|
||||
for chunk in librarian_client.stream_document(doc_id):
|
||||
buffer += chunk.decode('utf-8')
|
||||
|
||||
# Yield complete lines/paragraphs as they arrive
|
||||
while '\n\n' in buffer:
|
||||
paragraph, buffer = buffer.split('\n\n', 1)
|
||||
yield paragraph + '\n\n'
|
||||
|
||||
# Yield remaining buffer
|
||||
if buffer:
|
||||
yield buffer
|
||||
```
|
||||
|
||||
Text documents can stream directly without temp file since they're
|
||||
linearly structured.
|
||||
|
||||
#### Streaming Chunker Integration
|
||||
|
||||
The chunker receives an iterator of text (pages or paragraphs) and produces
|
||||
chunks incrementally:
|
||||
|
||||
```python
|
||||
class StreamingChunker:
|
||||
def __init__(self, chunk_size: int, overlap: int):
|
||||
self.chunk_size = chunk_size
|
||||
self.overlap = overlap
|
||||
|
||||
def process(self, text_stream: Iterator[str]) -> Iterator[str]:
|
||||
"""Yield chunks as text arrives."""
|
||||
buffer = ""
|
||||
|
||||
for text_segment in text_stream:
|
||||
buffer += text_segment
|
||||
|
||||
while len(buffer) >= self.chunk_size:
|
||||
chunk = buffer[:self.chunk_size]
|
||||
yield chunk
|
||||
# Keep overlap for context continuity
|
||||
buffer = buffer[self.chunk_size - self.overlap:]
|
||||
|
||||
# Yield remaining buffer as final chunk
|
||||
if buffer.strip():
|
||||
yield buffer
|
||||
```
|
||||
|
||||
#### End-to-End Processing Pipeline
|
||||
|
||||
```python
|
||||
async def process_document(doc_id: str, librarian_client, embedder):
|
||||
"""Process document with bounded memory."""
|
||||
|
||||
# Get document metadata to determine type
|
||||
metadata = await librarian_client.get_document_metadata(doc_id)
|
||||
|
||||
# Select decoder based on document type
|
||||
if metadata.kind == 'application/pdf':
|
||||
text_stream = decode_pdf_streaming(doc_id, librarian_client)
|
||||
elif metadata.kind == 'text/plain':
|
||||
text_stream = decode_text_streaming(doc_id, librarian_client)
|
||||
else:
|
||||
raise UnsupportedDocumentType(metadata.kind)
|
||||
|
||||
# Chunk incrementally
|
||||
chunker = StreamingChunker(chunk_size=1000, overlap=100)
|
||||
|
||||
# Process each chunk as it's produced
|
||||
for chunk in chunker.process(text_stream):
|
||||
# Generate embeddings, store in vector DB, etc.
|
||||
embedding = await embedder.embed(chunk)
|
||||
await store_chunk(doc_id, chunk, embedding)
|
||||
```
|
||||
|
||||
At no point is the full document or full extracted text held in memory.
|
||||
|
||||
#### Temp File Considerations
|
||||
|
||||
**Location**: Use system temp directory (`/tmp` or equivalent). For
|
||||
containerized deployments, ensure temp directory has sufficient space
|
||||
and is on fast storage (not network-mounted if possible).
|
||||
|
||||
**Cleanup**: Use context managers (`with tempfile...`) to ensure cleanup
|
||||
even on exceptions.
|
||||
|
||||
**Concurrent processing**: Each processing job gets its own temp file.
|
||||
No conflicts between parallel document processing.
|
||||
|
||||
**Disk space**: Temp files are short-lived (duration of processing). For
|
||||
a 500MB PDF, need 500MB temp space during processing. Size limit could
|
||||
be enforced at upload time if disk space is constrained.
|
||||
|
||||
### Unified Processing Interface: Child Documents
|
||||
|
||||
PDF extraction and text document processing need to feed into the same
|
||||
downstream pipeline (chunker → embeddings → storage). To achieve this with
|
||||
a consistent "fetch by ID" interface, extracted text blobs are stored back
|
||||
to librarian as child documents.
|
||||
|
||||
#### Processing Flow with Child Documents
|
||||
|
||||
```
|
||||
PDF Document Text Document
|
||||
│ │
|
||||
▼ │
|
||||
pdf-extractor │
|
||||
│ │
|
||||
│ (stream PDF from librarian) │
|
||||
│ (extract page 1 text) │
|
||||
│ (store as child doc → librarian) │
|
||||
│ (extract page 2 text) │
|
||||
│ (store as child doc → librarian) │
|
||||
│ ⋮ │
|
||||
▼ ▼
|
||||
[child-doc-id, child-doc-id, ...] [doc-id]
|
||||
│ │
|
||||
└─────────────────────┬───────────────────────────────┘
|
||||
▼
|
||||
chunker
|
||||
│
|
||||
│ (receives document ID)
|
||||
│ (streams content from librarian)
|
||||
│ (chunks incrementally)
|
||||
▼
|
||||
[chunks → embedding → storage]
|
||||
```
|
||||
|
||||
The chunker has one uniform interface:
|
||||
- Receive a document ID (via Pulsar)
|
||||
- Stream content from librarian
|
||||
- Chunk it
|
||||
|
||||
It doesn't know or care whether the ID refers to:
|
||||
- A user-uploaded text document
|
||||
- An extracted text blob from a PDF page
|
||||
- Any future document type
|
||||
|
||||
#### Child Document Metadata
|
||||
|
||||
Extend the document schema to track parent/child relationships:
|
||||
|
||||
```sql
|
||||
-- Add columns to document table
|
||||
ALTER TABLE document ADD parent_id text;
|
||||
ALTER TABLE document ADD document_type text;
|
||||
|
||||
-- Index for finding children of a parent
|
||||
CREATE INDEX document_parent ON document (parent_id);
|
||||
```
|
||||
|
||||
**Document types:**
|
||||
|
||||
| `document_type` | Description |
|
||||
|-----------------|-------------|
|
||||
| `source` | User-uploaded document (PDF, text, etc.) |
|
||||
| `extracted` | Derived from a source document (e.g., PDF page text) |
|
||||
|
||||
**Metadata fields:**
|
||||
|
||||
| Field | Source Document | Extracted Child |
|
||||
|-------|-----------------|-----------------|
|
||||
| `id` | user-provided or generated | generated (e.g., `{parent-id}-page-{n}`) |
|
||||
| `parent_id` | `NULL` | parent document ID |
|
||||
| `document_type` | `source` | `extracted` |
|
||||
| `kind` | `application/pdf`, etc. | `text/plain` |
|
||||
| `title` | user-provided | generated (e.g., "Page 3 of Report.pdf") |
|
||||
| `user` | authenticated user | same as parent |
|
||||
|
||||
#### Librarian API for Child Documents
|
||||
|
||||
**Creating child documents** (internal, used by pdf-extractor):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "add-child-document",
|
||||
"parent-id": "doc-123",
|
||||
"document-metadata": {
|
||||
"id": "doc-123-page-1",
|
||||
"kind": "text/plain",
|
||||
"title": "Page 1"
|
||||
},
|
||||
"content": "<base64-encoded-text>"
|
||||
}
|
||||
```
|
||||
|
||||
For small extracted text (typical page text is < 100KB), single-operation
|
||||
upload is acceptable. For very large text extractions, chunked upload
|
||||
could be used.
|
||||
|
||||
**Listing child documents** (for debugging/admin):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "list-children",
|
||||
"parent-id": "doc-123"
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"children": [
|
||||
{ "id": "doc-123-page-1", "title": "Page 1", "kind": "text/plain" },
|
||||
{ "id": "doc-123-page-2", "title": "Page 2", "kind": "text/plain" },
|
||||
...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### User-Facing Behavior
|
||||
|
||||
**`list-documents` default behavior:**
|
||||
|
||||
```sql
|
||||
SELECT * FROM document WHERE user = ? AND parent_id IS NULL;
|
||||
```
|
||||
|
||||
Only top-level (source) documents appear in the user's document list.
|
||||
Child documents are filtered out by default.
|
||||
|
||||
**Optional include-children flag** (for admin/debugging):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "list-documents",
|
||||
"include-children": true
|
||||
}
|
||||
```
|
||||
|
||||
#### Cascade Delete
|
||||
|
||||
When a parent document is deleted, all children must be deleted:
|
||||
|
||||
```python
|
||||
def delete_document(doc_id: str):
|
||||
# Find all children
|
||||
children = query("SELECT id, object_id FROM document WHERE parent_id = ?", doc_id)
|
||||
|
||||
# Delete child blobs from S3
|
||||
for child in children:
|
||||
blob_store.delete(child.object_id)
|
||||
|
||||
# Delete child metadata from Cassandra
|
||||
execute("DELETE FROM document WHERE parent_id = ?", doc_id)
|
||||
|
||||
# Delete parent blob and metadata
|
||||
parent = get_document(doc_id)
|
||||
blob_store.delete(parent.object_id)
|
||||
execute("DELETE FROM document WHERE id = ? AND user = ?", doc_id, user)
|
||||
```
|
||||
|
||||
#### Storage Considerations
|
||||
|
||||
Extracted text blobs do duplicate content:
|
||||
- Original PDF stored in Garage
|
||||
- Extracted text per page also stored in Garage
|
||||
|
||||
This tradeoff enables:
|
||||
- **Uniform chunker interface**: Chunker always fetches by ID
|
||||
- **Resume/retry**: Can restart at chunker stage without re-extracting PDF
|
||||
- **Debugging**: Extracted text is inspectable
|
||||
- **Separation of concerns**: PDF extractor and chunker are independent services
|
||||
|
||||
For a 500MB PDF with 200 pages averaging 5KB text per page:
|
||||
- PDF storage: 500MB
|
||||
- Extracted text storage: ~1MB total
|
||||
- Overhead: negligible
|
||||
|
||||
#### PDF Extractor Output
|
||||
|
||||
The pdf-extractor, after processing a document:
|
||||
|
||||
1. Streams PDF from librarian to temp file
|
||||
2. Extracts text page by page
|
||||
3. For each page, stores extracted text as child document via librarian
|
||||
4. Sends child document IDs to chunker queue
|
||||
|
||||
```python
|
||||
async def extract_pdf(doc_id: str, librarian_client, output_queue):
|
||||
"""Extract PDF pages and store as child documents."""
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=True, suffix='.pdf') as tmp:
|
||||
# Stream PDF to temp file
|
||||
for chunk in librarian_client.stream_document(doc_id):
|
||||
tmp.write(chunk)
|
||||
tmp.flush()
|
||||
|
||||
# Extract pages
|
||||
reader = PdfReader(tmp.name)
|
||||
for page_num, page in enumerate(reader.pages, start=1):
|
||||
text = page.extract_text()
|
||||
|
||||
# Store as child document
|
||||
child_id = f"{doc_id}-page-{page_num}"
|
||||
await librarian_client.add_child_document(
|
||||
parent_id=doc_id,
|
||||
document_id=child_id,
|
||||
kind="text/plain",
|
||||
title=f"Page {page_num}",
|
||||
content=text.encode('utf-8')
|
||||
)
|
||||
|
||||
# Send to chunker queue
|
||||
await output_queue.send(child_id)
|
||||
```
|
||||
|
||||
The chunker receives these child IDs and processes them identically to
|
||||
how it would process a user-uploaded text document.
|
||||
|
||||
### Client Updates
|
||||
|
||||
#### Python SDK
|
||||
|
||||
The Python SDK (`trustgraph-base/trustgraph/api/library.py`) should handle
|
||||
chunked uploads transparently. The public interface remains unchanged:
|
||||
|
||||
```python
|
||||
# Existing interface - no change for users
|
||||
library.add_document(
|
||||
id="doc-123",
|
||||
title="Large Report",
|
||||
kind="application/pdf",
|
||||
content=large_pdf_bytes, # Can be hundreds of MB
|
||||
tags=["reports"]
|
||||
)
|
||||
```
|
||||
|
||||
Internally, the SDK detects document size and switches strategy:
|
||||
|
||||
```python
|
||||
class Library:
|
||||
CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024 # 2MB
|
||||
|
||||
def add_document(self, id, title, kind, content, tags=None, ...):
|
||||
if len(content) < self.CHUNKED_UPLOAD_THRESHOLD:
|
||||
# Small document: single operation (existing behavior)
|
||||
return self._add_document_single(id, title, kind, content, tags)
|
||||
else:
|
||||
# Large document: chunked upload
|
||||
return self._add_document_chunked(id, title, kind, content, tags)
|
||||
|
||||
def _add_document_chunked(self, id, title, kind, content, tags):
|
||||
# 1. begin-upload
|
||||
session = self._begin_upload(
|
||||
document_metadata={...},
|
||||
total_size=len(content),
|
||||
chunk_size=5 * 1024 * 1024
|
||||
)
|
||||
|
||||
# 2. upload-chunk for each chunk
|
||||
for i, chunk in enumerate(self._chunk_bytes(content, session.chunk_size)):
|
||||
self._upload_chunk(session.upload_id, i, chunk)
|
||||
|
||||
# 3. complete-upload
|
||||
return self._complete_upload(session.upload_id)
|
||||
```
|
||||
|
||||
**Progress callbacks** (optional enhancement):
|
||||
|
||||
```python
|
||||
def add_document(self, ..., on_progress=None):
|
||||
"""
|
||||
on_progress: Optional callback(bytes_sent, total_bytes)
|
||||
"""
|
||||
```
|
||||
|
||||
This allows UIs to display upload progress without changing the basic API.
|
||||
|
||||
#### CLI Tools
|
||||
|
||||
**`tg-add-library-document`** continues to work unchanged:
|
||||
|
||||
```bash
|
||||
# Works transparently for any size - SDK handles chunking internally
|
||||
tg-add-library-document --file large-report.pdf --title "Large Report"
|
||||
```
|
||||
|
||||
Optional progress display could be added:
|
||||
|
||||
```bash
|
||||
tg-add-library-document --file large-report.pdf --title "Large Report" --progress
|
||||
# Output:
|
||||
# Uploading: 45% (225MB / 500MB)
|
||||
```
|
||||
|
||||
**Legacy tools removed:**
|
||||
|
||||
- `tg-load-pdf` - deprecated, use `tg-add-library-document`
|
||||
- `tg-load-text` - deprecated, use `tg-add-library-document`
|
||||
|
||||
**Admin/debug commands** (optional, low priority):
|
||||
|
||||
```bash
|
||||
# List incomplete uploads (admin troubleshooting)
|
||||
tg-add-library-document --list-pending
|
||||
|
||||
# Resume specific upload (recovery scenario)
|
||||
tg-add-library-document --resume upload-abc-123 --file large-report.pdf
|
||||
```
|
||||
|
||||
These could be flags on the existing command rather than separate tools.
|
||||
|
||||
#### API Specification Updates
|
||||
|
||||
The OpenAPI spec (`specs/api/paths/librarian.yaml`) needs updates for:
|
||||
|
||||
**New operations:**
|
||||
|
||||
- `begin-upload` - Initialize chunked upload session
|
||||
- `upload-chunk` - Upload individual chunk
|
||||
- `complete-upload` - Finalize upload
|
||||
- `abort-upload` - Cancel upload
|
||||
- `get-upload-status` - Query upload progress
|
||||
- `list-uploads` - List incomplete uploads for user
|
||||
- `stream-document` - Streaming document retrieval
|
||||
- `add-child-document` - Store extracted text (internal)
|
||||
- `list-children` - List child documents (admin)
|
||||
|
||||
**Modified operations:**
|
||||
|
||||
- `list-documents` - Add `include-children` parameter
|
||||
|
||||
**New schemas:**
|
||||
|
||||
- `ChunkedUploadBeginRequest`
|
||||
- `ChunkedUploadBeginResponse`
|
||||
- `ChunkedUploadChunkRequest`
|
||||
- `ChunkedUploadChunkResponse`
|
||||
- `UploadSession`
|
||||
- `UploadProgress`
|
||||
|
||||
**WebSocket spec updates** (`specs/websocket/`):
|
||||
|
||||
Mirror the REST operations for WebSocket clients, enabling real-time
|
||||
progress updates during upload.
|
||||
|
||||
#### UX Considerations
|
||||
|
||||
The API spec updates enable frontend improvements:
|
||||
|
||||
**Upload progress UI:**
|
||||
- Progress bar showing chunks uploaded
|
||||
- Estimated time remaining
|
||||
- Pause/resume capability
|
||||
|
||||
**Error recovery:**
|
||||
- "Resume upload" option for interrupted uploads
|
||||
- List of pending uploads on reconnect
|
||||
|
||||
**Large file handling:**
|
||||
- Client-side file size detection
|
||||
- Automatic chunked upload for large files
|
||||
- Clear feedback during long uploads
|
||||
|
||||
These UX improvements require frontend work guided by the updated API spec.
|
||||
263
docs/tech-specs/query-time-explainability.md
Normal file
263
docs/tech-specs/query-time-explainability.md
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
# Query-Time Explainability
|
||||
|
||||
## Status
|
||||
|
||||
Implemented
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes how GraphRAG records and communicates explainability data during query execution. The goal is full traceability: from final answer back through selected edges to source documents.
|
||||
|
||||
Query-time explainability captures what the GraphRAG pipeline did during reasoning. It connects to extraction-time provenance which records where knowledge graph facts originated.
|
||||
|
||||
## Terminology
|
||||
|
||||
| Term | Definition |
|
||||
|------|------------|
|
||||
| **Explainability** | The record of how a result was derived |
|
||||
| **Session** | A single GraphRAG query execution |
|
||||
| **Edge Selection** | LLM-driven selection of relevant edges with reasoning |
|
||||
| **Provenance Chain** | Path from edge → chunk → page → document |
|
||||
|
||||
## Architecture
|
||||
|
||||
### Explainability Flow
|
||||
|
||||
```
|
||||
GraphRAG Query
|
||||
│
|
||||
├─► Session Activity
|
||||
│ └─► Query text, timestamp
|
||||
│
|
||||
├─► Retrieval Entity
|
||||
│ └─► All edges retrieved from subgraph
|
||||
│
|
||||
├─► Selection Entity
|
||||
│ └─► Selected edges with LLM reasoning
|
||||
│ └─► Each edge links to extraction provenance
|
||||
│
|
||||
└─► Answer Entity
|
||||
└─► Reference to synthesized response (in librarian)
|
||||
```
|
||||
|
||||
### Two-Stage GraphRAG Pipeline
|
||||
|
||||
1. **Edge Selection**: LLM selects relevant edges from subgraph, providing reasoning for each
|
||||
2. **Synthesis**: LLM generates answer from selected edges only
|
||||
|
||||
This separation enables explainability - we know exactly which edges contributed.
|
||||
|
||||
### Storage
|
||||
|
||||
- Explainability triples stored in configurable collection (default: `explainability`)
|
||||
- Uses PROV-O ontology for provenance relationships
|
||||
- RDF-star reification for edge references
|
||||
- Answer content stored in librarian service (not inline - too large)
|
||||
|
||||
### Real-Time Streaming
|
||||
|
||||
Explainability events stream to client as the query executes:
|
||||
|
||||
1. Session created → event emitted
|
||||
2. Edges retrieved → event emitted
|
||||
3. Edges selected with reasoning → event emitted
|
||||
4. Answer synthesized → event emitted
|
||||
|
||||
Client receives `explain_id` and `explain_collection` to fetch full details.
|
||||
|
||||
## URI Structure
|
||||
|
||||
All URIs use the `urn:trustgraph:` namespace with UUIDs:
|
||||
|
||||
| Entity | URI Pattern |
|
||||
|--------|-------------|
|
||||
| Session | `urn:trustgraph:session:{uuid}` |
|
||||
| Retrieval | `urn:trustgraph:prov:retrieval:{uuid}` |
|
||||
| Selection | `urn:trustgraph:prov:selection:{uuid}` |
|
||||
| Answer | `urn:trustgraph:prov:answer:{uuid}` |
|
||||
| Edge Selection | `urn:trustgraph:prov:edge:{uuid}:{index}` |
|
||||
|
||||
## RDF Model (PROV-O)
|
||||
|
||||
### Session Activity
|
||||
|
||||
```turtle
|
||||
<session-uri> a prov:Activity ;
|
||||
rdfs:label "GraphRAG query session" ;
|
||||
prov:startedAtTime "2024-01-15T10:30:00Z" ;
|
||||
tg:query "What was the War on Terror?" .
|
||||
```
|
||||
|
||||
### Retrieval Entity
|
||||
|
||||
```turtle
|
||||
<retrieval-uri> a prov:Entity ;
|
||||
rdfs:label "Retrieved edges" ;
|
||||
prov:wasGeneratedBy <session-uri> ;
|
||||
tg:edgeCount 50 .
|
||||
```
|
||||
|
||||
### Selection Entity
|
||||
|
||||
```turtle
|
||||
<selection-uri> a prov:Entity ;
|
||||
rdfs:label "Selected edges" ;
|
||||
prov:wasDerivedFrom <retrieval-uri> ;
|
||||
tg:selectedEdge <edge-sel-0> ;
|
||||
tg:selectedEdge <edge-sel-1> .
|
||||
|
||||
<edge-sel-0> tg:edge << <s> <p> <o> >> ;
|
||||
tg:reasoning "This edge establishes the key relationship..." .
|
||||
```
|
||||
|
||||
### Answer Entity
|
||||
|
||||
```turtle
|
||||
<answer-uri> a prov:Entity ;
|
||||
rdfs:label "GraphRAG answer" ;
|
||||
prov:wasDerivedFrom <selection-uri> ;
|
||||
tg:document <urn:trustgraph:answer:{uuid}> .
|
||||
```
|
||||
|
||||
The `tg:document` references the answer stored in the librarian service.
|
||||
|
||||
## Namespace Constants
|
||||
|
||||
Defined in `trustgraph-base/trustgraph/provenance/namespaces.py`:
|
||||
|
||||
| Constant | URI |
|
||||
|----------|-----|
|
||||
| `TG_QUERY` | `https://trustgraph.ai/ns/query` |
|
||||
| `TG_EDGE_COUNT` | `https://trustgraph.ai/ns/edgeCount` |
|
||||
| `TG_SELECTED_EDGE` | `https://trustgraph.ai/ns/selectedEdge` |
|
||||
| `TG_EDGE` | `https://trustgraph.ai/ns/edge` |
|
||||
| `TG_REASONING` | `https://trustgraph.ai/ns/reasoning` |
|
||||
| `TG_CONTENT` | `https://trustgraph.ai/ns/content` |
|
||||
| `TG_DOCUMENT` | `https://trustgraph.ai/ns/document` |
|
||||
|
||||
## GraphRagResponse Schema
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class GraphRagResponse:
|
||||
error: Error | None = None
|
||||
response: str = ""
|
||||
end_of_stream: bool = False
|
||||
explain_id: str | None = None
|
||||
explain_collection: str | None = None
|
||||
message_type: str = "" # "chunk" or "explain"
|
||||
end_of_session: bool = False
|
||||
```
|
||||
|
||||
### Message Types
|
||||
|
||||
| message_type | Purpose |
|
||||
|--------------|---------|
|
||||
| `chunk` | Response text (streaming or final) |
|
||||
| `explain` | Explainability event with IRI reference |
|
||||
|
||||
### Session Lifecycle
|
||||
|
||||
1. Multiple `explain` messages (session, retrieval, selection, answer)
|
||||
2. Multiple `chunk` messages (streaming response)
|
||||
3. Final `chunk` with `end_of_session=True`
|
||||
|
||||
## Edge Selection Format
|
||||
|
||||
LLM returns JSONL with selected edges:
|
||||
|
||||
```jsonl
|
||||
{"id": "edge-hash-1", "reasoning": "This edge shows the key relationship..."}
|
||||
{"id": "edge-hash-2", "reasoning": "Provides supporting evidence..."}
|
||||
```
|
||||
|
||||
The `id` is a hash of `(labeled_s, labeled_p, labeled_o)` computed by `edge_id()`.
|
||||
|
||||
## URI Preservation
|
||||
|
||||
### The Problem
|
||||
|
||||
GraphRAG displays human-readable labels to the LLM, but explainability needs original URIs for provenance tracing.
|
||||
|
||||
### Solution
|
||||
|
||||
`get_labelgraph()` returns both:
|
||||
- `labeled_edges`: List of `(label_s, label_p, label_o)` for LLM
|
||||
- `uri_map`: Dict mapping `edge_id(labels)` → `(uri_s, uri_p, uri_o)`
|
||||
|
||||
When storing explainability data, URIs from `uri_map` are used.
|
||||
|
||||
## Provenance Tracing
|
||||
|
||||
### From Edge to Source
|
||||
|
||||
Selected edges can be traced back to source documents:
|
||||
|
||||
1. Query for containing subgraph: `?subgraph tg:contains <<s p o>>`
|
||||
2. Follow `prov:wasDerivedFrom` chain to root document
|
||||
3. Each step in chain: chunk → page → document
|
||||
|
||||
### Cassandra Quoted Triple Support
|
||||
|
||||
The Cassandra query service supports matching quoted triples:
|
||||
|
||||
```python
|
||||
# In get_term_value():
|
||||
elif term.type == TRIPLE:
|
||||
return serialize_triple(term.triple)
|
||||
```
|
||||
|
||||
This enables queries like:
|
||||
```
|
||||
?subgraph tg:contains <<http://example.org/s http://example.org/p "value">>
|
||||
```
|
||||
|
||||
## CLI Usage
|
||||
|
||||
```bash
|
||||
tg-invoke-graph-rag --explainable -q "What was the War on Terror?"
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
```
|
||||
[session] urn:trustgraph:session:abc123
|
||||
|
||||
[retrieval] urn:trustgraph:prov:retrieval:abc123
|
||||
|
||||
[selection] urn:trustgraph:prov:selection:abc123
|
||||
Selected 12 edge(s)
|
||||
Edge: (Guantanamo, definition, A detention facility...)
|
||||
Reason: Directly connects Guantanamo to the War on Terror
|
||||
Source: Chunk 1 → Page 2 → Beyond the Vigilant State
|
||||
|
||||
[answer] urn:trustgraph:prov:answer:abc123
|
||||
|
||||
Based on the provided knowledge statements...
|
||||
```
|
||||
|
||||
### Features
|
||||
|
||||
- Real-time explainability events during query
|
||||
- Label resolution for edge components via `rdfs:label`
|
||||
- Source chain tracing via `prov:wasDerivedFrom`
|
||||
- Label caching to avoid repeated queries
|
||||
|
||||
## Files Implemented
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `trustgraph-base/trustgraph/provenance/uris.py` | URI generators |
|
||||
| `trustgraph-base/trustgraph/provenance/namespaces.py` | RDF namespace constants |
|
||||
| `trustgraph-base/trustgraph/provenance/triples.py` | Triple builders |
|
||||
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | GraphRagResponse schema |
|
||||
| `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py` | Core GraphRAG with URI preservation |
|
||||
| `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` | Service with librarian integration |
|
||||
| `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` | Quoted triple query support |
|
||||
| `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py` | CLI with explainability display |
|
||||
|
||||
## References
|
||||
|
||||
- PROV-O (W3C Provenance Ontology): https://www.w3.org/TR/prov-o/
|
||||
- RDF-star: https://w3c.github.io/rdf-star/
|
||||
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
|
||||
471
docs/tech-specs/tool-services.md
Normal file
471
docs/tech-specs/tool-services.md
Normal file
|
|
@ -0,0 +1,471 @@
|
|||
# Tool Services: Dynamically Pluggable Agent Tools
|
||||
|
||||
## Status
|
||||
|
||||
Implemented
|
||||
|
||||
## Overview
|
||||
|
||||
This specification defines a mechanism for dynamically pluggable agent tools called "tool services". Unlike the existing built-in tool types (`KnowledgeQueryImpl`, `McpToolImpl`, etc.), tool services allow new tools to be introduced by:
|
||||
|
||||
1. Deploying a new Pulsar-based service
|
||||
2. Adding a configuration descriptor that tells the agent how to invoke it
|
||||
|
||||
This enables extensibility without modifying the core agent-react framework.
|
||||
|
||||
## Terminology
|
||||
|
||||
| Term | Definition |
|
||||
|------|------------|
|
||||
| **Built-in Tool** | Existing tool types with hardcoded implementations in `tools.py` |
|
||||
| **Tool Service** | A Pulsar service that can be invoked as an agent tool, defined by a service descriptor |
|
||||
| **Tool** | A configured instance that references a tool service, exposed to the agent/LLM |
|
||||
|
||||
This is a two-tier model, analogous to MCP tools:
|
||||
- MCP: MCP server defines the tool interface → Tool config references it
|
||||
- Tool Services: Tool service defines the Pulsar interface → Tool config references it
|
||||
|
||||
## Background: Existing Tools
|
||||
|
||||
### Built-in Tool Implementation
|
||||
|
||||
Tools are currently defined in `trustgraph-flow/trustgraph/agent/react/tools.py` with typed implementations:
|
||||
|
||||
```python
|
||||
class KnowledgeQueryImpl:
|
||||
async def invoke(self, question):
|
||||
client = self.context("graph-rag-request")
|
||||
return await client.rag(question, self.collection)
|
||||
```
|
||||
|
||||
Each tool type:
|
||||
- Has a hardcoded Pulsar service it calls (e.g., `graph-rag-request`)
|
||||
- Knows the exact method to call on the client (e.g., `client.rag()`)
|
||||
- Has typed arguments defined in the implementation
|
||||
|
||||
### Tool Registration (service.py:105-214)
|
||||
|
||||
Tools are loaded from config with a `type` field that maps to an implementation:
|
||||
|
||||
```python
|
||||
if impl_id == "knowledge-query":
|
||||
impl = functools.partial(KnowledgeQueryImpl, collection=data.get("collection"))
|
||||
elif impl_id == "text-completion":
|
||||
impl = TextCompletionImpl
|
||||
# ... etc
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two-Tier Model
|
||||
|
||||
#### Tier 1: Tool Service Descriptor
|
||||
|
||||
A tool service defines a Pulsar service interface. It declares:
|
||||
- The Pulsar queues for request/response
|
||||
- Configuration parameters it requires from tools that use it
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "custom-rag",
|
||||
"request-queue": "non-persistent://tg/request/custom-rag",
|
||||
"response-queue": "non-persistent://tg/response/custom-rag",
|
||||
"config-params": [
|
||||
{"name": "collection", "required": true}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
A tool service that needs no configuration parameters:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "calculator",
|
||||
"request-queue": "non-persistent://tg/request/calc",
|
||||
"response-queue": "non-persistent://tg/response/calc",
|
||||
"config-params": []
|
||||
}
|
||||
```
|
||||
|
||||
#### Tier 2: Tool Descriptor
|
||||
|
||||
A tool references a tool service and provides:
|
||||
- Config parameter values (satisfying the service's requirements)
|
||||
- Tool metadata for the agent (name, description)
|
||||
- Argument definitions for the LLM
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "tool-service",
|
||||
"name": "query-customers",
|
||||
"description": "Query the customer knowledge base",
|
||||
"service": "custom-rag",
|
||||
"collection": "customers",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "question",
|
||||
"type": "string",
|
||||
"description": "The question to ask about customers"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Multiple tools can reference the same service with different configurations:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "tool-service",
|
||||
"name": "query-products",
|
||||
"description": "Query the product knowledge base",
|
||||
"service": "custom-rag",
|
||||
"collection": "products",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "question",
|
||||
"type": "string",
|
||||
"description": "The question to ask about products"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Request Format
|
||||
|
||||
When a tool is invoked, the request to the tool service includes:
|
||||
- `user`: From the agent request (multi-tenancy)
|
||||
- `config`: JSON-encoded config values from the tool descriptor
|
||||
- `arguments`: JSON-encoded arguments from the LLM
|
||||
|
||||
```json
|
||||
{
|
||||
"user": "alice",
|
||||
"config": "{\"collection\": \"customers\"}",
|
||||
"arguments": "{\"question\": \"What are the top customer complaints?\"}"
|
||||
}
|
||||
```
|
||||
|
||||
The tool service receives these as parsed dicts in the `invoke` method.
|
||||
|
||||
### Generic Tool Service Implementation
|
||||
|
||||
A `ToolServiceImpl` class invokes tool services based on configuration:
|
||||
|
||||
```python
|
||||
class ToolServiceImpl:
|
||||
def __init__(self, context, request_queue, response_queue, config_values, arguments, processor):
|
||||
self.request_queue = request_queue
|
||||
self.response_queue = response_queue
|
||||
self.config_values = config_values # e.g., {"collection": "customers"}
|
||||
# ...
|
||||
|
||||
async def invoke(self, **arguments):
|
||||
client = await self._get_or_create_client()
|
||||
response = await client.call(user, self.config_values, arguments)
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
else:
|
||||
return json.dumps(response)
|
||||
```
|
||||
|
||||
## Design Decisions
|
||||
|
||||
### Two-Tier Configuration Model
|
||||
|
||||
Tool services follow a two-tier model similar to MCP tools:
|
||||
|
||||
1. **Tool Service**: Defines the Pulsar service interface (topic, required config params)
|
||||
2. **Tool**: References a tool service, provides config values, defines LLM arguments
|
||||
|
||||
This separation allows:
|
||||
- One tool service to be used by multiple tools with different configurations
|
||||
- Clear distinction between service interface and tool configuration
|
||||
- Reusability of service definitions
|
||||
|
||||
### Request Mapping: Pass-Through with Envelope
|
||||
|
||||
The request to a tool service is a structured envelope containing:
|
||||
- `user`: Propagated from the agent request for multi-tenancy
|
||||
- Config values: From the tool descriptor (e.g., `collection`)
|
||||
- `arguments`: LLM-provided arguments, passed through as a dict
|
||||
|
||||
The agent manager parses the LLM's response into `act.arguments` as a dict (`agent_manager.py:117-154`). This dict is included in the request envelope.
|
||||
|
||||
### Schema Handling: Untyped
|
||||
|
||||
Requests and responses use untyped dicts. No schema validation at the agent level - the tool service is responsible for validating its inputs. This provides maximum flexibility for defining new services.
|
||||
|
||||
### Client Interface: Direct Pulsar Topics
|
||||
|
||||
Tool services use direct Pulsar topics without requiring flow configuration. The tool-service descriptor specifies the full queue names:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "joke-service",
|
||||
"request-queue": "non-persistent://tg/request/joke",
|
||||
"response-queue": "non-persistent://tg/response/joke",
|
||||
"config-params": [...]
|
||||
}
|
||||
```
|
||||
|
||||
This allows services to be hosted in any namespace.
|
||||
|
||||
### Error Handling: Standard Error Convention
|
||||
|
||||
Tool service responses follow the existing schema convention with an `error` field:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class Error:
|
||||
type: str = ""
|
||||
message: str = ""
|
||||
```
|
||||
|
||||
Response structure:
|
||||
- Success: `error` is `None`, response contains result
|
||||
- Error: `error` is populated with `type` and `message`
|
||||
|
||||
This matches the pattern used throughout existing service schemas (e.g., `PromptResponse`, `QueryResponse`, `AgentResponse`).
|
||||
|
||||
### Request/Response Correlation
|
||||
|
||||
Requests and responses are correlated using an `id` in Pulsar message properties:
|
||||
|
||||
- Request includes `id` in properties: `properties={"id": id}`
|
||||
- Response(s) include the same `id`: `properties={"id": id}`
|
||||
|
||||
This follows the existing pattern used throughout the codebase (e.g., `agent_service.py`, `llm_service.py`).
|
||||
|
||||
### Streaming Support
|
||||
|
||||
Tool services can return streaming responses:
|
||||
|
||||
- Multiple response messages with the same `id` in properties
|
||||
- Each response includes `end_of_stream: bool` field
|
||||
- Final response has `end_of_stream: True`
|
||||
|
||||
This matches the pattern used in `AgentResponse` and other streaming services.
|
||||
|
||||
### Response Handling: String Return
|
||||
|
||||
All existing tools follow the same pattern: **receive arguments as a dict, return observation as a string**.
|
||||
|
||||
| Tool | Response Handling |
|
||||
|------|------------------|
|
||||
| `KnowledgeQueryImpl` | Returns `client.rag()` directly (string) |
|
||||
| `TextCompletionImpl` | Returns `client.question()` directly (string) |
|
||||
| `McpToolImpl` | Returns string, or `json.dumps(output)` if not string |
|
||||
| `StructuredQueryImpl` | Formats result to string |
|
||||
| `PromptImpl` | Returns `client.prompt()` directly (string) |
|
||||
|
||||
Tool services follow the same contract:
|
||||
- The service returns a string response (the observation)
|
||||
- If the response is not a string, it is converted via `json.dumps()`
|
||||
- No extraction configuration needed in the descriptor
|
||||
|
||||
This keeps the descriptor simple and places responsibility on the service to return an appropriate text response for the agent.
|
||||
|
||||
## Configuration Guide
|
||||
|
||||
To add a new tool service, two configuration items are required:
|
||||
|
||||
### 1. Tool Service Configuration
|
||||
|
||||
Stored under the `tool-service` config key. Defines the Pulsar queues and available config parameters.
|
||||
|
||||
| Field | Required | Description |
|
||||
|-------|----------|-------------|
|
||||
| `id` | Yes | Unique identifier for the tool service |
|
||||
| `request-queue` | Yes | Full Pulsar topic for requests (e.g., `non-persistent://tg/request/joke`) |
|
||||
| `response-queue` | Yes | Full Pulsar topic for responses (e.g., `non-persistent://tg/response/joke`) |
|
||||
| `config-params` | No | Array of config parameters the service accepts |
|
||||
|
||||
Each config param can specify:
|
||||
- `name`: Parameter name (required)
|
||||
- `required`: Whether the parameter must be provided by tools (default: false)
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"id": "joke-service",
|
||||
"request-queue": "non-persistent://tg/request/joke",
|
||||
"response-queue": "non-persistent://tg/response/joke",
|
||||
"config-params": [
|
||||
{"name": "style", "required": false}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Tool Configuration
|
||||
|
||||
Stored under the `tool` config key. Defines a tool that the agent can use.
|
||||
|
||||
| Field | Required | Description |
|
||||
|-------|----------|-------------|
|
||||
| `type` | Yes | Must be `"tool-service"` |
|
||||
| `name` | Yes | Tool name exposed to the LLM |
|
||||
| `description` | Yes | Description of what the tool does (shown to LLM) |
|
||||
| `service` | Yes | ID of the tool-service to invoke |
|
||||
| `arguments` | No | Array of argument definitions for the LLM |
|
||||
| *(config params)* | Varies | Any config params defined by the service |
|
||||
|
||||
Each argument can specify:
|
||||
- `name`: Argument name (required)
|
||||
- `type`: Data type, e.g., `"string"` (required)
|
||||
- `description`: Description shown to the LLM (required)
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"type": "tool-service",
|
||||
"name": "tell-joke",
|
||||
"description": "Tell a joke on a given topic",
|
||||
"service": "joke-service",
|
||||
"style": "pun",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "topic",
|
||||
"type": "string",
|
||||
"description": "The topic for the joke (e.g., programming, animals, food)"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Loading Configuration
|
||||
|
||||
Use `tg-put-config-item` to load configurations:
|
||||
|
||||
```bash
|
||||
# Load tool-service config
|
||||
tg-put-config-item tool-service/joke-service < joke-service.json
|
||||
|
||||
# Load tool config
|
||||
tg-put-config-item tool/tell-joke < tell-joke.json
|
||||
```
|
||||
|
||||
The agent-manager must be restarted to pick up new configurations.
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Schema
|
||||
|
||||
Request and response types in `trustgraph-base/trustgraph/schema/services/tool_service.py`:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ToolServiceRequest:
|
||||
user: str = "" # User context for multi-tenancy
|
||||
config: str = "" # JSON-encoded config values from tool descriptor
|
||||
arguments: str = "" # JSON-encoded arguments from LLM
|
||||
|
||||
@dataclass
|
||||
class ToolServiceResponse:
|
||||
error: Error | None = None
|
||||
response: str = "" # String response (the observation)
|
||||
end_of_stream: bool = False
|
||||
```
|
||||
|
||||
### Server-Side: DynamicToolService
|
||||
|
||||
Base class in `trustgraph-base/trustgraph/base/dynamic_tool_service.py`:
|
||||
|
||||
```python
|
||||
class DynamicToolService(AsyncProcessor):
|
||||
"""Base class for implementing tool services."""
|
||||
|
||||
def __init__(self, **params):
|
||||
topic = params.get("topic", default_topic)
|
||||
# Constructs topics: non-persistent://tg/request/{topic}, non-persistent://tg/response/{topic}
|
||||
# Sets up Consumer and Producer
|
||||
|
||||
async def invoke(self, user, config, arguments):
|
||||
"""Override this method to implement the tool's logic."""
|
||||
raise NotImplementedError()
|
||||
```
|
||||
|
||||
### Client-Side: ToolServiceImpl
|
||||
|
||||
Implementation in `trustgraph-flow/trustgraph/agent/react/tools.py`:
|
||||
|
||||
```python
|
||||
class ToolServiceImpl:
|
||||
def __init__(self, context, request_queue, response_queue, config_values, arguments, processor):
|
||||
# Uses the provided queue paths directly
|
||||
# Creates ToolServiceClient on first use
|
||||
|
||||
async def invoke(self, **arguments):
|
||||
client = await self._get_or_create_client()
|
||||
response = await client.call(user, config_values, arguments)
|
||||
return response if isinstance(response, str) else json.dumps(response)
|
||||
```
|
||||
|
||||
### Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `trustgraph-base/trustgraph/schema/services/tool_service.py` | Request/response schemas |
|
||||
| `trustgraph-base/trustgraph/base/tool_service_client.py` | Client for invoking services |
|
||||
| `trustgraph-base/trustgraph/base/dynamic_tool_service.py` | Base class for service implementation |
|
||||
| `trustgraph-flow/trustgraph/agent/react/tools.py` | `ToolServiceImpl` class |
|
||||
| `trustgraph-flow/trustgraph/agent/react/service.py` | Config loading |
|
||||
|
||||
### Example: Joke Service
|
||||
|
||||
An example service in `trustgraph-flow/trustgraph/tool_service/joke/`:
|
||||
|
||||
```python
|
||||
class Processor(DynamicToolService):
|
||||
async def invoke(self, user, config, arguments):
|
||||
style = config.get("style", "pun")
|
||||
topic = arguments.get("topic", "")
|
||||
joke = pick_joke(topic, style)
|
||||
return f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
```
|
||||
|
||||
Tool service config:
|
||||
```json
|
||||
{
|
||||
"id": "joke-service",
|
||||
"request-queue": "non-persistent://tg/request/joke",
|
||||
"response-queue": "non-persistent://tg/response/joke",
|
||||
"config-params": [{"name": "style", "required": false}]
|
||||
}
|
||||
```
|
||||
|
||||
Tool config:
|
||||
```json
|
||||
{
|
||||
"type": "tool-service",
|
||||
"name": "tell-joke",
|
||||
"description": "Tell a joke on a given topic",
|
||||
"service": "joke-service",
|
||||
"style": "pun",
|
||||
"arguments": [
|
||||
{"name": "topic", "type": "string", "description": "The topic for the joke"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Backward Compatibility
|
||||
|
||||
- Existing built-in tool types continue to work unchanged
|
||||
- `tool-service` is a new tool type alongside existing types (`knowledge-query`, `mcp-tool`, etc.)
|
||||
|
||||
## Future Considerations
|
||||
|
||||
### Self-Announcing Services
|
||||
|
||||
A future enhancement could allow services to publish their own descriptors:
|
||||
|
||||
- Services publish to a well-known `tool-descriptors` topic on startup
|
||||
- Agent subscribes and dynamically registers tools
|
||||
- Enables true plug-and-play without config changes
|
||||
|
||||
This is out of scope for the initial implementation.
|
||||
|
||||
## References
|
||||
|
||||
- Current tool implementation: `trustgraph-flow/trustgraph/agent/react/tools.py`
|
||||
- Tool registration: `trustgraph-flow/trustgraph/agent/react/service.py:105-214`
|
||||
- Agent schemas: `trustgraph-base/trustgraph/schema/services/agent.py`
|
||||
1411
docs/websocket.html
1411
docs/websocket.html
File diff suppressed because one or more lines are too long
|
|
@ -1,14 +1,60 @@
|
|||
type: object
|
||||
description: RDF value - can be entity/URI or literal
|
||||
required:
|
||||
- v
|
||||
- e
|
||||
description: |
|
||||
RDF Term - typed representation of a value in the knowledge graph.
|
||||
|
||||
Term types (discriminated by `t` field):
|
||||
- `i`: IRI (URI reference)
|
||||
- `l`: Literal (string value, optionally with datatype or language tag)
|
||||
- `r`: Quoted triple (RDF-star reification)
|
||||
- `b`: Blank node
|
||||
properties:
|
||||
t:
|
||||
type: string
|
||||
description: Term type discriminator
|
||||
enum: [i, l, r, b]
|
||||
example: i
|
||||
i:
|
||||
type: string
|
||||
description: IRI value (when t=i)
|
||||
example: http://example.com/Person1
|
||||
v:
|
||||
type: string
|
||||
description: Value (URI or literal text)
|
||||
example: https://example.com/entity1
|
||||
e:
|
||||
type: boolean
|
||||
description: True if entity/URI, false if literal
|
||||
example: true
|
||||
description: Literal value (when t=l)
|
||||
example: John Doe
|
||||
d:
|
||||
type: string
|
||||
description: Datatype IRI for literal (when t=l, optional)
|
||||
example: http://www.w3.org/2001/XMLSchema#integer
|
||||
l:
|
||||
type: string
|
||||
description: Language tag for literal (when t=l, optional)
|
||||
example: en
|
||||
r:
|
||||
type: object
|
||||
description: Quoted triple (when t=r) - contains s, p, o as nested Term objects with the same structure
|
||||
properties:
|
||||
s:
|
||||
type: object
|
||||
description: Subject term
|
||||
p:
|
||||
type: object
|
||||
description: Predicate term
|
||||
o:
|
||||
type: object
|
||||
description: Object term
|
||||
required:
|
||||
- t
|
||||
examples:
|
||||
- description: IRI term
|
||||
value:
|
||||
t: i
|
||||
i: http://schema.org/name
|
||||
- description: Literal term
|
||||
value:
|
||||
t: l
|
||||
v: John Doe
|
||||
- description: Literal with language tag
|
||||
value:
|
||||
t: l
|
||||
v: Bonjour
|
||||
l: fr
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
type: object
|
||||
description: RDF triple (subject-predicate-object)
|
||||
description: |
|
||||
RDF triple (subject-predicate-object), optionally scoped to a named graph.
|
||||
required:
|
||||
- s
|
||||
- p
|
||||
|
|
@ -14,3 +15,7 @@ properties:
|
|||
o:
|
||||
$ref: './RdfValue.yaml'
|
||||
description: Object
|
||||
g:
|
||||
type: string
|
||||
description: Named graph URI (optional)
|
||||
example: urn:graph:source
|
||||
|
|
|
|||
|
|
@ -9,12 +9,26 @@ properties:
|
|||
- action
|
||||
- observation
|
||||
- answer
|
||||
- final-answer
|
||||
- error
|
||||
example: answer
|
||||
content:
|
||||
type: string
|
||||
description: Chunk content (streaming mode only)
|
||||
example: Paris is the capital of France.
|
||||
message_type:
|
||||
type: string
|
||||
description: Message type - "chunk" for agent chunks, "explain" for explainability events
|
||||
enum: [chunk, explain]
|
||||
example: chunk
|
||||
explain_id:
|
||||
type: string
|
||||
description: Explainability node URI (for explain messages)
|
||||
example: urn:trustgraph:agent:abc123
|
||||
explain_graph:
|
||||
type: string
|
||||
description: Named graph containing the explainability data
|
||||
example: urn:graph:retrieval
|
||||
end-of-message:
|
||||
type: boolean
|
||||
description: Current chunk type is complete (streaming mode)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,60 @@
|
|||
type: object
|
||||
description: |
|
||||
RDF value - represents either a URI/entity or a literal value.
|
||||
RDF Term - typed representation of a value in the knowledge graph.
|
||||
|
||||
When `e` is true, `v` must be a full URI (e.g., http://schema.org/name).
|
||||
When `e` is false, `v` is a literal value (string, number, etc.).
|
||||
Term types (discriminated by `t` field):
|
||||
- `i`: IRI (URI reference)
|
||||
- `l`: Literal (string value, optionally with datatype or language tag)
|
||||
- `r`: Quoted triple (RDF-star reification)
|
||||
- `b`: Blank node
|
||||
properties:
|
||||
t:
|
||||
type: string
|
||||
description: Term type discriminator
|
||||
enum: [i, l, r, b]
|
||||
example: i
|
||||
i:
|
||||
type: string
|
||||
description: IRI value (when t=i)
|
||||
example: http://example.com/Person1
|
||||
v:
|
||||
type: string
|
||||
description: The value - full URI when e=true, literal when e=false
|
||||
example: http://example.com/Person1
|
||||
e:
|
||||
type: boolean
|
||||
description: True if entity/URI, false if literal value
|
||||
example: true
|
||||
description: Literal value (when t=l)
|
||||
example: John Doe
|
||||
d:
|
||||
type: string
|
||||
description: Datatype IRI for literal (when t=l, optional)
|
||||
example: http://www.w3.org/2001/XMLSchema#integer
|
||||
l:
|
||||
type: string
|
||||
description: Language tag for literal (when t=l, optional)
|
||||
example: en
|
||||
r:
|
||||
type: object
|
||||
description: Quoted triple (when t=r) - contains s, p, o as nested Term objects with the same structure
|
||||
properties:
|
||||
s:
|
||||
type: object
|
||||
description: Subject term
|
||||
p:
|
||||
type: object
|
||||
description: Predicate term
|
||||
o:
|
||||
type: object
|
||||
description: Object term
|
||||
required:
|
||||
- v
|
||||
- e
|
||||
example:
|
||||
v: http://schema.org/name
|
||||
e: true
|
||||
- t
|
||||
examples:
|
||||
- description: IRI term
|
||||
value:
|
||||
t: i
|
||||
i: http://schema.org/name
|
||||
- description: Literal term
|
||||
value:
|
||||
t: l
|
||||
v: John Doe
|
||||
- description: Literal with language tag
|
||||
value:
|
||||
t: l
|
||||
v: Bonjour
|
||||
l: fr
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
type: object
|
||||
description: |
|
||||
RDF triple representing a subject-predicate-object statement in the knowledge graph.
|
||||
RDF triple representing a subject-predicate-object statement in the knowledge graph,
|
||||
optionally scoped to a named graph.
|
||||
|
||||
Example: (Person1) -[has name]-> ("John Doe")
|
||||
properties:
|
||||
|
|
@ -13,17 +14,26 @@ properties:
|
|||
o:
|
||||
$ref: './RdfValue.yaml'
|
||||
description: Object - the value or target entity
|
||||
g:
|
||||
type: string
|
||||
description: |
|
||||
Named graph URI (optional). When absent, the triple is in the default graph.
|
||||
Well-known graphs:
|
||||
- (empty/absent): Core knowledge facts
|
||||
- urn:graph:source: Extraction provenance
|
||||
- urn:graph:retrieval: Query-time explainability
|
||||
example: urn:graph:source
|
||||
required:
|
||||
- s
|
||||
- p
|
||||
- o
|
||||
example:
|
||||
s:
|
||||
v: http://example.com/Person1
|
||||
e: true
|
||||
t: i
|
||||
i: http://example.com/Person1
|
||||
p:
|
||||
v: http://schema.org/name
|
||||
e: true
|
||||
t: i
|
||||
i: http://schema.org/name
|
||||
o:
|
||||
t: l
|
||||
v: John Doe
|
||||
e: false
|
||||
|
|
|
|||
|
|
@ -1,12 +1,22 @@
|
|||
type: object
|
||||
description: Document embeddings query response
|
||||
description: Document embeddings query response with matching chunks and similarity scores
|
||||
properties:
|
||||
chunks:
|
||||
type: array
|
||||
description: Similar document chunks (text strings)
|
||||
description: Matching document chunks with similarity scores
|
||||
items:
|
||||
type: string
|
||||
type: object
|
||||
properties:
|
||||
chunk_id:
|
||||
type: string
|
||||
description: Chunk identifier URI
|
||||
example: "urn:trustgraph:chunk:abc123"
|
||||
score:
|
||||
type: number
|
||||
description: Similarity score (higher is more similar)
|
||||
example: 0.89
|
||||
example:
|
||||
- "Quantum computing uses quantum mechanics principles for computation..."
|
||||
- "Neural networks are computing systems inspired by biological neurons..."
|
||||
- "Machine learning algorithms learn patterns from data..."
|
||||
- chunk_id: "urn:trustgraph:chunk:abc123"
|
||||
score: 0.95
|
||||
- chunk_id: "urn:trustgraph:chunk:def456"
|
||||
score: 0.82
|
||||
|
|
|
|||
|
|
@ -1,12 +1,21 @@
|
|||
type: object
|
||||
description: Graph embeddings query response
|
||||
description: Graph embeddings query response with matching entities and similarity scores
|
||||
properties:
|
||||
entities:
|
||||
type: array
|
||||
description: Similar entities (RDF values)
|
||||
description: Matching graph entities with similarity scores
|
||||
items:
|
||||
$ref: '../../common/RdfValue.yaml'
|
||||
type: object
|
||||
properties:
|
||||
entity:
|
||||
$ref: '../../common/RdfValue.yaml'
|
||||
description: Matching graph entity
|
||||
score:
|
||||
type: number
|
||||
description: Similarity score (higher is more similar)
|
||||
example: 0.92
|
||||
example:
|
||||
- {v: "https://example.com/person/alice", e: true}
|
||||
- {v: "https://example.com/person/bob", e: true}
|
||||
- {v: "https://example.com/concept/quantum", e: true}
|
||||
- entity: {t: i, i: "https://example.com/person/alice"}
|
||||
score: 0.95
|
||||
- entity: {t: i, i: "https://example.com/concept/quantum"}
|
||||
score: 0.82
|
||||
|
|
|
|||
|
|
@ -28,3 +28,23 @@ properties:
|
|||
description: Collection to query
|
||||
default: default
|
||||
example: research
|
||||
g:
|
||||
type: string
|
||||
description: |
|
||||
Named graph filter (optional).
|
||||
- Omitted/null: all graphs
|
||||
- Empty string: default graph only
|
||||
- URI string: specific named graph (e.g., urn:graph:source, urn:graph:retrieval)
|
||||
example: urn:graph:source
|
||||
streaming:
|
||||
type: boolean
|
||||
description: Enable streaming response delivery
|
||||
default: false
|
||||
example: true
|
||||
batch-size:
|
||||
type: integer
|
||||
description: Number of triples per streaming batch
|
||||
default: 20
|
||||
minimum: 1
|
||||
maximum: 1000
|
||||
example: 50
|
||||
|
|
|
|||
|
|
@ -1,13 +1,31 @@
|
|||
type: object
|
||||
description: Document RAG response
|
||||
description: Document RAG response message
|
||||
properties:
|
||||
message_type:
|
||||
type: string
|
||||
description: Type of message - "chunk" for LLM response chunks, "explain" for explainability events
|
||||
enum: [chunk, explain]
|
||||
example: chunk
|
||||
response:
|
||||
type: string
|
||||
description: Generated response based on retrieved documents
|
||||
example: The research papers found three key findings...
|
||||
description: Generated response text (for chunk messages)
|
||||
example: Based on the policy documents, customers can return items within 30 days...
|
||||
explain_id:
|
||||
type: string
|
||||
description: Explainability node URI (for explain messages)
|
||||
example: urn:trustgraph:question:abc123
|
||||
explain_graph:
|
||||
type: string
|
||||
description: Named graph containing the explainability data
|
||||
example: urn:graph:retrieval
|
||||
end-of-stream:
|
||||
type: boolean
|
||||
description: Indicates streaming is complete (streaming mode)
|
||||
description: Indicates LLM response stream is complete
|
||||
default: false
|
||||
example: true
|
||||
end_of_session:
|
||||
type: boolean
|
||||
description: Indicates entire session is complete (all messages sent)
|
||||
default: false
|
||||
example: true
|
||||
error:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,31 @@
|
|||
type: object
|
||||
description: Graph RAG response
|
||||
description: Graph RAG response message
|
||||
properties:
|
||||
message_type:
|
||||
type: string
|
||||
description: Type of message - "chunk" for LLM response chunks, "explain" for explainability events
|
||||
enum: [chunk, explain]
|
||||
example: chunk
|
||||
response:
|
||||
type: string
|
||||
description: Generated response based on retrieved knowledge graph
|
||||
description: Generated response text (for chunk messages)
|
||||
example: Quantum physics and computer science intersect in quantum computing...
|
||||
end-of-stream:
|
||||
explain_id:
|
||||
type: string
|
||||
description: Explainability node URI (for explain messages)
|
||||
example: urn:trustgraph:question:abc123
|
||||
explain_graph:
|
||||
type: string
|
||||
description: Named graph containing the explainability data
|
||||
example: urn:graph:retrieval
|
||||
end_of_stream:
|
||||
type: boolean
|
||||
description: Indicates streaming is complete (streaming mode)
|
||||
description: Indicates LLM response stream is complete
|
||||
default: false
|
||||
example: true
|
||||
end_of_session:
|
||||
type: boolean
|
||||
description: Indicates entire session is complete (all messages sent)
|
||||
default: false
|
||||
example: true
|
||||
error:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ openapi: 3.1.0
|
|||
|
||||
info:
|
||||
title: TrustGraph API Gateway
|
||||
version: "1.8"
|
||||
version: "2.1"
|
||||
description: |
|
||||
REST API for TrustGraph - an AI-powered knowledge graph and RAG system.
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ info:
|
|||
Require running flow instance, accessed via `/api/v1/flow/{flow}/service/{kind}`:
|
||||
- AI services: agent, text-completion, prompt, RAG (document/graph)
|
||||
- Embeddings: embeddings, graph-embeddings, document-embeddings
|
||||
- Query: triples, objects, nlp-query, structured-query
|
||||
- Query: triples, rows, nlp-query, structured-query, row-embeddings
|
||||
- Data loading: text-load, document-load
|
||||
- Utilities: mcp-tool, structured-diag
|
||||
|
||||
|
|
@ -140,6 +140,10 @@ paths:
|
|||
/api/v1/flow/{flow}/service/document-load:
|
||||
$ref: './paths/flow/document-load.yaml'
|
||||
|
||||
# Document streaming
|
||||
/api/v1/document-stream:
|
||||
$ref: './paths/document-stream.yaml'
|
||||
|
||||
# Import/Export endpoints
|
||||
/api/v1/import-core:
|
||||
$ref: './paths/import-core.yaml'
|
||||
|
|
|
|||
53
specs/api/paths/document-stream.yaml
Normal file
53
specs/api/paths/document-stream.yaml
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
get:
|
||||
tags:
|
||||
- Import/Export
|
||||
summary: Stream document content from library
|
||||
description: |
|
||||
Streams the raw content of a document stored in the library.
|
||||
Returns the document content in chunked transfer encoding.
|
||||
|
||||
## Parameters
|
||||
|
||||
- `user`: User identifier (required)
|
||||
- `document-id`: Document IRI to retrieve (required)
|
||||
- `chunk-size`: Size of each response chunk in bytes (optional, default: 1MB)
|
||||
|
||||
operationId: documentStream
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: user
|
||||
in: query
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: User identifier
|
||||
example: trustgraph
|
||||
- name: document-id
|
||||
in: query
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: Document IRI to retrieve
|
||||
example: "urn:trustgraph:doc:abc123"
|
||||
- name: chunk-size
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: integer
|
||||
default: 1048576
|
||||
description: Chunk size in bytes (default 1MB)
|
||||
responses:
|
||||
'200':
|
||||
description: Document content streamed as raw bytes
|
||||
content:
|
||||
application/octet-stream:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
'400':
|
||||
description: Missing required parameters
|
||||
'401':
|
||||
$ref: '../components/responses/Unauthorized.yaml'
|
||||
'500':
|
||||
$ref: '../components/responses/Error.yaml'
|
||||
|
|
@ -24,7 +24,7 @@ echo
|
|||
# Build WebSocket API documentation
|
||||
echo "Building WebSocket API documentation (AsyncAPI)..."
|
||||
cd ../websocket
|
||||
npx --yes -p @asyncapi/cli asyncapi generate fromTemplate asyncapi.yaml @asyncapi/html-template@3.0.0 --use-new-generator -o /tmp/asyncapi-build -p singleFile=true --force-write
|
||||
npx --yes -p @asyncapi/cli asyncapi generate fromTemplate asyncapi.yaml @asyncapi/html-template -o /tmp/asyncapi-build -p singleFile=true --force-write
|
||||
mv /tmp/asyncapi-build/index.html ../../docs/websocket.html
|
||||
rm -rf /tmp/asyncapi-build
|
||||
echo "✓ WebSocket API docs generated: docs/websocket.html"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ asyncapi: 3.0.0
|
|||
|
||||
info:
|
||||
title: TrustGraph WebSocket API
|
||||
version: "1.8"
|
||||
version: "2.1"
|
||||
description: |
|
||||
WebSocket API for TrustGraph - providing multiplexed, asynchronous access to all services.
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ info:
|
|||
**Flow-Hosted Services** (require `flow` parameter):
|
||||
- agent, text-completion, prompt, document-rag, graph-rag
|
||||
- embeddings, graph-embeddings, document-embeddings
|
||||
- triples, objects, nlp-query, structured-query, structured-diag
|
||||
- triples, rows, nlp-query, structured-query, structured-diag, row-embeddings
|
||||
- text-load, document-load, mcp-tool
|
||||
|
||||
## Schema Reuse
|
||||
|
|
|
|||
|
|
@ -95,8 +95,7 @@ def sample_message_data():
|
|||
"Metadata": {
|
||||
"id": "test-doc-123",
|
||||
"user": "test_user",
|
||||
"collection": "test_collection",
|
||||
"metadata": []
|
||||
"collection": "test_collection"
|
||||
},
|
||||
"Term": {
|
||||
"type": IRI,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Ensures that message formats remain consistent across services
|
|||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, Error
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, ChunkMatch, Error
|
||||
from trustgraph.messaging.translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator,
|
||||
DocumentEmbeddingsResponseTranslator
|
||||
|
|
@ -20,20 +20,20 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
"""Test that DocumentEmbeddingsRequest has expected fields"""
|
||||
# Create a request
|
||||
request = DocumentEmbeddingsRequest(
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify all expected fields exist
|
||||
assert hasattr(request, 'vectors')
|
||||
assert hasattr(request, 'vector')
|
||||
assert hasattr(request, 'limit')
|
||||
assert hasattr(request, 'user')
|
||||
assert hasattr(request, 'collection')
|
||||
|
||||
|
||||
# Verify field values
|
||||
assert request.vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
assert request.vector == [0.1, 0.2, 0.3]
|
||||
assert request.limit == 10
|
||||
assert request.user == "test_user"
|
||||
assert request.collection == "test_collection"
|
||||
|
|
@ -41,18 +41,18 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
def test_request_translator_to_pulsar(self):
|
||||
"""Test request translator converts dict to Pulsar schema"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
||||
data = {
|
||||
"vectors": [[0.1, 0.2], [0.3, 0.4]],
|
||||
"vector": [0.1, 0.2, 0.3, 0.4],
|
||||
"limit": 5,
|
||||
"user": "custom_user",
|
||||
"collection": "custom_collection"
|
||||
}
|
||||
|
||||
|
||||
result = translator.to_pulsar(data)
|
||||
|
||||
|
||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
||||
assert result.vector == [0.1, 0.2, 0.3, 0.4]
|
||||
assert result.limit == 5
|
||||
assert result.user == "custom_user"
|
||||
assert result.collection == "custom_collection"
|
||||
|
|
@ -60,16 +60,16 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
def test_request_translator_to_pulsar_with_defaults(self):
|
||||
"""Test request translator uses correct defaults"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
||||
data = {
|
||||
"vectors": [[0.1, 0.2]]
|
||||
"vector": [0.1, 0.2]
|
||||
# No limit, user, or collection provided
|
||||
}
|
||||
|
||||
|
||||
result = translator.to_pulsar(data)
|
||||
|
||||
|
||||
assert isinstance(result, DocumentEmbeddingsRequest)
|
||||
assert result.vectors == [[0.1, 0.2]]
|
||||
assert result.vector == [0.1, 0.2]
|
||||
assert result.limit == 10 # Default
|
||||
assert result.user == "trustgraph" # Default
|
||||
assert result.collection == "default" # Default
|
||||
|
|
@ -77,18 +77,18 @@ class TestDocumentEmbeddingsRequestContract:
|
|||
def test_request_translator_from_pulsar(self):
|
||||
"""Test request translator converts Pulsar schema to dict"""
|
||||
translator = DocumentEmbeddingsRequestTranslator()
|
||||
|
||||
|
||||
request = DocumentEmbeddingsRequest(
|
||||
vectors=[[0.5, 0.6]],
|
||||
vector=[0.5, 0.6],
|
||||
limit=20,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
result = translator.from_pulsar(request)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["vectors"] == [[0.5, 0.6]]
|
||||
assert result["vector"] == [0.5, 0.6]
|
||||
assert result["limit"] == 20
|
||||
assert result["user"] == "test_user"
|
||||
assert result["collection"] == "test_collection"
|
||||
|
|
@ -102,16 +102,22 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
# Create a response with chunks
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=None,
|
||||
chunks=["chunk1", "chunk2", "chunk3"]
|
||||
chunks=[
|
||||
ChunkMatch(chunk_id="chunk1", score=0.9),
|
||||
ChunkMatch(chunk_id="chunk2", score=0.8),
|
||||
ChunkMatch(chunk_id="chunk3", score=0.7)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Verify all expected fields exist
|
||||
assert hasattr(response, 'error')
|
||||
assert hasattr(response, 'chunks')
|
||||
|
||||
|
||||
# Verify field values
|
||||
assert response.error is None
|
||||
assert response.chunks == ["chunk1", "chunk2", "chunk3"]
|
||||
assert len(response.chunks) == 3
|
||||
assert response.chunks[0].chunk_id == "chunk1"
|
||||
assert response.chunks[0].score == 0.9
|
||||
|
||||
def test_response_schema_with_error(self):
|
||||
"""Test response schema with error"""
|
||||
|
|
@ -119,52 +125,47 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
type="query_error",
|
||||
message="Database connection failed"
|
||||
)
|
||||
|
||||
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=error,
|
||||
chunks=None
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
|
||||
assert response.error == error
|
||||
assert response.chunks is None
|
||||
assert response.chunks == []
|
||||
|
||||
def test_response_translator_from_pulsar_with_chunks(self):
|
||||
"""Test response translator converts Pulsar schema with chunks to dict"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=None,
|
||||
chunks=["doc1", "doc2", "doc3"]
|
||||
chunks=[
|
||||
ChunkMatch(chunk_id="doc1/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc2/c2", score=0.85),
|
||||
ChunkMatch(chunk_id="doc3/c3", score=0.75)
|
||||
]
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
assert result["chunks"] == ["doc1", "doc2", "doc3"]
|
||||
|
||||
def test_response_translator_from_pulsar_with_bytes(self):
|
||||
"""Test response translator handles byte chunks correctly"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
response = MagicMock()
|
||||
response.chunks = [b"byte_chunk1", b"byte_chunk2"]
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
assert result["chunks"] == ["byte_chunk1", "byte_chunk2"]
|
||||
assert len(result["chunks"]) == 3
|
||||
assert result["chunks"][0]["chunk_id"] == "doc1/c1"
|
||||
assert result["chunks"][0]["score"] == 0.95
|
||||
|
||||
def test_response_translator_from_pulsar_with_empty_chunks(self):
|
||||
"""Test response translator handles empty chunks list"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
response = MagicMock()
|
||||
response.chunks = []
|
||||
|
||||
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=None,
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
assert result["chunks"] == []
|
||||
|
|
@ -172,37 +173,41 @@ class TestDocumentEmbeddingsResponseContract:
|
|||
def test_response_translator_from_pulsar_with_none_chunks(self):
|
||||
"""Test response translator handles None chunks"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
||||
response = MagicMock()
|
||||
response.chunks = None
|
||||
|
||||
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" not in result or result.get("chunks") is None
|
||||
|
||||
def test_response_translator_from_response_with_completion(self):
|
||||
"""Test response translator with completion flag"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=None,
|
||||
chunks=["chunk1", "chunk2"]
|
||||
chunks=[
|
||||
ChunkMatch(chunk_id="chunk1", score=0.9),
|
||||
ChunkMatch(chunk_id="chunk2", score=0.8)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
result, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "chunks" in result
|
||||
assert result["chunks"] == ["chunk1", "chunk2"]
|
||||
assert len(result["chunks"]) == 2
|
||||
assert result["chunks"][0]["chunk_id"] == "chunk1"
|
||||
assert is_final is True # Document embeddings responses are always final
|
||||
|
||||
def test_response_translator_to_pulsar_not_implemented(self):
|
||||
"""Test that to_pulsar raises NotImplementedError for responses"""
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
translator.to_pulsar({"chunks": ["test"]})
|
||||
translator.to_pulsar({"chunks": [{"chunk_id": "test", "score": 0.9}]})
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsMessageCompatibility:
|
||||
|
|
@ -212,26 +217,29 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
"""Test complete request-response flow maintains data integrity"""
|
||||
# Create request
|
||||
request_data = {
|
||||
"vectors": [[0.1, 0.2, 0.3]],
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"limit": 5,
|
||||
"user": "test_user",
|
||||
"collection": "test_collection"
|
||||
}
|
||||
|
||||
|
||||
# Convert to Pulsar request
|
||||
req_translator = DocumentEmbeddingsRequestTranslator()
|
||||
pulsar_request = req_translator.to_pulsar(request_data)
|
||||
|
||||
|
||||
# Simulate service processing and creating response
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=None,
|
||||
chunks=["relevant chunk 1", "relevant chunk 2"]
|
||||
chunks=[
|
||||
ChunkMatch(chunk_id="doc1/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc2/c2", score=0.85)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Convert response back to dict
|
||||
resp_translator = DocumentEmbeddingsResponseTranslator()
|
||||
response_data = resp_translator.from_pulsar(response)
|
||||
|
||||
|
||||
# Verify data integrity
|
||||
assert isinstance(pulsar_request, DocumentEmbeddingsRequest)
|
||||
assert isinstance(response_data, dict)
|
||||
|
|
@ -245,17 +253,18 @@ class TestDocumentEmbeddingsMessageCompatibility:
|
|||
type="vector_db_error",
|
||||
message="Collection not found"
|
||||
)
|
||||
|
||||
|
||||
response = DocumentEmbeddingsResponse(
|
||||
error=error,
|
||||
chunks=None
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
|
||||
# Convert response to dict
|
||||
translator = DocumentEmbeddingsResponseTranslator()
|
||||
response_data = translator.from_pulsar(response)
|
||||
|
||||
|
||||
# Verify error handling
|
||||
assert isinstance(response_data, dict)
|
||||
# The translator doesn't include error in the dict, only chunks
|
||||
assert "chunks" not in response_data or response_data.get("chunks") is None
|
||||
assert "chunks" in response_data
|
||||
assert response_data["chunks"] == []
|
||||
|
|
|
|||
|
|
@ -401,25 +401,6 @@ class TestMetadataMessageContracts:
|
|||
assert metadata.id == "test-doc-123"
|
||||
assert metadata.user == "test_user"
|
||||
assert metadata.collection == "test_collection"
|
||||
assert isinstance(metadata.metadata, list)
|
||||
|
||||
def test_metadata_with_triples_contract(self, sample_message_data):
|
||||
"""Test Metadata with embedded triples contract"""
|
||||
# Arrange
|
||||
triple = Triple(**sample_message_data["Triple"])
|
||||
metadata_data = {
|
||||
"id": "doc-with-triples",
|
||||
"user": "test_user",
|
||||
"collection": "test_collection",
|
||||
"metadata": [triple]
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
assert validate_schema_contract(Metadata, metadata_data)
|
||||
|
||||
metadata = Metadata(**metadata_data)
|
||||
assert len(metadata.metadata) == 1
|
||||
assert metadata.metadata[0].s.iri == "http://example.com/subject"
|
||||
|
||||
def test_error_schema_contract(self):
|
||||
"""Test Error schema contract"""
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ class TestRowsCassandraContracts:
|
|||
id="test-doc-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
test_object = ExtractedObject(
|
||||
|
|
@ -50,7 +49,6 @@ class TestRowsCassandraContracts:
|
|||
assert hasattr(test_object.metadata, 'id')
|
||||
assert hasattr(test_object.metadata, 'user')
|
||||
assert hasattr(test_object.metadata, 'collection')
|
||||
assert hasattr(test_object.metadata, 'metadata')
|
||||
|
||||
# Verify types
|
||||
assert isinstance(test_object.schema_name, str)
|
||||
|
|
@ -154,7 +152,6 @@ class TestRowsCassandraContracts:
|
|||
id="serial-001",
|
||||
user="test_user",
|
||||
collection="test_coll",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values=[{"field1": "value1", "field2": "123"}],
|
||||
|
|
@ -234,7 +231,6 @@ class TestRowsCassandraContracts:
|
|||
id="meta-001",
|
||||
user="user123", # -> keyspace
|
||||
collection="coll456", # -> partition key
|
||||
metadata=[{"key": "value"}]
|
||||
),
|
||||
schema_name="table789", # -> table name
|
||||
values=[{"field": "value"}],
|
||||
|
|
@ -262,7 +258,6 @@ class TestRowsCassandraContractsBatch:
|
|||
id="batch-doc-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
batch_object = ExtractedObject(
|
||||
|
|
@ -308,10 +303,9 @@ class TestRowsCassandraContractsBatch:
|
|||
test_metadata = Metadata(
|
||||
id="empty-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
collection="test_collection",
|
||||
)
|
||||
|
||||
|
||||
empty_batch_object = ExtractedObject(
|
||||
metadata=test_metadata,
|
||||
schema_name="empty_schema",
|
||||
|
|
@ -332,9 +326,8 @@ class TestRowsCassandraContractsBatch:
|
|||
id="single-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
single_batch_object = ExtractedObject(
|
||||
metadata=test_metadata,
|
||||
schema_name="customer_records",
|
||||
|
|
@ -362,12 +355,11 @@ class TestRowsCassandraContractsBatch:
|
|||
id="batch-serial-001",
|
||||
user="test_user",
|
||||
collection="test_coll",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values=[
|
||||
{"field1": "value1", "field2": "123"},
|
||||
{"field1": "value2", "field2": "456"},
|
||||
{"field1": "value2", "field2": "456"},
|
||||
{"field1": "value3", "field2": "789"}
|
||||
],
|
||||
confidence=0.92,
|
||||
|
|
@ -436,9 +428,8 @@ class TestRowsCassandraContractsBatch:
|
|||
id="partition-test-001",
|
||||
user="consistent_user", # Same keyspace
|
||||
collection="consistent_collection", # Same partition
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
batch_object = ExtractedObject(
|
||||
metadata=test_metadata,
|
||||
schema_name="partition_test",
|
||||
|
|
|
|||
|
|
@ -95,9 +95,8 @@ class TestStructuredDataSchemaContracts:
|
|||
id="structured-data-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
submission = StructuredDataSubmission(
|
||||
metadata=metadata,
|
||||
|
|
@ -121,9 +120,8 @@ class TestStructuredDataSchemaContracts:
|
|||
id="extracted-obj-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
|
|
@ -147,9 +145,8 @@ class TestStructuredDataSchemaContracts:
|
|||
id="extracted-batch-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act - create object with multiple values
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
|
|
@ -180,11 +177,10 @@ class TestStructuredDataSchemaContracts:
|
|||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="extracted-empty-001",
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act - create object with empty values array
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
|
|
@ -283,13 +279,12 @@ class TestStructuredEmbeddingsContracts:
|
|||
id="struct-embed-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
embedding = StructuredObjectEmbedding(
|
||||
metadata=metadata,
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
schema_name="customer_records",
|
||||
object_id="customer_123",
|
||||
field_embeddings={
|
||||
|
|
@ -301,7 +296,7 @@ class TestStructuredEmbeddingsContracts:
|
|||
# Assert
|
||||
assert embedding.schema_name == "customer_records"
|
||||
assert embedding.object_id == "customer_123"
|
||||
assert len(embedding.vectors) == 2
|
||||
assert len(embedding.vector) == 3
|
||||
assert len(embedding.field_embeddings) == 2
|
||||
assert "name" in embedding.field_embeddings
|
||||
|
||||
|
|
@ -313,7 +308,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_structured_data_submission_serialization(self):
|
||||
"""Test StructuredDataSubmission serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
submission_data = {
|
||||
"metadata": metadata,
|
||||
"format": "json",
|
||||
|
|
@ -328,7 +323,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_serialization(self):
|
||||
"""Test ExtractedObject serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
@ -378,7 +373,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_batch_serialization(self):
|
||||
"""Test ExtractedObject batch serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
batch_object_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
@ -397,7 +392,7 @@ class TestStructuredDataSerializationContracts:
|
|||
def test_extracted_object_empty_batch_serialization(self):
|
||||
"""Test ExtractedObject empty batch serialization contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(id="test", user="user", collection="col", metadata=[])
|
||||
metadata = Metadata(id="test", user="user", collection="col")
|
||||
empty_batch_data = {
|
||||
"metadata": metadata,
|
||||
"schema_name": "test_schema",
|
||||
|
|
|
|||
|
|
@ -17,16 +17,18 @@ from trustgraph.messaging import TranslatorRegistry
|
|||
class TestRAGTranslatorCompletionFlags:
|
||||
"""Contract tests for RAG response translator completion flags"""
|
||||
|
||||
def test_graph_rag_translator_is_final_with_end_of_stream_true(self):
|
||||
def test_graph_rag_translator_is_final_with_end_of_session_true(self):
|
||||
"""
|
||||
Test that GraphRagResponseTranslator returns is_final=True
|
||||
when end_of_stream=True.
|
||||
when end_of_session=True.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="A small domesticated mammal.",
|
||||
message_type="chunk",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
|
|
@ -34,20 +36,23 @@ class TestRAGTranslatorCompletionFlags:
|
|||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_stream=True"
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
assert response_dict["response"] == "A small domesticated mammal."
|
||||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is True
|
||||
assert response_dict["message_type"] == "chunk"
|
||||
|
||||
def test_graph_rag_translator_is_final_with_end_of_stream_false(self):
|
||||
def test_graph_rag_translator_is_final_with_end_of_session_false(self):
|
||||
"""
|
||||
Test that GraphRagResponseTranslator returns is_final=False
|
||||
when end_of_stream=False.
|
||||
when end_of_session=False (even if end_of_stream=True).
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="Chunk 1",
|
||||
message_type="chunk",
|
||||
end_of_stream=False,
|
||||
end_of_session=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
|
|
@ -55,20 +60,67 @@ class TestRAGTranslatorCompletionFlags:
|
|||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "is_final must be False when end_of_stream=False"
|
||||
assert is_final is False, "is_final must be False when end_of_session=False"
|
||||
assert response_dict["response"] == "Chunk 1"
|
||||
assert response_dict["end_of_stream"] is False
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_document_rag_translator_is_final_with_end_of_stream_true(self):
|
||||
def test_graph_rag_translator_provenance_message(self):
|
||||
"""
|
||||
Test that GraphRagResponseTranslator handles provenance messages.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="",
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:session:abc123",
|
||||
end_of_stream=False,
|
||||
end_of_session=False,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False
|
||||
assert response_dict["message_type"] == "explain"
|
||||
assert response_dict["explain_id"] == "urn:trustgraph:session:abc123"
|
||||
|
||||
def test_graph_rag_translator_end_of_stream_not_final(self):
|
||||
"""
|
||||
Test that end_of_stream=True alone does NOT make is_final=True.
|
||||
The session continues with provenance messages after LLM stream completes.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
response = GraphRagResponse(
|
||||
response="Final chunk",
|
||||
message_type="chunk",
|
||||
end_of_stream=True,
|
||||
end_of_session=False, # Session continues with provenance
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_document_rag_translator_is_final_with_end_of_session_true(self):
|
||||
"""
|
||||
Test that DocumentRagResponseTranslator returns is_final=True
|
||||
when end_of_stream=True.
|
||||
when end_of_session=True.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("document-rag")
|
||||
response = DocumentRagResponse(
|
||||
response="A document about cats.",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
|
|
@ -76,9 +128,31 @@ class TestRAGTranslatorCompletionFlags:
|
|||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is True, "is_final must be True when end_of_stream=True"
|
||||
assert is_final is True, "is_final must be True when end_of_session=True"
|
||||
assert response_dict["response"] == "A document about cats."
|
||||
assert response_dict["end_of_session"] is True
|
||||
|
||||
def test_document_rag_translator_end_of_stream_not_final(self):
|
||||
"""
|
||||
Test that end_of_stream=True alone does NOT make is_final=True.
|
||||
The session continues with provenance messages after LLM stream completes.
|
||||
"""
|
||||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("document-rag")
|
||||
response = DocumentRagResponse(
|
||||
response="Final chunk",
|
||||
end_of_stream=True,
|
||||
end_of_session=False, # Session continues with provenance
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
response_dict, is_final = translator.from_response_with_completion(response)
|
||||
|
||||
# Assert
|
||||
assert is_final is False, "end_of_stream=True should NOT make is_final=True"
|
||||
assert response_dict["end_of_stream"] is True
|
||||
assert response_dict["end_of_session"] is False
|
||||
|
||||
def test_document_rag_translator_is_final_with_end_of_stream_false(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = AsyncMock()
|
||||
|
||||
# Mock successful agent response in JSONL format
|
||||
def mock_agent_response(recipient, question):
|
||||
def mock_agent_response(question):
|
||||
# Simulate agent processing and return structured JSONL response
|
||||
mock_response = MagicMock()
|
||||
mock_response.error = None
|
||||
|
|
@ -76,13 +76,6 @@ class TestAgentKgExtractionIntegration:
|
|||
chunk=text.encode('utf-8'),
|
||||
metadata=Metadata(
|
||||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="doc123"),
|
||||
p=Term(type=IRI, iri="http://example.org/type"),
|
||||
o=Term(type=LITERAL, value="document")
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -131,16 +124,12 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
# Get agent response (the mock returns a string directly)
|
||||
agent_client = flow("agent-request")
|
||||
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
|
||||
agent_response = agent_client.invoke(question=prompt)
|
||||
|
||||
# Parse and process
|
||||
extraction_data = extractor.parse_jsonl(agent_response)
|
||||
triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)
|
||||
|
||||
# Add metadata triples
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
|
||||
triples, entity_contexts, extracted_triples = extractor.process_extraction_data(extraction_data, v.metadata)
|
||||
|
||||
# Emit outputs
|
||||
if triples:
|
||||
await extractor.emit_triples(flow("triples"), v.metadata, triples)
|
||||
|
|
@ -185,10 +174,6 @@ class TestAgentKgExtractionIntegration:
|
|||
label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL]
|
||||
assert len(label_triples) >= 2 # Should have labels for entities
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) >= 2 # Entities should be linked to document
|
||||
|
||||
# Verify entity contexts were emitted
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
entity_contexts_publisher.send.assert_called_once()
|
||||
|
|
@ -208,7 +193,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock agent error response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_error_response(recipient, question):
|
||||
def mock_error_response(question):
|
||||
# Simulate agent error by raising an exception
|
||||
raise RuntimeError("Agent processing failed")
|
||||
|
||||
|
|
@ -230,7 +215,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock invalid JSON response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_invalid_json_response(recipient, question):
|
||||
def mock_invalid_json_response(question):
|
||||
return "This is not valid JSON at all"
|
||||
|
||||
agent_client.invoke = mock_invalid_json_response
|
||||
|
|
@ -242,9 +227,9 @@ class TestAgentKgExtractionIntegration:
|
|||
# Act - JSONL parsing is lenient, invalid lines are skipped
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert - should emit triples (with just metadata) but no entity contexts
|
||||
# Assert - with no valid extraction data, nothing is emitted
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
triples_publisher.send.assert_called_once()
|
||||
triples_publisher.send.assert_not_called()
|
||||
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
entity_contexts_publisher.send.assert_not_called()
|
||||
|
|
@ -255,7 +240,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock empty extraction response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_empty_response(recipient, question):
|
||||
def mock_empty_response(question):
|
||||
# Return empty JSONL (just empty/whitespace)
|
||||
return ''
|
||||
|
||||
|
|
@ -268,17 +253,12 @@ class TestAgentKgExtractionIntegration:
|
|||
# Act
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Should still emit outputs (even if empty) to maintain flow consistency
|
||||
# Assert - with empty extraction results, nothing is emitted
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
|
||||
# Triples should include metadata triples at minimum
|
||||
triples_publisher.send.assert_called_once()
|
||||
sent_triples = triples_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_triples, Triples)
|
||||
|
||||
# Entity contexts should not be sent if empty
|
||||
|
||||
# No triples or entity contexts emitted for empty results
|
||||
triples_publisher.send.assert_not_called()
|
||||
entity_contexts_publisher.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -287,7 +267,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Arrange - mock malformed extraction response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_malformed_response(recipient, question):
|
||||
def mock_malformed_response(question):
|
||||
# JSONL with definition missing required field
|
||||
return '{"type": "definition", "entity": "Missing Definition"}'
|
||||
|
||||
|
|
@ -308,12 +288,12 @@ class TestAgentKgExtractionIntegration:
|
|||
test_text = "Test text for prompt rendering"
|
||||
chunk = Chunk(
|
||||
chunk=test_text.encode('utf-8'),
|
||||
metadata=Metadata(id="test-doc", metadata=[])
|
||||
metadata=Metadata(id="test-doc")
|
||||
)
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def capture_prompt(recipient, question):
|
||||
def capture_prompt(question):
|
||||
# Verify the prompt contains the test text
|
||||
assert test_text in question
|
||||
return '' # Empty JSONL response
|
||||
|
|
@ -340,13 +320,13 @@ class TestAgentKgExtractionIntegration:
|
|||
text = f"Test document {i} content"
|
||||
chunks.append(Chunk(
|
||||
chunk=text.encode('utf-8'),
|
||||
metadata=Metadata(id=f"doc{i}", metadata=[])
|
||||
metadata=Metadata(id=f"doc{i}")
|
||||
))
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
responses = []
|
||||
|
||||
def mock_response(recipient, question):
|
||||
def mock_response(question):
|
||||
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
|
||||
responses.append(response)
|
||||
return response
|
||||
|
|
@ -375,12 +355,12 @@ class TestAgentKgExtractionIntegration:
|
|||
unicode_text = "Machine Learning (学习机器) は人工知能の一分野です。"
|
||||
chunk = Chunk(
|
||||
chunk=unicode_text.encode('utf-8'),
|
||||
metadata=Metadata(id="unicode-doc", metadata=[])
|
||||
metadata=Metadata(id="unicode-doc")
|
||||
)
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_unicode_response(recipient, question):
|
||||
def mock_unicode_response(question):
|
||||
# Verify unicode text was properly decoded and included
|
||||
assert "学习机器" in question
|
||||
assert "人工知能" in question
|
||||
|
|
@ -411,12 +391,12 @@ class TestAgentKgExtractionIntegration:
|
|||
large_text = "Machine Learning is important. " * 1000 # Repeat to create large text
|
||||
chunk = Chunk(
|
||||
chunk=large_text.encode('utf-8'),
|
||||
metadata=Metadata(id="large-doc", metadata=[])
|
||||
metadata=Metadata(id="large-doc")
|
||||
)
|
||||
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_large_text_response(recipient, question):
|
||||
def mock_large_text_response(question):
|
||||
# Verify large text was included
|
||||
assert len(question) > 10000
|
||||
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'
|
||||
|
|
|
|||
|
|
@ -30,10 +30,13 @@ class TestAgentStructuredQueryIntegration:
|
|||
pulsar_client=AsyncMock(),
|
||||
max_iterations=3
|
||||
)
|
||||
|
||||
|
||||
# Mock the client method for structured query
|
||||
proc.client = MagicMock()
|
||||
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
proc.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
return proc
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -9,6 +9,15 @@ Following the TEST_STRATEGY.md approach for integration testing.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
CHUNK_CONTENT = {
|
||||
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -19,23 +28,35 @@ class TestDocumentRagIntegration:
|
|||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client that returns realistic vector embeddings"""
|
||||
client = AsyncMock()
|
||||
# New batch format: [[[vectors_for_text1], ...]]
|
||||
# One text input returns one vector set containing two vectors
|
||||
client.embed.return_value = [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0] # Second embedding for testing
|
||||
[
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], # First vector for text
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0] # Second vector for text
|
||||
]
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client that returns realistic document chunks"""
|
||||
"""Mock document embeddings client that returns chunk matches"""
|
||||
client = AsyncMock()
|
||||
# Returns ChunkMatch objects with chunk_id and score
|
||||
client.query.return_value = [
|
||||
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc/c2", score=0.90),
|
||||
ChunkMatch(chunk_id="doc/c3", score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fetch_chunk(self):
|
||||
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||
async def fetch(chunk_id, user):
|
||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
return fetch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
|
|
@ -48,17 +69,19 @@ class TestDocumentRagIntegration:
|
|||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
|
||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||
mock_prompt_client, mock_fetch_chunk):
|
||||
"""Create DocumentRag instance with mocked dependencies"""
|
||||
return DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
||||
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test complete DocumentRAG pipeline from query to response"""
|
||||
# Arrange
|
||||
|
|
@ -76,15 +99,16 @@ class TestDocumentRagIntegration:
|
|||
)
|
||||
|
||||
# Assert - Verify service coordination
|
||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
||||
|
||||
mock_embeddings_client.embed.assert_called_once_with([query])
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
||||
# Documents are fetched from librarian using chunk_ids
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query=query,
|
||||
documents=[
|
||||
|
|
@ -101,17 +125,19 @@ class TestDocumentRagIntegration:
|
|||
assert "artificial intelligence" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
|
@ -125,92 +151,98 @@ class TestDocumentRagIntegration:
|
|||
query="very obscure query",
|
||||
documents=[]
|
||||
)
|
||||
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG error handling when embeddings service fails"""
|
||||
# Arrange
|
||||
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
|
||||
assert "Embeddings service unavailable" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_not_called()
|
||||
mock_prompt_client.document_prompt.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG error handling when document service fails"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
|
||||
assert "Document service connection failed" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk):
|
||||
"""Test DocumentRAG error handling when prompt service fails"""
|
||||
# Arrange
|
||||
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
|
||||
assert "LLM service rate limited" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG with various document limit configurations"""
|
||||
# Test different document limits
|
||||
test_cases = [1, 5, 10, 25, 50]
|
||||
|
||||
|
||||
for limit in test_cases:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
|
||||
# Act
|
||||
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
|
||||
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
|
|
@ -230,14 +262,14 @@ class TestDocumentRagIntegration:
|
|||
for user, collection in test_scenarios:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
|
||||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
|
|
@ -245,19 +277,21 @@ class TestDocumentRagIntegration:
|
|||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
mock_fetch_chunk,
|
||||
caplog):
|
||||
"""Test DocumentRAG verbose logging functionality"""
|
||||
import logging
|
||||
|
||||
|
||||
# Arrange - Configure logging to capture debug messages
|
||||
caplog.set_level(logging.DEBUG)
|
||||
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
|
@ -269,25 +303,25 @@ class TestDocumentRagIntegration:
|
|||
assert "DocumentRag initialized" in log_messages
|
||||
assert "Constructing prompt..." in log_messages
|
||||
assert "Computing embeddings..." in log_messages
|
||||
assert "Getting documents..." in log_messages
|
||||
assert "chunks" in log_messages.lower()
|
||||
assert "Invoking LLM..." in log_messages
|
||||
assert "Query processing complete" in log_messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG performance with large document retrieval"""
|
||||
# Arrange - Mock large document set (100 documents)
|
||||
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_doc_set
|
||||
# Arrange - Mock large chunk match set (100 chunks)
|
||||
large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_chunk_matches
|
||||
|
||||
# Act
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
result = await document_rag.query("performance test query", doc_limit=100)
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
|
|
@ -309,4 +343,4 @@ class TestDocumentRagIntegration:
|
|||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == "trustgraph"
|
||||
assert call_args.kwargs['collection'] == "default"
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
|
|
|
|||
|
|
@ -8,12 +8,21 @@ response delivery through the complete pipeline.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
)
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
CHUNK_CONTENT = {
|
||||
"doc/c1": "Machine learning is a subset of AI.",
|
||||
"doc/c2": "Deep learning uses neural networks.",
|
||||
"doc/c3": "Supervised learning needs labeled data.",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestDocumentRagStreaming:
|
||||
"""Integration tests for DocumentRAG streaming"""
|
||||
|
|
@ -22,20 +31,29 @@ class TestDocumentRagStreaming:
|
|||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
# New batch format: [[[vectors_for_text1]]]
|
||||
client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client"""
|
||||
"""Mock document embeddings client that returns chunk matches"""
|
||||
client = AsyncMock()
|
||||
# Returns ChunkMatch objects with chunk_id and score
|
||||
client.query.return_value = [
|
||||
"Machine learning is a subset of AI.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Supervised learning needs labeled data."
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc/c2", score=0.90),
|
||||
ChunkMatch(chunk_id="doc/c3", score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fetch_chunk(self):
|
||||
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||
async def fetch(chunk_id, user):
|
||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||
return fetch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
|
||||
"""Mock prompt client with streaming support"""
|
||||
|
|
@ -66,12 +84,13 @@ class TestDocumentRagStreaming:
|
|||
|
||||
@pytest.fixture
|
||||
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||
mock_streaming_prompt_client):
|
||||
mock_streaming_prompt_client, mock_fetch_chunk):
|
||||
"""Create DocumentRag instance with streaming support"""
|
||||
return DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
|
|
@ -190,7 +209,7 @@ class TestDocumentRagStreaming:
|
|||
mock_doc_embeddings_client):
|
||||
"""Test streaming with no documents found"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No documents
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ NOTE: This is the first integration test file for GraphRAG (previously had only
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -21,8 +22,12 @@ class TestGraphRagIntegration:
|
|||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client that returns realistic vector embeddings"""
|
||||
client = AsyncMock()
|
||||
# New batch format: [[[vectors_for_text1], ...]]
|
||||
# One text input returns one vector set containing one vector
|
||||
client.embed.return_value = [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
|
||||
[
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], # Vector for text
|
||||
]
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
@ -31,9 +36,9 @@ class TestGraphRagIntegration:
|
|||
"""Mock graph embeddings client that returns realistic entities"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
"http://trustgraph.ai/e/machine-learning",
|
||||
"http://trustgraph.ai/e/artificial-intelligence",
|
||||
"http://trustgraph.ai/e/neural-networks"
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/artificial-intelligence"), score=0.90),
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/neural-networks"), score=0.85)
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
@ -43,7 +48,7 @@ class TestGraphRagIntegration:
|
|||
client = AsyncMock()
|
||||
|
||||
# Mock different queries return different triples
|
||||
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
|
||||
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
|
||||
# Mock label queries
|
||||
if p == "http://www.w3.org/2000/01/rdf-schema#label":
|
||||
if s == "http://trustgraph.ai/e/machine-learning":
|
||||
|
|
@ -71,18 +76,37 @@ class TestGraphRagIntegration:
|
|||
|
||||
return []
|
||||
|
||||
client.query.side_effect = query_side_effect
|
||||
client.query_stream.side_effect = query_stream_side_effect
|
||||
# Also mock query for label lookups (maybe_label uses query, not query_stream)
|
||||
client.query.side_effect = query_stream_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
"""Mock prompt client that generates realistic responses for two-step process"""
|
||||
client = AsyncMock()
|
||||
client.kg_prompt.return_value = (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
|
||||
# Mock responses for the multi-step process:
|
||||
# 1. extract-concepts extracts key concepts from the query
|
||||
# 2. kg-edge-scoring scores edges for relevance
|
||||
# 3. kg-edge-reasoning provides reasoning for selected edges
|
||||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
return ""
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -101,7 +125,7 @@ class TestGraphRagIntegration:
|
|||
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
|
||||
mock_graph_embeddings_client, mock_triples_client,
|
||||
mock_prompt_client):
|
||||
"""Test complete GraphRAG pipeline from query to response"""
|
||||
"""Test complete GraphRAG pipeline from query to response with real-time provenance"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
user = "test_user"
|
||||
|
|
@ -109,41 +133,51 @@ class TestGraphRagIntegration:
|
|||
entity_limit = 50
|
||||
triple_limit = 30
|
||||
|
||||
# Collect provenance events
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
response = await graph_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit
|
||||
triple_limit=triple_limit,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert - Verify service coordination
|
||||
|
||||
# 1. Should compute embeddings for query
|
||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
||||
# 1. Should compute embeddings for query (now expects list of texts)
|
||||
mock_embeddings_client.embed.assert_called_once_with([query])
|
||||
|
||||
# 2. Should query graph embeddings to find relevant entities
|
||||
mock_graph_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_graph_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['vector'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
assert call_args.kwargs['limit'] == entity_limit
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query.call_count > 0
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt with knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.kg_prompt.call_args
|
||||
assert call_args.args[0] == query # First arg is query
|
||||
assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples)
|
||||
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 5
|
||||
for triples, prov_id in provenance_events:
|
||||
assert isinstance(triples, list)
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
|
||||
|
|
@ -197,21 +231,27 @@ class TestGraphRagIntegration:
|
|||
"""Test GraphRAG handles empty knowledge graph gracefully"""
|
||||
# Arrange
|
||||
mock_graph_embeddings_client.query.return_value = [] # No entities found
|
||||
mock_triples_client.query.return_value = [] # No triples found
|
||||
mock_triples_client.query_stream.return_value = [] # No triples found
|
||||
|
||||
# Collect provenance
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act
|
||||
result = await graph_rag.query(
|
||||
response = await graph_rag.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
collection="test_collection",
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should still call prompt client with empty knowledge graph
|
||||
mock_prompt_client.kg_prompt.assert_called_once()
|
||||
call_args = mock_prompt_client.kg_prompt.call_args
|
||||
assert isinstance(call_args.args[1], list) # kg should be a list
|
||||
assert result is not None
|
||||
# Should still call prompt client
|
||||
assert response is not None
|
||||
# Provenance should still be emitted (5 events)
|
||||
assert len(provenance_events) == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
|
||||
|
|
@ -226,7 +266,7 @@ class TestGraphRagIntegration:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
first_call_count = mock_triples_client.query.call_count
|
||||
first_call_count = mock_triples_client.query_stream.call_count
|
||||
mock_triples_client.reset_mock()
|
||||
|
||||
# Second identical query
|
||||
|
|
@ -236,7 +276,7 @@ class TestGraphRagIntegration:
|
|||
collection="test_collection"
|
||||
)
|
||||
|
||||
second_call_count = mock_triples_client.query.call_count
|
||||
second_call_count = mock_triples_client.query_stream.call_count
|
||||
|
||||
# Assert - Second query should make fewer triple queries due to caching
|
||||
# Note: This is a weak assertion because caching behavior depends on
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ response delivery through the complete pipeline.
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -24,7 +25,8 @@ class TestGraphRagStreaming:
|
|||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
# New batch format: [[[vectors_for_text1]]]
|
||||
client.embed.return_value = [[[0.1, 0.2, 0.3, 0.4, 0.5]]]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -32,7 +34,7 @@ class TestGraphRagStreaming:
|
|||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
"http://trustgraph.ai/e/machine-learning",
|
||||
EntityMatch(entity=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"), score=0.95),
|
||||
]
|
||||
return client
|
||||
|
||||
|
|
@ -51,30 +53,38 @@ class TestGraphRagStreaming:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
|
||||
"""Mock prompt client with streaming support"""
|
||||
"""Mock prompt client with streaming support for two-stage GraphRAG"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
# Both modes return the same text
|
||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
# Full synthesis text
|
||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
chunks = []
|
||||
async for chunk in mock_streaming_llm_response():
|
||||
chunks.append(chunk)
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
chunks = []
|
||||
async for chunk in mock_streaming_llm_response():
|
||||
chunks.append(chunk)
|
||||
|
||||
# Send all chunks with end_of_stream=False except the last
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
# Send all chunks with end_of_stream=False except the last
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return full_text
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -91,18 +101,25 @@ class TestGraphRagStreaming:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
|
||||
"""Test basic GraphRAG streaming functionality"""
|
||||
"""Test basic GraphRAG streaming functionality with real-time provenance"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
collector = streaming_chunk_collector()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
# Collect provenance events
|
||||
provenance_events = []
|
||||
|
||||
async def collect_provenance(triples, prov_id):
|
||||
provenance_events.append((triples, prov_id))
|
||||
|
||||
# Act - query() returns response, provenance via callback
|
||||
response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collector.collect
|
||||
chunk_callback=collector.collect,
|
||||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
@ -114,10 +131,15 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Verify full response matches concatenated chunks
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert response == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert "machine" in result.lower() or "learning" in result.lower()
|
||||
assert "machine" in response.lower() or "learning" in response.lower()
|
||||
|
||||
# Verify provenance was emitted in real-time (5 events: question, grounding, exploration, focus, synthesis)
|
||||
assert len(provenance_events) == 5
|
||||
for triples, prov_id in provenance_events:
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
|
||||
|
|
@ -128,7 +150,7 @@ class TestGraphRagStreaming:
|
|||
collection = "test_collection"
|
||||
|
||||
# Act - Non-streaming
|
||||
non_streaming_result = await graph_rag_streaming.query(
|
||||
non_streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -141,7 +163,7 @@ class TestGraphRagStreaming:
|
|||
async def collect(chunk, end_of_stream):
|
||||
streaming_chunks.append(chunk)
|
||||
|
||||
streaming_result = await graph_rag_streaming.query(
|
||||
streaming_response = await graph_rag_streaming.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
@ -150,9 +172,9 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
assert streaming_response == non_streaming_response
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -161,7 +183,7 @@ class TestGraphRagStreaming:
|
|||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -171,7 +193,7 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -181,7 +203,7 @@ class TestGraphRagStreaming:
|
|||
async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
|
||||
"""Test streaming parameter without callback (should fall back to non-streaming)"""
|
||||
# Arrange & Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -190,8 +212,8 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
@ -202,7 +224,7 @@ class TestGraphRagStreaming:
|
|||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await graph_rag_streaming.query(
|
||||
response = await graph_rag_streaming.query(
|
||||
query="unknown topic",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
|
|
@ -211,7 +233,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
assert response is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -171,7 +171,6 @@ async def test_export_no_message_loss_integration(mock_backend):
|
|||
triples_obj = Triples(
|
||||
metadata=Metadata(
|
||||
id=f"export-msg-{i}",
|
||||
metadata=to_subgraph(msg_data["metadata"]["metadata"]),
|
||||
user=msg_data["metadata"]["user"],
|
||||
collection=msg_data["metadata"]["collection"],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from trustgraph.extract.kg.relationships.extract import Processor as Relationshi
|
|||
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -92,7 +92,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
id="doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
|
||||
)
|
||||
|
|
@ -243,13 +242,12 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
# Act
|
||||
triples = []
|
||||
entities = []
|
||||
|
||||
|
||||
for defn in sample_definitions_response:
|
||||
s = defn["entity"]
|
||||
o = defn["definition"]
|
||||
|
|
@ -302,12 +300,11 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
# Act
|
||||
triples = []
|
||||
|
||||
|
||||
for rel in sample_relationships_response:
|
||||
s = rel["subject"]
|
||||
p = rel["predicate"]
|
||||
|
|
@ -373,7 +370,6 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
triples=[
|
||||
Triple(
|
||||
|
|
@ -406,12 +402,11 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri="http://example.org/entity"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -542,7 +537,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
]
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection", metadata=[]),
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
chunk=b"Test chunk"
|
||||
)
|
||||
|
||||
|
|
@ -569,7 +564,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
# Arrange
|
||||
large_chunk_batch = [
|
||||
Chunk(
|
||||
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection", metadata=[]),
|
||||
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection"),
|
||||
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
|
||||
)
|
||||
for i in range(100) # Large batch
|
||||
|
|
@ -608,15 +603,8 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="doc:test"),
|
||||
p=Term(type=IRI, iri="dc:title"),
|
||||
o=Term(type=LITERAL, value="Test Document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=original_metadata,
|
||||
chunk=b"Test content for metadata propagation"
|
||||
|
|
|
|||
|
|
@ -231,7 +231,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
id="customer-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
chunk_text = """
|
||||
|
|
@ -299,7 +298,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
id="product-doc-001",
|
||||
user="integration_test",
|
||||
collection="test_documents",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
chunk_text = """
|
||||
|
|
@ -373,7 +371,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
id=chunk_id,
|
||||
user="concurrent_test",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
chunk = Chunk(metadata=metadata, chunk=text.encode('utf-8'))
|
||||
chunks.append(chunk)
|
||||
|
|
@ -470,7 +467,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
await processor.on_schema_config(integration_config, version=1)
|
||||
|
||||
# Create test chunk
|
||||
metadata = Metadata(id="error-test", user="test", collection="test", metadata=[])
|
||||
metadata = Metadata(id="error-test", user="test", collection="test")
|
||||
chunk = Chunk(metadata=metadata, chunk=b"Some text that will fail to process")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
|
|
@ -507,7 +504,6 @@ class TestObjectExtractionServiceIntegration:
|
|||
id="metadata-test-chunk",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[] # Could include source document metadata
|
||||
)
|
||||
|
||||
chunk = Chunk(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -18,14 +19,17 @@ class TestGraphRagStreamingProtocol:
|
|||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
client.embed.return_value = [[[0.1, 0.2, 0.3]]]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_client(self):
|
||||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = ["entity1", "entity2"]
|
||||
client.query.return_value = [
|
||||
EntityMatch(entity=Term(type=IRI, iri="entity1"), score=0.95),
|
||||
EntityMatch(entity=Term(type=IRI, iri="entity2"), score=0.90)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -40,18 +44,23 @@ class TestGraphRagStreamingProtocol:
|
|||
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -197,16 +206,26 @@ class TestDocumentRagStreamingProtocol:
|
|||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
client.embed.return_value = [[[0.1, 0.2, 0.3]]]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client"""
|
||||
"""Mock document embeddings client that returns chunk matches"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = ["doc1", "doc2"]
|
||||
client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.95),
|
||||
ChunkMatch(chunk_id="doc/c2", score=0.90)
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fetch_chunk(self):
|
||||
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||
async def fetch(chunk_id, user):
|
||||
return f"Content for {chunk_id}"
|
||||
return fetch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_streaming_prompt_client(self):
|
||||
"""Mock prompt client with streaming support"""
|
||||
|
|
@ -227,12 +246,13 @@ class TestDocumentRagStreamingProtocol:
|
|||
|
||||
@pytest.fixture
|
||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||
mock_streaming_prompt_client):
|
||||
mock_streaming_prompt_client, mock_fetch_chunk):
|
||||
"""Create DocumentRag instance with mocked dependencies"""
|
||||
return DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
fetch_chunk=mock_fetch_chunk,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
|
@ -312,20 +332,24 @@ class TestStreamingProtocolEdgeCases:
|
|||
# Arrange
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
else:
|
||||
return "textmore"
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_with_empties
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
rag = GraphRag(
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
|
||||
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
prompt_client=client,
|
||||
|
|
|
|||
|
|
@ -120,7 +120,6 @@ class TestRowsCassandraIntegration:
|
|||
id="doc-001",
|
||||
user="test_user",
|
||||
collection="import_2024",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="customer_records",
|
||||
values=[{
|
||||
|
|
@ -201,7 +200,7 @@ class TestRowsCassandraIntegration:
|
|||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog"),
|
||||
schema_name="products",
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -209,7 +208,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales"),
|
||||
schema_name="orders",
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
|
|
@ -254,7 +253,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
schema_name="indexed_data",
|
||||
values=[{
|
||||
"id": "123",
|
||||
|
|
@ -337,7 +336,6 @@ class TestRowsCassandraIntegration:
|
|||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_import",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="batch_customers",
|
||||
values=[
|
||||
|
|
@ -391,7 +389,7 @@ class TestRowsCassandraIntegration:
|
|||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty"),
|
||||
schema_name="empty_test",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
|
|
@ -426,7 +424,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
metadata=Metadata(id="t1", user="test", collection="test"),
|
||||
schema_name="map_test",
|
||||
values=[{"id": "123", "name": "Test Item", "count": "42"}],
|
||||
confidence=0.9,
|
||||
|
|
@ -470,7 +468,7 @@ class TestRowsCassandraIntegration:
|
|||
)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection"),
|
||||
schema_name="partition_test",
|
||||
values=[{"id": "123", "category": "test"}],
|
||||
confidence=0.9,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ class TestAgentServiceNonStreaming:
|
|||
max_iterations=10
|
||||
)
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
processor.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
# Track all responses sent
|
||||
sent_responses = []
|
||||
|
||||
|
|
@ -106,6 +109,9 @@ class TestAgentServiceNonStreaming:
|
|||
max_iterations=10
|
||||
)
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
processor.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
# Track all responses sent
|
||||
sent_responses = []
|
||||
|
||||
|
|
@ -173,6 +179,9 @@ class TestAgentServiceNonStreaming:
|
|||
max_iterations=10
|
||||
)
|
||||
|
||||
# Mock librarian to avoid hanging on save operations
|
||||
processor.save_answer_content = AsyncMock(return_value=None)
|
||||
|
||||
# Track all responses sent
|
||||
sent_responses = []
|
||||
|
||||
|
|
|
|||
495
tests/unit/test_agent/test_tool_service.py
Normal file
495
tests/unit/test_agent/test_tool_service.py
Normal file
|
|
@ -0,0 +1,495 @@
|
|||
"""
|
||||
Unit tests for Tool Service functionality
|
||||
|
||||
Tests the dynamically pluggable tool services feature including:
|
||||
- Tool service configuration parsing
|
||||
- ToolServiceImpl initialization
|
||||
- Request/response format
|
||||
- Config parameter handling
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
import json
|
||||
|
||||
|
||||
class TestToolServiceConfigParsing:
|
||||
"""Test cases for tool service configuration parsing"""
|
||||
|
||||
def test_tool_service_config_structure(self):
|
||||
"""Test that tool-service config has required fields"""
|
||||
# Arrange
|
||||
valid_config = {
|
||||
"id": "joke-service",
|
||||
"request-queue": "non-persistent://tg/request/joke",
|
||||
"response-queue": "non-persistent://tg/response/joke",
|
||||
"config-params": [
|
||||
{"name": "style", "required": False}
|
||||
]
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
assert "id" in valid_config
|
||||
assert "request-queue" in valid_config
|
||||
assert "response-queue" in valid_config
|
||||
assert valid_config["request-queue"].startswith("non-persistent://")
|
||||
assert valid_config["response-queue"].startswith("non-persistent://")
|
||||
|
||||
def test_tool_service_config_without_queues_is_invalid(self):
|
||||
"""Test that tool-service config requires request-queue and response-queue"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
"id": "joke-service",
|
||||
"config-params": []
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
def validate_config(config):
|
||||
request_queue = config.get("request-queue")
|
||||
response_queue = config.get("response-queue")
|
||||
if not request_queue or not response_queue:
|
||||
raise RuntimeError("Tool-service must define 'request-queue' and 'response-queue'")
|
||||
return True
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
validate_config(invalid_config)
|
||||
assert "request-queue" in str(exc_info.value)
|
||||
|
||||
def test_tool_config_references_tool_service(self):
|
||||
"""Test that tool config correctly references a tool-service"""
|
||||
# Arrange
|
||||
tool_services = {
|
||||
"joke-service": {
|
||||
"id": "joke-service",
|
||||
"request-queue": "non-persistent://tg/request/joke",
|
||||
"response-queue": "non-persistent://tg/response/joke",
|
||||
"config-params": [{"name": "style", "required": False}]
|
||||
}
|
||||
}
|
||||
|
||||
tool_config = {
|
||||
"type": "tool-service",
|
||||
"name": "tell-joke",
|
||||
"description": "Tell a joke on a given topic",
|
||||
"service": "joke-service",
|
||||
"style": "pun",
|
||||
"arguments": [
|
||||
{"name": "topic", "type": "string", "description": "The topic for the joke"}
|
||||
]
|
||||
}
|
||||
|
||||
# Act
|
||||
service_ref = tool_config.get("service")
|
||||
service_config = tool_services.get(service_ref)
|
||||
|
||||
# Assert
|
||||
assert service_ref == "joke-service"
|
||||
assert service_config is not None
|
||||
assert service_config["request-queue"] == "non-persistent://tg/request/joke"
|
||||
|
||||
def test_tool_config_extracts_config_values(self):
|
||||
"""Test that config values are extracted from tool config"""
|
||||
# Arrange
|
||||
tool_services = {
|
||||
"joke-service": {
|
||||
"id": "joke-service",
|
||||
"request-queue": "non-persistent://tg/request/joke",
|
||||
"response-queue": "non-persistent://tg/response/joke",
|
||||
"config-params": [
|
||||
{"name": "style", "required": False},
|
||||
{"name": "max-length", "required": False}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
tool_config = {
|
||||
"type": "tool-service",
|
||||
"name": "tell-joke",
|
||||
"description": "Tell a joke",
|
||||
"service": "joke-service",
|
||||
"style": "pun",
|
||||
"max-length": 100,
|
||||
"arguments": []
|
||||
}
|
||||
|
||||
# Act - simulate config extraction
|
||||
service_config = tool_services[tool_config["service"]]
|
||||
config_params = service_config.get("config-params", [])
|
||||
config_values = {}
|
||||
for param in config_params:
|
||||
param_name = param.get("name") if isinstance(param, dict) else param
|
||||
if param_name in tool_config:
|
||||
config_values[param_name] = tool_config[param_name]
|
||||
|
||||
# Assert
|
||||
assert config_values == {"style": "pun", "max-length": 100}
|
||||
|
||||
def test_required_config_param_validation(self):
|
||||
"""Test that required config params are validated"""
|
||||
# Arrange
|
||||
tool_services = {
|
||||
"custom-service": {
|
||||
"id": "custom-service",
|
||||
"request-queue": "non-persistent://tg/request/custom",
|
||||
"response-queue": "non-persistent://tg/response/custom",
|
||||
"config-params": [
|
||||
{"name": "collection", "required": True},
|
||||
{"name": "optional-param", "required": False}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
tool_config_missing_required = {
|
||||
"type": "tool-service",
|
||||
"name": "custom-tool",
|
||||
"description": "Custom tool",
|
||||
"service": "custom-service",
|
||||
# Missing required "collection" param
|
||||
"optional-param": "value"
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
def validate_config_params(tool_config, service_config):
|
||||
config_params = service_config.get("config-params", [])
|
||||
for param in config_params:
|
||||
param_name = param.get("name")
|
||||
if param.get("required", False) and param_name not in tool_config:
|
||||
raise RuntimeError(f"Missing required config param '{param_name}'")
|
||||
return True
|
||||
|
||||
service_config = tool_services["custom-service"]
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
validate_config_params(tool_config_missing_required, service_config)
|
||||
assert "collection" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestToolServiceRequest:
|
||||
"""Test cases for tool service request format"""
|
||||
|
||||
def test_request_format(self):
|
||||
"""Test that request is properly formatted with user, config, and arguments"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
config_values = {"style": "pun", "collection": "jokes"}
|
||||
arguments = {"topic": "programming"}
|
||||
|
||||
# Act - simulate request building
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values),
|
||||
"arguments": json.dumps(arguments)
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["user"] == "alice"
|
||||
assert json.loads(request["config"]) == {"style": "pun", "collection": "jokes"}
|
||||
assert json.loads(request["arguments"]) == {"topic": "programming"}
|
||||
|
||||
def test_request_with_empty_config(self):
|
||||
"""Test request when no config values are provided"""
|
||||
# Arrange
|
||||
user = "bob"
|
||||
config_values = {}
|
||||
arguments = {"query": "test"}
|
||||
|
||||
# Act
|
||||
request = {
|
||||
"user": user,
|
||||
"config": json.dumps(config_values) if config_values else "{}",
|
||||
"arguments": json.dumps(arguments) if arguments else "{}"
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert request["config"] == "{}"
|
||||
assert json.loads(request["arguments"]) == {"query": "test"}
|
||||
|
||||
|
||||
class TestToolServiceResponse:
|
||||
"""Test cases for tool service response handling"""
|
||||
|
||||
def test_success_response_handling(self):
|
||||
"""Test handling of successful response"""
|
||||
# Arrange
|
||||
response = {
|
||||
"error": None,
|
||||
"response": "Hey alice! Here's a pun for you:\n\nWhy do programmers prefer dark mode?",
|
||||
"end_of_stream": True
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
assert response["error"] is None
|
||||
assert "pun" in response["response"]
|
||||
assert response["end_of_stream"] is True
|
||||
|
||||
def test_error_response_handling(self):
|
||||
"""Test handling of error response"""
|
||||
# Arrange
|
||||
response = {
|
||||
"error": {
|
||||
"type": "tool-service-error",
|
||||
"message": "Service unavailable"
|
||||
},
|
||||
"response": "",
|
||||
"end_of_stream": True
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
assert response["error"] is not None
|
||||
assert response["error"]["type"] == "tool-service-error"
|
||||
assert response["error"]["message"] == "Service unavailable"
|
||||
|
||||
def test_string_response_passthrough(self):
|
||||
"""Test that string responses are passed through directly"""
|
||||
# Arrange
|
||||
response_text = "This is a joke response"
|
||||
|
||||
# Act - simulate response handling
|
||||
def handle_response(response):
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
else:
|
||||
return json.dumps(response)
|
||||
|
||||
result = handle_response(response_text)
|
||||
|
||||
# Assert
|
||||
assert result == response_text
|
||||
|
||||
def test_dict_response_json_serialization(self):
|
||||
"""Test that dict responses are JSON serialized"""
|
||||
# Arrange
|
||||
response_data = {"joke": "Why did the chicken cross the road?", "category": "classic"}
|
||||
|
||||
# Act
|
||||
def handle_response(response):
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
else:
|
||||
return json.dumps(response)
|
||||
|
||||
result = handle_response(response_data)
|
||||
|
||||
# Assert
|
||||
assert result == json.dumps(response_data)
|
||||
assert json.loads(result) == response_data
|
||||
|
||||
|
||||
class TestToolServiceImpl:
|
||||
"""Test cases for ToolServiceImpl class"""
|
||||
|
||||
def test_tool_service_impl_initialization(self):
|
||||
"""Test ToolServiceImpl stores queues and config correctly"""
|
||||
# Arrange
|
||||
class MockArgument:
|
||||
def __init__(self, name, type, description):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.description = description
|
||||
|
||||
# Simulate ToolServiceImpl initialization
|
||||
class MockToolServiceImpl:
|
||||
def __init__(self, context, request_queue, response_queue, config_values=None, arguments=None, processor=None):
|
||||
self.context = context
|
||||
self.request_queue = request_queue
|
||||
self.response_queue = response_queue
|
||||
self.config_values = config_values or {}
|
||||
self.arguments = arguments or []
|
||||
self.processor = processor
|
||||
self._client = None
|
||||
|
||||
def get_arguments(self):
|
||||
return self.arguments
|
||||
|
||||
# Act
|
||||
arguments = [
|
||||
MockArgument("topic", "string", "The topic for the joke")
|
||||
]
|
||||
|
||||
impl = MockToolServiceImpl(
|
||||
context=lambda x: None,
|
||||
request_queue="non-persistent://tg/request/joke",
|
||||
response_queue="non-persistent://tg/response/joke",
|
||||
config_values={"style": "pun"},
|
||||
arguments=arguments,
|
||||
processor=Mock()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert impl.request_queue == "non-persistent://tg/request/joke"
|
||||
assert impl.response_queue == "non-persistent://tg/response/joke"
|
||||
assert impl.config_values == {"style": "pun"}
|
||||
assert len(impl.get_arguments()) == 1
|
||||
assert impl.get_arguments()[0].name == "topic"
|
||||
|
||||
def test_tool_service_impl_client_caching(self):
|
||||
"""Test that client is cached and reused"""
|
||||
# Arrange
|
||||
client_key = "non-persistent://tg/request/joke|non-persistent://tg/response/joke"
|
||||
|
||||
# Simulate client caching behavior
|
||||
tool_service_clients = {}
|
||||
|
||||
def get_or_create_client(request_queue, response_queue, clients_cache):
|
||||
client_key = f"{request_queue}|{response_queue}"
|
||||
if client_key in clients_cache:
|
||||
return clients_cache[client_key], False # False = not created
|
||||
client = Mock()
|
||||
clients_cache[client_key] = client
|
||||
return client, True # True = newly created
|
||||
|
||||
# Act
|
||||
client1, created1 = get_or_create_client(
|
||||
"non-persistent://tg/request/joke",
|
||||
"non-persistent://tg/response/joke",
|
||||
tool_service_clients
|
||||
)
|
||||
client2, created2 = get_or_create_client(
|
||||
"non-persistent://tg/request/joke",
|
||||
"non-persistent://tg/response/joke",
|
||||
tool_service_clients
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert created1 is True
|
||||
assert created2 is False
|
||||
assert client1 is client2
|
||||
|
||||
|
||||
class TestJokeServiceLogic:
|
||||
"""Test cases for the joke service example"""
|
||||
|
||||
def test_topic_to_category_mapping(self):
|
||||
"""Test that topics are mapped to categories correctly"""
|
||||
# Arrange
|
||||
def map_topic_to_category(topic):
|
||||
topic = topic.lower()
|
||||
if "program" in topic or "code" in topic or "computer" in topic or "software" in topic:
|
||||
return "programming"
|
||||
elif "llama" in topic:
|
||||
return "llama"
|
||||
elif "animal" in topic or "dog" in topic or "cat" in topic or "bird" in topic:
|
||||
return "animals"
|
||||
elif "food" in topic or "eat" in topic or "cook" in topic or "drink" in topic:
|
||||
return "food"
|
||||
else:
|
||||
return "default"
|
||||
|
||||
# Act & Assert
|
||||
assert map_topic_to_category("programming") == "programming"
|
||||
assert map_topic_to_category("software engineering") == "programming"
|
||||
assert map_topic_to_category("llamas") == "llama"
|
||||
assert map_topic_to_category("llama") == "llama"
|
||||
assert map_topic_to_category("animals") == "animals"
|
||||
assert map_topic_to_category("my dog") == "animals"
|
||||
assert map_topic_to_category("food") == "food"
|
||||
assert map_topic_to_category("cooking recipes") == "food"
|
||||
assert map_topic_to_category("random topic") == "default"
|
||||
assert map_topic_to_category("") == "default"
|
||||
|
||||
def test_joke_response_personalization(self):
|
||||
"""Test that joke responses include user personalization"""
|
||||
# Arrange
|
||||
user = "alice"
|
||||
style = "pun"
|
||||
joke = "Why do programmers prefer dark mode? Because light attracts bugs!"
|
||||
|
||||
# Act
|
||||
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
|
||||
# Assert
|
||||
assert "Hey alice!" in response
|
||||
assert "pun" in response
|
||||
assert joke in response
|
||||
|
||||
def test_style_normalization(self):
|
||||
"""Test that invalid styles fall back to valid ones"""
|
||||
import random
|
||||
|
||||
# Arrange
|
||||
valid_styles = ["pun", "dad-joke", "one-liner"]
|
||||
|
||||
def normalize_style(style):
|
||||
if style not in valid_styles:
|
||||
return random.choice(valid_styles)
|
||||
return style
|
||||
|
||||
# Act & Assert
|
||||
assert normalize_style("pun") == "pun"
|
||||
assert normalize_style("dad-joke") == "dad-joke"
|
||||
assert normalize_style("one-liner") == "one-liner"
|
||||
assert normalize_style("invalid-style") in valid_styles
|
||||
assert normalize_style("") in valid_styles
|
||||
|
||||
|
||||
class TestDynamicToolServiceBase:
|
||||
"""Test cases for DynamicToolService base class behavior"""
|
||||
|
||||
def test_topic_to_pulsar_path_conversion(self):
|
||||
"""Test that topic names are converted to full Pulsar paths"""
|
||||
# Arrange
|
||||
topic = "joke"
|
||||
|
||||
# Act
|
||||
request_topic = f"non-persistent://tg/request/{topic}"
|
||||
response_topic = f"non-persistent://tg/response/{topic}"
|
||||
|
||||
# Assert
|
||||
assert request_topic == "non-persistent://tg/request/joke"
|
||||
assert response_topic == "non-persistent://tg/response/joke"
|
||||
|
||||
def test_request_parsing(self):
|
||||
"""Test parsing of incoming request"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
"user": "alice",
|
||||
"config": '{"style": "pun"}',
|
||||
"arguments": '{"topic": "programming"}'
|
||||
}
|
||||
|
||||
# Act
|
||||
user = request_data.get("user", "trustgraph")
|
||||
config = json.loads(request_data["config"]) if request_data["config"] else {}
|
||||
arguments = json.loads(request_data["arguments"]) if request_data["arguments"] else {}
|
||||
|
||||
# Assert
|
||||
assert user == "alice"
|
||||
assert config == {"style": "pun"}
|
||||
assert arguments == {"topic": "programming"}
|
||||
|
||||
def test_response_building(self):
|
||||
"""Test building of response message"""
|
||||
# Arrange
|
||||
response_text = "Hey alice! Here's a joke"
|
||||
error = None
|
||||
|
||||
# Act
|
||||
response = {
|
||||
"error": error,
|
||||
"response": response_text if isinstance(response_text, str) else json.dumps(response_text),
|
||||
"end_of_stream": True
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert response["error"] is None
|
||||
assert response["response"] == "Hey alice! Here's a joke"
|
||||
assert response["end_of_stream"] is True
|
||||
|
||||
def test_error_response_building(self):
|
||||
"""Test building of error response"""
|
||||
# Arrange
|
||||
error_message = "Service temporarily unavailable"
|
||||
|
||||
# Act
|
||||
response = {
|
||||
"error": {
|
||||
"type": "tool-service-error",
|
||||
"message": error_message
|
||||
},
|
||||
"response": "",
|
||||
"end_of_stream": True
|
||||
}
|
||||
|
||||
# Assert
|
||||
assert response["error"]["type"] == "tool-service-error"
|
||||
assert response["error"]["message"] == error_message
|
||||
assert response["response"] == ""
|
||||
624
tests/unit/test_agent/test_tool_service_lifecycle.py
Normal file
624
tests/unit/test_agent/test_tool_service_lifecycle.py
Normal file
|
|
@ -0,0 +1,624 @@
|
|||
"""
|
||||
Tests for tool service lifecycle, invoke contract, streaming responses,
|
||||
multi-tenancy, and error propagation.
|
||||
|
||||
Tests the actual DynamicToolService, ToolService, and ToolServiceClient
|
||||
classes rather than plain dicts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
ToolServiceRequest, ToolServiceResponse, Error,
|
||||
ToolRequest, ToolResponse,
|
||||
)
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DynamicToolService tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDynamicToolServiceInvokeContract:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_invoke_raises_not_implemented(self):
|
||||
"""Base class invoke() should raise NotImplementedError."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await svc.invoke("user", {}, {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_calls_invoke_with_parsed_args(self):
|
||||
"""on_request should JSON-parse config/arguments and pass to invoke."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
calls = []
|
||||
|
||||
async def tracking_invoke(user, config, arguments):
|
||||
calls.append({"user": user, "config": config, "arguments": arguments})
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking_invoke
|
||||
|
||||
# Ensure the class-level metric exists
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user="alice",
|
||||
config='{"style": "pun"}',
|
||||
arguments='{"topic": "cats"}',
|
||||
)
|
||||
msg.properties.return_value = {"id": "req-1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["user"] == "alice"
|
||||
assert calls[0]["config"] == {"style": "pun"}
|
||||
assert calls[0]["arguments"] == {"topic": "cats"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_empty_user_defaults_to_trustgraph(self):
|
||||
"""Empty user field should default to 'trustgraph'."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
received_user = None
|
||||
|
||||
async def capture_invoke(user, config, arguments):
|
||||
nonlocal received_user
|
||||
received_user = user
|
||||
return "ok"
|
||||
|
||||
svc.invoke = capture_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="", config="", arguments="")
|
||||
msg.properties.return_value = {"id": "req-2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert received_user == "trustgraph"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_string_response_sent_directly(self):
|
||||
"""String return from invoke → response field is the string."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def string_invoke(user, config, arguments):
|
||||
return "hello world"
|
||||
|
||||
svc.invoke = string_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
sent = svc.producer.send.call_args[0][0]
|
||||
assert isinstance(sent, ToolServiceResponse)
|
||||
assert sent.response == "hello world"
|
||||
assert sent.end_of_stream is True
|
||||
assert sent.error is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_dict_response_json_encoded(self):
|
||||
"""Dict return from invoke → response field is JSON-encoded."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def dict_invoke(user, config, arguments):
|
||||
return {"result": 42}
|
||||
|
||||
svc.invoke = dict_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
sent = svc.producer.send.call_args[0][0]
|
||||
assert json.loads(sent.response) == {"result": 42}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_error_sends_error_response(self):
|
||||
"""Exception in invoke → error response sent."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def failing_invoke(user, config, arguments):
|
||||
raise ValueError("bad input")
|
||||
|
||||
svc.invoke = failing_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r3"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
sent = svc.producer.send.call_args[0][0]
|
||||
assert sent.error is not None
|
||||
assert sent.error.type == "tool-service-error"
|
||||
assert "bad input" in sent.error.message
|
||||
assert sent.response == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_too_many_requests_propagates(self):
|
||||
"""TooManyRequests should propagate (not caught as error response)."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def rate_limited_invoke(user, config, arguments):
|
||||
raise TooManyRequests("rate limited")
|
||||
|
||||
svc.invoke = rate_limited_invoke
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "r4"}
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_request_preserves_message_id(self):
|
||||
"""Response should include the original message id in properties."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test-svc"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
async def ok_invoke(user, config, arguments):
|
||||
return "ok"
|
||||
|
||||
svc.invoke = ok_invoke
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(user="u", config="{}", arguments="{}")
|
||||
msg.properties.return_value = {"id": "unique-42"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
props = svc.producer.send.call_args[1]["properties"]
|
||||
assert props["id"] == "unique-42"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolService (flow-based) tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolServiceOnRequest:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_response_sent_as_text(self):
|
||||
"""String return from invoke_tool → ToolResponse.text is set."""
|
||||
from trustgraph.base.tool_service import ToolService
|
||||
|
||||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
return "tool result"
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
||||
if not hasattr(ToolService, "tool_invocation_metric"):
|
||||
ToolService.tool_invocation_metric = MagicMock()
|
||||
|
||||
mock_response_pub = AsyncMock()
|
||||
flow = MagicMock()
|
||||
flow.name = "test-flow"
|
||||
|
||||
def flow_callable(name):
|
||||
if name == "response":
|
||||
return mock_response_pub
|
||||
return MagicMock()
|
||||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters='{"key": "val"}')
|
||||
msg.properties.return_value = {"id": "t1"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), flow_callable)
|
||||
|
||||
sent = mock_response_pub.send.call_args[0][0]
|
||||
assert isinstance(sent, ToolResponse)
|
||||
assert sent.text == "tool result"
|
||||
assert sent.object is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dict_response_sent_as_json_object(self):
|
||||
"""Dict return from invoke_tool → ToolResponse.object is JSON."""
|
||||
from trustgraph.base.tool_service import ToolService
|
||||
|
||||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def mock_invoke(name, params):
|
||||
return {"data": [1, 2, 3]}
|
||||
|
||||
svc.invoke_tool = mock_invoke
|
||||
|
||||
if not hasattr(ToolService, "tool_invocation_metric"):
|
||||
ToolService.tool_invocation_metric = MagicMock()
|
||||
|
||||
mock_response_pub = AsyncMock()
|
||||
flow = MagicMock()
|
||||
|
||||
def flow_callable(name):
|
||||
if name == "response":
|
||||
return mock_response_pub
|
||||
return MagicMock()
|
||||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
msg.properties.return_value = {"id": "t2"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), flow_callable)
|
||||
|
||||
sent = mock_response_pub.send.call_args[0][0]
|
||||
assert sent.text is None
|
||||
assert json.loads(sent.object) == {"data": [1, 2, 3]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_sends_error_response(self):
|
||||
"""Exception in invoke_tool → error response via flow producer."""
|
||||
from trustgraph.base.tool_service import ToolService
|
||||
|
||||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def failing_invoke(name, params):
|
||||
raise RuntimeError("tool broke")
|
||||
|
||||
svc.invoke_tool = failing_invoke
|
||||
|
||||
mock_response_pub = AsyncMock()
|
||||
flow = MagicMock()
|
||||
|
||||
def flow_callable(name):
|
||||
return MagicMock()
|
||||
|
||||
flow_callable.producer = {"response": mock_response_pub}
|
||||
flow_callable.name = "test-flow"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
msg.properties.return_value = {"id": "t3"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), flow_callable)
|
||||
|
||||
sent = mock_response_pub.send.call_args[0][0]
|
||||
assert sent.error is not None
|
||||
assert sent.error.type == "tool-error"
|
||||
assert "tool broke" in sent.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_many_requests_propagates(self):
|
||||
"""TooManyRequests should propagate from ToolService.on_request."""
|
||||
from trustgraph.base.tool_service import ToolService
|
||||
|
||||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
async def rate_limited(name, params):
|
||||
raise TooManyRequests("slow down")
|
||||
|
||||
svc.invoke_tool = rate_limited
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(name="my-tool", parameters="{}")
|
||||
msg.properties.return_value = {"id": "t4"}
|
||||
|
||||
flow = MagicMock()
|
||||
flow.producer = {"response": AsyncMock()}
|
||||
flow.name = "test-flow"
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parameters_json_parsed(self):
|
||||
"""Parameters should be JSON-parsed before passing to invoke_tool."""
|
||||
from trustgraph.base.tool_service import ToolService
|
||||
|
||||
svc = ToolService.__new__(ToolService)
|
||||
svc.id = "test-tool"
|
||||
|
||||
received = {}
|
||||
|
||||
async def capture_invoke(name, params):
|
||||
received["name"] = name
|
||||
received["params"] = params
|
||||
return "ok"
|
||||
|
||||
svc.invoke_tool = capture_invoke
|
||||
|
||||
if not hasattr(ToolService, "tool_invocation_metric"):
|
||||
ToolService.tool_invocation_metric = MagicMock()
|
||||
|
||||
mock_pub = AsyncMock()
|
||||
flow = lambda name: mock_pub
|
||||
flow.producer = {"response": mock_pub}
|
||||
flow.name = "f"
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolRequest(
|
||||
name="search",
|
||||
parameters='{"query": "test", "limit": 10}',
|
||||
)
|
||||
msg.properties.return_value = {"id": "t5"}
|
||||
|
||||
await svc.on_request(msg, MagicMock(), flow)
|
||||
|
||||
assert received["name"] == "search"
|
||||
assert received["params"] == {"query": "test", "limit": 10}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolServiceClient tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolServiceClientCall:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_sends_request_and_returns_response(self):
|
||||
"""call() should send ToolServiceRequest and return response string."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="joke result", end_of_stream=True,
|
||||
))
|
||||
|
||||
result = await client.call(
|
||||
user="alice",
|
||||
config={"style": "pun"},
|
||||
arguments={"topic": "cats"},
|
||||
)
|
||||
|
||||
assert result == "joke result"
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, ToolServiceRequest)
|
||||
assert req.user == "alice"
|
||||
assert json.loads(req.config) == {"style": "pun"}
|
||||
assert json.loads(req.arguments) == {"topic": "cats"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_raises_on_error(self):
|
||||
"""call() should raise RuntimeError when response has error."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=Error(type="tool-service-error", message="service down"),
|
||||
response="",
|
||||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await client.call(user="u", config={}, arguments={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_empty_config_sends_empty_json(self):
|
||||
"""Empty config/arguments should be sent as '{}'."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config=None, arguments=None)
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.config == "{}"
|
||||
assert req.arguments == "{}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_passes_timeout(self):
|
||||
"""call() should forward timeout to underlying request."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="u", config={}, arguments={}, timeout=30)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 30
|
||||
|
||||
|
||||
class TestToolServiceClientStreaming:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_streaming_collects_chunks(self):
|
||||
"""call_streaming should accumulate chunks and return full result."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
|
||||
# Simulate streaming: request() calls recipient with each chunk
|
||||
chunks = [
|
||||
ToolServiceResponse(error=None, response="chunk1", end_of_stream=False),
|
||||
ToolServiceResponse(error=None, response="chunk2", end_of_stream=True),
|
||||
]
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for chunk in chunks:
|
||||
done = await recipient(chunk)
|
||||
if done:
|
||||
break
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
received = []
|
||||
|
||||
async def callback(text):
|
||||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
assert result == "chunk1chunk2"
|
||||
assert received == ["chunk1", "chunk2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_streaming_raises_on_error(self):
|
||||
"""call_streaming should raise RuntimeError on error chunk."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
error_resp = ToolServiceResponse(
|
||||
error=Error(type="tool-service-error", message="stream failed"),
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
)
|
||||
await recipient(error_resp)
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
with pytest.raises(RuntimeError, match="stream failed"):
|
||||
await client.call_streaming(
|
||||
user="u", config={}, arguments={},
|
||||
callback=AsyncMock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_streaming_skips_empty_response(self):
|
||||
"""Empty response chunks should not be added to result."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
|
||||
chunks = [
|
||||
ToolServiceResponse(error=None, response="", end_of_stream=False),
|
||||
ToolServiceResponse(error=None, response="data", end_of_stream=True),
|
||||
]
|
||||
|
||||
async def mock_request(req, timeout=600, recipient=None):
|
||||
for chunk in chunks:
|
||||
done = await recipient(chunk)
|
||||
if done:
|
||||
break
|
||||
|
||||
client.request = mock_request
|
||||
|
||||
received = []
|
||||
|
||||
async def callback(text):
|
||||
received.append(text)
|
||||
|
||||
result = await client.call_streaming(
|
||||
user="u", config={}, arguments={}, callback=callback,
|
||||
)
|
||||
|
||||
# Empty response is falsy, so callback shouldn't be called for it
|
||||
assert result == "data"
|
||||
assert received == ["data"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-tenancy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMultiTenancy:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_propagated_to_invoke(self):
|
||||
"""User from request should reach the invoke method."""
|
||||
from trustgraph.base.dynamic_tool_service import DynamicToolService
|
||||
|
||||
svc = DynamicToolService.__new__(DynamicToolService)
|
||||
svc.id = "test"
|
||||
svc.producer = AsyncMock()
|
||||
|
||||
users_seen = []
|
||||
|
||||
async def tracking(user, config, arguments):
|
||||
users_seen.append(user)
|
||||
return "ok"
|
||||
|
||||
svc.invoke = tracking
|
||||
|
||||
if not hasattr(DynamicToolService, "tool_service_metric"):
|
||||
DynamicToolService.tool_service_metric = MagicMock()
|
||||
|
||||
for u in ["tenant-a", "tenant-b", "tenant-c"]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = ToolServiceRequest(
|
||||
user=u, config="{}", arguments="{}",
|
||||
)
|
||||
msg.properties.return_value = {"id": f"req-{u}"}
|
||||
await svc.on_request(msg, MagicMock(), None)
|
||||
|
||||
assert users_seen == ["tenant-a", "tenant-b", "tenant-c"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_sends_user_in_request(self):
|
||||
"""ToolServiceClient.call should include user in request."""
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
|
||||
client = ToolServiceClient.__new__(ToolServiceClient)
|
||||
client.request = AsyncMock(return_value=ToolServiceResponse(
|
||||
error=None, response="ok",
|
||||
))
|
||||
|
||||
await client.call(user="isolated-tenant", config={}, arguments={})
|
||||
|
||||
req = client.request.call_args[0][0]
|
||||
assert req.user == "isolated-tenant"
|
||||
|
|
@ -23,27 +23,27 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
|
||||
# Mock the request method
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
|
||||
|
||||
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
|
||||
# Act
|
||||
result = await client.query(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
timeout=30
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.request.assert_called_once()
|
||||
call_args = client.request.call_args[0][0]
|
||||
assert isinstance(call_args, DocumentEmbeddingsRequest)
|
||||
assert call_args.vectors == vectors
|
||||
assert call_args.vector == vector
|
||||
assert call_args.limit == 10
|
||||
assert call_args.user == "test_user"
|
||||
assert call_args.collection == "test_collection"
|
||||
|
|
@ -63,7 +63,7 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||
await client.query(
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -76,12 +76,12 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunks = []
|
||||
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||
|
||||
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
|
@ -94,11 +94,11 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunks = ["test_chunk"]
|
||||
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
client.request.assert_called_once()
|
||||
|
|
@ -116,15 +116,15 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunks = ["chunk1"]
|
||||
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
await client.query(
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
timeout=60
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert client.request.call_args[1]["timeout"] == 60
|
||||
|
||||
|
|
@ -137,13 +137,13 @@ class TestDocumentEmbeddingsClient(IsolatedAsyncioTestCase):
|
|||
mock_response = MagicMock(spec=DocumentEmbeddingsResponse)
|
||||
mock_response.error = None
|
||||
mock_response.chunks = ["test_chunk"]
|
||||
|
||||
|
||||
client.request = AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.document_embeddings_client.logger') as mock_logger:
|
||||
result = await client.query(vectors=[[0.1, 0.2, 0.3]])
|
||||
|
||||
result = await client.query(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
mock_logger.debug.assert_called_once()
|
||||
assert "Document embeddings response" in str(mock_logger.debug.call_args)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ def sample_text_document():
|
|||
"""Sample document with moderate length text."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
|
|
@ -44,7 +43,6 @@ def long_text_document():
|
|||
"""Long document for testing multiple chunks."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-long",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
|
|
@ -61,7 +59,6 @@ def unicode_text_document():
|
|||
"""Document with various unicode characters."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-unicode",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
|
|
@ -87,7 +84,6 @@ def empty_text_document():
|
|||
"""Empty document for edge case testing."""
|
||||
metadata = Metadata(
|
||||
id="test-doc-empty",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,13 +17,17 @@ class MockAsyncProcessor:
|
|||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
self.pubsub = MagicMock()
|
||||
self.taskgroup = params.get('taskgroup', MagicMock())
|
||||
|
||||
|
||||
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Recursive chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self):
|
||||
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -47,8 +51,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self):
|
||||
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -79,8 +85,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 2000 # Should use overridden value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self):
|
||||
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -111,8 +119,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 200 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self):
|
||||
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -143,9 +153,11 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 1500 # Should use overridden value
|
||||
assert chunk_overlap == 150 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
|
||||
"""Test that on_message method uses parameters from flow"""
|
||||
# Arrange
|
||||
mock_splitter = MagicMock()
|
||||
|
|
@ -164,26 +176,31 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content"
|
||||
mock_text_doc.document_id = "" # No librarian fetch needed
|
||||
mock_message.value.return_value = mock_text_doc
|
||||
|
||||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
"output": mock_producer
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
|
|
@ -202,8 +219,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self):
|
||||
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
|
|||
|
|
@ -17,13 +17,17 @@ class MockAsyncProcessor:
|
|||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
self.pubsub = MagicMock()
|
||||
self.taskgroup = params.get('taskgroup', MagicMock())
|
||||
|
||||
|
||||
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Token chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self):
|
||||
def test_processor_initialization_basic(self, mock_producer, mock_consumer):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -47,8 +51,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self):
|
||||
async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -79,8 +85,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 400 # Should use overridden value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self):
|
||||
async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -111,8 +119,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 25 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self):
|
||||
async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -143,9 +153,11 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 350 # Should use overridden value
|
||||
assert chunk_overlap == 30 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
|
||||
"""Test that on_message method uses parameters from flow"""
|
||||
# Arrange
|
||||
mock_splitter = MagicMock()
|
||||
|
|
@ -164,26 +176,31 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid librarian producer interactions
|
||||
processor.save_child_document = AsyncMock(return_value="chunk-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-456",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content for token chunking"
|
||||
mock_text_doc.document_id = "" # No librarian fetch needed
|
||||
mock_message.value.return_value = mock_text_doc
|
||||
|
||||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
"output": mock_producer
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
|
|
@ -206,8 +223,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self):
|
||||
async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
# Arrange
|
||||
config = {
|
||||
|
|
@ -235,8 +254,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_token_chunker_uses_different_defaults(self):
|
||||
def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
|
||||
"""Test that token chunker has different defaults than recursive chunker"""
|
||||
# Arrange & Act
|
||||
config = {
|
||||
|
|
|
|||
|
|
@ -69,24 +69,24 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
mock_response = MagicMock()
|
||||
mock_response.chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
client.call = MagicMock(return_value=mock_response)
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
|
||||
|
||||
vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
|
||||
# Act
|
||||
result = client.request(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
limit=10,
|
||||
timeout=300
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2", "chunk3"]
|
||||
client.call.assert_called_once_with(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
limit=10,
|
||||
timeout=300
|
||||
)
|
||||
|
|
@ -101,18 +101,18 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
mock_response = MagicMock()
|
||||
mock_response.chunks = ["test_chunk"]
|
||||
client.call = MagicMock(return_value=mock_response)
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
|
||||
# Act
|
||||
result = client.request(vectors=vectors)
|
||||
|
||||
result = client.request(vector=vector)
|
||||
|
||||
# Assert
|
||||
assert result == ["test_chunk"]
|
||||
client.call.assert_called_once_with(
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
limit=10,
|
||||
timeout=300
|
||||
)
|
||||
|
|
@ -127,10 +127,10 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
mock_response = MagicMock()
|
||||
mock_response.chunks = []
|
||||
client.call = MagicMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
result = client.request(vectors=[[0.1, 0.2, 0.3]])
|
||||
|
||||
result = client.request(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
|
@ -144,10 +144,10 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
mock_response = MagicMock()
|
||||
mock_response.chunks = None
|
||||
client.call = MagicMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
result = client.request(vectors=[[0.1, 0.2, 0.3]])
|
||||
|
||||
result = client.request(vector=[0.1, 0.2, 0.3])
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
|
@ -161,12 +161,12 @@ class TestSyncDocumentEmbeddingsClient:
|
|||
mock_response = MagicMock()
|
||||
mock_response.chunks = ["chunk1"]
|
||||
client.call = MagicMock(return_value=mock_response)
|
||||
|
||||
|
||||
# Act
|
||||
client.request(
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
timeout=600
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert client.call.call_args[1]["timeout"] == 600
|
||||
1
tests/unit/test_concurrency/__init__.py
Normal file
1
tests/unit/test_concurrency/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
286
tests/unit/test_concurrency/test_consumer_concurrency.py
Normal file
286
tests/unit/test_concurrency/test_consumer_concurrency.py
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
"""
|
||||
Tests for Consumer concurrency: TaskGroup-based concurrent message processing,
|
||||
rate-limit retry with backpressure, and message acknowledgement.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.base.consumer import Consumer
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_consumer(
|
||||
concurrency=1,
|
||||
handler=None,
|
||||
rate_limit_retry_time=0.01,
|
||||
rate_limit_timeout=1,
|
||||
):
|
||||
"""Create a Consumer with mocked infrastructure."""
|
||||
taskgroup = MagicMock()
|
||||
flow = MagicMock()
|
||||
backend = MagicMock()
|
||||
schema = MagicMock()
|
||||
handler = handler or AsyncMock()
|
||||
|
||||
consumer = Consumer(
|
||||
taskgroup=taskgroup,
|
||||
flow=flow,
|
||||
backend=backend,
|
||||
topic="test-topic",
|
||||
subscriber="test-sub",
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
rate_limit_retry_time=rate_limit_retry_time,
|
||||
rate_limit_timeout=rate_limit_timeout,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
|
||||
return consumer
|
||||
|
||||
|
||||
def _make_msg():
|
||||
"""Create a mock Pulsar message."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrency configuration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConcurrencyConfiguration:
|
||||
|
||||
def test_default_concurrency_is_1(self):
|
||||
consumer = _make_consumer()
|
||||
assert consumer.concurrency == 1
|
||||
|
||||
def test_custom_concurrency(self):
|
||||
consumer = _make_consumer(concurrency=10)
|
||||
assert consumer.concurrency == 10
|
||||
|
||||
def test_concurrency_stored(self):
|
||||
for n in [1, 5, 20, 100]:
|
||||
consumer = _make_consumer(concurrency=n)
|
||||
assert consumer.concurrency == n
|
||||
|
||||
|
||||
class TestTaskGroupConcurrency:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_n_concurrent_tasks(self):
|
||||
"""consumer_run should create exactly N concurrent consume_from_queue tasks."""
|
||||
concurrency = 5
|
||||
consumer = _make_consumer(concurrency=concurrency)
|
||||
|
||||
# Track how many consume_from_queue calls are made
|
||||
call_count = 0
|
||||
original_running = True
|
||||
|
||||
async def mock_consume():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Wait a bit to let all tasks start, then signal stop
|
||||
await asyncio.sleep(0.05)
|
||||
consumer.running = False
|
||||
|
||||
consumer.consume_from_queue = mock_consume
|
||||
|
||||
# Mock the backend.create_consumer
|
||||
consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Run consumer_run - it will create TaskGroup with N tasks
|
||||
consumer.running = True
|
||||
await consumer.consumer_run()
|
||||
|
||||
assert call_count == concurrency
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_concurrency_creates_one_task(self):
|
||||
"""With concurrency=1, only one consume_from_queue task is created."""
|
||||
consumer = _make_consumer(concurrency=1)
|
||||
call_count = 0
|
||||
|
||||
async def mock_consume():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
consumer.running = False
|
||||
|
||||
consumer.consume_from_queue = mock_consume
|
||||
consumer.backend.create_consumer = MagicMock(return_value=MagicMock())
|
||||
|
||||
consumer.running = True
|
||||
await consumer.consumer_run()
|
||||
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate-limit retry tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimitRetry:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_retries_then_succeeds(self):
|
||||
"""TooManyRequests should cause retry, then succeed on next attempt."""
|
||||
call_count = 0
|
||||
|
||||
async def handler_with_retry(msg, consumer_ref, flow):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise TooManyRequests("rate limited")
|
||||
# Second call succeeds
|
||||
|
||||
consumer = _make_consumer(
|
||||
handler=handler_with_retry,
|
||||
rate_limit_retry_time=0.01,
|
||||
)
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
assert call_count == 2
|
||||
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_timeout_negative_acks(self):
|
||||
"""If rate limit retries exhaust the timeout, message is negative-acked."""
|
||||
async def always_rate_limited(msg, consumer_ref, flow):
|
||||
raise TooManyRequests("rate limited")
|
||||
|
||||
consumer = _make_consumer(
|
||||
handler=always_rate_limited,
|
||||
rate_limit_retry_time=0.01,
|
||||
rate_limit_timeout=0.05,
|
||||
)
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
|
||||
consumer.consumer.acknowledge.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_rate_limit_error_negative_acks_immediately(self):
|
||||
"""Non-TooManyRequests errors should negative-ack immediately (no retry)."""
|
||||
call_count = 0
|
||||
|
||||
async def failing_handler(msg, consumer_ref, flow):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise ValueError("bad data")
|
||||
|
||||
consumer = _make_consumer(handler=failing_handler)
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
assert call_count == 1
|
||||
consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_message_acknowledged(self):
|
||||
"""Successfully processed messages are acknowledged."""
|
||||
consumer = _make_consumer(handler=AsyncMock())
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metrics integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMetricsIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_metric_on_success(self):
|
||||
consumer = _make_consumer(handler=AsyncMock())
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
mock_metrics = MagicMock()
|
||||
mock_metrics.record_time.return_value.__enter__ = MagicMock()
|
||||
mock_metrics.record_time.return_value.__exit__ = MagicMock()
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
mock_metrics.process.assert_called_once_with("success")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_metric_on_failure(self):
|
||||
async def failing(msg, c, f):
|
||||
raise ValueError("fail")
|
||||
|
||||
consumer = _make_consumer(handler=failing)
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
mock_metrics = MagicMock()
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
mock_metrics.process.assert_called_once_with("error")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_metric_on_too_many_requests(self):
|
||||
call_count = 0
|
||||
|
||||
async def handler(msg, c, f):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise TooManyRequests("limited")
|
||||
|
||||
consumer = _make_consumer(
|
||||
handler=handler,
|
||||
rate_limit_retry_time=0.01,
|
||||
)
|
||||
mock_msg = _make_msg()
|
||||
consumer.consumer = MagicMock()
|
||||
|
||||
mock_metrics = MagicMock()
|
||||
mock_metrics.record_time.return_value.__enter__ = MagicMock()
|
||||
mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
|
||||
consumer.metrics = mock_metrics
|
||||
|
||||
await consumer.handle_one_from_queue(mock_msg)
|
||||
|
||||
mock_metrics.rate_limit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stop / running flag
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStopBehaviour:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sets_running_false(self):
|
||||
consumer = _make_consumer()
|
||||
consumer.running = True
|
||||
|
||||
await consumer.stop()
|
||||
|
||||
assert consumer.running is False
|
||||
|
||||
def test_initial_running_state(self):
|
||||
consumer = _make_consumer()
|
||||
assert consumer.running is True
|
||||
136
tests/unit/test_concurrency/test_dispatcher_semaphore.py
Normal file
136
tests/unit/test_concurrency/test_dispatcher_semaphore.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
Tests for MessageDispatcher semaphore-based concurrency enforcement.
|
||||
|
||||
Verifies that the dispatcher limits concurrent message processing to
|
||||
max_workers via asyncio.Semaphore.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.rev_gateway.dispatcher import MessageDispatcher
|
||||
|
||||
|
||||
class TestSemaphoreEnforcement:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semaphore_limits_concurrent_processing(self):
|
||||
"""Only max_workers messages should be processed concurrently."""
|
||||
max_workers = 2
|
||||
dispatcher = MessageDispatcher(max_workers=max_workers)
|
||||
|
||||
concurrent_count = 0
|
||||
max_concurrent = 0
|
||||
processing_event = asyncio.Event()
|
||||
|
||||
async def slow_process(message):
|
||||
nonlocal concurrent_count, max_concurrent
|
||||
concurrent_count += 1
|
||||
max_concurrent = max(max_concurrent, concurrent_count)
|
||||
await asyncio.sleep(0.05)
|
||||
concurrent_count -= 1
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = slow_process
|
||||
|
||||
# Launch more tasks than max_workers
|
||||
messages = [
|
||||
{"id": f"msg-{i}", "service": "test", "request": {}}
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(dispatcher.handle_message(m))
|
||||
for m in messages
|
||||
]
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# At no point should more than max_workers have been active
|
||||
assert max_concurrent <= max_workers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semaphore_value_matches_max_workers(self):
|
||||
for n in [1, 5, 20]:
|
||||
dispatcher = MessageDispatcher(max_workers=n)
|
||||
assert dispatcher.semaphore._value == n
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_tasks_tracked(self):
|
||||
"""Active tasks should be added/removed during processing."""
|
||||
dispatcher = MessageDispatcher(max_workers=5)
|
||||
|
||||
task_was_tracked = False
|
||||
|
||||
original_process = dispatcher._process_message
|
||||
|
||||
async def tracking_process(message):
|
||||
nonlocal task_was_tracked
|
||||
# During processing, our task should be in active_tasks
|
||||
if len(dispatcher.active_tasks) > 0:
|
||||
task_was_tracked = True
|
||||
return {"id": message.get("id"), "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = tracking_process
|
||||
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
)
|
||||
|
||||
assert task_was_tracked
|
||||
# After completion, task should be discarded
|
||||
assert len(dispatcher.active_tasks) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semaphore_released_on_error(self):
|
||||
"""Semaphore should be released even if processing raises."""
|
||||
dispatcher = MessageDispatcher(max_workers=2)
|
||||
|
||||
async def failing_process(message):
|
||||
raise RuntimeError("process failed")
|
||||
|
||||
dispatcher._process_message = failing_process
|
||||
|
||||
# Should not deadlock — semaphore must be released on error
|
||||
with pytest.raises(RuntimeError):
|
||||
await dispatcher.handle_message(
|
||||
{"id": "test", "service": "test", "request": {}}
|
||||
)
|
||||
|
||||
# Semaphore should be back at max
|
||||
assert dispatcher.semaphore._value == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_worker_serializes_processing(self):
|
||||
"""With max_workers=1, messages are processed one at a time."""
|
||||
dispatcher = MessageDispatcher(max_workers=1)
|
||||
|
||||
order = []
|
||||
|
||||
async def ordered_process(message):
|
||||
msg_id = message["id"]
|
||||
order.append(f"start-{msg_id}")
|
||||
await asyncio.sleep(0.02)
|
||||
order.append(f"end-{msg_id}")
|
||||
return {"id": msg_id, "response": {"ok": True}}
|
||||
|
||||
dispatcher._process_message = ordered_process
|
||||
|
||||
messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)]
|
||||
tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# With semaphore=1, each message should complete before next starts
|
||||
# Check that no two "start" entries appear without an intervening "end"
|
||||
active = 0
|
||||
max_active = 0
|
||||
for event in order:
|
||||
if event.startswith("start"):
|
||||
active += 1
|
||||
max_active = max(max_active, active)
|
||||
elif event.startswith("end"):
|
||||
active -= 1
|
||||
|
||||
assert max_active == 1
|
||||
268
tests/unit/test_concurrency/test_graph_rag_concurrency.py
Normal file
268
tests/unit/test_concurrency/test_graph_rag_concurrency.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
"""
|
||||
Tests for Graph RAG concurrent query execution.
|
||||
|
||||
Covers: execute_batch_triple_queries concurrent task spawning,
|
||||
exception handling in gather, and result aggregation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_query(
|
||||
triples_client=None,
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=1000,
|
||||
max_path_length=2,
|
||||
):
|
||||
"""Create a Query object with mocked rag dependencies."""
|
||||
rag = MagicMock()
|
||||
rag.triples_client = triples_client or AsyncMock()
|
||||
rag.label_cache = LRUCacheWithTTL()
|
||||
|
||||
query = Query(
|
||||
rag=rag,
|
||||
user="test-user",
|
||||
collection="test-collection",
|
||||
verbose=False,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
)
|
||||
return query
|
||||
|
||||
|
||||
def _make_triple(s, p, o):
|
||||
"""Create a simple mock triple."""
|
||||
t = MagicMock()
|
||||
t.s = s
|
||||
t.p = p
|
||||
t.o = o
|
||||
return t
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBatchTripleQueries:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_three_queries_per_entity(self):
|
||||
"""Each entity should generate 3 concurrent queries (s, p, o positions)."""
|
||||
client = AsyncMock()
|
||||
client.query_stream = AsyncMock(return_value=[])
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
entities = ["entity-1"]
|
||||
await query.execute_batch_triple_queries(entities, limit_per_entity=10)
|
||||
|
||||
assert client.query_stream.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_entities_multiply_queries(self):
|
||||
"""N entities should produce N*3 concurrent queries."""
|
||||
client = AsyncMock()
|
||||
client.query_stream = AsyncMock(return_value=[])
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
entities = ["e1", "e2", "e3"]
|
||||
await query.execute_batch_triple_queries(entities, limit_per_entity=10)
|
||||
|
||||
assert client.query_stream.call_count == 9 # 3 * 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queries_executed_concurrently(self):
|
||||
"""All queries should run concurrently via asyncio.gather."""
|
||||
concurrent_count = 0
|
||||
max_concurrent = 0
|
||||
|
||||
async def tracking_query(**kwargs):
|
||||
nonlocal concurrent_count, max_concurrent
|
||||
concurrent_count += 1
|
||||
max_concurrent = max(max_concurrent, concurrent_count)
|
||||
await asyncio.sleep(0.02)
|
||||
concurrent_count -= 1
|
||||
return []
|
||||
|
||||
client = AsyncMock()
|
||||
client.query_stream = tracking_query
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
entities = ["e1", "e2", "e3"]
|
||||
await query.execute_batch_triple_queries(entities, limit_per_entity=5)
|
||||
|
||||
# All 9 queries should have run concurrently
|
||||
assert max_concurrent == 9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_results_aggregated(self):
|
||||
"""Results from all queries should be combined into a single list."""
|
||||
triple_a = _make_triple("a", "p", "b")
|
||||
triple_b = _make_triple("c", "p", "d")
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def alternating_results(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count % 2 == 0:
|
||||
return [triple_a]
|
||||
return [triple_b]
|
||||
|
||||
client = AsyncMock()
|
||||
client.query_stream = alternating_results
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries(
|
||||
["e1"], limit_per_entity=10
|
||||
)
|
||||
|
||||
# 3 queries, alternating results
|
||||
assert len(result) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_in_one_query_does_not_block_others(self):
|
||||
"""If one query raises, other results are still collected."""
|
||||
good_triple = _make_triple("a", "p", "b")
|
||||
call_count = 0
|
||||
|
||||
async def mixed_results(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 2:
|
||||
raise RuntimeError("query failed")
|
||||
return [good_triple]
|
||||
|
||||
client = AsyncMock()
|
||||
client.query_stream = mixed_results
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries(
|
||||
["e1"], limit_per_entity=10
|
||||
)
|
||||
|
||||
# 3 queries: 2 succeed, 1 fails → 2 triples
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_results_filtered(self):
|
||||
"""None results from queries should be filtered out."""
|
||||
call_count = 0
|
||||
|
||||
async def sometimes_none(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return None
|
||||
return [_make_triple("a", "p", "b")]
|
||||
|
||||
client = AsyncMock()
|
||||
client.query_stream = sometimes_none
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries(
|
||||
["e1"], limit_per_entity=10
|
||||
)
|
||||
|
||||
# 3 queries: 1 returns None, 2 return triples
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entities_no_queries(self):
|
||||
"""Empty entity list should produce no queries."""
|
||||
client = AsyncMock()
|
||||
client.query_stream = AsyncMock(return_value=[])
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries([], limit_per_entity=10)
|
||||
|
||||
assert result == []
|
||||
client.query_stream.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_params_correct(self):
|
||||
"""Each query should use correct s/p/o positions and params."""
|
||||
client = AsyncMock()
|
||||
client.query_stream = AsyncMock(return_value=[])
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
entities = ["ent-1"]
|
||||
await query.execute_batch_triple_queries(entities, limit_per_entity=15)
|
||||
|
||||
calls = client.query_stream.call_args_list
|
||||
assert len(calls) == 3
|
||||
|
||||
# First call: s=entity, p=None, o=None
|
||||
assert calls[0].kwargs["s"] == "ent-1"
|
||||
assert calls[0].kwargs["p"] is None
|
||||
assert calls[0].kwargs["o"] is None
|
||||
assert calls[0].kwargs["limit"] == 15
|
||||
assert calls[0].kwargs["user"] == "test-user"
|
||||
assert calls[0].kwargs["collection"] == "test-collection"
|
||||
assert calls[0].kwargs["batch_size"] == 20
|
||||
|
||||
# Second call: s=None, p=entity, o=None
|
||||
assert calls[1].kwargs["s"] is None
|
||||
assert calls[1].kwargs["p"] == "ent-1"
|
||||
assert calls[1].kwargs["o"] is None
|
||||
|
||||
# Third call: s=None, p=None, o=entity
|
||||
assert calls[2].kwargs["s"] is None
|
||||
assert calls[2].kwargs["p"] is None
|
||||
assert calls[2].kwargs["o"] == "ent-1"
|
||||
|
||||
|
||||
class TestLRUCacheWithTTL:
|
||||
|
||||
def test_put_and_get(self):
|
||||
cache = LRUCacheWithTTL(max_size=10, ttl=60)
|
||||
cache.put("key1", "value1")
|
||||
assert cache.get("key1") == "value1"
|
||||
|
||||
def test_get_missing_returns_none(self):
|
||||
cache = LRUCacheWithTTL()
|
||||
assert cache.get("nonexistent") is None
|
||||
|
||||
def test_max_size_eviction(self):
|
||||
cache = LRUCacheWithTTL(max_size=2, ttl=60)
|
||||
cache.put("a", 1)
|
||||
cache.put("b", 2)
|
||||
cache.put("c", 3) # Should evict "a"
|
||||
assert cache.get("a") is None
|
||||
assert cache.get("b") == 2
|
||||
assert cache.get("c") == 3
|
||||
|
||||
def test_lru_order(self):
|
||||
cache = LRUCacheWithTTL(max_size=2, ttl=60)
|
||||
cache.put("a", 1)
|
||||
cache.put("b", 2)
|
||||
cache.get("a") # Access "a" — now "b" is LRU
|
||||
cache.put("c", 3) # Should evict "b"
|
||||
assert cache.get("a") == 1
|
||||
assert cache.get("b") is None
|
||||
assert cache.get("c") == 3
|
||||
|
||||
def test_ttl_expiration(self):
|
||||
cache = LRUCacheWithTTL(max_size=10, ttl=0) # TTL=0 means instant expiry
|
||||
cache.put("key", "value")
|
||||
# With TTL=0, any time check > 0 means expired
|
||||
import time
|
||||
time.sleep(0.01)
|
||||
assert cache.get("key") is None
|
||||
|
||||
def test_update_existing_key(self):
|
||||
cache = LRUCacheWithTTL(max_size=10, ttl=60)
|
||||
cache.put("key", "v1")
|
||||
cache.put("key", "v2")
|
||||
assert cache.get("key") == "v2"
|
||||
|
|
@ -73,7 +73,6 @@ def sample_triples():
|
|||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
metadata=[]
|
||||
),
|
||||
triples=[
|
||||
Triple(
|
||||
|
|
@ -93,12 +92,11 @@ def sample_graph_embeddings():
|
|||
id="test-doc-id",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
metadata=[]
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri="http://example.org/john"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
vector=[0.1, 0.2, 0.3]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,218 +12,184 @@ from trustgraph.decoding.pdf.pdf_decoder import Processor
|
|||
from trustgraph.schema import Document, TextDocument, Metadata
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
self.pubsub = MagicMock()
|
||||
self.taskgroup = params.get('taskgroup', MagicMock())
|
||||
|
||||
|
||||
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test PDF decoder processor functionality"""
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_processor_initialization(self, mock_flow_init):
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
"""Test PDF decoder processor initialization"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-pdf-decoder',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(Processor, 'register_specification') as mock_register:
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
mock_flow_init.assert_called_once()
|
||||
# Verify register_specification was called twice (consumer and producer)
|
||||
assert mock_register.call_count == 2
|
||||
|
||||
# Check consumer spec
|
||||
consumer_call = mock_register.call_args_list[0]
|
||||
consumer_spec = consumer_call[0][0]
|
||||
assert consumer_spec.name == "input"
|
||||
assert consumer_spec.schema == Document
|
||||
assert consumer_spec.handler == processor.on_message
|
||||
|
||||
# Check producer spec
|
||||
producer_call = mock_register.call_args_list[1]
|
||||
producer_spec = producer_call[0][0]
|
||||
assert producer_spec.name == "output"
|
||||
assert producer_spec.schema == TextDocument
|
||||
consumer_specs = [s for s in processor.specifications if hasattr(s, 'handler')]
|
||||
assert len(consumer_specs) >= 1
|
||||
assert consumer_specs[0].name == "input"
|
||||
assert consumer_specs[0].schema == Document
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_on_message_success(self, mock_flow_init, mock_pdf_loader_class):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
"""Test successful PDF processing"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
|
||||
# Mock PDF content
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
||||
|
||||
# Mock PyPDFLoader
|
||||
mock_loader = MagicMock()
|
||||
mock_page1 = MagicMock(page_content="Page 1 content")
|
||||
mock_page2 = MagicMock(page_content="Page 2 content")
|
||||
mock_loader.load.return_value = [mock_page1, mock_page2]
|
||||
mock_pdf_loader_class.return_value = mock_loader
|
||||
|
||||
|
||||
# Mock message
|
||||
mock_metadata = Metadata(id="test-doc")
|
||||
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
# Mock flow - needs to be a callable that returns an object with send method
|
||||
|
||||
# Mock flow - separate mocks for output and triples
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_flow = MagicMock(return_value=mock_output_flow)
|
||||
|
||||
mock_triples_flow = AsyncMock()
|
||||
mock_flow = MagicMock(side_effect=lambda name: {
|
||||
"output": mock_output_flow,
|
||||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
config = {
|
||||
'id': 'test-pdf-decoder',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
with patch.object(Processor, 'register_specification'):
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify PyPDFLoader was called
|
||||
mock_pdf_loader_class.assert_called_once()
|
||||
mock_loader.load.assert_called_once()
|
||||
|
||||
# Verify output was sent for each page
|
||||
assert mock_output_flow.send.call_count == 2
|
||||
|
||||
# Check first page output
|
||||
first_call = mock_output_flow.send.call_args_list[0]
|
||||
first_output = first_call[0][0]
|
||||
assert isinstance(first_output, TextDocument)
|
||||
assert first_output.metadata == mock_metadata
|
||||
assert first_output.text == b"Page 1 content"
|
||||
|
||||
# Check second page output
|
||||
second_call = mock_output_flow.send.call_args_list[1]
|
||||
second_output = second_call[0][0]
|
||||
assert isinstance(second_output, TextDocument)
|
||||
assert second_output.metadata == mock_metadata
|
||||
assert second_output.text == b"Page 2 content"
|
||||
# Verify triples were sent for each page (provenance)
|
||||
assert mock_triples_flow.send.call_count == 2
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_on_message_empty_pdf(self, mock_flow_init, mock_pdf_loader_class):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
"""Test handling of empty PDF"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
|
||||
# Mock PDF content
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
||||
# Mock PyPDFLoader with no pages
|
||||
|
||||
mock_loader = MagicMock()
|
||||
mock_loader.load.return_value = []
|
||||
mock_pdf_loader_class.return_value = mock_loader
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_metadata = Metadata(id="test-doc")
|
||||
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
# Mock flow - needs to be a callable that returns an object with send method
|
||||
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_flow = MagicMock(return_value=mock_output_flow)
|
||||
|
||||
|
||||
config = {
|
||||
'id': 'test-pdf-decoder',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
with patch.object(Processor, 'register_specification'):
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify PyPDFLoader was called
|
||||
mock_pdf_loader_class.assert_called_once()
|
||||
mock_loader.load.assert_called_once()
|
||||
|
||||
# Verify no output was sent
|
||||
mock_output_flow.send.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.chunking_service.Consumer')
|
||||
@patch('trustgraph.base.chunking_service.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.__init__')
|
||||
async def test_on_message_unicode_content(self, mock_flow_init, mock_pdf_loader_class):
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
|
||||
"""Test handling of unicode content in PDF"""
|
||||
# Arrange
|
||||
mock_flow_init.return_value = None
|
||||
|
||||
# Mock PDF content
|
||||
pdf_content = b"fake pdf content"
|
||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||
|
||||
# Mock PyPDFLoader with unicode content
|
||||
|
||||
mock_loader = MagicMock()
|
||||
mock_page = MagicMock(page_content="Page with unicode: 你好世界 🌍")
|
||||
mock_loader.load.return_value = [mock_page]
|
||||
mock_pdf_loader_class.return_value = mock_loader
|
||||
|
||||
# Mock message
|
||||
|
||||
mock_metadata = Metadata(id="test-doc")
|
||||
mock_document = Document(metadata=mock_metadata, data=pdf_base64)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = mock_document
|
||||
|
||||
# Mock flow - needs to be a callable that returns an object with send method
|
||||
|
||||
# Mock flow - separate mocks for output and triples
|
||||
mock_output_flow = AsyncMock()
|
||||
mock_flow = MagicMock(return_value=mock_output_flow)
|
||||
|
||||
mock_triples_flow = AsyncMock()
|
||||
mock_flow = MagicMock(side_effect=lambda name: {
|
||||
"output": mock_output_flow,
|
||||
"triples": mock_triples_flow,
|
||||
}.get(name))
|
||||
|
||||
config = {
|
||||
'id': 'test-pdf-decoder',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
with patch.object(Processor, 'register_specification'):
|
||||
processor = Processor(**config)
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid waiting for librarian response
|
||||
processor.save_child_document = AsyncMock(return_value="mock-doc-id")
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify output was sent
|
||||
mock_output_flow.send.assert_called_once()
|
||||
|
||||
# Check output
|
||||
call_args = mock_output_flow.send.call_args[0][0]
|
||||
assert isinstance(call_args, TextDocument)
|
||||
assert call_args.text == "Page with unicode: 你好世界 🌍".encode('utf-8')
|
||||
# PDF decoder now forwards document_id, chunker fetches content from librarian
|
||||
assert call_args.document_id == "test-doc/p1"
|
||||
assert call_args.text == b"" # Content stored in librarian, not inline
|
||||
|
||||
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
|
||||
def test_add_args(self, mock_parent_add_args):
|
||||
"""Test add_args calls parent method"""
|
||||
# Arrange
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
@patch('trustgraph.decoding.pdf.pdf_decoder.Processor.launch')
|
||||
def test_run(self, mock_launch):
|
||||
"""Test run function"""
|
||||
# Act
|
||||
from trustgraph.decoding.pdf.pdf_decoder import run
|
||||
run()
|
||||
|
||||
# Assert
|
||||
mock_launch.assert_called_once_with("pdf-decoder",
|
||||
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n")
|
||||
mock_launch.assert_called_once_with("pdf-decoder",
|
||||
"\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n\nSupports both inline document data and fetching from librarian via Pulsar\nfor large documents.\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
|
|
@ -305,9 +305,8 @@ class TestEntityCentricKnowledgeGraph:
|
|||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_graph_wildcard_returns_all_graphs(self, entity_kg):
|
||||
"""Test that g='*' returns quads from all graphs"""
|
||||
from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD
|
||||
def test_graph_none_returns_all_graphs(self, entity_kg):
|
||||
"""Test that g=None returns quads from all graphs"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
|
|
@ -320,7 +319,7 @@ class TestEntityCentricKnowledgeGraph:
|
|||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD)
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice', g=None)
|
||||
|
||||
# Should return quads from both graphs
|
||||
assert len(results) == 2
|
||||
|
|
@ -547,21 +546,21 @@ class TestServiceHelperFunctions:
|
|||
"""Test cases for helper functions in service.py"""
|
||||
|
||||
def test_create_term_with_uri_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='u'"""
|
||||
"""Test create_term creates IRI Term for term_type='u'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice', otype='u')
|
||||
term = create_term('http://example.org/Alice', term_type='u')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_with_literal_otype(self):
|
||||
"""Test create_term creates LITERAL Term for otype='l'"""
|
||||
"""Test create_term creates LITERAL Term for term_type='l'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
|
||||
term = create_term('Alice Smith', term_type='l', datatype='xsd:string', language='en')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
|
|
@ -569,14 +568,24 @@ class TestServiceHelperFunctions:
|
|||
assert term.language == 'en'
|
||||
|
||||
def test_create_term_with_triple_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='t'"""
|
||||
"""Test create_term creates TRIPLE Term for term_type='t' with valid JSON"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
from trustgraph.schema import TRIPLE, IRI
|
||||
import json
|
||||
|
||||
term = create_term('http://example.org/statement1', otype='t')
|
||||
# Valid JSON triple data
|
||||
triple_json = json.dumps({
|
||||
"s": {"type": "i", "iri": "http://example.org/Alice"},
|
||||
"p": {"type": "i", "iri": "http://example.org/knows"},
|
||||
"o": {"type": "i", "iri": "http://example.org/Bob"},
|
||||
})
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/statement1'
|
||||
term = create_term(triple_json, term_type='t')
|
||||
|
||||
assert term.type == TRIPLE
|
||||
assert term.triple is not None
|
||||
assert term.triple.s.type == IRI
|
||||
assert term.triple.s.iri == "http://example.org/Alice"
|
||||
|
||||
def test_create_term_heuristic_fallback_uri(self):
|
||||
"""Test create_term uses URL heuristic when otype not provided"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,441 @@
|
|||
"""
|
||||
Tests for entity-centric KG write amplification, delete collection batching,
|
||||
in-partition filtering, and term type metadata round-trips.
|
||||
|
||||
Complements test_entity_centric_kg.py with deeper verification of the
|
||||
2-table schema mechanics.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra():
|
||||
"""Provide mocked Cassandra cluster, session, and BatchStatement."""
|
||||
with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cls, \
|
||||
patch('trustgraph.direct.cassandra_kg.BatchStatement') as mock_batch_cls:
|
||||
|
||||
mock_cluster = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster.connect.return_value = mock_session
|
||||
mock_cls.return_value = mock_cluster
|
||||
|
||||
# Track batch.add calls per batch instance
|
||||
batches = []
|
||||
|
||||
def make_batch():
|
||||
batch = MagicMock()
|
||||
batch._adds = []
|
||||
original_add = batch.add
|
||||
|
||||
def tracking_add(stmt, params):
|
||||
batch._adds.append((stmt, params))
|
||||
|
||||
batch.add = tracking_add
|
||||
batches.append(batch)
|
||||
return batch
|
||||
|
||||
mock_batch_cls.side_effect = make_batch
|
||||
|
||||
yield {
|
||||
"cluster_cls": mock_cls,
|
||||
"cluster": mock_cluster,
|
||||
"session": mock_session,
|
||||
"batch_cls": mock_batch_cls,
|
||||
"batches": batches,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_kg(mock_cassandra):
|
||||
"""Create an EntityCentricKnowledgeGraph with mocked Cassandra."""
|
||||
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
|
||||
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_ks')
|
||||
return kg, mock_cassandra
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write amplification: row count verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWriteAmplification:
|
||||
|
||||
def test_uri_object_produces_4_entity_rows_plus_collection(self, entity_kg):
|
||||
"""URI object → S + P + O + G-if-non-default entity rows + 1 collection row."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/Alice',
|
||||
p='http://ex.org/knows',
|
||||
o='http://ex.org/Bob',
|
||||
g='http://ex.org/g1',
|
||||
otype='u',
|
||||
)
|
||||
|
||||
# Should be exactly one batch
|
||||
assert len(ctx["batches"]) == 1
|
||||
batch = ctx["batches"][0]
|
||||
|
||||
# 4 entity rows (S, P, O, G) + 1 collection row = 5
|
||||
assert len(batch._adds) == 5
|
||||
|
||||
# Check roles assigned
|
||||
roles = [params[2] for _, params in batch._adds if len(params) == 10]
|
||||
assert 'S' in roles
|
||||
assert 'P' in roles
|
||||
assert 'O' in roles
|
||||
assert 'G' in roles
|
||||
|
||||
def test_literal_object_produces_3_entity_rows(self, entity_kg):
|
||||
"""Literal object → S + P entity rows (no O row) + collection row."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/Alice',
|
||||
p='http://ex.org/name',
|
||||
o='Alice Smith',
|
||||
g=None, # default graph
|
||||
otype='l',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
|
||||
# S + P entity rows + 1 collection = 3 (no O row for literal, no G for default)
|
||||
assert len(batch._adds) == 3
|
||||
|
||||
roles = [params[2] for _, params in batch._adds if len(params) == 10]
|
||||
assert 'S' in roles
|
||||
assert 'P' in roles
|
||||
assert 'O' not in roles
|
||||
assert 'G' not in roles
|
||||
|
||||
def test_triple_otype_gets_object_entity_row(self, entity_kg):
|
||||
"""otype='t' (quoted triple) → object gets entity row like URI."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/s',
|
||||
p='http://ex.org/p',
|
||||
o='{"s":{},"p":{},"o":{}}',
|
||||
g=None,
|
||||
otype='t',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
|
||||
# S + P + O entity rows + collection = 4 (no G for default graph)
|
||||
assert len(batch._adds) == 4
|
||||
|
||||
roles = [params[2] for _, params in batch._adds if len(params) == 10]
|
||||
assert 'O' in roles
|
||||
|
||||
def test_default_graph_no_g_row(self, entity_kg):
|
||||
"""Default graph (g=None) → no G entity row."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/s',
|
||||
p='http://ex.org/p',
|
||||
o='http://ex.org/o',
|
||||
g=None,
|
||||
otype='u',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
|
||||
# S + P + O entity rows + collection = 4 (no G)
|
||||
assert len(batch._adds) == 4
|
||||
roles = [params[2] for _, params in batch._adds if len(params) == 10]
|
||||
assert 'G' not in roles
|
||||
|
||||
def test_non_default_graph_gets_g_row(self, entity_kg):
|
||||
"""Non-default graph → gets G entity row."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/s',
|
||||
p='http://ex.org/p',
|
||||
o='http://ex.org/o',
|
||||
g='http://ex.org/graph1',
|
||||
otype='u',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
|
||||
# S + P + O + G entity rows + collection = 5
|
||||
assert len(batch._adds) == 5
|
||||
roles = [params[2] for _, params in batch._adds if len(params) == 10]
|
||||
assert 'G' in roles
|
||||
|
||||
def test_dtype_and_lang_passed_to_all_rows(self, entity_kg):
|
||||
"""dtype and lang should be stored in every entity row."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/s',
|
||||
p='http://ex.org/label',
|
||||
o='thing',
|
||||
g=None,
|
||||
otype='l',
|
||||
dtype='xsd:string',
|
||||
lang='en',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
|
||||
# Check entity rows carry dtype and lang
|
||||
for _, params in batch._adds:
|
||||
if len(params) == 10:
|
||||
# Entity row: (collection, entity, role, p, otype, s, o, d, dtype, lang)
|
||||
assert params[8] == 'xsd:string'
|
||||
assert params[9] == 'en'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-partition filtering: get_os, get_spo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInPartitionFiltering:
|
||||
|
||||
def test_get_os_filters_by_object(self, entity_kg):
|
||||
"""get_os should filter results by matching object value."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
# Simulate rows returned from subject partition (all have same s)
|
||||
mock_rows = [
|
||||
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
|
||||
d='', otype='u', dtype='', lang='',
|
||||
s='http://ex.org/Alice'),
|
||||
MagicMock(p='http://ex.org/likes', o='http://ex.org/Charlie',
|
||||
d='', otype='u', dtype='', lang='',
|
||||
s='http://ex.org/Alice'),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
|
||||
results = kg.get_os('col', 'http://ex.org/Bob', 'http://ex.org/Alice')
|
||||
|
||||
# Only the Bob row should pass the filter
|
||||
assert len(results) == 1
|
||||
assert results[0].o == 'http://ex.org/Bob'
|
||||
assert results[0].p == 'http://ex.org/knows'
|
||||
|
||||
def test_get_os_returns_empty_when_no_match(self, entity_kg):
|
||||
"""get_os should return empty list when object doesn't match any row."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
|
||||
d='', otype='u', dtype='', lang='',
|
||||
s='http://ex.org/Alice'),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
|
||||
results = kg.get_os('col', 'http://ex.org/Charlie', 'http://ex.org/Alice')
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
def test_get_spo_filters_by_object(self, entity_kg):
|
||||
"""get_spo should filter results by matching object value."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(o='http://ex.org/Bob', d='', otype='u', dtype='', lang=''),
|
||||
MagicMock(o='http://ex.org/Charlie', d='', otype='u', dtype='', lang=''),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
|
||||
results = kg.get_spo(
|
||||
'col', 'http://ex.org/Alice', 'http://ex.org/knows',
|
||||
'http://ex.org/Bob',
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].o == 'http://ex.org/Bob'
|
||||
|
||||
def test_get_os_with_graph_filter(self, entity_kg):
|
||||
"""get_os with specific graph should filter both object and graph."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
|
||||
d='http://ex.org/g1', otype='u', dtype='', lang='',
|
||||
s='http://ex.org/Alice'),
|
||||
MagicMock(p='http://ex.org/knows', o='http://ex.org/Bob',
|
||||
d='http://ex.org/g2', otype='u', dtype='', lang='',
|
||||
s='http://ex.org/Alice'),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
|
||||
results = kg.get_os(
|
||||
'col', 'http://ex.org/Bob', 'http://ex.org/Alice',
|
||||
g='http://ex.org/g1',
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].g == 'http://ex.org/g1'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete collection batching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteCollectionBatching:
|
||||
|
||||
def test_extracts_unique_entities_from_quads(self, entity_kg):
|
||||
"""delete_collection should extract s, p, and URI o as entities."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/knows',
|
||||
o='http://ex.org/B', otype='u', dtype='', lang=''),
|
||||
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
|
||||
o='Alice', otype='l', dtype='', lang=''),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.delete_collection('col')
|
||||
|
||||
# Unique entities: A, knows, B, name (literal 'Alice' excluded)
|
||||
# The batches should include entity partition deletes
|
||||
all_adds = []
|
||||
for batch in ctx["batches"]:
|
||||
all_adds.extend(batch._adds)
|
||||
|
||||
# We expect entity deletes + collection row deletes + metadata delete
|
||||
# Just verify the function completes and calls execute
|
||||
assert ctx["session"].execute.called
|
||||
|
||||
def test_literal_objects_not_treated_as_entities(self, entity_kg):
|
||||
"""Literal objects (otype='l') should not get entity partition deletes."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(d='', s='http://ex.org/A', p='http://ex.org/name',
|
||||
o='Alice', otype='l', dtype='', lang=''),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.delete_collection('col')
|
||||
|
||||
# Entity partition deletes should only include A and name, not Alice
|
||||
entity_deletes = []
|
||||
for batch in ctx["batches"]:
|
||||
for _, params in batch._adds:
|
||||
if len(params) == 2: # delete_entity_partition takes (collection, entity)
|
||||
entity_deletes.append(params[1])
|
||||
|
||||
assert 'http://ex.org/A' in entity_deletes
|
||||
assert 'http://ex.org/name' in entity_deletes
|
||||
assert 'Alice' not in entity_deletes
|
||||
|
||||
def test_non_default_graph_treated_as_entity(self, entity_kg):
|
||||
"""Non-default graphs should get entity partition deletes."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(d='http://ex.org/g1', s='http://ex.org/A',
|
||||
p='http://ex.org/p', o='http://ex.org/B',
|
||||
otype='u', dtype='', lang=''),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.delete_collection('col')
|
||||
|
||||
entity_deletes = []
|
||||
for batch in ctx["batches"]:
|
||||
for _, params in batch._adds:
|
||||
if len(params) == 2:
|
||||
entity_deletes.append(params[1])
|
||||
|
||||
assert 'http://ex.org/g1' in entity_deletes
|
||||
|
||||
def test_empty_collection_delete_completes(self, entity_kg):
|
||||
"""Deleting an empty collection should not error."""
|
||||
kg, ctx = entity_kg
|
||||
|
||||
ctx["session"].execute.return_value = []
|
||||
ctx["batches"].clear()
|
||||
|
||||
# Should not raise
|
||||
kg.delete_collection('empty-col')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Term type metadata round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTermTypeMetadata:
|
||||
|
||||
def test_query_results_include_otype(self, entity_kg):
|
||||
"""Query results should include otype from Cassandra rows."""
|
||||
kg, ctx = entity_kg
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
mock_rows = [
|
||||
MagicMock(p='http://ex.org/name', o='Alice',
|
||||
d='', otype='l', dtype='xsd:string', lang='en',
|
||||
s='http://ex.org/Alice'),
|
||||
]
|
||||
ctx["session"].execute.return_value = mock_rows
|
||||
|
||||
results = kg.get_s('col', 'http://ex.org/Alice')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].otype == 'l'
|
||||
assert results[0].dtype == 'xsd:string'
|
||||
assert results[0].lang == 'en'
|
||||
|
||||
def test_auto_detect_otype_uri(self, entity_kg):
|
||||
"""Auto-detect should classify http:// as URI."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/s',
|
||||
p='http://ex.org/p',
|
||||
o='http://ex.org/o',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
# Check otype in entity rows (position 4)
|
||||
for _, params in batch._adds:
|
||||
if len(params) == 10:
|
||||
assert params[4] == 'u'
|
||||
|
||||
def test_auto_detect_otype_literal(self, entity_kg):
|
||||
"""Auto-detect should classify non-http:// as literal."""
|
||||
kg, ctx = entity_kg
|
||||
ctx["batches"].clear()
|
||||
|
||||
kg.insert(
|
||||
collection='col',
|
||||
s='http://ex.org/s',
|
||||
p='http://ex.org/p',
|
||||
o='plain text',
|
||||
)
|
||||
|
||||
batch = ctx["batches"][0]
|
||||
for _, params in batch._adds:
|
||||
if len(params) == 10:
|
||||
assert params[4] == 'l'
|
||||
164
tests/unit/test_embeddings/test_document_embeddings_processor.py
Normal file
164
tests/unit/test_embeddings/test_document_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""
|
||||
Tests for document embeddings processor — single-chunk embedding via batch API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.embeddings.document_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import (
|
||||
Chunk, DocumentEmbeddings, ChunkEmbeddings,
|
||||
EmbeddingsRequest, EmbeddingsResponse, Metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
return Processor(
|
||||
taskgroup=AsyncMock(),
|
||||
id="test-doc-embeddings",
|
||||
)
|
||||
|
||||
|
||||
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
|
||||
user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
return msg
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsProcessor:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_single_text_as_list(self, processor):
|
||||
"""Document embeddings should wrap single chunk in a list for the API."""
|
||||
msg = _make_chunk_message("test chunk text")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.1, 0.2, 0.3]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# Should send EmbeddingsRequest with texts=[chunk]
|
||||
mock_request.assert_called_once()
|
||||
req = mock_request.call_args[0][0]
|
||||
assert isinstance(req, EmbeddingsRequest)
|
||||
assert req.texts == ["test chunk text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_first_vector(self, processor):
|
||||
"""Should use vectors[0] from the response."""
|
||||
msg = _make_chunk_message("chunk")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[1.0, 2.0, 3.0]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert isinstance(result, DocumentEmbeddings)
|
||||
assert len(result.chunks) == 1
|
||||
assert result.chunks[0].vector == [1.0, 2.0, 3.0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_vectors_response(self, processor):
|
||||
"""Should handle empty vectors response gracefully."""
|
||||
msg = _make_chunk_message("chunk")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.chunks[0].vector == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_id_is_document_id(self, processor):
|
||||
"""ChunkEmbeddings should use document_id as chunk_id."""
|
||||
msg = _make_chunk_message(doc_id="my-doc-42")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.chunks[0].chunk_id == "my-doc-42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_preserved(self, processor):
|
||||
"""Output should carry the original metadata."""
|
||||
msg = _make_chunk_message(user="alice", collection="reports", doc_id="d1")
|
||||
|
||||
mock_request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]]
|
||||
))
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "reports"
|
||||
assert result.metadata.id == "d1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_propagates(self, processor):
|
||||
"""Embedding errors should propagate for retry."""
|
||||
msg = _make_chunk_message()
|
||||
|
||||
mock_request = AsyncMock(side_effect=RuntimeError("service down"))
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(request=mock_request)
|
||||
return MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="service down"):
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
109
tests/unit/test_embeddings/test_embeddings_client.py
Normal file
109
tests/unit/test_embeddings/test_embeddings_client.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
Tests for EmbeddingsClient — the client interface for batch embeddings.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.base.embeddings_client import EmbeddingsClient
|
||||
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
|
||||
|
||||
|
||||
class TestEmbeddingsClient:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_sends_request_and_returns_vectors(self):
|
||||
"""embed() should send an EmbeddingsRequest and return vectors."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None,
|
||||
vectors=[[0.1, 0.2], [0.3, 0.4]],
|
||||
))
|
||||
|
||||
result = await client.embed(texts=["hello", "world"])
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
client.request.assert_called_once()
|
||||
req = client.request.call_args[0][0]
|
||||
assert isinstance(req, EmbeddingsRequest)
|
||||
assert req.texts == ["hello", "world"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_single_text(self):
|
||||
"""embed() should work with a single text."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None,
|
||||
vectors=[[1.0, 2.0, 3.0]],
|
||||
))
|
||||
|
||||
result = await client.embed(texts=["single"])
|
||||
|
||||
assert result == [[1.0, 2.0, 3.0]]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_raises_on_error_response(self):
|
||||
"""embed() should raise RuntimeError when response contains an error."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=Error(type="embeddings-error", message="model not found"),
|
||||
vectors=[],
|
||||
))
|
||||
|
||||
with pytest.raises(RuntimeError, match="model not found"):
|
||||
await client.embed(texts=["test"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_passes_timeout(self):
|
||||
"""embed() should pass timeout to the underlying request."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]],
|
||||
))
|
||||
|
||||
await client.embed(texts=["test"], timeout=60)
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 60
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_default_timeout(self):
|
||||
"""embed() should use 300s default timeout."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[[0.0]],
|
||||
))
|
||||
|
||||
await client.embed(texts=["test"])
|
||||
|
||||
_, kwargs = client.request.call_args
|
||||
assert kwargs["timeout"] == 300
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_empty_texts(self):
|
||||
"""embed() with empty list should still make the request."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=[],
|
||||
))
|
||||
|
||||
result = await client.embed(texts=[])
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_large_batch(self):
|
||||
"""embed() should handle large batches."""
|
||||
client = EmbeddingsClient.__new__(EmbeddingsClient)
|
||||
n = 100
|
||||
vectors = [[float(i)] for i in range(n)]
|
||||
client.request = AsyncMock(return_value=EmbeddingsResponse(
|
||||
error=None, vectors=vectors,
|
||||
))
|
||||
|
||||
texts = [f"text {i}" for i in range(n)]
|
||||
result = await client.embed(texts=texts)
|
||||
|
||||
assert len(result) == n
|
||||
req = client.request.call_args[0][0]
|
||||
assert len(req.texts) == n
|
||||
135
tests/unit/test_embeddings/test_embeddings_service_request.py
Normal file
135
tests/unit/test_embeddings/test_embeddings_service_request.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""
|
||||
Tests for EmbeddingsService.on_request — the request handler that dispatches
|
||||
to on_embeddings and sends responses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.base import EmbeddingsService
|
||||
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse, Error
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class StubEmbeddingsService(EmbeddingsService):
|
||||
"""Minimal concrete implementation for testing on_request."""
|
||||
|
||||
def __init__(self, embed_result=None, embed_error=None):
|
||||
# Skip super().__init__ to avoid taskgroup/registration
|
||||
self.embed_result = embed_result or [[0.1, 0.2]]
|
||||
self.embed_error = embed_error
|
||||
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
if self.embed_error:
|
||||
raise self.embed_error
|
||||
return self.embed_result
|
||||
|
||||
|
||||
def _make_msg(texts, msg_id="req-1"):
|
||||
request = EmbeddingsRequest(texts=texts)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = request
|
||||
msg.properties.return_value = {"id": msg_id}
|
||||
return msg
|
||||
|
||||
|
||||
def _make_flow(model="test-model"):
|
||||
mock_response_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
|
||||
def flow_callable(name):
|
||||
if name == "model":
|
||||
return model
|
||||
if name == "response":
|
||||
return mock_response_producer
|
||||
return MagicMock()
|
||||
|
||||
flow_callable.producer = {"response": mock_response_producer}
|
||||
return flow_callable, mock_response_producer
|
||||
|
||||
|
||||
class TestEmbeddingsServiceOnRequest:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_request(self):
|
||||
"""on_request should call on_embeddings and send response."""
|
||||
service = StubEmbeddingsService(embed_result=[[0.1, 0.2], [0.3, 0.4]])
|
||||
msg = _make_msg(["hello", "world"], msg_id="r1")
|
||||
flow, mock_response = _make_flow(model="my-model")
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
mock_response.send.assert_called_once()
|
||||
resp = mock_response.send.call_args[0][0]
|
||||
assert isinstance(resp, EmbeddingsResponse)
|
||||
assert resp.error is None
|
||||
assert resp.vectors == [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Check id is passed through
|
||||
props = mock_response.send.call_args[1]["properties"]
|
||||
assert props["id"] == "r1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_model_from_flow(self):
|
||||
"""on_request should pass model parameter from flow to on_embeddings."""
|
||||
calls = []
|
||||
|
||||
class TrackingService(EmbeddingsService):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
calls.append({"texts": texts, "model": model})
|
||||
return [[0.0]]
|
||||
|
||||
service = TrackingService()
|
||||
msg = _make_msg(["test"])
|
||||
flow, _ = _make_flow(model="custom-model-v2")
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["model"] == "custom-model-v2"
|
||||
assert calls[0]["texts"] == ["test"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_sends_error_response(self):
|
||||
"""Non-rate-limit errors should send an error response."""
|
||||
service = StubEmbeddingsService(
|
||||
embed_error=ValueError("dimension mismatch")
|
||||
)
|
||||
msg = _make_msg(["test"], msg_id="r2")
|
||||
flow, mock_response = _make_flow()
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
mock_response.send.assert_called_once()
|
||||
resp = mock_response.send.call_args[0][0]
|
||||
assert resp.error is not None
|
||||
assert resp.error.type == "embeddings-error"
|
||||
assert "dimension mismatch" in resp.error.message
|
||||
assert resp.vectors == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_propagates(self):
|
||||
"""TooManyRequests should propagate (not caught as error response)."""
|
||||
service = StubEmbeddingsService(
|
||||
embed_error=TooManyRequests("rate limited")
|
||||
)
|
||||
msg = _make_msg(["test"])
|
||||
flow, _ = _make_flow()
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_id_preserved(self):
|
||||
"""The request message id should be forwarded in the response properties."""
|
||||
service = StubEmbeddingsService()
|
||||
msg = _make_msg(["test"], msg_id="unique-id-42")
|
||||
flow, mock_response = _make_flow()
|
||||
|
||||
await service.on_request(msg, MagicMock(), flow)
|
||||
|
||||
props = mock_response.send.call_args[1]["properties"]
|
||||
assert props["id"] == "unique-id-42"
|
||||
|
|
@ -103,7 +103,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
mock_text_embedding_class.reset_mock()
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text")
|
||||
result = await processor.on_embeddings(["test text"])
|
||||
|
||||
# Assert
|
||||
mock_fastembed_instance.embed.assert_called_once_with(["test text"])
|
||||
|
|
@ -126,7 +126,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
mock_text_embedding_class.reset_mock()
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model="custom-model")
|
||||
result = await processor.on_embeddings(["test text"], model="custom-model")
|
||||
|
||||
# Assert
|
||||
mock_text_embedding_class.assert_called_once_with(model_name="custom-model")
|
||||
|
|
@ -149,16 +149,16 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
initial_call_count = mock_text_embedding_class.call_count
|
||||
|
||||
# Act - switch between models
|
||||
await processor.on_embeddings("text1", model="model-a")
|
||||
await processor.on_embeddings(["text1"], model="model-a")
|
||||
call_count_after_a = mock_text_embedding_class.call_count
|
||||
|
||||
await processor.on_embeddings("text2", model="model-a") # Same, no reload
|
||||
await processor.on_embeddings(["text2"], model="model-a") # Same, no reload
|
||||
call_count_after_a_repeat = mock_text_embedding_class.call_count
|
||||
|
||||
await processor.on_embeddings("text3", model="model-b") # Different, reload
|
||||
await processor.on_embeddings(["text3"], model="model-b") # Different, reload
|
||||
call_count_after_b = mock_text_embedding_class.call_count
|
||||
|
||||
await processor.on_embeddings("text4", model="model-a") # Back to A, reload
|
||||
await processor.on_embeddings(["text4"], model="model-a") # Back to A, reload
|
||||
call_count_after_a_again = mock_text_embedding_class.call_count
|
||||
|
||||
# Assert
|
||||
|
|
@ -183,7 +183,7 @@ class TestFastEmbedDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
initial_count = mock_text_embedding_class.call_count
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model=None)
|
||||
result = await processor.on_embeddings(["test text"], model=None)
|
||||
|
||||
# Assert
|
||||
# No reload, using cached default
|
||||
|
|
|
|||
233
tests/unit/test_embeddings/test_graph_embeddings_processor.py
Normal file
233
tests/unit/test_embeddings/test_graph_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"""
|
||||
Tests for graph embeddings processor — batch embedding of entity contexts.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.embeddings.graph_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import (
|
||||
EntityContexts, EntityEmbeddings, GraphEmbeddings,
|
||||
Term, IRI, Metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
return Processor(
|
||||
taskgroup=AsyncMock(),
|
||||
id="test-graph-embeddings",
|
||||
batch_size=3,
|
||||
)
|
||||
|
||||
|
||||
def _make_entity_context(name, context, chunk_id="chunk-1"):
|
||||
"""Create an entity context for testing."""
|
||||
entity = Term(type=IRI, iri=f"urn:entity:{name}")
|
||||
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
|
||||
|
||||
|
||||
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
|
||||
metadata = Metadata(id=doc_id, user=user, collection=collection)
|
||||
value = EntityContexts(metadata=metadata, entities=entities)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = value
|
||||
return msg
|
||||
|
||||
|
||||
class TestGraphEmbeddingsInit:
|
||||
|
||||
def test_default_batch_size(self):
|
||||
p = Processor(taskgroup=AsyncMock(), id="test")
|
||||
assert p.batch_size == 5
|
||||
|
||||
def test_custom_batch_size(self):
|
||||
p = Processor(taskgroup=AsyncMock(), id="test", batch_size=20)
|
||||
assert p.batch_size == 20
|
||||
|
||||
|
||||
class TestGraphEmbeddingsBatchProcessing:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_batch_call_for_all_entities(self, processor):
|
||||
"""All entity contexts should be embedded in a single API call."""
|
||||
entities = [
|
||||
_make_entity_context("Alice", "Alice is a person"),
|
||||
_make_entity_context("Bob", "Bob is a developer"),
|
||||
_make_entity_context("Acme", "Acme is a company"),
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(return_value=[
|
||||
[0.1, 0.2], [0.3, 0.4], [0.5, 0.6],
|
||||
])
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# Single batch call with all three texts
|
||||
mock_embed.assert_called_once_with(
|
||||
texts=["Alice is a person", "Bob is a developer", "Acme is a company"]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vectors_paired_with_correct_entities(self, processor):
|
||||
"""Each vector should be paired with its corresponding entity."""
|
||||
entities = [
|
||||
_make_entity_context("Alice", "ctx-A", chunk_id="c1"),
|
||||
_make_entity_context("Bob", "ctx-B", chunk_id="c2"),
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
vectors = [[1.0, 2.0], [3.0, 4.0]]
|
||||
mock_embed = AsyncMock(return_value=vectors)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# With batch_size=3, all 2 entities fit in one output message
|
||||
mock_output.send.assert_called_once()
|
||||
result = mock_output.send.call_args[0][0]
|
||||
assert isinstance(result, GraphEmbeddings)
|
||||
assert len(result.entities) == 2
|
||||
assert result.entities[0].vector == [1.0, 2.0]
|
||||
assert result.entities[0].entity.iri == "urn:entity:Alice"
|
||||
assert result.entities[0].chunk_id == "c1"
|
||||
assert result.entities[1].vector == [3.0, 4.0]
|
||||
assert result.entities[1].entity.iri == "urn:entity:Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_batching(self, processor):
|
||||
"""Output should be split into batches of batch_size."""
|
||||
# batch_size=3, 7 entities -> 3 output messages (3+3+1)
|
||||
entities = [
|
||||
_make_entity_context(f"E{i}", f"context {i}")
|
||||
for i in range(7)
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
vectors = [[float(i)] for i in range(7)]
|
||||
mock_embed = AsyncMock(return_value=vectors)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert mock_output.send.call_count == 3
|
||||
# First batch has 3 entities
|
||||
batch1 = mock_output.send.call_args_list[0][0][0]
|
||||
assert len(batch1.entities) == 3
|
||||
# Second batch has 3 entities
|
||||
batch2 = mock_output.send.call_args_list[1][0][0]
|
||||
assert len(batch2.entities) == 3
|
||||
# Third batch has 1 entity
|
||||
batch3 = mock_output.send.call_args_list[2][0][0]
|
||||
assert len(batch3.entities) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_batches_preserve_metadata(self, processor):
|
||||
"""Each output batch should carry the original metadata."""
|
||||
entities = [
|
||||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(5)
|
||||
]
|
||||
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for call in mock_output.send.call_args_list:
|
||||
result = call[0][0]
|
||||
assert result.metadata.id == "doc-42"
|
||||
assert result.metadata.user == "alice"
|
||||
assert result.metadata.collection == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_entity(self, processor):
|
||||
"""Single entity should work with one embed call and one output."""
|
||||
entities = [_make_entity_context("Solo", "solo context")]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[1.0, 2.0, 3.0]])
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
mock_embed.assert_called_once_with(texts=["solo context"])
|
||||
mock_output.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_error_propagates(self, processor):
|
||||
"""Embedding service errors should propagate for retry."""
|
||||
entities = [_make_entity_context("E", "ctx")]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(side_effect=RuntimeError("embedding failed"))
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
return MagicMock()
|
||||
|
||||
with pytest.raises(RuntimeError, match="embedding failed"):
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_batch_size(self, processor):
|
||||
"""When entity count equals batch_size, exactly one output message."""
|
||||
entities = [
|
||||
_make_entity_context(f"E{i}", f"ctx {i}")
|
||||
for i in range(3) # batch_size=3
|
||||
]
|
||||
msg = _make_message(entities)
|
||||
|
||||
mock_embed = AsyncMock(return_value=[[0.0]] * 3)
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow(name):
|
||||
if name == "embeddings-request":
|
||||
return MagicMock(embed=mock_embed)
|
||||
elif name == "output":
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
await processor.on_message(msg, MagicMock(), flow)
|
||||
|
||||
mock_output.send.assert_called_once()
|
||||
assert len(mock_output.send.call_args[0][0].entities) == 3
|
||||
|
|
@ -53,12 +53,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text")
|
||||
result = await processor.on_embeddings(["test text"])
|
||||
|
||||
# Assert
|
||||
mock_ollama_client.embed.assert_called_once_with(
|
||||
model="test-model",
|
||||
input="test text"
|
||||
input=["test text"]
|
||||
)
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
|
||||
|
|
@ -79,12 +79,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model="custom-model")
|
||||
result = await processor.on_embeddings(["test text"], model="custom-model")
|
||||
|
||||
# Assert
|
||||
mock_ollama_client.embed.assert_called_once_with(
|
||||
model="custom-model",
|
||||
input="test text"
|
||||
input=["test text"]
|
||||
)
|
||||
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
|
||||
|
|
@ -105,10 +105,10 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act - switch between different models
|
||||
await processor.on_embeddings("text1", model="model-a")
|
||||
await processor.on_embeddings("text2", model="model-b")
|
||||
await processor.on_embeddings("text3", model="model-a")
|
||||
await processor.on_embeddings("text4") # Use default
|
||||
await processor.on_embeddings(["text1"], model="model-a")
|
||||
await processor.on_embeddings(["text2"], model="model-b")
|
||||
await processor.on_embeddings(["text3"], model="model-a")
|
||||
await processor.on_embeddings(["text4"]) # Use default
|
||||
|
||||
# Assert
|
||||
calls = mock_ollama_client.embed.call_args_list
|
||||
|
|
@ -135,12 +135,12 @@ class TestOllamaDynamicModelLoading(IsolatedAsyncioTestCase):
|
|||
processor = Processor(id="test", concurrency=1, model="test-model", taskgroup=AsyncMock())
|
||||
|
||||
# Act
|
||||
result = await processor.on_embeddings("test text", model=None)
|
||||
result = await processor.on_embeddings(["test text"], model=None)
|
||||
|
||||
# Assert
|
||||
mock_ollama_client.embed.assert_called_once_with(
|
||||
model="test-model",
|
||||
input="test text"
|
||||
input=["test text"]
|
||||
)
|
||||
|
||||
@patch('trustgraph.embeddings.ollama.processor.Client')
|
||||
|
|
|
|||
|
|
@ -353,7 +353,14 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock the flow
|
||||
mock_embeddings_request = AsyncMock()
|
||||
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
# Return batch of vector sets (one per text)
|
||||
# 4 unique texts: CUST001, John Doe, CUST002, Jane Smith
|
||||
mock_embeddings_request.embed.return_value = [
|
||||
[[0.1, 0.2, 0.3]], # vectors for text 1
|
||||
[[0.2, 0.3, 0.4]], # vectors for text 2
|
||||
[[0.3, 0.4, 0.5]], # vectors for text 3
|
||||
[[0.4, 0.5, 0.6]], # vectors for text 4
|
||||
]
|
||||
|
||||
mock_output = AsyncMock()
|
||||
|
||||
|
|
@ -368,9 +375,12 @@ class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
|||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Should have called embed for each unique text
|
||||
# 4 values: CUST001, John Doe, CUST002, Jane Smith
|
||||
assert mock_embeddings_request.embed.call_count == 4
|
||||
# Should have called embed once with all texts in a batch
|
||||
assert mock_embeddings_request.embed.call_count == 1
|
||||
# Verify it was called with a list of texts
|
||||
call_args = mock_embeddings_request.embed.call_args
|
||||
assert 'texts' in call_args.kwargs
|
||||
assert len(call_args.kwargs['texts']) == 4
|
||||
|
||||
# Should have sent output
|
||||
mock_output.send.assert_called()
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,407 @@
|
|||
"""
|
||||
Tests for streaming triple and entity context batching in the definitions
|
||||
KG extractor.
|
||||
|
||||
Covers: triples batch splitting, entity context batch splitting,
|
||||
metadata preservation, provenance, and empty/null filtering.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.extract.kg.definitions.extract import (
|
||||
Processor, default_triples_batch_size, default_entity_batch_size,
|
||||
)
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_processor(triples_batch_size=default_triples_batch_size,
|
||||
entity_batch_size=default_entity_batch_size):
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.triples_batch_size = triples_batch_size
|
||||
proc.entity_batch_size = entity_batch_size
|
||||
return proc
|
||||
|
||||
|
||||
def _make_defn(entity, definition):
|
||||
return {"entity": entity, "definition": definition}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = chunk
|
||||
return msg
|
||||
|
||||
|
||||
def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
||||
mock_triples_pub = AsyncMock()
|
||||
mock_ecs_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_definitions = AsyncMock(
|
||||
return_value=prompt_result
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
if name == "triples":
|
||||
return mock_triples_pub
|
||||
if name == "entity-contexts":
|
||||
return mock_ecs_pub
|
||||
if name == "llm-model":
|
||||
return llm_model
|
||||
if name == "ontology":
|
||||
return ontology_uri
|
||||
return MagicMock()
|
||||
|
||||
return flow, mock_triples_pub, mock_ecs_pub, mock_prompt_client
|
||||
|
||||
|
||||
def _sent_triples(mock_pub):
|
||||
return [call.args[0] for call in mock_pub.send.call_args_list]
|
||||
|
||||
|
||||
def _sent_ecs(mock_pub):
|
||||
return [call.args[0] for call in mock_pub.send.call_args_list]
|
||||
|
||||
|
||||
def _all_triples_flat(mock_pub):
|
||||
result = []
|
||||
for triples_msg in _sent_triples(mock_pub):
|
||||
result.extend(triples_msg.triples)
|
||||
return result
|
||||
|
||||
|
||||
def _all_entities_flat(mock_pub):
|
||||
result = []
|
||||
for ecs_msg in _sent_ecs(mock_pub):
|
||||
result.extend(ecs_msg.entities)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDefaults:
|
||||
|
||||
def test_default_triples_batch_size(self):
|
||||
assert default_triples_batch_size == 50
|
||||
|
||||
def test_default_entity_batch_size(self):
|
||||
assert default_entity_batch_size == 5
|
||||
|
||||
|
||||
class TestTriplesBatching:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_batch_when_under_limit(self):
|
||||
proc = _make_processor(triples_batch_size=100)
|
||||
defs = [_make_defn("Cat", "A feline animal")]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert triples_pub.send.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_triples_batches(self):
|
||||
proc = _make_processor(triples_batch_size=2)
|
||||
defs = [
|
||||
_make_defn("Cat", "A feline"),
|
||||
_make_defn("Dog", "A canine"),
|
||||
]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# 2 defs → 2 labels + 2 definitions = 4 triples + provenance
|
||||
# With batch_size=2, should produce multiple batches
|
||||
assert triples_pub.send.call_count > 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triples_batch_sizes_within_limit(self):
|
||||
batch_size = 3
|
||||
proc = _make_processor(triples_batch_size=batch_size)
|
||||
defs = [
|
||||
_make_defn("A", "def A"),
|
||||
_make_defn("B", "def B"),
|
||||
_make_defn("C", "def C"),
|
||||
]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for triples_msg in _sent_triples(triples_pub):
|
||||
assert len(triples_msg.triples) <= batch_size
|
||||
|
||||
|
||||
class TestEntityContextBatching:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_entity_batch_when_under_limit(self):
|
||||
proc = _make_processor(entity_batch_size=100)
|
||||
defs = [_make_defn("Cat", "A feline")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# 1 def → 2 entity contexts (name + definition)
|
||||
assert ecs_pub.send.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_entity_batches(self):
|
||||
proc = _make_processor(entity_batch_size=2)
|
||||
defs = [
|
||||
_make_defn("Cat", "A feline"),
|
||||
_make_defn("Dog", "A canine"),
|
||||
]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# 2 defs → 4 entity contexts, batch_size=2 → 2 batches
|
||||
assert ecs_pub.send.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_entity_batch_sizes_within_limit(self):
|
||||
batch_size = 3
|
||||
proc = _make_processor(entity_batch_size=batch_size)
|
||||
defs = [
|
||||
_make_defn("A", "def A"),
|
||||
_make_defn("B", "def B"),
|
||||
_make_defn("C", "def C"),
|
||||
]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for ecs_msg in _sent_ecs(ecs_pub):
|
||||
assert len(ecs_msg.entities) <= batch_size
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_entity_contexts_have_name_and_definition(self):
|
||||
"""Each definition produces 2 entity contexts: name and definition."""
|
||||
proc = _make_processor(entity_batch_size=100)
|
||||
defs = [_make_defn("Cat", "A feline animal")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
entities = _all_entities_flat(ecs_pub)
|
||||
assert len(entities) == 2
|
||||
contexts = {e.context for e in entities}
|
||||
assert "Cat" in contexts
|
||||
assert "A feline animal" in contexts
|
||||
|
||||
|
||||
class TestMetadataPreservation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triples_metadata(self):
|
||||
proc = _make_processor(triples_batch_size=2)
|
||||
defs = [_make_defn("X", "def X")]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for triples_msg in _sent_triples(triples_pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_entity_contexts_metadata(self):
|
||||
proc = _make_processor(entity_batch_size=1)
|
||||
defs = [_make_defn("X", "def X")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-2", root="r-2",
|
||||
user="u-2", collection="coll-2",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for ecs_msg in _sent_ecs(ecs_pub):
|
||||
assert ecs_msg.metadata.id == "c-2"
|
||||
assert ecs_msg.metadata.root == "r-2"
|
||||
|
||||
|
||||
class TestEmptyAndNullFiltering:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entity_skipped(self):
|
||||
proc = _make_processor()
|
||||
defs = [
|
||||
_make_defn("", "some definition"),
|
||||
_make_defn("Valid", "a valid definition"),
|
||||
]
|
||||
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(triples_pub)
|
||||
all_e = _all_entities_flat(ecs_pub)
|
||||
# Only "Valid" should be present
|
||||
entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
|
||||
assert any("valid" in iri for iri in entity_iris)
|
||||
assert len(all_e) == 2 # name + definition for "Valid" only
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_definition_skipped(self):
|
||||
proc = _make_processor()
|
||||
defs = [
|
||||
_make_defn("Entity", ""),
|
||||
_make_defn("Good", "good definition"),
|
||||
]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(triples_pub)
|
||||
entity_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri")}
|
||||
assert any("good" in iri for iri in entity_iris)
|
||||
# "Entity" with empty def should have been skipped
|
||||
assert not any("entity" in iri and "good" not in iri for iri in entity_iris)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_fields_skipped(self):
|
||||
proc = _make_processor()
|
||||
defs = [
|
||||
_make_defn(None, "some definition"),
|
||||
_make_defn("Entity", None),
|
||||
]
|
||||
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert triples_pub.send.call_count == 0
|
||||
assert ecs_pub.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_filtered_no_output(self):
|
||||
proc = _make_processor()
|
||||
defs = [_make_defn("", ""), _make_defn(None, None)]
|
||||
flow, triples_pub, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert triples_pub.send.call_count == 0
|
||||
assert ecs_pub.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_response(self):
|
||||
proc = _make_processor()
|
||||
flow, triples_pub, ecs_pub, _ = _make_flow([])
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert triples_pub.send.call_count == 0
|
||||
assert ecs_pub.send.call_count == 0
|
||||
|
||||
|
||||
class TestProvenanceInclusion:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provenance_triples_present(self):
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
defs = [_make_defn("Cat", "A feline")]
|
||||
flow, triples_pub, _, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(triples_pub)
|
||||
# 1 def → 1 label + 1 definition = 2 content triples
|
||||
# Provenance adds more
|
||||
assert len(all_t) > 2
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_error_caught(self):
|
||||
proc = _make_processor()
|
||||
flow, triples_pub, ecs_pub, prompt = _make_flow([])
|
||||
prompt.extract_definitions = AsyncMock(
|
||||
side_effect=RuntimeError("LLM error")
|
||||
)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert triples_pub.send.call_count == 0
|
||||
assert ecs_pub.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_list_response_caught(self):
|
||||
proc = _make_processor()
|
||||
flow, triples_pub, ecs_pub, prompt = _make_flow("not a list")
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert triples_pub.send.call_count == 0
|
||||
assert ecs_pub.send.call_count == 0
|
||||
|
||||
|
||||
class TestDocumentIdProvenance:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_id_used_for_chunk_id(self):
|
||||
"""When document_id is set, entity contexts should use it as chunk_id."""
|
||||
proc = _make_processor(entity_batch_size=100)
|
||||
defs = [_make_defn("Cat", "A feline")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text", document_id="doc-123")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
entities = _all_entities_flat(ecs_pub)
|
||||
for e in entities:
|
||||
assert e.chunk_id == "doc-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_id_fallback_for_chunk_id(self):
|
||||
"""When document_id is empty, metadata.id is used as chunk_id."""
|
||||
proc = _make_processor(entity_batch_size=100)
|
||||
defs = [_make_defn("Cat", "A feline")]
|
||||
flow, _, ecs_pub, _ = _make_flow(defs)
|
||||
msg = _make_chunk_msg("text", meta_id="chunk-42", document_id="")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
entities = _all_entities_flat(ecs_pub)
|
||||
for e in entities:
|
||||
assert e.chunk_id == "chunk-42"
|
||||
|
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
Tests for streaming triple batching in the relationships KG extractor.
|
||||
|
||||
Covers: batch size configuration, output splitting, metadata preservation,
|
||||
provenance inclusion, empty/null filtering, and error propagation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.relationships.extract import (
|
||||
Processor, default_triples_batch_size,
|
||||
)
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_processor(triples_batch_size=default_triples_batch_size):
|
||||
"""Create a Processor without triggering FlowProcessor.__init__."""
|
||||
proc = Processor.__new__(Processor)
|
||||
proc.triples_batch_size = triples_batch_size
|
||||
return proc
|
||||
|
||||
|
||||
def _make_rel(subject, predicate, obj, object_entity=True):
|
||||
"""Build a relationship dict as returned by the prompt client."""
|
||||
return {
|
||||
"subject": subject,
|
||||
"predicate": predicate,
|
||||
"object": obj,
|
||||
"object-entity": object_entity,
|
||||
}
|
||||
|
||||
|
||||
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
|
||||
user="user-1", collection="col-1", document_id=""):
|
||||
"""Build a mock message wrapping a Chunk."""
|
||||
chunk = Chunk(
|
||||
metadata=Metadata(
|
||||
id=meta_id, root=root, user=user, collection=collection,
|
||||
),
|
||||
chunk=text.encode("utf-8"),
|
||||
document_id=document_id,
|
||||
)
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = chunk
|
||||
return msg
|
||||
|
||||
|
||||
def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
||||
"""Build a mock flow callable that provides prompt client, triples
|
||||
producer, and parameter specs."""
|
||||
mock_triples_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_relationships = AsyncMock(
|
||||
return_value=prompt_result
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
if name == "triples":
|
||||
return mock_triples_pub
|
||||
if name == "llm-model":
|
||||
return llm_model
|
||||
if name == "ontology":
|
||||
return ontology_uri
|
||||
return MagicMock()
|
||||
|
||||
return flow, mock_triples_pub, mock_prompt_client
|
||||
|
||||
|
||||
def _sent_triples(mock_pub):
|
||||
"""Collect all Triples objects sent to a mock publisher."""
|
||||
return [call.args[0] for call in mock_pub.send.call_args_list]
|
||||
|
||||
|
||||
def _all_triples_flat(mock_pub):
|
||||
"""Flatten all batches into one list of Triple objects."""
|
||||
result = []
|
||||
for triples_msg in _sent_triples(mock_pub):
|
||||
result.extend(triples_msg.triples)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDefaultBatchSize:
|
||||
|
||||
def test_default_is_50(self):
|
||||
assert default_triples_batch_size == 50
|
||||
|
||||
def test_processor_uses_default(self):
|
||||
proc = _make_processor()
|
||||
assert proc.triples_batch_size == 50
|
||||
|
||||
|
||||
class TestBatchSplitting:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_batch_when_under_limit(self):
|
||||
"""Few triples → single send call."""
|
||||
proc = _make_processor(triples_batch_size=50)
|
||||
rels = [_make_rel("A", "knows", "B")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("some text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# One relationship produces: rel triple + 3 labels + provenance
|
||||
# All should fit in one batch of 50
|
||||
assert pub.send.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_batches_with_small_batch_size(self):
|
||||
"""With batch_size=3 and many triples, multiple batches are sent."""
|
||||
proc = _make_processor(triples_batch_size=3)
|
||||
# 2 relationships → 2 rel triples + 6 labels = 8 triples + provenance
|
||||
rels = [
|
||||
_make_rel("A", "knows", "B"),
|
||||
_make_rel("C", "likes", "D"),
|
||||
]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("some text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
# Should have more than one batch
|
||||
assert pub.send.call_count > 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_sizes_respect_limit(self):
|
||||
"""No batch should exceed the configured batch size."""
|
||||
batch_size = 3
|
||||
proc = _make_processor(triples_batch_size=batch_size)
|
||||
rels = [
|
||||
_make_rel("A", "knows", "B"),
|
||||
_make_rel("C", "likes", "D"),
|
||||
_make_rel("E", "has", "F"),
|
||||
]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for triples_msg in _sent_triples(pub):
|
||||
assert len(triples_msg.triples) <= batch_size
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_present_across_batches(self):
|
||||
"""Total triples across batches equals expected count."""
|
||||
proc = _make_processor(triples_batch_size=2)
|
||||
# 1 relationship with object-entity=True → 1 rel + 3 labels = 4 triples
|
||||
# + provenance triples
|
||||
rels = [_make_rel("A", "knows", "B", object_entity=True)]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
# At minimum: 1 rel + 3 labels = 4 content triples
|
||||
assert len(all_t) >= 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_batch_size(self):
|
||||
"""Processor respects custom triples_batch_size parameter."""
|
||||
proc = _make_processor(triples_batch_size=100)
|
||||
assert proc.triples_batch_size == 100
|
||||
|
||||
|
||||
class TestMetadataPreservation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_forwarded_to_all_batches(self):
|
||||
"""Every batch should carry the original chunk metadata."""
|
||||
proc = _make_processor(triples_batch_size=2)
|
||||
rels = [_make_rel("X", "rel", "Y")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg(
|
||||
"text", meta_id="c-1", root="r-1",
|
||||
user="u-1", collection="coll-1",
|
||||
)
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
for triples_msg in _sent_triples(pub):
|
||||
assert triples_msg.metadata.id == "c-1"
|
||||
assert triples_msg.metadata.root == "r-1"
|
||||
assert triples_msg.metadata.user == "u-1"
|
||||
assert triples_msg.metadata.collection == "coll-1"
|
||||
|
||||
|
||||
class TestRelationshipTriples:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_entity_object_produces_iri(self):
|
||||
"""object-entity=True → object is an IRI, with label triple."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [_make_rel("Alice", "knows", "Bob", object_entity=True)]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
# Find the relationship triple (not a label)
|
||||
rel_triples = [
|
||||
t for t in all_t
|
||||
if t.o.type == IRI and "bob" in t.o.iri
|
||||
]
|
||||
assert len(rel_triples) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_literal_object_produces_literal(self):
|
||||
"""object-entity=False → object is a LITERAL, no label for object."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [_make_rel("Alice", "age", "30", object_entity=False)]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
# Find the relationship triple with literal object
|
||||
lit_triples = [
|
||||
t for t in all_t
|
||||
if t.o.type == LITERAL and t.o.value == "30"
|
||||
]
|
||||
assert len(lit_triples) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_labels_emitted_for_subject_and_predicate(self):
|
||||
"""Every relationship should produce label triples for s and p."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [_make_rel("Alice", "knows", "Bob")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
label_triples = [
|
||||
t for t in all_t
|
||||
if t.p.type == IRI and "label" in t.p.iri.lower()
|
||||
]
|
||||
labels = {t.o.value for t in label_triples}
|
||||
assert "Alice" in labels
|
||||
assert "knows" in labels
|
||||
assert "Bob" in labels # object-entity default is True
|
||||
|
||||
|
||||
class TestEmptyAndNullFiltering:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_fields_skipped(self):
|
||||
"""Relationships with empty string s/p/o are skipped."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [
|
||||
_make_rel("", "knows", "Bob"),
|
||||
_make_rel("Alice", "", "Bob"),
|
||||
_make_rel("Alice", "knows", ""),
|
||||
_make_rel("Good", "triple", "Here"),
|
||||
]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
# Only the "Good triple Here" relationship should produce content triples
|
||||
rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
|
||||
assert any("good" in iri for iri in rel_iris)
|
||||
assert not any("alice" in iri for iri in rel_iris)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_fields_skipped(self):
|
||||
"""Relationships with None s/p/o are skipped."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [
|
||||
_make_rel(None, "knows", "Bob"),
|
||||
_make_rel("Alice", None, "Bob"),
|
||||
_make_rel("Alice", "knows", None),
|
||||
_make_rel("Valid", "rel", "Here"),
|
||||
]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
rel_iris = {t.s.iri for t in all_t if hasattr(t.s, "iri") and t.s.iri}
|
||||
assert any("valid" in iri for iri in rel_iris)
|
||||
assert not any("alice" in iri for iri in rel_iris)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_filtered_produces_no_output(self):
|
||||
"""If all relationships are empty/null, nothing is emitted."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [
|
||||
_make_rel("", "", ""),
|
||||
_make_rel(None, None, None),
|
||||
]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert pub.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_response_produces_no_output(self):
|
||||
"""Empty relationship list from prompt → no triples emitted."""
|
||||
proc = _make_processor()
|
||||
flow, pub, _ = _make_flow([])
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert pub.send.call_count == 0
|
||||
|
||||
|
||||
class TestProvenanceInclusion:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provenance_triples_present(self):
|
||||
"""Extracted relationships should include provenance triples."""
|
||||
proc = _make_processor(triples_batch_size=200)
|
||||
rels = [_make_rel("A", "knows", "B")]
|
||||
flow, pub, _ = _make_flow(rels)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
all_t = _all_triples_flat(pub)
|
||||
# Provenance triples use GRAPH_SOURCE graph context
|
||||
# They contain terms referencing prov: namespace or subgraph URIs
|
||||
# We just check that total count > 4 (1 rel + 3 labels)
|
||||
assert len(all_t) > 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_provenance_when_no_extracted_triples(self):
|
||||
"""Empty relationships → no provenance generated."""
|
||||
proc = _make_processor()
|
||||
flow, pub, _ = _make_flow([_make_rel("", "x", "y")])
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert pub.send.call_count == 0
|
||||
|
||||
|
||||
class TestErrorPropagation:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_error_is_caught(self):
|
||||
"""Errors from the prompt client are caught (logged, not raised)."""
|
||||
proc = _make_processor()
|
||||
flow, pub, prompt = _make_flow([])
|
||||
prompt.extract_relationships = AsyncMock(
|
||||
side_effect=RuntimeError("LLM unavailable")
|
||||
)
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
# The outer try/except in on_message catches and logs
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert pub.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_list_response_is_caught(self):
|
||||
"""Non-list prompt response triggers RuntimeError, caught by handler."""
|
||||
proc = _make_processor()
|
||||
flow, pub, prompt = _make_flow("not a list")
|
||||
msg = _make_chunk_msg("text")
|
||||
|
||||
await proc.on_message(msg, MagicMock(), flow)
|
||||
|
||||
assert pub.send.call_count == 0
|
||||
|
||||
|
||||
class TestToUri:
|
||||
|
||||
def test_spaces_replaced_with_hyphens(self):
|
||||
proc = _make_processor()
|
||||
uri = proc.to_uri("hello world")
|
||||
assert "hello-world" in uri
|
||||
|
||||
def test_lowercased(self):
|
||||
proc = _make_processor()
|
||||
uri = proc.to_uri("Hello World")
|
||||
assert "hello-world" in uri
|
||||
|
||||
def test_special_chars_encoded(self):
|
||||
proc = _make_processor()
|
||||
# urllib.parse.quote keeps / as safe by default
|
||||
uri = proc.to_uri("a/b")
|
||||
assert "a/b" in uri
|
||||
# Characters like spaces are encoded (handled via replace → hyphen)
|
||||
uri2 = proc.to_uri("hello world")
|
||||
assert " " not in uri2
|
||||
|
|
@ -55,13 +55,6 @@ def sample_objects_message():
|
|||
return {
|
||||
"metadata": {
|
||||
"id": "obj-123",
|
||||
"metadata": [
|
||||
{
|
||||
"s": {"v": "obj-123", "e": False},
|
||||
"p": {"v": "source", "e": False},
|
||||
"o": {"v": "test", "e": False}
|
||||
}
|
||||
],
|
||||
"user": "testuser",
|
||||
"collection": "testcollection"
|
||||
},
|
||||
|
|
@ -244,7 +237,6 @@ class TestRowsImportMessageProcessing:
|
|||
assert sent_object.metadata.id == "obj-123"
|
||||
assert sent_object.metadata.user == "testuser"
|
||||
assert sent_object.metadata.collection == "testcollection"
|
||||
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -277,7 +269,6 @@ class TestRowsImportMessageProcessing:
|
|||
assert sent_object.values[0]["field1"] == "value1"
|
||||
assert sent_object.confidence == 1.0 # Default value
|
||||
assert sent_object.source_span == "" # Default value
|
||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -96,20 +96,21 @@ class TestGraphRagResponseTranslator:
|
|||
assert is_final is False
|
||||
assert result["end_of_stream"] is False
|
||||
|
||||
# Test final chunk with empty content
|
||||
# Test final message with end_of_session=True
|
||||
final_response = GraphRagResponse(
|
||||
response="",
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
error=None
|
||||
)
|
||||
|
||||
# Act
|
||||
result, is_final = translator.from_response_with_completion(final_response)
|
||||
|
||||
# Assert
|
||||
# Assert - is_final is based on end_of_session, not end_of_stream
|
||||
assert is_final is True
|
||||
assert result["response"] == ""
|
||||
assert result["end_of_stream"] is True
|
||||
assert result["end_of_session"] is True
|
||||
|
||||
|
||||
class TestDocumentRagResponseTranslator:
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ class Triple:
|
|||
self.o = o
|
||||
|
||||
class Metadata:
|
||||
def __init__(self, id, user, collection, metadata):
|
||||
def __init__(self, id, user, collection, root=""):
|
||||
self.id = id
|
||||
self.root = root
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.metadata = metadata
|
||||
|
||||
class Triples:
|
||||
def __init__(self, metadata, triples):
|
||||
|
|
@ -110,7 +110,6 @@ def sample_triples(sample_triple):
|
|||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
return Triples(
|
||||
|
|
@ -126,7 +125,6 @@ def sample_chunk():
|
|||
id="test-chunk-456",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
return Chunk(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
||||
|
||||
|
|
@ -51,13 +51,6 @@ class TestAgentKgExtractor:
|
|||
"""Sample metadata for testing"""
|
||||
return Metadata(
|
||||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="doc123"),
|
||||
p=Term(type=IRI, iri="http://example.org/type"),
|
||||
o=Term(type=LITERAL, value="document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -175,7 +168,7 @@ This is not JSON at all
|
|||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check entity label triple
|
||||
label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None)
|
||||
|
|
@ -190,12 +183,6 @@ This is not JSON at all
|
|||
assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert def_triple.o.value == "A subset of AI that enables learning from data."
|
||||
|
||||
# Check subject-of triple
|
||||
subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None)
|
||||
assert subject_of_triple is not None
|
||||
assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert subject_of_triple.o.iri == "doc123"
|
||||
|
||||
# Check entity context
|
||||
assert len(entity_contexts) == 1
|
||||
assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
|
|
@ -213,7 +200,7 @@ This is not JSON at all
|
|||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check that subject, predicate, and object labels are created
|
||||
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
|
|
@ -235,10 +222,6 @@ This is not JSON at all
|
|||
assert rel_triple.o.iri == object_uri
|
||||
assert rel_triple.o.type == IRI
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"]
|
||||
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
|
||||
|
||||
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of relationships with literal objects"""
|
||||
data = [
|
||||
|
|
@ -251,7 +234,7 @@ This is not JSON at all
|
|||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check that object labels are not created for literal objects
|
||||
object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"]
|
||||
|
|
@ -260,7 +243,7 @@ This is not JSON at all
|
|||
|
||||
def test_process_extraction_data_combined(self, agent_extractor, sample_metadata, sample_extraction_data):
|
||||
"""Test processing of combined definitions and relationships"""
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
|
||||
|
||||
# Check that we have both definition and relationship triples
|
||||
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
|
||||
|
|
@ -274,16 +257,12 @@ This is not JSON at all
|
|||
|
||||
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
|
||||
"""Test processing when metadata has no ID"""
|
||||
metadata = Metadata(id=None, metadata=[])
|
||||
metadata = Metadata(id=None)
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test Entity", "definition": "Test definition"}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should not create subject-of relationships when no metadata ID
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) == 0
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should still create entity contexts
|
||||
assert len(entity_contexts) == 1
|
||||
|
|
@ -292,7 +271,7 @@ This is not JSON at all
|
|||
"""Test processing of empty extraction data"""
|
||||
data = []
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Should have no entity contexts
|
||||
assert len(entity_contexts) == 0
|
||||
|
|
@ -307,7 +286,7 @@ This is not JSON at all
|
|||
{"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
triples, entity_contexts, _ = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Should process valid items and ignore unknown types
|
||||
assert len(entity_contexts) == 1 # Only the definition creates entity context
|
||||
|
|
@ -345,8 +324,6 @@ This is not JSON at all
|
|||
assert sent_triples.metadata.id == sample_metadata.id
|
||||
assert sent_triples.metadata.user == sample_metadata.user
|
||||
assert sent_triples.metadata.collection == sample_metadata.collection
|
||||
# Note: metadata.metadata is now empty array in the new implementation
|
||||
assert sent_triples.metadata.metadata == []
|
||||
assert len(sent_triples.triples) == 1
|
||||
assert sent_triples.triples[0].s.iri == "test:subject"
|
||||
|
||||
|
|
@ -371,8 +348,6 @@ This is not JSON at all
|
|||
assert sent_contexts.metadata.id == sample_metadata.id
|
||||
assert sent_contexts.metadata.user == sample_metadata.user
|
||||
assert sent_contexts.metadata.collection == sample_metadata.collection
|
||||
# Note: metadata.metadata is now empty array in the new implementation
|
||||
assert sent_contexts.metadata.metadata == []
|
||||
assert len(sent_contexts.entities) == 1
|
||||
assert sent_contexts.entities[0].entity.iri == "test:entity"
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
|
@ -168,7 +168,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
"""Test processing with empty or minimal metadata"""
|
||||
# Test with None metadata - may not raise AttributeError depending on implementation
|
||||
try:
|
||||
triples, contexts = agent_extractor.process_extraction_data([], None)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data([], None)
|
||||
# If it doesn't raise, check the results
|
||||
assert len(triples) == 0
|
||||
assert len(contexts) == 0
|
||||
|
|
@ -177,23 +177,19 @@ class TestAgentKgExtractionEdgeCases:
|
|||
pass
|
||||
|
||||
# Test with metadata without ID
|
||||
metadata = Metadata(id=None, metadata=[])
|
||||
triples, contexts = agent_extractor.process_extraction_data([], metadata)
|
||||
metadata = Metadata(id=None)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data([], metadata)
|
||||
assert len(triples) == 0
|
||||
assert len(contexts) == 0
|
||||
|
||||
# Test with metadata with empty string ID
|
||||
metadata = Metadata(id="", metadata=[])
|
||||
metadata = Metadata(id="")
|
||||
data = [{"type": "definition", "entity": "Test", "definition": "Test def"}]
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should not create subject-of triples when ID is empty string
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) == 0
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
def test_process_extraction_data_special_entity_names(self, agent_extractor):
|
||||
"""Test processing with special characters in entity names"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
metadata = Metadata(id="doc123")
|
||||
|
||||
special_entities = [
|
||||
"Entity with spaces",
|
||||
|
|
@ -213,7 +209,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
for entity in special_entities
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Verify all entities were processed
|
||||
assert len(contexts) == len(special_entities)
|
||||
|
|
@ -225,7 +221,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
def test_process_extraction_data_very_long_definitions(self, agent_extractor):
|
||||
"""Test processing with very long entity definitions"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
metadata = Metadata(id="doc123")
|
||||
|
||||
# Create very long definition
|
||||
long_definition = "This is a very long definition. " * 1000
|
||||
|
|
@ -234,7 +230,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
{"type": "definition", "entity": "Test Entity", "definition": long_definition}
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should handle long definitions without issues
|
||||
assert len(contexts) == 1
|
||||
|
|
@ -247,7 +243,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
def test_process_extraction_data_duplicate_entities(self, agent_extractor):
|
||||
"""Test processing with duplicate entity names"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
metadata = Metadata(id="doc123")
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "First definition"},
|
||||
|
|
@ -256,7 +252,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
{"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should process all entries (including duplicates)
|
||||
assert len(contexts) == 4
|
||||
|
|
@ -269,7 +265,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
def test_process_extraction_data_empty_strings(self, agent_extractor):
|
||||
"""Test processing with empty strings in data"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
metadata = Metadata(id="doc123")
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "", "definition": "Definition for empty entity"},
|
||||
|
|
@ -280,7 +276,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
{"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True},
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should handle empty strings by creating URIs (even if empty)
|
||||
assert len(contexts) == 3
|
||||
|
|
@ -291,7 +287,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
def test_process_extraction_data_nested_json_in_strings(self, agent_extractor):
|
||||
"""Test processing when definitions contain JSON-like strings"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
metadata = Metadata(id="doc123")
|
||||
|
||||
data = [
|
||||
{
|
||||
|
|
@ -306,7 +302,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
}
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should handle JSON strings in definitions without parsing them
|
||||
assert len(contexts) == 2
|
||||
|
|
@ -315,7 +311,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor):
|
||||
"""Test processing with various boolean values for object-entity"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
metadata = Metadata(id="doc123")
|
||||
|
||||
data = [
|
||||
# Explicit True
|
||||
|
|
@ -334,16 +330,16 @@ class TestAgentKgExtractionEdgeCases:
|
|||
{"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
# Should process all relationships
|
||||
# Note: The current implementation has some logic issues that these tests document
|
||||
assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7
|
||||
assert len([t for t in triples if t.p.iri != RDF_LABEL]) >= 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_empty_collections(self, agent_extractor):
|
||||
"""Test emitting empty triples and entity contexts"""
|
||||
metadata = Metadata(id="test", metadata=[])
|
||||
metadata = Metadata(id="test")
|
||||
|
||||
# Test emitting empty triples
|
||||
mock_publisher = AsyncMock()
|
||||
|
|
@ -389,7 +385,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
def test_process_extraction_data_performance_large_dataset(self, agent_extractor):
|
||||
"""Test performance with large extraction datasets"""
|
||||
metadata = Metadata(id="large-doc", metadata=[])
|
||||
metadata = Metadata(id="large-doc")
|
||||
|
||||
# Create large dataset in JSONL format
|
||||
num_definitions = 1000
|
||||
|
|
@ -416,7 +412,7 @@ class TestAgentKgExtractionEdgeCases:
|
|||
import time
|
||||
start_time = time.time()
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(large_data, metadata)
|
||||
triples, contexts, _ = agent_extractor.process_extraction_data(large_data, metadata)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
|
|
|||
|
|
@ -314,7 +314,6 @@ class TestObjectExtractionBusinessLogic:
|
|||
id="test-extraction-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
values = [{
|
||||
|
|
|
|||
|
|
@ -373,7 +373,6 @@ class TestTripleConstructionLogic:
|
|||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
|
|||
0
tests/unit/test_librarian/__init__.py
Normal file
0
tests/unit/test_librarian/__init__.py
Normal file
716
tests/unit/test_librarian/test_chunked_upload.py
Normal file
716
tests/unit/test_librarian/test_chunked_upload.py
Normal file
|
|
@ -0,0 +1,716 @@
|
|||
"""
|
||||
Tests for librarian chunked upload operations:
|
||||
begin_upload, upload_chunk, complete_upload, abort_upload, get_upload_status,
|
||||
list_uploads, and stream_document.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.librarian.librarian import Librarian, DEFAULT_CHUNK_SIZE
|
||||
from trustgraph.exceptions import RequestError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_librarian(min_chunk_size=1):
|
||||
"""Create a Librarian with mocked blob_store and table_store."""
|
||||
lib = Librarian.__new__(Librarian)
|
||||
lib.blob_store = MagicMock()
|
||||
lib.table_store = AsyncMock()
|
||||
lib.load_document = AsyncMock()
|
||||
lib.min_chunk_size = min_chunk_size
|
||||
return lib
|
||||
|
||||
|
||||
def _make_doc_metadata(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice", title="Test Doc"
|
||||
):
|
||||
meta = MagicMock()
|
||||
meta.id = doc_id
|
||||
meta.kind = kind
|
||||
meta.user = user
|
||||
meta.title = title
|
||||
meta.time = 1700000000
|
||||
meta.comments = ""
|
||||
meta.tags = []
|
||||
return meta
|
||||
|
||||
|
||||
def _make_begin_request(
|
||||
doc_id="doc-1", kind="application/pdf", user="alice",
|
||||
total_size=10_000_000, chunk_size=0
|
||||
):
|
||||
req = MagicMock()
|
||||
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, user=user)
|
||||
req.total_size = total_size
|
||||
req.chunk_size = chunk_size
|
||||
return req
|
||||
|
||||
|
||||
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, user="alice", content=b"data"):
|
||||
req = MagicMock()
|
||||
req.upload_id = upload_id
|
||||
req.chunk_index = chunk_index
|
||||
req.user = user
|
||||
req.content = base64.b64encode(content)
|
||||
return req
|
||||
|
||||
|
||||
def _make_session(
|
||||
user="alice", total_chunks=5, chunk_size=2_000_000,
|
||||
total_size=10_000_000, chunks_received=None, object_id="obj-1",
|
||||
s3_upload_id="s3-up-1", document_metadata=None, document_id="doc-1",
|
||||
):
|
||||
if chunks_received is None:
|
||||
chunks_received = {}
|
||||
if document_metadata is None:
|
||||
document_metadata = json.dumps({
|
||||
"id": document_id, "kind": "application/pdf",
|
||||
"user": user, "title": "Test", "time": 1700000000,
|
||||
"comments": "", "tags": [],
|
||||
})
|
||||
return {
|
||||
"user": user,
|
||||
"total_chunks": total_chunks,
|
||||
"chunk_size": chunk_size,
|
||||
"total_size": total_size,
|
||||
"chunks_received": chunks_received,
|
||||
"object_id": object_id,
|
||||
"s3_upload_id": s3_upload_id,
|
||||
"document_metadata": document_metadata,
|
||||
"document_id": document_id,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# begin_upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBeginUpload:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_session(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = False
|
||||
lib.blob_store.create_multipart_upload.return_value = "s3-upload-id"
|
||||
|
||||
req = _make_begin_request(total_size=10_000_000)
|
||||
resp = await lib.begin_upload(req)
|
||||
|
||||
assert resp.error is None
|
||||
assert resp.upload_id is not None
|
||||
assert resp.total_chunks == math.ceil(10_000_000 / DEFAULT_CHUNK_SIZE)
|
||||
assert resp.chunk_size == DEFAULT_CHUNK_SIZE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_chunk_size(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = False
|
||||
lib.blob_store.create_multipart_upload.return_value = "s3-id"
|
||||
|
||||
req = _make_begin_request(total_size=10_000, chunk_size=3000)
|
||||
resp = await lib.begin_upload(req)
|
||||
|
||||
assert resp.chunk_size == 3000
|
||||
assert resp.total_chunks == math.ceil(10_000 / 3000)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_invalid_kind(self):
|
||||
lib = _make_librarian()
|
||||
req = _make_begin_request(kind="image/png")
|
||||
|
||||
with pytest.raises(RequestError, match="Invalid document kind"):
|
||||
await lib.begin_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_duplicate_document(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = True
|
||||
|
||||
req = _make_begin_request()
|
||||
with pytest.raises(RequestError, match="already exists"):
|
||||
await lib.begin_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_zero_size(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = False
|
||||
|
||||
req = _make_begin_request(total_size=0)
|
||||
with pytest.raises(RequestError, match="positive"):
|
||||
await lib.begin_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_chunk_below_minimum(self):
|
||||
lib = _make_librarian(min_chunk_size=1024)
|
||||
lib.table_store.document_exists.return_value = False
|
||||
|
||||
req = _make_begin_request(total_size=10_000, chunk_size=512)
|
||||
with pytest.raises(RequestError, match="below minimum"):
|
||||
await lib.begin_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_s3_create_multipart(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = False
|
||||
lib.blob_store.create_multipart_upload.return_value = "s3-id"
|
||||
|
||||
req = _make_begin_request(kind="application/pdf")
|
||||
await lib.begin_upload(req)
|
||||
|
||||
lib.blob_store.create_multipart_upload.assert_called_once()
|
||||
# create_multipart_upload(object_id, kind) — positional args
|
||||
args = lib.blob_store.create_multipart_upload.call_args[0]
|
||||
assert args[1] == "application/pdf"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stores_session_in_cassandra(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = False
|
||||
lib.blob_store.create_multipart_upload.return_value = "s3-id"
|
||||
|
||||
req = _make_begin_request(total_size=5_000_000)
|
||||
resp = await lib.begin_upload(req)
|
||||
|
||||
lib.table_store.create_upload_session.assert_called_once()
|
||||
kwargs = lib.table_store.create_upload_session.call_args[1]
|
||||
assert kwargs["upload_id"] == resp.upload_id
|
||||
assert kwargs["total_size"] == 5_000_000
|
||||
assert kwargs["total_chunks"] == resp.total_chunks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_text_plain(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.document_exists.return_value = False
|
||||
lib.blob_store.create_multipart_upload.return_value = "s3-id"
|
||||
|
||||
req = _make_begin_request(kind="text/plain", total_size=1000)
|
||||
resp = await lib.begin_upload(req)
|
||||
assert resp.error is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# upload_chunk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUploadChunk:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_chunk_upload(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(total_chunks=5, chunks_received={})
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
lib.blob_store.upload_part.return_value = "etag-1"
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
|
||||
resp = await lib.upload_chunk(req)
|
||||
|
||||
assert resp.error is None
|
||||
assert resp.chunk_index == 0
|
||||
assert resp.total_chunks == 5
|
||||
# The chunk is added to the dict (len=1), then +1 applied => 2
|
||||
assert resp.chunks_received == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_part_number_is_1_indexed(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session()
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
lib.blob_store.upload_part.return_value = "etag"
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=0)
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
kwargs = lib.blob_store.upload_part.call_args[1]
|
||||
assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_index_3_becomes_part_4(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session()
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
lib.blob_store.upload_part.return_value = "etag"
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=3)
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
kwargs = lib.blob_store.upload_part.call_args[1]
|
||||
assert kwargs["part_number"] == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_expired_session(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_upload_session.return_value = None
|
||||
|
||||
req = _make_upload_chunk_request()
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = _make_upload_chunk_request(user="bob")
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_negative_chunk_index(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(total_chunks=5)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=-1)
|
||||
with pytest.raises(RequestError, match="Invalid chunk index"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_out_of_range_chunk_index(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(total_chunks=5)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=5)
|
||||
with pytest.raises(RequestError, match="Invalid chunk index"):
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_tracking(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(
|
||||
total_chunks=4, chunk_size=1000, total_size=3500,
|
||||
chunks_received={0: "e1", 1: "e2"},
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
lib.blob_store.upload_part.return_value = "e3"
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=2)
|
||||
resp = await lib.upload_chunk(req)
|
||||
|
||||
# Dict gets chunk 2 added (len=3), then +1 => 4
|
||||
assert resp.chunks_received == 4
|
||||
assert resp.total_chunks == 4
|
||||
assert resp.total_bytes == 3500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bytes_capped_at_total_size(self):
|
||||
"""bytes_received should not exceed total_size for the final chunk."""
|
||||
lib = _make_librarian()
|
||||
session = _make_session(
|
||||
total_chunks=2, chunk_size=3000, total_size=5000,
|
||||
chunks_received={0: "e1"},
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
lib.blob_store.upload_part.return_value = "e2"
|
||||
|
||||
req = _make_upload_chunk_request(chunk_index=1)
|
||||
resp = await lib.upload_chunk(req)
|
||||
|
||||
# 3 chunks × 3000 = 9000 > 5000, so capped
|
||||
assert resp.bytes_received <= 5000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base64_decodes_content(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session()
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
lib.blob_store.upload_part.return_value = "etag"
|
||||
|
||||
raw = b"hello world binary data"
|
||||
req = _make_upload_chunk_request(content=raw)
|
||||
await lib.upload_chunk(req)
|
||||
|
||||
kwargs = lib.blob_store.upload_part.call_args[1]
|
||||
assert kwargs["data"] == raw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# complete_upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCompleteUpload:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_completion(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(
|
||||
total_chunks=3,
|
||||
chunks_received={0: "e1", 1: "e2", 2: "e3"},
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.complete_upload(req)
|
||||
|
||||
assert resp.error is None
|
||||
assert resp.document_id == "doc-1"
|
||||
lib.blob_store.complete_multipart_upload.assert_called_once()
|
||||
lib.table_store.add_document.assert_called_once()
|
||||
lib.table_store.delete_upload_session.assert_called_once_with("up-1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parts_sorted_by_index(self):
|
||||
lib = _make_librarian()
|
||||
# Chunks received out of order
|
||||
session = _make_session(
|
||||
total_chunks=3,
|
||||
chunks_received={2: "e3", 0: "e1", 1: "e2"},
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
|
||||
await lib.complete_upload(req)
|
||||
|
||||
parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
|
||||
part_numbers = [p[0] for p in parts]
|
||||
assert part_numbers == [1, 2, 3] # Sorted, 1-indexed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_missing_chunks(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(
|
||||
total_chunks=3,
|
||||
chunks_received={0: "e1", 2: "e3"}, # chunk 1 missing
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="Missing chunks"):
|
||||
await lib.complete_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_expired_session(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_upload_session.return_value = None
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.complete_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.complete_upload(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# abort_upload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAbortUpload:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aborts_and_cleans_up(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session()
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.abort_upload(req)
|
||||
|
||||
assert resp.error is None
|
||||
lib.blob_store.abort_multipart_upload.assert_called_once_with(
|
||||
object_id="obj-1", upload_id="s3-up-1"
|
||||
)
|
||||
lib.table_store.delete_upload_session.assert_called_once_with("up-1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_expired_session(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_upload_session.return_value = None
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-gone"
|
||||
req.user = "alice"
|
||||
|
||||
with pytest.raises(RequestError, match="not found"):
|
||||
await lib.abort_upload(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.abort_upload(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_upload_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetUploadStatus:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_progress_status(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(
|
||||
total_chunks=5, chunk_size=2000, total_size=10_000,
|
||||
chunks_received={0: "e1", 2: "e3", 4: "e5"},
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
assert resp.upload_state == "in-progress"
|
||||
assert resp.chunks_received == 3
|
||||
assert resp.total_chunks == 5
|
||||
assert sorted(resp.received_chunks) == [0, 2, 4]
|
||||
assert sorted(resp.missing_chunks) == [1, 3]
|
||||
assert resp.total_bytes == 10_000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_session(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_upload_session.return_value = None
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-expired"
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
assert resp.upload_state == "expired"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_chunks_received(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(
|
||||
total_chunks=3, chunk_size=1000, total_size=2500,
|
||||
chunks_received={0: "e1", 1: "e2", 2: "e3"},
|
||||
)
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.get_upload_status(req)
|
||||
|
||||
assert resp.missing_chunks == []
|
||||
assert resp.chunks_received == 3
|
||||
# 3 * 1000 = 3000 > 2500, so capped
|
||||
assert resp.bytes_received <= 2500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_wrong_user(self):
|
||||
lib = _make_librarian()
|
||||
session = _make_session(user="alice")
|
||||
lib.table_store.get_upload_session.return_value = session
|
||||
|
||||
req = MagicMock()
|
||||
req.upload_id = "up-1"
|
||||
req.user = "bob"
|
||||
|
||||
with pytest.raises(RequestError, match="Not authorized"):
|
||||
await lib.get_upload_status(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stream_document
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStreamDocument:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streams_chunks_with_progress(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_document_object_id.return_value = "obj-1"
|
||||
lib.blob_store.get_size = AsyncMock(return_value=5000)
|
||||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
chunks = []
|
||||
async for resp in lib.stream_document(req):
|
||||
chunks.append(resp)
|
||||
|
||||
assert len(chunks) == 3 # ceil(5000/2000)
|
||||
assert chunks[0].chunk_index == 0
|
||||
assert chunks[0].total_chunks == 3
|
||||
assert chunks[0].is_final is False
|
||||
assert chunks[-1].is_final is True
|
||||
assert chunks[-1].chunk_index == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_document(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_document_object_id.return_value = "obj-1"
|
||||
lib.blob_store.get_size = AsyncMock(return_value=500)
|
||||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
chunks = []
|
||||
async for resp in lib.stream_document(req):
|
||||
chunks.append(resp)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].bytes_received == 500
|
||||
assert chunks[0].total_bytes == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_byte_ranges_correct(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_document_object_id.return_value = "obj-1"
|
||||
lib.blob_store.get_size = AsyncMock(return_value=5000)
|
||||
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 2000
|
||||
|
||||
chunks = []
|
||||
async for resp in lib.stream_document(req):
|
||||
chunks.append(resp)
|
||||
|
||||
# Verify the byte ranges passed to get_range
|
||||
calls = lib.blob_store.get_range.call_args_list
|
||||
assert calls[0][0] == ("obj-1", 0, 2000)
|
||||
assert calls[1][0] == ("obj-1", 2000, 2000)
|
||||
assert calls[2][0] == ("obj-1", 4000, 1000) # Last chunk: 5000-4000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_chunk_size(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_document_object_id.return_value = "obj-1"
|
||||
lib.blob_store.get_size = AsyncMock(return_value=2_000_000)
|
||||
lib.blob_store.get_range = AsyncMock(return_value=b"x")
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 0 # Should use default 1MB
|
||||
|
||||
chunks = []
|
||||
async for resp in lib.stream_document(req):
|
||||
chunks.append(resp)
|
||||
|
||||
assert len(chunks) == 2 # ceil(2MB / 1MB)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_is_base64_encoded(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.get_document_object_id.return_value = "obj-1"
|
||||
lib.blob_store.get_size = AsyncMock(return_value=100)
|
||||
raw = b"hello world"
|
||||
lib.blob_store.get_range = AsyncMock(return_value=raw)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 1000
|
||||
|
||||
chunks = []
|
||||
async for resp in lib.stream_document(req):
|
||||
chunks.append(resp)
|
||||
|
||||
assert chunks[0].content == base64.b64encode(raw)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_chunk_below_minimum(self):
|
||||
lib = _make_librarian(min_chunk_size=1024)
|
||||
lib.table_store.get_document_object_id.return_value = "obj-1"
|
||||
lib.blob_store.get_size = AsyncMock(return_value=5000)
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
req.document_id = "doc-1"
|
||||
req.chunk_size = 512
|
||||
|
||||
with pytest.raises(RequestError, match="below minimum"):
|
||||
async for _ in lib.stream_document(req):
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_uploads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListUploads:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_sessions(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.list_upload_sessions.return_value = [
|
||||
{
|
||||
"upload_id": "up-1",
|
||||
"document_id": "doc-1",
|
||||
"document_metadata": '{"id":"doc-1"}',
|
||||
"total_size": 10000,
|
||||
"chunk_size": 2000,
|
||||
"total_chunks": 5,
|
||||
"chunks_received": {0: "e1", 1: "e2"},
|
||||
"created_at": "2024-01-01",
|
||||
},
|
||||
]
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
assert resp.error is None
|
||||
assert len(resp.upload_sessions) == 1
|
||||
assert resp.upload_sessions[0].upload_id == "up-1"
|
||||
assert resp.upload_sessions[0].total_chunks == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_uploads(self):
|
||||
lib = _make_librarian()
|
||||
lib.table_store.list_upload_sessions.return_value = []
|
||||
|
||||
req = MagicMock()
|
||||
req.user = "alice"
|
||||
|
||||
resp = await lib.list_uploads(req)
|
||||
|
||||
assert resp.upload_sessions == []
|
||||
0
tests/unit/test_provenance/__init__.py
Normal file
0
tests/unit/test_provenance/__init__.py
Normal file
336
tests/unit/test_provenance/test_agent_provenance.py
Normal file
336
tests/unit/test_provenance/test_agent_provenance.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
"""
|
||||
Tests for agent provenance triple builder functions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import Triple, Term, IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance.agent import (
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
agent_final_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ACTIVITY, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
PROV_WAS_GENERATED_BY, PROV_STARTED_AT_TIME,
|
||||
TG_QUERY, TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_ANALYSIS, TG_CONCLUSION, TG_DOCUMENT,
|
||||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_AGENT_QUESTION,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
for t in triples:
|
||||
if (t.s.iri == subject and t.p.iri == RDF_TYPE
|
||||
and t.o.type == IRI and t.o.iri == rdf_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_session_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentSessionTriples:
|
||||
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
|
||||
def test_session_types(self):
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
assert has_type(triples, self.SESSION_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.SESSION_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.SESSION_URI, TG_AGENT_QUESTION)
|
||||
|
||||
def test_session_query_text(self):
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "What is X?", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
query = find_triple(triples, TG_QUERY, self.SESSION_URI)
|
||||
assert query is not None
|
||||
assert query.o.value == "What is X?"
|
||||
|
||||
def test_session_timestamp(self):
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-06-15T10:00:00Z"
|
||||
)
|
||||
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
|
||||
assert ts is not None
|
||||
assert ts.o.value == "2024-06-15T10:00:00Z"
|
||||
|
||||
def test_session_default_timestamp(self):
|
||||
triples = agent_session_triples(self.SESSION_URI, "Q")
|
||||
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.SESSION_URI)
|
||||
assert ts is not None
|
||||
assert len(ts.o.value) > 0
|
||||
|
||||
def test_session_label(self):
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.SESSION_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Agent Question"
|
||||
|
||||
def test_session_triple_count(self):
|
||||
triples = agent_session_triples(
|
||||
self.SESSION_URI, "Q", "2024-01-01T00:00:00Z"
|
||||
)
|
||||
assert len(triples) == 6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_iteration_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentIterationTriples:
|
||||
|
||||
ITER_URI = "urn:trustgraph:agent:test-session/i1"
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
PREV_URI = "urn:trustgraph:agent:test-session/i0"
|
||||
|
||||
def test_iteration_types(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
assert has_type(triples, self.ITER_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ITER_URI, TG_ANALYSIS)
|
||||
|
||||
def test_first_iteration_generated_by_question(self):
|
||||
"""First iteration uses wasGeneratedBy to link to question activity."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
# Should NOT have wasDerivedFrom
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is None
|
||||
|
||||
def test_subsequent_iteration_derived_from_previous(self):
|
||||
"""Subsequent iterations use wasDerivedFrom to link to previous iteration."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, previous_uri=self.PREV_URI,
|
||||
action="search",
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ITER_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
# Should NOT have wasGeneratedBy
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ITER_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_iteration_label_includes_action(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="graph-rag-query",
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.ITER_URI)
|
||||
assert label is not None
|
||||
assert "graph-rag-query" in label.o.value
|
||||
|
||||
def test_iteration_thought_sub_entity(self):
|
||||
"""Thought is a sub-entity with Reflection and Thought types."""
|
||||
thought_uri = "urn:trustgraph:agent:test-session/i1/thought"
|
||||
thought_doc = "urn:doc:thought-1"
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
thought_uri=thought_uri,
|
||||
thought_document_id=thought_doc,
|
||||
)
|
||||
# Iteration links to thought sub-entity
|
||||
thought_link = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
assert thought_link is not None
|
||||
assert thought_link.o.iri == thought_uri
|
||||
# Thought has correct types
|
||||
assert has_type(triples, thought_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, thought_uri, TG_THOUGHT_TYPE)
|
||||
# Thought was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, thought_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Thought has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, thought_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == thought_doc
|
||||
|
||||
def test_iteration_observation_sub_entity(self):
|
||||
"""Observation is a sub-entity with Reflection and Observation types."""
|
||||
obs_uri = "urn:trustgraph:agent:test-session/i1/observation"
|
||||
obs_doc = "urn:doc:obs-1"
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
observation_uri=obs_uri,
|
||||
observation_document_id=obs_doc,
|
||||
)
|
||||
# Iteration links to observation sub-entity
|
||||
obs_link = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert obs_link is not None
|
||||
assert obs_link.o.iri == obs_uri
|
||||
# Observation has correct types
|
||||
assert has_type(triples, obs_uri, TG_REFLECTION_TYPE)
|
||||
assert has_type(triples, obs_uri, TG_OBSERVATION_TYPE)
|
||||
# Observation was generated by iteration
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, obs_uri)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.ITER_URI
|
||||
# Observation has document reference
|
||||
doc = find_triple(triples, TG_DOCUMENT, obs_uri)
|
||||
assert doc is not None
|
||||
assert doc.o.iri == obs_doc
|
||||
|
||||
def test_iteration_action_recorded(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="graph-rag-query",
|
||||
)
|
||||
action = find_triple(triples, TG_ACTION, self.ITER_URI)
|
||||
assert action is not None
|
||||
assert action.o.value == "graph-rag-query"
|
||||
|
||||
def test_iteration_arguments_json_encoded(self):
|
||||
args = {"query": "test query", "limit": 10}
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
arguments=args,
|
||||
)
|
||||
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
|
||||
assert arguments is not None
|
||||
parsed = json.loads(arguments.o.value)
|
||||
assert parsed == args
|
||||
|
||||
def test_iteration_default_arguments_empty_dict(self):
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="search",
|
||||
)
|
||||
arguments = find_triple(triples, TG_ARGUMENTS, self.ITER_URI)
|
||||
assert arguments is not None
|
||||
parsed = json.loads(arguments.o.value)
|
||||
assert parsed == {}
|
||||
|
||||
def test_iteration_no_thought_or_observation(self):
|
||||
"""Minimal iteration with just action — no thought or observation triples."""
|
||||
triples = agent_iteration_triples(
|
||||
self.ITER_URI, question_uri=self.SESSION_URI,
|
||||
action="noop",
|
||||
)
|
||||
thought = find_triple(triples, TG_THOUGHT, self.ITER_URI)
|
||||
obs = find_triple(triples, TG_OBSERVATION, self.ITER_URI)
|
||||
assert thought is None
|
||||
assert obs is None
|
||||
|
||||
def test_iteration_chaining(self):
|
||||
"""First iteration uses wasGeneratedBy, second uses wasDerivedFrom."""
|
||||
iter1_uri = "urn:trustgraph:agent:sess/i1"
|
||||
iter2_uri = "urn:trustgraph:agent:sess/i2"
|
||||
|
||||
triples1 = agent_iteration_triples(
|
||||
iter1_uri, question_uri=self.SESSION_URI, action="step1",
|
||||
)
|
||||
triples2 = agent_iteration_triples(
|
||||
iter2_uri, previous_uri=iter1_uri, action="step2",
|
||||
)
|
||||
|
||||
gen1 = find_triple(triples1, PROV_WAS_GENERATED_BY, iter1_uri)
|
||||
assert gen1.o.iri == self.SESSION_URI
|
||||
|
||||
derived2 = find_triple(triples2, PROV_WAS_DERIVED_FROM, iter2_uri)
|
||||
assert derived2.o.iri == iter1_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# agent_final_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAgentFinalTriples:
|
||||
|
||||
FINAL_URI = "urn:trustgraph:agent:test-session/final"
|
||||
PREV_URI = "urn:trustgraph:agent:test-session/i3"
|
||||
SESSION_URI = "urn:trustgraph:agent:test-session"
|
||||
|
||||
def test_final_types(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
assert has_type(triples, self.FINAL_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.FINAL_URI, TG_CONCLUSION)
|
||||
assert has_type(triples, self.FINAL_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_final_derived_from_previous(self):
|
||||
"""Conclusion with iterations uses wasDerivedFrom."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PREV_URI
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is None
|
||||
|
||||
def test_final_generated_by_question_when_no_iterations(self):
|
||||
"""When agent answers immediately, final uses wasGeneratedBy."""
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, question_uri=self.SESSION_URI,
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.FINAL_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.SESSION_URI
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FINAL_URI)
|
||||
assert derived is None
|
||||
|
||||
def test_final_label(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.FINAL_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Conclusion"
|
||||
|
||||
def test_final_document_reference(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
document_id="urn:trustgraph:agent:sess/answer",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.type == IRI
|
||||
assert doc.o.iri == "urn:trustgraph:agent:sess/answer"
|
||||
|
||||
def test_final_no_document(self):
|
||||
triples = agent_final_triples(
|
||||
self.FINAL_URI, previous_uri=self.PREV_URI,
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.FINAL_URI)
|
||||
assert doc is None
|
||||
543
tests/unit/test_provenance/test_explainability.py
Normal file
543
tests/unit/test_provenance/test_explainability.py
Normal file
|
|
@ -0,0 +1,543 @@
|
|||
"""
|
||||
Tests for the explainability API (entity parsing, wire format conversion,
|
||||
and ExplainabilityClient).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.api.explainability import (
|
||||
EdgeSelection,
|
||||
ExplainEntity,
|
||||
Question,
|
||||
Grounding,
|
||||
Exploration,
|
||||
Focus,
|
||||
Synthesis,
|
||||
Reflection,
|
||||
Analysis,
|
||||
Conclusion,
|
||||
parse_edge_selection_triples,
|
||||
extract_term_value,
|
||||
wire_triples_to_tuples,
|
||||
ExplainabilityClient,
|
||||
TG_QUERY, TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT, TG_CHUNK_COUNT, TG_CONCEPT, TG_ENTITY,
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, TG_AGENT_QUESTION,
|
||||
PROV_STARTED_AT_TIME, PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entity from_triples parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExplainEntityFromTriples:
|
||||
"""Test ExplainEntity.from_triples dispatches to correct subclass."""
|
||||
|
||||
def test_graphrag_question(self):
|
||||
triples = [
|
||||
("urn:q:1", RDF_TYPE, TG_QUESTION),
|
||||
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
|
||||
("urn:q:1", TG_QUERY, "What is AI?"),
|
||||
("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:q:1", triples)
|
||||
assert isinstance(entity, Question)
|
||||
assert entity.query == "What is AI?"
|
||||
assert entity.timestamp == "2024-01-01T00:00:00Z"
|
||||
assert entity.question_type == "graph-rag"
|
||||
|
||||
def test_docrag_question(self):
|
||||
triples = [
|
||||
("urn:q:2", RDF_TYPE, TG_QUESTION),
|
||||
("urn:q:2", RDF_TYPE, TG_DOC_RAG_QUESTION),
|
||||
("urn:q:2", TG_QUERY, "Find info"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:q:2", triples)
|
||||
assert isinstance(entity, Question)
|
||||
assert entity.question_type == "document-rag"
|
||||
|
||||
def test_agent_question(self):
|
||||
triples = [
|
||||
("urn:q:3", RDF_TYPE, TG_QUESTION),
|
||||
("urn:q:3", RDF_TYPE, TG_AGENT_QUESTION),
|
||||
("urn:q:3", TG_QUERY, "Agent query"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:q:3", triples)
|
||||
assert isinstance(entity, Question)
|
||||
assert entity.question_type == "agent"
|
||||
|
||||
def test_grounding(self):
|
||||
triples = [
|
||||
("urn:gnd:1", RDF_TYPE, TG_GROUNDING),
|
||||
("urn:gnd:1", TG_CONCEPT, "machine learning"),
|
||||
("urn:gnd:1", TG_CONCEPT, "neural networks"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:gnd:1", triples)
|
||||
assert isinstance(entity, Grounding)
|
||||
assert len(entity.concepts) == 2
|
||||
assert "machine learning" in entity.concepts
|
||||
assert "neural networks" in entity.concepts
|
||||
|
||||
def test_exploration(self):
|
||||
triples = [
|
||||
("urn:exp:1", RDF_TYPE, TG_EXPLORATION),
|
||||
("urn:exp:1", TG_EDGE_COUNT, "15"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:exp:1", triples)
|
||||
assert isinstance(entity, Exploration)
|
||||
assert entity.edge_count == 15
|
||||
|
||||
def test_exploration_with_chunk_count(self):
|
||||
triples = [
|
||||
("urn:exp:2", RDF_TYPE, TG_EXPLORATION),
|
||||
("urn:exp:2", TG_CHUNK_COUNT, "5"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:exp:2", triples)
|
||||
assert isinstance(entity, Exploration)
|
||||
assert entity.chunk_count == 5
|
||||
|
||||
def test_exploration_with_entities(self):
|
||||
triples = [
|
||||
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
|
||||
("urn:exp:3", TG_EDGE_COUNT, "10"),
|
||||
("urn:exp:3", TG_ENTITY, "urn:e:machine-learning"),
|
||||
("urn:exp:3", TG_ENTITY, "urn:e:neural-networks"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:exp:3", triples)
|
||||
assert isinstance(entity, Exploration)
|
||||
assert len(entity.entities) == 2
|
||||
|
||||
def test_exploration_invalid_count(self):
|
||||
triples = [
|
||||
("urn:exp:3", RDF_TYPE, TG_EXPLORATION),
|
||||
("urn:exp:3", TG_EDGE_COUNT, "not-a-number"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:exp:3", triples)
|
||||
assert isinstance(entity, Exploration)
|
||||
assert entity.edge_count == 0
|
||||
|
||||
def test_focus(self):
|
||||
triples = [
|
||||
("urn:foc:1", RDF_TYPE, TG_FOCUS),
|
||||
("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:1"),
|
||||
("urn:foc:1", TG_SELECTED_EDGE, "urn:edge:2"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:foc:1", triples)
|
||||
assert isinstance(entity, Focus)
|
||||
assert len(entity.selected_edge_uris) == 2
|
||||
assert "urn:edge:1" in entity.selected_edge_uris
|
||||
assert "urn:edge:2" in entity.selected_edge_uris
|
||||
|
||||
def test_synthesis_with_document(self):
|
||||
triples = [
|
||||
("urn:syn:1", RDF_TYPE, TG_SYNTHESIS),
|
||||
("urn:syn:1", TG_DOCUMENT, "urn:doc:answer-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:syn:1", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.document == "urn:doc:answer-1"
|
||||
|
||||
def test_synthesis_no_document(self):
|
||||
triples = [
|
||||
("urn:syn:2", RDF_TYPE, TG_SYNTHESIS),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:syn:2", triples)
|
||||
assert isinstance(entity, Synthesis)
|
||||
assert entity.document == ""
|
||||
|
||||
def test_reflection_thought(self):
|
||||
triples = [
|
||||
("urn:ref:1", RDF_TYPE, TG_REFLECTION_TYPE),
|
||||
("urn:ref:1", RDF_TYPE, TG_THOUGHT_TYPE),
|
||||
("urn:ref:1", TG_DOCUMENT, "urn:doc:thought-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ref:1", triples)
|
||||
assert isinstance(entity, Reflection)
|
||||
assert entity.reflection_type == "thought"
|
||||
assert entity.document == "urn:doc:thought-1"
|
||||
|
||||
def test_reflection_observation(self):
|
||||
triples = [
|
||||
("urn:ref:2", RDF_TYPE, TG_REFLECTION_TYPE),
|
||||
("urn:ref:2", RDF_TYPE, TG_OBSERVATION_TYPE),
|
||||
("urn:ref:2", TG_DOCUMENT, "urn:doc:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ref:2", triples)
|
||||
assert isinstance(entity, Reflection)
|
||||
assert entity.reflection_type == "observation"
|
||||
assert entity.document == "urn:doc:obs-1"
|
||||
|
||||
def test_analysis(self):
|
||||
triples = [
|
||||
("urn:ana:1", RDF_TYPE, TG_ANALYSIS),
|
||||
("urn:ana:1", TG_ACTION, "graph-rag-query"),
|
||||
("urn:ana:1", TG_ARGUMENTS, '{"query": "test"}'),
|
||||
("urn:ana:1", TG_THOUGHT, "urn:ref:thought-1"),
|
||||
("urn:ana:1", TG_OBSERVATION, "urn:ref:obs-1"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:ana:1", triples)
|
||||
assert isinstance(entity, Analysis)
|
||||
assert entity.action == "graph-rag-query"
|
||||
assert entity.arguments == '{"query": "test"}'
|
||||
assert entity.thought == "urn:ref:thought-1"
|
||||
assert entity.observation == "urn:ref:obs-1"
|
||||
|
||||
def test_conclusion_with_document(self):
|
||||
triples = [
|
||||
("urn:conc:1", RDF_TYPE, TG_CONCLUSION),
|
||||
("urn:conc:1", TG_DOCUMENT, "urn:doc:final"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:1", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.document == "urn:doc:final"
|
||||
|
||||
def test_conclusion_no_document(self):
|
||||
triples = [
|
||||
("urn:conc:2", RDF_TYPE, TG_CONCLUSION),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:conc:2", triples)
|
||||
assert isinstance(entity, Conclusion)
|
||||
assert entity.document == ""
|
||||
|
||||
def test_unknown_type(self):
|
||||
triples = [
|
||||
("urn:x:1", RDF_TYPE, "http://example.com/UnknownType"),
|
||||
]
|
||||
entity = ExplainEntity.from_triples("urn:x:1", triples)
|
||||
assert isinstance(entity, ExplainEntity)
|
||||
assert entity.entity_type == "unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_edge_selection_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseEdgeSelectionTriples:
|
||||
|
||||
def test_with_edge_and_reasoning(self):
|
||||
triples = [
|
||||
("urn:edge:1", TG_EDGE, {"s": "Alice", "p": "knows", "o": "Bob"}),
|
||||
("urn:edge:1", TG_REASONING, "Alice and Bob are connected"),
|
||||
]
|
||||
result = parse_edge_selection_triples(triples)
|
||||
assert isinstance(result, EdgeSelection)
|
||||
assert result.uri == "urn:edge:1"
|
||||
assert result.edge == {"s": "Alice", "p": "knows", "o": "Bob"}
|
||||
assert result.reasoning == "Alice and Bob are connected"
|
||||
|
||||
def test_with_edge_only(self):
|
||||
triples = [
|
||||
("urn:edge:2", TG_EDGE, {"s": "A", "p": "r", "o": "B"}),
|
||||
]
|
||||
result = parse_edge_selection_triples(triples)
|
||||
assert result.edge is not None
|
||||
assert result.reasoning == ""
|
||||
|
||||
def test_with_reasoning_only(self):
|
||||
triples = [
|
||||
("urn:edge:3", TG_REASONING, "some reason"),
|
||||
]
|
||||
result = parse_edge_selection_triples(triples)
|
||||
assert result.edge is None
|
||||
assert result.reasoning == "some reason"
|
||||
|
||||
def test_empty_triples(self):
|
||||
result = parse_edge_selection_triples([])
|
||||
assert result.uri == ""
|
||||
assert result.edge is None
|
||||
assert result.reasoning == ""
|
||||
|
||||
def test_edge_must_be_dict(self):
|
||||
"""Non-dict values for TG_EDGE should not be treated as edges."""
|
||||
triples = [
|
||||
("urn:edge:4", TG_EDGE, "not-a-dict"),
|
||||
]
|
||||
result = parse_edge_selection_triples(triples)
|
||||
assert result.edge is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_term_value
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractTermValue:
|
||||
|
||||
def test_iri_short_format(self):
|
||||
assert extract_term_value({"t": "i", "i": "urn:test"}) == "urn:test"
|
||||
|
||||
def test_iri_long_format(self):
|
||||
assert extract_term_value({"type": "i", "iri": "urn:test"}) == "urn:test"
|
||||
|
||||
def test_literal_short_format(self):
|
||||
assert extract_term_value({"t": "l", "v": "hello"}) == "hello"
|
||||
|
||||
def test_literal_long_format(self):
|
||||
assert extract_term_value({"type": "l", "value": "hello"}) == "hello"
|
||||
|
||||
def test_quoted_triple(self):
|
||||
term = {
|
||||
"t": "t",
|
||||
"tr": {
|
||||
"s": {"t": "i", "i": "urn:s"},
|
||||
"p": {"t": "i", "i": "urn:p"},
|
||||
"o": {"t": "i", "i": "urn:o"},
|
||||
}
|
||||
}
|
||||
result = extract_term_value(term)
|
||||
assert result == {"s": "urn:s", "p": "urn:p", "o": "urn:o"}
|
||||
|
||||
def test_quoted_triple_long_format(self):
|
||||
term = {
|
||||
"type": "t",
|
||||
"triple": {
|
||||
"s": {"type": "i", "iri": "urn:s"},
|
||||
"p": {"type": "i", "iri": "urn:p"},
|
||||
"o": {"type": "l", "value": "val"},
|
||||
}
|
||||
}
|
||||
result = extract_term_value(term)
|
||||
assert result == {"s": "urn:s", "p": "urn:p", "o": "val"}
|
||||
|
||||
def test_unknown_type_fallback(self):
|
||||
result = extract_term_value({"t": "x", "i": "urn:fallback"})
|
||||
assert result == "urn:fallback"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wire_triples_to_tuples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWireTriplesToTuples:
|
||||
|
||||
def test_basic_conversion(self):
|
||||
wire = [
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:s1"},
|
||||
"p": {"t": "i", "i": "urn:p1"},
|
||||
"o": {"t": "l", "v": "value1"},
|
||||
},
|
||||
]
|
||||
result = wire_triples_to_tuples(wire)
|
||||
assert len(result) == 1
|
||||
assert result[0] == ("urn:s1", "urn:p1", "value1")
|
||||
|
||||
def test_multiple_triples(self):
|
||||
wire = [
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:s1"},
|
||||
"p": {"t": "i", "i": "urn:p1"},
|
||||
"o": {"t": "l", "v": "v1"},
|
||||
},
|
||||
{
|
||||
"s": {"t": "i", "i": "urn:s2"},
|
||||
"p": {"t": "i", "i": "urn:p2"},
|
||||
"o": {"t": "i", "i": "urn:o2"},
|
||||
},
|
||||
]
|
||||
result = wire_triples_to_tuples(wire)
|
||||
assert len(result) == 2
|
||||
assert result[0] == ("urn:s1", "urn:p1", "v1")
|
||||
assert result[1] == ("urn:s2", "urn:p2", "urn:o2")
|
||||
|
||||
def test_empty_list(self):
|
||||
assert wire_triples_to_tuples([]) == []
|
||||
|
||||
def test_missing_fields(self):
|
||||
wire = [{"s": {}, "p": {}, "o": {}}]
|
||||
result = wire_triples_to_tuples(wire)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExplainabilityClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_wire_triples(tuples):
|
||||
"""Convert (s, p, o) tuples to wire format for mocking."""
|
||||
result = []
|
||||
for s, p, o in tuples:
|
||||
entry = {
|
||||
"s": {"t": "i", "i": s},
|
||||
"p": {"t": "i", "i": p},
|
||||
}
|
||||
if o.startswith("urn:") or o.startswith("http"):
|
||||
entry["o"] = {"t": "i", "i": o}
|
||||
else:
|
||||
entry["o"] = {"t": "l", "v": o}
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
|
||||
class TestExplainabilityClientFetchEntity:
|
||||
|
||||
def test_fetch_question_entity(self):
|
||||
wire = _make_wire_triples([
|
||||
("urn:q:1", RDF_TYPE, TG_QUESTION),
|
||||
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
|
||||
("urn:q:1", TG_QUERY, "What is AI?"),
|
||||
("urn:q:1", PROV_STARTED_AT_TIME, "2024-01-01T00:00:00Z"),
|
||||
])
|
||||
|
||||
mock_flow = MagicMock()
|
||||
# Return same results twice for quiescence
|
||||
mock_flow.triples_query.side_effect = [wire, wire]
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
entity = client.fetch_entity("urn:q:1", graph="urn:graph:retrieval")
|
||||
|
||||
assert isinstance(entity, Question)
|
||||
assert entity.query == "What is AI?"
|
||||
assert entity.question_type == "graph-rag"
|
||||
|
||||
def test_fetch_returns_none_when_no_data(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.return_value = []
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0, max_retries=2)
|
||||
entity = client.fetch_entity("urn:nonexistent")
|
||||
|
||||
assert entity is None
|
||||
|
||||
def test_fetch_retries_on_empty_results(self):
|
||||
wire = _make_wire_triples([
|
||||
("urn:q:1", RDF_TYPE, TG_QUESTION),
|
||||
("urn:q:1", RDF_TYPE, TG_GRAPH_RAG_QUESTION),
|
||||
("urn:q:1", TG_QUERY, "Q"),
|
||||
])
|
||||
|
||||
mock_flow = MagicMock()
|
||||
# Empty, then data, then same data (stable)
|
||||
mock_flow.triples_query.side_effect = [[], wire, wire]
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
entity = client.fetch_entity("urn:q:1")
|
||||
|
||||
assert isinstance(entity, Question)
|
||||
assert mock_flow.triples_query.call_count == 3
|
||||
|
||||
|
||||
class TestExplainabilityClientResolveLabel:
|
||||
|
||||
def test_resolve_label_found(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.return_value = _make_wire_triples([
|
||||
("urn:entity:1", RDFS_LABEL, "Entity One"),
|
||||
])
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
label = client.resolve_label("urn:entity:1")
|
||||
assert label == "Entity One"
|
||||
|
||||
def test_resolve_label_not_found(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.return_value = []
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
label = client.resolve_label("urn:entity:1")
|
||||
assert label == "urn:entity:1"
|
||||
|
||||
def test_resolve_label_cached(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.triples_query.return_value = _make_wire_triples([
|
||||
("urn:entity:1", RDFS_LABEL, "Entity One"),
|
||||
])
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
client.resolve_label("urn:entity:1")
|
||||
client.resolve_label("urn:entity:1")
|
||||
|
||||
# Only one query should be made
|
||||
assert mock_flow.triples_query.call_count == 1
|
||||
|
||||
def test_resolve_label_non_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.resolve_label("plain text") == "plain text"
|
||||
assert client.resolve_label("") == ""
|
||||
mock_flow.triples_query.assert_not_called()
|
||||
|
||||
def test_resolve_edge_labels(self):
|
||||
mock_flow = MagicMock()
|
||||
|
||||
def mock_query(s=None, p=None, **kwargs):
|
||||
labels = {
|
||||
"urn:e:Alice": "Alice",
|
||||
"urn:r:knows": "knows",
|
||||
"urn:e:Bob": "Bob",
|
||||
}
|
||||
if s in labels:
|
||||
return _make_wire_triples([(s, RDFS_LABEL, labels[s])])
|
||||
return []
|
||||
|
||||
mock_flow.triples_query.side_effect = mock_query
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
s, p, o = client.resolve_edge_labels(
|
||||
{"s": "urn:e:Alice", "p": "urn:r:knows", "o": "urn:e:Bob"}
|
||||
)
|
||||
assert s == "Alice"
|
||||
assert p == "knows"
|
||||
assert o == "Bob"
|
||||
|
||||
|
||||
class TestExplainabilityClientContentFetching:
|
||||
|
||||
def test_fetch_document_content_from_librarian(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
mock_library = MagicMock()
|
||||
mock_api.library.return_value = mock_library
|
||||
mock_library.get_document_content.return_value = b"librarian content"
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
result = client.fetch_document_content(
|
||||
"urn:document:abc123", api=mock_api
|
||||
)
|
||||
assert result == "librarian content"
|
||||
|
||||
def test_fetch_document_content_truncated(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
mock_library = MagicMock()
|
||||
mock_api.library.return_value = mock_library
|
||||
mock_library.get_document_content.return_value = b"x" * 20000
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
result = client.fetch_document_content(
|
||||
"urn:doc:1", api=mock_api, max_content=100
|
||||
)
|
||||
assert len(result) < 20000
|
||||
assert result.endswith("... [truncated]")
|
||||
|
||||
def test_fetch_document_content_empty_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
mock_api = MagicMock()
|
||||
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
result = client.fetch_document_content("", api=mock_api)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestExplainabilityClientDetectSessionType:
|
||||
|
||||
def test_detect_agent_from_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.detect_session_type("urn:trustgraph:agent:abc") == "agent"
|
||||
|
||||
def test_detect_graphrag_from_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.detect_session_type("urn:trustgraph:question:abc") == "graphrag"
|
||||
|
||||
def test_detect_docrag_from_uri(self):
|
||||
mock_flow = MagicMock()
|
||||
client = ExplainabilityClient(mock_flow, retry_delay=0.0)
|
||||
assert client.detect_session_type("urn:trustgraph:docrag:abc") == "docrag"
|
||||
812
tests/unit/test_provenance/test_triples.py
Normal file
812
tests/unit/test_provenance/test_triples.py
Normal file
|
|
@ -0,0 +1,812 @@
|
|||
"""
|
||||
Tests for provenance triple builder functions (extraction-time and query-time).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from trustgraph.schema import Triple, Term, IRI, LITERAL, TRIPLE
|
||||
|
||||
from trustgraph.provenance.triples import (
|
||||
set_graph,
|
||||
document_triples,
|
||||
derived_entity_triples,
|
||||
subgraph_provenance_triples,
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
docrag_question_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_synthesis_triples,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
|
||||
PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
|
||||
DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
|
||||
TG_PAGE_COUNT, TG_MIME_TYPE, TG_PAGE_NUMBER,
|
||||
TG_CHUNK_INDEX, TG_CHAR_OFFSET, TG_CHAR_LENGTH,
|
||||
TG_CHUNK_SIZE, TG_CHUNK_OVERLAP, TG_COMPONENT_VERSION,
|
||||
TG_LLM_MODEL, TG_ONTOLOGY, TG_CONTAINS,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_DOCUMENT,
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANSWER_TYPE,
|
||||
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
GRAPH_SOURCE, GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
"""Find first triple matching predicate (and optionally subject)."""
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
"""Find all triples matching predicate (and optionally subject)."""
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
"""Check if subject has rdf:type rdf_type."""
|
||||
for t in triples:
|
||||
if (t.s.iri == subject and t.p.iri == RDF_TYPE
|
||||
and t.o.type == IRI and t.o.iri == rdf_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_graph
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSetGraph:
|
||||
|
||||
def test_sets_graph_on_all_triples(self):
|
||||
triples = [
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="urn:s1"),
|
||||
p=Term(type=IRI, iri="urn:p1"),
|
||||
o=Term(type=LITERAL, value="v1"),
|
||||
),
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="urn:s2"),
|
||||
p=Term(type=IRI, iri="urn:p2"),
|
||||
o=Term(type=LITERAL, value="v2"),
|
||||
),
|
||||
]
|
||||
result = set_graph(triples, GRAPH_RETRIEVAL)
|
||||
assert len(result) == 2
|
||||
for t in result:
|
||||
assert t.g == GRAPH_RETRIEVAL
|
||||
|
||||
def test_does_not_modify_originals(self):
|
||||
original = Triple(
|
||||
s=Term(type=IRI, iri="urn:s"),
|
||||
p=Term(type=IRI, iri="urn:p"),
|
||||
o=Term(type=LITERAL, value="v"),
|
||||
)
|
||||
result = set_graph([original], "urn:graph:test")
|
||||
assert original.g is None
|
||||
assert result[0].g == "urn:graph:test"
|
||||
|
||||
def test_empty_list(self):
|
||||
result = set_graph([], GRAPH_SOURCE)
|
||||
assert result == []
|
||||
|
||||
def test_preserves_spo(self):
|
||||
original = Triple(
|
||||
s=Term(type=IRI, iri="urn:s"),
|
||||
p=Term(type=IRI, iri="urn:p"),
|
||||
o=Term(type=LITERAL, value="hello"),
|
||||
)
|
||||
result = set_graph([original], "urn:g")[0]
|
||||
assert result.s.iri == "urn:s"
|
||||
assert result.p.iri == "urn:p"
|
||||
assert result.o.value == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# document_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentTriples:
|
||||
|
||||
DOC_URI = "https://example.com/doc/abc"
|
||||
|
||||
def test_minimal_document(self):
|
||||
triples = document_triples(self.DOC_URI)
|
||||
assert has_type(triples, self.DOC_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.DOC_URI, TG_DOCUMENT_TYPE)
|
||||
assert len(triples) == 2
|
||||
|
||||
def test_with_title(self):
|
||||
triples = document_triples(self.DOC_URI, title="My Doc")
|
||||
title_t = find_triple(triples, DC_TITLE)
|
||||
assert title_t is not None
|
||||
assert title_t.o.value == "My Doc"
|
||||
# Title also creates an rdfs:label
|
||||
label_t = find_triple(triples, RDFS_LABEL)
|
||||
assert label_t is not None
|
||||
assert label_t.o.value == "My Doc"
|
||||
|
||||
def test_with_source(self):
|
||||
triples = document_triples(self.DOC_URI, source="https://source.com/f.pdf")
|
||||
source_t = find_triple(triples, DC_SOURCE)
|
||||
assert source_t is not None
|
||||
assert source_t.o.type == IRI
|
||||
assert source_t.o.iri == "https://source.com/f.pdf"
|
||||
|
||||
def test_with_date(self):
|
||||
triples = document_triples(self.DOC_URI, date="2024-01-15")
|
||||
date_t = find_triple(triples, DC_DATE)
|
||||
assert date_t is not None
|
||||
assert date_t.o.value == "2024-01-15"
|
||||
|
||||
def test_with_creator(self):
|
||||
triples = document_triples(self.DOC_URI, creator="Alice")
|
||||
creator_t = find_triple(triples, DC_CREATOR)
|
||||
assert creator_t is not None
|
||||
assert creator_t.o.value == "Alice"
|
||||
|
||||
def test_with_page_count(self):
|
||||
triples = document_triples(self.DOC_URI, page_count=42)
|
||||
pc_t = find_triple(triples, TG_PAGE_COUNT)
|
||||
assert pc_t is not None
|
||||
assert pc_t.o.value == "42"
|
||||
|
||||
def test_with_page_count_zero(self):
|
||||
triples = document_triples(self.DOC_URI, page_count=0)
|
||||
pc_t = find_triple(triples, TG_PAGE_COUNT)
|
||||
assert pc_t is not None
|
||||
assert pc_t.o.value == "0"
|
||||
|
||||
def test_with_mime_type(self):
|
||||
triples = document_triples(self.DOC_URI, mime_type="application/pdf")
|
||||
mt_t = find_triple(triples, TG_MIME_TYPE)
|
||||
assert mt_t is not None
|
||||
assert mt_t.o.value == "application/pdf"
|
||||
|
||||
def test_all_metadata(self):
|
||||
triples = document_triples(
|
||||
self.DOC_URI,
|
||||
title="Test",
|
||||
source="https://s.com",
|
||||
date="2024-01-01",
|
||||
creator="Bob",
|
||||
page_count=10,
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
# 2 type triples + title + label + source + date + creator + page_count + mime_type
|
||||
assert len(triples) == 9
|
||||
|
||||
def test_subject_is_doc_uri(self):
|
||||
triples = document_triples(self.DOC_URI, title="T")
|
||||
for t in triples:
|
||||
assert t.s.iri == self.DOC_URI
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# derived_entity_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDerivedEntityTriples:
|
||||
|
||||
ENTITY_URI = "https://example.com/doc/abc/p1"
|
||||
PARENT_URI = "https://example.com/doc/abc"
|
||||
|
||||
def test_page_entity_has_page_type(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
page_number=1,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
|
||||
|
||||
def test_chunk_entity_has_chunk_type(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"chunker", "1.0",
|
||||
chunk_index=0,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
assert has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
|
||||
|
||||
def test_no_specific_type_without_page_or_chunk(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"component", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
|
||||
assert not has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
|
||||
assert not has_type(triples, self.ENTITY_URI, TG_CHUNK_TYPE)
|
||||
|
||||
def test_was_derived_from_parent(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.ENTITY_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.PARENT_URI
|
||||
|
||||
def test_activity_created(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
# Entity was generated by an activity
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
|
||||
assert gen is not None
|
||||
act_uri = gen.o.iri
|
||||
|
||||
# Activity has correct type and metadata
|
||||
assert has_type(triples, act_uri, PROV_ACTIVITY)
|
||||
|
||||
# Activity used the parent
|
||||
used = find_triple(triples, PROV_USED, act_uri)
|
||||
assert used is not None
|
||||
assert used.o.iri == self.PARENT_URI
|
||||
|
||||
# Activity has component version
|
||||
version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
|
||||
assert version is not None
|
||||
assert version.o.value == "1.0"
|
||||
|
||||
def test_agent_created(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
# Find the agent URI via wasAssociatedWith
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.ENTITY_URI)
|
||||
act_uri = gen.o.iri
|
||||
assoc = find_triple(triples, PROV_WAS_ASSOCIATED_WITH, act_uri)
|
||||
assert assoc is not None
|
||||
agt_uri = assoc.o.iri
|
||||
|
||||
assert has_type(triples, agt_uri, PROV_AGENT)
|
||||
label = find_triple(triples, RDFS_LABEL, agt_uri)
|
||||
assert label is not None
|
||||
assert label.o.value == "pdf-extractor"
|
||||
|
||||
def test_timestamp_recorded(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
timestamp="2024-06-15T12:30:00Z",
|
||||
)
|
||||
ts = find_triple(triples, PROV_STARTED_AT_TIME)
|
||||
assert ts is not None
|
||||
assert ts.o.value == "2024-06-15T12:30:00Z"
|
||||
|
||||
def test_default_timestamp_generated(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
)
|
||||
ts = find_triple(triples, PROV_STARTED_AT_TIME)
|
||||
assert ts is not None
|
||||
assert len(ts.o.value) > 0
|
||||
|
||||
def test_optional_label(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
label="Page 1",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
label = find_triple(triples, RDFS_LABEL, self.ENTITY_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Page 1"
|
||||
|
||||
def test_page_number_recorded(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"pdf-extractor", "1.0",
|
||||
page_number=3,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
pn = find_triple(triples, TG_PAGE_NUMBER, self.ENTITY_URI)
|
||||
assert pn is not None
|
||||
assert pn.o.value == "3"
|
||||
|
||||
def test_chunk_metadata_recorded(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"chunker", "2.0",
|
||||
chunk_index=5,
|
||||
char_offset=1000,
|
||||
char_length=500,
|
||||
chunk_size=512,
|
||||
chunk_overlap=64,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
ci = find_triple(triples, TG_CHUNK_INDEX, self.ENTITY_URI)
|
||||
assert ci is not None and ci.o.value == "5"
|
||||
|
||||
co = find_triple(triples, TG_CHAR_OFFSET, self.ENTITY_URI)
|
||||
assert co is not None and co.o.value == "1000"
|
||||
|
||||
cl = find_triple(triples, TG_CHAR_LENGTH, self.ENTITY_URI)
|
||||
assert cl is not None and cl.o.value == "500"
|
||||
|
||||
# chunk_size and chunk_overlap are on the activity, not the entity
|
||||
cs = find_triple(triples, TG_CHUNK_SIZE)
|
||||
assert cs is not None and cs.o.value == "512"
|
||||
|
||||
ov = find_triple(triples, TG_CHUNK_OVERLAP)
|
||||
assert ov is not None and ov.o.value == "64"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# subgraph_provenance_triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubgraphProvenanceTriples:
|
||||
|
||||
SG_URI = "https://trustgraph.ai/subgraph/test-sg"
|
||||
CHUNK_URI = "https://example.com/doc/abc/p1/c0"
|
||||
|
||||
def _make_extracted_triple(self, s="urn:e:Alice", p="urn:r:knows", o="urn:e:Bob"):
|
||||
return Triple(
|
||||
s=Term(type=IRI, iri=s),
|
||||
p=Term(type=IRI, iri=p),
|
||||
o=Term(type=IRI, iri=o),
|
||||
)
|
||||
|
||||
def test_contains_quoted_triples(self):
|
||||
extracted = [self._make_extracted_triple()]
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, extracted, self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
|
||||
assert len(contains) == 1
|
||||
assert contains[0].o.type == TRIPLE
|
||||
assert contains[0].o.triple.s.iri == "urn:e:Alice"
|
||||
assert contains[0].o.triple.p.iri == "urn:r:knows"
|
||||
assert contains[0].o.triple.o.iri == "urn:e:Bob"
|
||||
|
||||
def test_multiple_extracted_triples(self):
|
||||
extracted = [
|
||||
self._make_extracted_triple("urn:e:A", "urn:r:x", "urn:e:B"),
|
||||
self._make_extracted_triple("urn:e:C", "urn:r:y", "urn:e:D"),
|
||||
self._make_extracted_triple("urn:e:E", "urn:r:z", "urn:e:F"),
|
||||
]
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, extracted, self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
|
||||
assert len(contains) == 3
|
||||
|
||||
def test_empty_extracted_triples(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
contains = find_triples(triples, TG_CONTAINS, self.SG_URI)
|
||||
assert len(contains) == 0
|
||||
# Should still have subgraph provenance metadata
|
||||
assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
|
||||
|
||||
def test_subgraph_has_correct_types(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
assert has_type(triples, self.SG_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SG_URI, TG_SUBGRAPH_TYPE)
|
||||
|
||||
def test_derived_from_chunk(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SG_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.CHUNK_URI
|
||||
|
||||
def test_activity_and_agent(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.SG_URI)
|
||||
assert gen is not None
|
||||
act_uri = gen.o.iri
|
||||
|
||||
assert has_type(triples, act_uri, PROV_ACTIVITY)
|
||||
|
||||
used = find_triple(triples, PROV_USED, act_uri)
|
||||
assert used is not None
|
||||
assert used.o.iri == self.CHUNK_URI
|
||||
|
||||
version = find_triple(triples, TG_COMPONENT_VERSION, act_uri)
|
||||
assert version is not None
|
||||
assert version.o.value == "1.0"
|
||||
|
||||
def test_optional_llm_model(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
llm_model="claude-3-opus",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
llm = find_triple(triples, TG_LLM_MODEL)
|
||||
assert llm is not None
|
||||
assert llm.o.value == "claude-3-opus"
|
||||
|
||||
def test_no_llm_model_when_omitted(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
llm = find_triple(triples, TG_LLM_MODEL)
|
||||
assert llm is None
|
||||
|
||||
def test_optional_ontology(self):
|
||||
triples = subgraph_provenance_triples(
|
||||
self.SG_URI, [], self.CHUNK_URI,
|
||||
"kg-extractor", "1.0",
|
||||
ontology_uri="https://example.com/ontology/v1",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
)
|
||||
ont = find_triple(triples, TG_ONTOLOGY)
|
||||
assert ont is not None
|
||||
assert ont.o.type == IRI
|
||||
assert ont.o.iri == "https://example.com/ontology/v1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAG query-time triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuestionTriples:
|
||||
|
||||
Q_URI = "urn:trustgraph:question:test-session"
|
||||
|
||||
def test_question_types(self):
|
||||
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
|
||||
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.Q_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.Q_URI, TG_GRAPH_RAG_QUESTION)
|
||||
|
||||
def test_question_query_text(self):
|
||||
triples = question_triples(self.Q_URI, "What is AI?", "2024-01-01T00:00:00Z")
|
||||
query = find_triple(triples, TG_QUERY, self.Q_URI)
|
||||
assert query is not None
|
||||
assert query.o.value == "What is AI?"
|
||||
|
||||
def test_question_timestamp(self):
|
||||
triples = question_triples(self.Q_URI, "Q", "2024-06-15T10:00:00Z")
|
||||
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
|
||||
assert ts is not None
|
||||
assert ts.o.value == "2024-06-15T10:00:00Z"
|
||||
|
||||
def test_question_default_timestamp(self):
|
||||
triples = question_triples(self.Q_URI, "Q")
|
||||
ts = find_triple(triples, PROV_STARTED_AT_TIME, self.Q_URI)
|
||||
assert ts is not None
|
||||
assert len(ts.o.value) > 0
|
||||
|
||||
def test_question_label(self):
|
||||
triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
|
||||
label = find_triple(triples, RDFS_LABEL, self.Q_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "GraphRAG Question"
|
||||
|
||||
def test_question_triple_count(self):
|
||||
triples = question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
|
||||
assert len(triples) == 6
|
||||
|
||||
|
||||
class TestGroundingTriples:
|
||||
|
||||
GND_URI = "urn:trustgraph:prov:grounding:test-session"
|
||||
Q_URI = "urn:trustgraph:question:test-session"
|
||||
|
||||
def test_grounding_types(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML"])
|
||||
assert has_type(triples, self.GND_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.GND_URI, TG_GROUNDING)
|
||||
|
||||
def test_grounding_generated_by_question(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI"])
|
||||
gen = find_triple(triples, PROV_WAS_GENERATED_BY, self.GND_URI)
|
||||
assert gen is not None
|
||||
assert gen.o.iri == self.Q_URI
|
||||
|
||||
def test_grounding_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, ["AI", "ML", "robots"])
|
||||
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
|
||||
assert len(concepts) == 3
|
||||
values = {t.o.value for t in concepts}
|
||||
assert values == {"AI", "ML", "robots"}
|
||||
|
||||
def test_grounding_empty_concepts(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
|
||||
concepts = find_triples(triples, TG_CONCEPT, self.GND_URI)
|
||||
assert len(concepts) == 0
|
||||
|
||||
def test_grounding_label(self):
|
||||
triples = grounding_triples(self.GND_URI, self.Q_URI, [])
|
||||
label = find_triple(triples, RDFS_LABEL, self.GND_URI)
|
||||
assert label is not None
|
||||
assert label.o.value == "Grounding"
|
||||
|
||||
|
||||
class TestExplorationTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
|
||||
GND_URI = "urn:trustgraph:prov:grounding:test-session"
|
||||
|
||||
def test_exploration_types(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
|
||||
def test_exploration_derived_from_grounding(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.GND_URI
|
||||
|
||||
def test_exploration_edge_count(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 15)
|
||||
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
|
||||
assert ec is not None
|
||||
assert ec.o.value == "15"
|
||||
|
||||
def test_exploration_zero_edges(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 0)
|
||||
ec = find_triple(triples, TG_EDGE_COUNT, self.EXP_URI)
|
||||
assert ec is not None
|
||||
assert ec.o.value == "0"
|
||||
|
||||
def test_exploration_with_entities(self):
|
||||
entities = ["urn:e:machine-learning", "urn:e:neural-networks"]
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10, entities=entities)
|
||||
ent_triples = find_triples(triples, TG_ENTITY, self.EXP_URI)
|
||||
assert len(ent_triples) == 2
|
||||
|
||||
def test_exploration_triple_count(self):
|
||||
triples = exploration_triples(self.EXP_URI, self.GND_URI, 10)
|
||||
assert len(triples) == 5
|
||||
|
||||
|
||||
class TestFocusTriples:
|
||||
|
||||
FOC_URI = "urn:trustgraph:prov:focus:test-session"
|
||||
EXP_URI = "urn:trustgraph:prov:exploration:test-session"
|
||||
SESSION_ID = "test-session"
|
||||
|
||||
def test_focus_types(self):
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
|
||||
assert has_type(triples, self.FOC_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.FOC_URI, TG_FOCUS)
|
||||
|
||||
def test_focus_derived_from_exploration(self):
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.FOC_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.EXP_URI
|
||||
|
||||
def test_focus_no_edges(self):
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, [], self.SESSION_ID)
|
||||
selected = find_triples(triples, TG_SELECTED_EDGE)
|
||||
assert len(selected) == 0
|
||||
|
||||
def test_focus_with_edges_and_reasoning(self):
|
||||
edges = [
|
||||
{
|
||||
"edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
|
||||
"reasoning": "Alice is connected to Bob",
|
||||
},
|
||||
{
|
||||
"edge": ("urn:e:Bob", "urn:r:worksAt", "urn:e:Acme"),
|
||||
"reasoning": "Bob works at Acme",
|
||||
},
|
||||
]
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
|
||||
|
||||
# Two selectedEdge links
|
||||
selected = find_triples(triples, TG_SELECTED_EDGE, self.FOC_URI)
|
||||
assert len(selected) == 2
|
||||
|
||||
# Each edge selection has a quoted triple
|
||||
edge_triples = find_triples(triples, TG_EDGE)
|
||||
assert len(edge_triples) == 2
|
||||
for et in edge_triples:
|
||||
assert et.o.type == TRIPLE
|
||||
|
||||
# Each edge selection has reasoning
|
||||
reasoning_triples = find_triples(triples, TG_REASONING)
|
||||
assert len(reasoning_triples) == 2
|
||||
|
||||
def test_focus_edge_without_reasoning(self):
|
||||
edges = [
|
||||
{"edge": ("urn:e:A", "urn:r:x", "urn:e:B"), "reasoning": ""},
|
||||
]
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
|
||||
reasoning = find_triples(triples, TG_REASONING)
|
||||
assert len(reasoning) == 0
|
||||
|
||||
def test_focus_edge_without_edge_data(self):
|
||||
edges = [
|
||||
{"edge": None, "reasoning": "some reasoning"},
|
||||
]
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
|
||||
selected = find_triples(triples, TG_SELECTED_EDGE)
|
||||
assert len(selected) == 0
|
||||
|
||||
def test_focus_quoted_triple_content(self):
|
||||
edges = [
|
||||
{
|
||||
"edge": ("urn:e:Alice", "urn:r:knows", "urn:e:Bob"),
|
||||
"reasoning": "test",
|
||||
},
|
||||
]
|
||||
triples = focus_triples(self.FOC_URI, self.EXP_URI, edges, self.SESSION_ID)
|
||||
edge_t = find_triple(triples, TG_EDGE)
|
||||
qt = edge_t.o.triple
|
||||
assert qt.s.iri == "urn:e:Alice"
|
||||
assert qt.p.iri == "urn:r:knows"
|
||||
assert qt.o.iri == "urn:e:Bob"
|
||||
|
||||
|
||||
class TestSynthesisTriples:
|
||||
|
||||
SYN_URI = "urn:trustgraph:prov:synthesis:test-session"
|
||||
FOC_URI = "urn:trustgraph:prov:focus:test-session"
|
||||
|
||||
def test_synthesis_types(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
|
||||
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_synthesis_derived_from_focus(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
|
||||
assert derived is not None
|
||||
assert derived.o.iri == self.FOC_URI
|
||||
|
||||
def test_synthesis_with_document_reference(self):
|
||||
triples = synthesis_triples(
|
||||
self.SYN_URI, self.FOC_URI,
|
||||
document_id="urn:trustgraph:question:abc/answer",
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is not None
|
||||
assert doc.o.type == IRI
|
||||
assert doc.o.iri == "urn:trustgraph:question:abc/answer"
|
||||
|
||||
def test_synthesis_no_document(self):
|
||||
triples = synthesis_triples(self.SYN_URI, self.FOC_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentRAG query-time triples
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocRagQuestionTriples:
|
||||
|
||||
Q_URI = "urn:trustgraph:docrag:test-session"
|
||||
|
||||
def test_docrag_question_types(self):
|
||||
triples = docrag_question_triples(self.Q_URI, "Find info", "2024-01-01T00:00:00Z")
|
||||
assert has_type(triples, self.Q_URI, PROV_ACTIVITY)
|
||||
assert has_type(triples, self.Q_URI, TG_QUESTION)
|
||||
assert has_type(triples, self.Q_URI, TG_DOC_RAG_QUESTION)
|
||||
|
||||
def test_docrag_question_label(self):
|
||||
triples = docrag_question_triples(self.Q_URI, "Q", "2024-01-01T00:00:00Z")
|
||||
label = find_triple(triples, RDFS_LABEL, self.Q_URI)
|
||||
assert label.o.value == "DocumentRAG Question"
|
||||
|
||||
def test_docrag_question_query_text(self):
|
||||
triples = docrag_question_triples(self.Q_URI, "search query", "2024-01-01T00:00:00Z")
|
||||
query = find_triple(triples, TG_QUERY, self.Q_URI)
|
||||
assert query.o.value == "search query"
|
||||
|
||||
|
||||
class TestDocRagExplorationTriples:
|
||||
|
||||
EXP_URI = "urn:trustgraph:docrag:test/exploration"
|
||||
GND_URI = "urn:trustgraph:docrag:test/grounding"
|
||||
|
||||
def test_docrag_exploration_types(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
|
||||
assert has_type(triples, self.EXP_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.EXP_URI, TG_EXPLORATION)
|
||||
|
||||
def test_docrag_exploration_derived_from_grounding(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 5)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.EXP_URI)
|
||||
assert derived.o.iri == self.GND_URI
|
||||
|
||||
def test_docrag_exploration_chunk_count(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 7)
|
||||
cc = find_triple(triples, TG_CHUNK_COUNT, self.EXP_URI)
|
||||
assert cc.o.value == "7"
|
||||
|
||||
def test_docrag_exploration_without_chunk_ids(self):
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3)
|
||||
chunks = find_triples(triples, TG_SELECTED_CHUNK)
|
||||
assert len(chunks) == 0
|
||||
|
||||
def test_docrag_exploration_with_chunk_ids(self):
|
||||
chunk_ids = ["urn:chunk:1", "urn:chunk:2", "urn:chunk:3"]
|
||||
triples = docrag_exploration_triples(self.EXP_URI, self.GND_URI, 3, chunk_ids)
|
||||
chunks = find_triples(triples, TG_SELECTED_CHUNK, self.EXP_URI)
|
||||
assert len(chunks) == 3
|
||||
chunk_uris = {t.o.iri for t in chunks}
|
||||
assert chunk_uris == set(chunk_ids)
|
||||
|
||||
|
||||
class TestDocRagSynthesisTriples:
|
||||
|
||||
SYN_URI = "urn:trustgraph:docrag:test/synthesis"
|
||||
EXP_URI = "urn:trustgraph:docrag:test/exploration"
|
||||
|
||||
def test_docrag_synthesis_types(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
assert has_type(triples, self.SYN_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.SYN_URI, TG_SYNTHESIS)
|
||||
|
||||
def test_docrag_synthesis_derived_from_exploration(self):
|
||||
"""DocRAG skips the focus step — synthesis derives from exploration."""
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
derived = find_triple(triples, PROV_WAS_DERIVED_FROM, self.SYN_URI)
|
||||
assert derived.o.iri == self.EXP_URI
|
||||
|
||||
def test_docrag_synthesis_has_answer_type(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
assert has_type(triples, self.SYN_URI, TG_ANSWER_TYPE)
|
||||
|
||||
def test_docrag_synthesis_with_document(self):
|
||||
triples = docrag_synthesis_triples(
|
||||
self.SYN_URI, self.EXP_URI, document_id="urn:doc:ans"
|
||||
)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc.o.iri == "urn:doc:ans"
|
||||
|
||||
def test_docrag_synthesis_no_document(self):
|
||||
triples = docrag_synthesis_triples(self.SYN_URI, self.EXP_URI)
|
||||
doc = find_triple(triples, TG_DOCUMENT, self.SYN_URI)
|
||||
assert doc is None
|
||||
292
tests/unit/test_provenance/test_uris.py
Normal file
292
tests/unit/test_provenance/test_uris.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
"""
|
||||
Tests for provenance URI generation functions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from trustgraph.provenance.uris import (
|
||||
TRUSTGRAPH_BASE,
|
||||
_encode_id,
|
||||
document_uri,
|
||||
page_uri,
|
||||
chunk_uri_from_page,
|
||||
chunk_uri_from_doc,
|
||||
activity_uri,
|
||||
subgraph_uri,
|
||||
agent_uri,
|
||||
question_uri,
|
||||
exploration_uri,
|
||||
focus_uri,
|
||||
synthesis_uri,
|
||||
edge_selection_uri,
|
||||
agent_session_uri,
|
||||
agent_iteration_uri,
|
||||
agent_final_uri,
|
||||
docrag_question_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_synthesis_uri,
|
||||
)
|
||||
|
||||
|
||||
class TestEncodeId:
|
||||
"""Tests for the _encode_id helper."""
|
||||
|
||||
def test_plain_string(self):
|
||||
assert _encode_id("abc123") == "abc123"
|
||||
|
||||
def test_string_with_spaces(self):
|
||||
assert _encode_id("hello world") == "hello%20world"
|
||||
|
||||
def test_string_with_slashes(self):
|
||||
assert _encode_id("a/b/c") == "a%2Fb%2Fc"
|
||||
|
||||
def test_integer_input(self):
|
||||
assert _encode_id(42) == "42"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _encode_id("") == ""
|
||||
|
||||
def test_special_characters(self):
|
||||
result = _encode_id("name@domain.com")
|
||||
assert "@" not in result or result == "name%40domain.com"
|
||||
|
||||
|
||||
class TestDocumentUris:
|
||||
"""Tests for document, page, and chunk URI generation."""
|
||||
|
||||
def test_document_uri_passthrough(self):
|
||||
iri = "https://example.com/doc/123"
|
||||
assert document_uri(iri) == iri
|
||||
|
||||
def test_page_uri_format(self):
|
||||
result = page_uri("https://example.com/doc/123", 5)
|
||||
assert result == "https://example.com/doc/123/p5"
|
||||
|
||||
def test_page_uri_page_zero(self):
|
||||
result = page_uri("https://example.com/doc/123", 0)
|
||||
assert result == "https://example.com/doc/123/p0"
|
||||
|
||||
def test_chunk_uri_from_page_format(self):
|
||||
result = chunk_uri_from_page("https://example.com/doc/123", 2, 3)
|
||||
assert result == "https://example.com/doc/123/p2/c3"
|
||||
|
||||
def test_chunk_uri_from_doc_format(self):
|
||||
result = chunk_uri_from_doc("https://example.com/doc/123", 7)
|
||||
assert result == "https://example.com/doc/123/c7"
|
||||
|
||||
def test_page_uri_preserves_doc_iri(self):
|
||||
doc = "urn:isbn:978-3-16-148410-0"
|
||||
result = page_uri(doc, 1)
|
||||
assert result.startswith(doc)
|
||||
|
||||
def test_chunk_from_page_hierarchy(self):
|
||||
"""Chunk URI should contain both page and chunk identifiers."""
|
||||
result = chunk_uri_from_page("https://example.com/doc", 3, 5)
|
||||
assert "/p3/" in result
|
||||
assert result.endswith("/c5")
|
||||
|
||||
|
||||
class TestActivityAndSubgraphUris:
|
||||
"""Tests for activity_uri, subgraph_uri, and agent_uri."""
|
||||
|
||||
def test_activity_uri_with_id(self):
|
||||
result = activity_uri("my-activity-id")
|
||||
assert result == f"{TRUSTGRAPH_BASE}/activity/my-activity-id"
|
||||
|
||||
def test_activity_uri_auto_generates_uuid(self):
|
||||
result = activity_uri()
|
||||
assert result.startswith(f"{TRUSTGRAPH_BASE}/activity/")
|
||||
# UUID part should be non-empty
|
||||
uuid_part = result.split("/activity/")[1]
|
||||
assert len(uuid_part) > 0
|
||||
|
||||
def test_activity_uri_unique_uuids(self):
|
||||
r1 = activity_uri()
|
||||
r2 = activity_uri()
|
||||
assert r1 != r2
|
||||
|
||||
def test_activity_uri_encodes_special_chars(self):
|
||||
result = activity_uri("id with spaces")
|
||||
assert "id%20with%20spaces" in result
|
||||
|
||||
def test_subgraph_uri_with_id(self):
|
||||
result = subgraph_uri("sg-123")
|
||||
assert result == f"{TRUSTGRAPH_BASE}/subgraph/sg-123"
|
||||
|
||||
def test_subgraph_uri_auto_generates_uuid(self):
|
||||
result = subgraph_uri()
|
||||
assert result.startswith(f"{TRUSTGRAPH_BASE}/subgraph/")
|
||||
uuid_part = result.split("/subgraph/")[1]
|
||||
assert len(uuid_part) > 0
|
||||
|
||||
def test_subgraph_uri_unique_uuids(self):
|
||||
r1 = subgraph_uri()
|
||||
r2 = subgraph_uri()
|
||||
assert r1 != r2
|
||||
|
||||
def test_agent_uri_format(self):
|
||||
result = agent_uri("pdf-extractor")
|
||||
assert result == f"{TRUSTGRAPH_BASE}/agent/pdf-extractor"
|
||||
|
||||
def test_agent_uri_encodes_special_chars(self):
|
||||
result = agent_uri("my component")
|
||||
assert "my%20component" in result
|
||||
|
||||
|
||||
class TestGraphRagQueryUris:
|
||||
"""Tests for GraphRAG query-time provenance URIs."""
|
||||
|
||||
FIXED_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
def test_question_uri_with_session_id(self):
|
||||
result = question_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:question:{self.FIXED_UUID}"
|
||||
|
||||
def test_question_uri_auto_generates(self):
|
||||
result = question_uri()
|
||||
assert result.startswith("urn:trustgraph:question:")
|
||||
uuid_part = result.split("urn:trustgraph:question:")[1]
|
||||
assert len(uuid_part) > 0
|
||||
|
||||
def test_question_uri_unique(self):
|
||||
r1 = question_uri()
|
||||
r2 = question_uri()
|
||||
assert r1 != r2
|
||||
|
||||
def test_exploration_uri_format(self):
|
||||
result = exploration_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:prov:exploration:{self.FIXED_UUID}"
|
||||
|
||||
def test_focus_uri_format(self):
|
||||
result = focus_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:prov:focus:{self.FIXED_UUID}"
|
||||
|
||||
def test_synthesis_uri_format(self):
|
||||
result = synthesis_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:prov:synthesis:{self.FIXED_UUID}"
|
||||
|
||||
def test_edge_selection_uri_format(self):
|
||||
result = edge_selection_uri(self.FIXED_UUID, 3)
|
||||
assert result == f"urn:trustgraph:prov:edge:{self.FIXED_UUID}:3"
|
||||
|
||||
def test_edge_selection_uri_zero_index(self):
|
||||
result = edge_selection_uri(self.FIXED_UUID, 0)
|
||||
assert result.endswith(":0")
|
||||
|
||||
def test_session_uris_share_session_id(self):
|
||||
"""All URIs for a session should contain the same session ID."""
|
||||
sid = self.FIXED_UUID
|
||||
q = question_uri(sid)
|
||||
e = exploration_uri(sid)
|
||||
f = focus_uri(sid)
|
||||
s = synthesis_uri(sid)
|
||||
for uri in [q, e, f, s]:
|
||||
assert sid in uri
|
||||
|
||||
|
||||
class TestAgentProvenanceUris:
|
||||
"""Tests for agent provenance URIs."""
|
||||
|
||||
FIXED_UUID = "661e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
def test_agent_session_uri_with_id(self):
|
||||
result = agent_session_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}"
|
||||
|
||||
def test_agent_session_uri_auto_generates(self):
|
||||
result = agent_session_uri()
|
||||
assert result.startswith("urn:trustgraph:agent:")
|
||||
|
||||
def test_agent_session_uri_unique(self):
|
||||
r1 = agent_session_uri()
|
||||
r2 = agent_session_uri()
|
||||
assert r1 != r2
|
||||
|
||||
def test_agent_iteration_uri_format(self):
|
||||
result = agent_iteration_uri(self.FIXED_UUID, 1)
|
||||
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/i1"
|
||||
|
||||
def test_agent_iteration_uri_numbering(self):
|
||||
r1 = agent_iteration_uri(self.FIXED_UUID, 1)
|
||||
r2 = agent_iteration_uri(self.FIXED_UUID, 2)
|
||||
assert r1 != r2
|
||||
assert r1.endswith("/i1")
|
||||
assert r2.endswith("/i2")
|
||||
|
||||
def test_agent_final_uri_format(self):
|
||||
result = agent_final_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:agent:{self.FIXED_UUID}/final"
|
||||
|
||||
def test_agent_uris_share_session_id(self):
|
||||
sid = self.FIXED_UUID
|
||||
session = agent_session_uri(sid)
|
||||
iteration = agent_iteration_uri(sid, 1)
|
||||
final = agent_final_uri(sid)
|
||||
for uri in [session, iteration, final]:
|
||||
assert sid in uri
|
||||
|
||||
|
||||
class TestDocRagProvenanceUris:
|
||||
"""Tests for Document RAG provenance URIs."""
|
||||
|
||||
FIXED_UUID = "772e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
def test_docrag_question_uri_with_id(self):
|
||||
result = docrag_question_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}"
|
||||
|
||||
def test_docrag_question_uri_auto_generates(self):
|
||||
result = docrag_question_uri()
|
||||
assert result.startswith("urn:trustgraph:docrag:")
|
||||
|
||||
def test_docrag_question_uri_unique(self):
|
||||
r1 = docrag_question_uri()
|
||||
r2 = docrag_question_uri()
|
||||
assert r1 != r2
|
||||
|
||||
def test_docrag_exploration_uri_format(self):
|
||||
result = docrag_exploration_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/exploration"
|
||||
|
||||
def test_docrag_synthesis_uri_format(self):
|
||||
result = docrag_synthesis_uri(self.FIXED_UUID)
|
||||
assert result == f"urn:trustgraph:docrag:{self.FIXED_UUID}/synthesis"
|
||||
|
||||
def test_docrag_uris_share_session_id(self):
|
||||
sid = self.FIXED_UUID
|
||||
q = docrag_question_uri(sid)
|
||||
e = docrag_exploration_uri(sid)
|
||||
s = docrag_synthesis_uri(sid)
|
||||
for uri in [q, e, s]:
|
||||
assert sid in uri
|
||||
|
||||
|
||||
class TestUriNamespaceIsolation:
|
||||
"""Verify that different provenance types use distinct URI namespaces."""
|
||||
|
||||
FIXED_UUID = "883e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
def test_graphrag_vs_agent_namespace(self):
|
||||
graphrag = question_uri(self.FIXED_UUID)
|
||||
agent = agent_session_uri(self.FIXED_UUID)
|
||||
assert graphrag != agent
|
||||
assert "question" in graphrag
|
||||
assert "agent" in agent
|
||||
|
||||
def test_graphrag_vs_docrag_namespace(self):
|
||||
graphrag = question_uri(self.FIXED_UUID)
|
||||
docrag = docrag_question_uri(self.FIXED_UUID)
|
||||
assert graphrag != docrag
|
||||
|
||||
def test_agent_vs_docrag_namespace(self):
|
||||
agent = agent_session_uri(self.FIXED_UUID)
|
||||
docrag = docrag_question_uri(self.FIXED_UUID)
|
||||
assert agent != docrag
|
||||
|
||||
def test_extraction_vs_query_namespace(self):
|
||||
"""Extraction URIs use https://, query URIs use urn:."""
|
||||
ext = activity_uri(self.FIXED_UUID)
|
||||
query = question_uri(self.FIXED_UUID)
|
||||
assert ext.startswith("https://")
|
||||
assert query.startswith("urn:")
|
||||
124
tests/unit/test_provenance/test_vocabulary.py
Normal file
124
tests/unit/test_provenance/test_vocabulary.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
Tests for provenance vocabulary bootstrap.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import Triple, Term, IRI, LITERAL
|
||||
|
||||
from trustgraph.provenance.vocabulary import (
|
||||
get_vocabulary_triples,
|
||||
PROV_CLASS_LABELS,
|
||||
PROV_PREDICATE_LABELS,
|
||||
DC_PREDICATE_LABELS,
|
||||
SCHEMA_LABELS,
|
||||
SKOS_LABELS,
|
||||
TG_CLASS_LABELS,
|
||||
TG_PREDICATE_LABELS,
|
||||
)
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDFS_LABEL,
|
||||
PROV_ENTITY, PROV_ACTIVITY, PROV_AGENT,
|
||||
PROV_WAS_DERIVED_FROM, PROV_WAS_GENERATED_BY,
|
||||
PROV_USED, PROV_WAS_ASSOCIATED_WITH, PROV_STARTED_AT_TIME,
|
||||
DC_TITLE, DC_SOURCE, DC_DATE, DC_CREATOR,
|
||||
TG_DOCUMENT_TYPE, TG_PAGE_TYPE, TG_CHUNK_TYPE, TG_SUBGRAPH_TYPE,
|
||||
)
|
||||
|
||||
|
||||
class TestVocabularyTriples:
|
||||
"""Tests for the vocabulary bootstrap function."""
|
||||
|
||||
def test_returns_list_of_triples(self):
|
||||
result = get_vocabulary_triples()
|
||||
assert isinstance(result, list)
|
||||
assert len(result) > 0
|
||||
for t in result:
|
||||
assert isinstance(t, Triple)
|
||||
|
||||
def test_all_triples_are_label_triples(self):
|
||||
"""Every vocabulary triple should use rdfs:label as predicate."""
|
||||
for t in get_vocabulary_triples():
|
||||
assert t.p.type == IRI
|
||||
assert t.p.iri == RDFS_LABEL
|
||||
|
||||
def test_all_subjects_are_iris(self):
|
||||
for t in get_vocabulary_triples():
|
||||
assert t.s.type == IRI
|
||||
assert len(t.s.iri) > 0
|
||||
|
||||
def test_all_objects_are_literals(self):
|
||||
for t in get_vocabulary_triples():
|
||||
assert t.o.type == LITERAL
|
||||
assert len(t.o.value) > 0
|
||||
|
||||
def test_no_duplicate_subjects(self):
|
||||
subjects = [t.s.iri for t in get_vocabulary_triples()]
|
||||
assert len(subjects) == len(set(subjects))
|
||||
|
||||
def test_includes_prov_classes(self):
|
||||
subjects = {t.s.iri for t in get_vocabulary_triples()}
|
||||
assert PROV_ENTITY in subjects
|
||||
assert PROV_ACTIVITY in subjects
|
||||
assert PROV_AGENT in subjects
|
||||
|
||||
def test_includes_prov_predicates(self):
|
||||
subjects = {t.s.iri for t in get_vocabulary_triples()}
|
||||
assert PROV_WAS_DERIVED_FROM in subjects
|
||||
assert PROV_WAS_GENERATED_BY in subjects
|
||||
assert PROV_USED in subjects
|
||||
assert PROV_WAS_ASSOCIATED_WITH in subjects
|
||||
assert PROV_STARTED_AT_TIME in subjects
|
||||
|
||||
def test_includes_dc_predicates(self):
|
||||
subjects = {t.s.iri for t in get_vocabulary_triples()}
|
||||
assert DC_TITLE in subjects
|
||||
assert DC_SOURCE in subjects
|
||||
assert DC_DATE in subjects
|
||||
assert DC_CREATOR in subjects
|
||||
|
||||
def test_includes_tg_classes(self):
|
||||
subjects = {t.s.iri for t in get_vocabulary_triples()}
|
||||
assert TG_DOCUMENT_TYPE in subjects
|
||||
assert TG_PAGE_TYPE in subjects
|
||||
assert TG_CHUNK_TYPE in subjects
|
||||
assert TG_SUBGRAPH_TYPE in subjects
|
||||
|
||||
def test_component_lists_sum_to_total(self):
|
||||
total = get_vocabulary_triples()
|
||||
components = (
|
||||
PROV_CLASS_LABELS +
|
||||
PROV_PREDICATE_LABELS +
|
||||
DC_PREDICATE_LABELS +
|
||||
SCHEMA_LABELS +
|
||||
SKOS_LABELS +
|
||||
TG_CLASS_LABELS +
|
||||
TG_PREDICATE_LABELS
|
||||
)
|
||||
assert len(total) == len(components)
|
||||
|
||||
def test_idempotent(self):
|
||||
"""Calling twice should return equivalent triples."""
|
||||
r1 = get_vocabulary_triples()
|
||||
r2 = get_vocabulary_triples()
|
||||
assert len(r1) == len(r2)
|
||||
for t1, t2 in zip(r1, r2):
|
||||
assert t1.s.iri == t2.s.iri
|
||||
assert t1.o.value == t2.o.value
|
||||
|
||||
|
||||
class TestNamespaceConstants:
|
||||
"""Verify namespace constants are well-formed IRIs."""
|
||||
|
||||
def test_prov_namespace_prefix(self):
|
||||
assert PROV_ENTITY.startswith("http://www.w3.org/ns/prov#")
|
||||
|
||||
def test_dc_namespace_prefix(self):
|
||||
assert DC_TITLE.startswith("http://purl.org/dc/elements/1.1/")
|
||||
|
||||
def test_tg_namespace_prefix(self):
|
||||
assert TG_DOCUMENT_TYPE.startswith("https://trustgraph.ai/ns/")
|
||||
|
||||
def test_rdfs_label_iri(self):
|
||||
assert RDFS_LABEL == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
|
@ -37,7 +37,7 @@ def mock_qdrant_client():
|
|||
def mock_graph_embeddings_request():
|
||||
"""Mock graph embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -46,9 +46,9 @@ def mock_graph_embeddings_request():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_vectors():
|
||||
"""Mock graph embeddings request with multiple vectors"""
|
||||
"""Mock graph embeddings request with multiple vectors (legacy name, now single vector)"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3, 0.4]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
|
@ -82,7 +82,7 @@ def mock_graph_embeddings_uri_response():
|
|||
def mock_document_embeddings_request():
|
||||
"""Mock document embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -91,9 +91,9 @@ def mock_document_embeddings_request():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings request with multiple vectors"""
|
||||
"""Mock document embeddings request with multiple vectors (legacy name, now single vector)"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3, 0.4]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
|
@ -139,9 +139,9 @@ def mock_large_query_response():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_mixed_dimension_vectors():
|
||||
"""Mock request with vectors of different dimensions"""
|
||||
"""Mock request with vector (legacy name suggested mixed dimensions, now single vector)"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.doc_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest
|
||||
from trustgraph.schema import DocumentEmbeddingsRequest, ChunkMatch
|
||||
|
||||
|
||||
class TestMilvusDocEmbeddingsQueryProcessor:
|
||||
|
|
@ -33,7 +33,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
)
|
||||
return query
|
||||
|
|
@ -71,15 +71,15 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"doc": "First document chunk"}},
|
||||
{"entity": {"doc": "Second document chunk"}},
|
||||
{"entity": {"doc": "Third document chunk"}},
|
||||
{"entity": {"chunk_id": "First document chunk"}},
|
||||
{"entity": {"chunk_id": "Second document chunk"}},
|
||||
{"entity": {"chunk_id": "Third document chunk"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
|
|
@ -90,50 +90,44 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5
|
||||
)
|
||||
|
||||
# Verify results are document chunks
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert result[0] == "First document chunk"
|
||||
assert result[1] == "Second document chunk"
|
||||
assert result[2] == "Third document chunk"
|
||||
assert isinstance(result[0], ChunkMatch)
|
||||
assert result[0].chunk_id == "First document chunk"
|
||||
assert result[1].chunk_id == "Second document chunk"
|
||||
assert result[2].chunk_id == "Third document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_multiple_vectors(self, processor):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
async def test_query_document_embeddings_longer_vector(self, processor):
|
||||
"""Test querying document embeddings with a longer vector"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=3
|
||||
)
|
||||
|
||||
# Mock search results - different results for each vector
|
||||
mock_results_1 = [
|
||||
{"entity": {"doc": "Document from first vector"}},
|
||||
{"entity": {"doc": "Another doc from first vector"}},
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"chunk_id": "First document"}},
|
||||
{"entity": {"chunk_id": "Second document"}},
|
||||
{"entity": {"chunk_id": "Third document"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"doc": "Document from second vector"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify search was called twice with correct parameters including user/collection
|
||||
expected_calls = [
|
||||
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}),
|
||||
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}),
|
||||
]
|
||||
assert processor.vecstore.search.call_count == 2
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.search.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
# Verify results from all vectors are combined
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=3
|
||||
)
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert "Document from first vector" in result
|
||||
assert "Another doc from first vector" in result
|
||||
assert "Document from second vector" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "First document" in chunk_ids
|
||||
assert "Second document" in chunk_ids
|
||||
assert "Third document" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_with_limit(self, processor):
|
||||
|
|
@ -141,16 +135,16 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Mock search results - more results than limit
|
||||
mock_results = [
|
||||
{"entity": {"doc": "Document 1"}},
|
||||
{"entity": {"doc": "Document 2"}},
|
||||
{"entity": {"doc": "Document 3"}},
|
||||
{"entity": {"doc": "Document 4"}},
|
||||
{"entity": {"chunk_id": "Document 1"}},
|
||||
{"entity": {"chunk_id": "Document 2"}},
|
||||
{"entity": {"chunk_id": "Document 3"}},
|
||||
{"entity": {"chunk_id": "Document 4"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
|
|
@ -170,7 +164,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[],
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -188,7 +182,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -211,25 +205,26 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with Unicode content
|
||||
mock_results = [
|
||||
{"entity": {"doc": "Document with Unicode: éñ中文🚀"}},
|
||||
{"entity": {"doc": "Regular ASCII document"}},
|
||||
{"entity": {"doc": "Document with émojis: 😀🎉"}},
|
||||
{"entity": {"chunk_id": "Document with Unicode: éñ中文🚀"}},
|
||||
{"entity": {"chunk_id": "Regular ASCII document"}},
|
||||
{"entity": {"chunk_id": "Document with émojis: 😀🎉"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify Unicode content is preserved
|
||||
# Verify Unicode content is preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert "Document with Unicode: éñ中文🚀" in result
|
||||
assert "Regular ASCII document" in result
|
||||
assert "Document with émojis: 😀🎉" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document with Unicode: éñ中文🚀" in chunk_ids
|
||||
assert "Regular ASCII document" in chunk_ids
|
||||
assert "Document with émojis: 😀🎉" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_large_documents(self, processor):
|
||||
|
|
@ -237,24 +232,25 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with large content
|
||||
large_doc = "A" * 10000 # 10KB of content
|
||||
mock_results = [
|
||||
{"entity": {"doc": large_doc}},
|
||||
{"entity": {"doc": "Small document"}},
|
||||
{"entity": {"chunk_id": large_doc}},
|
||||
{"entity": {"chunk_id": "Small document"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify large content is preserved
|
||||
# Verify large content is preserved in ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
assert large_doc in result
|
||||
assert "Small document" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert large_doc in chunk_ids
|
||||
assert "Small document" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_special_characters(self, processor):
|
||||
|
|
@ -262,25 +258,26 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with special characters
|
||||
mock_results = [
|
||||
{"entity": {"doc": "Document with \"quotes\" and 'apostrophes'"}},
|
||||
{"entity": {"doc": "Document with\nnewlines\tand\ttabs"}},
|
||||
{"entity": {"doc": "Document with special chars: @#$%^&*()"}},
|
||||
{"entity": {"chunk_id": "Document with \"quotes\" and 'apostrophes'"}},
|
||||
{"entity": {"chunk_id": "Document with\nnewlines\tand\ttabs"}},
|
||||
{"entity": {"chunk_id": "Document with special chars: @#$%^&*()"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify special characters are preserved
|
||||
# Verify special characters are preserved in ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert "Document with \"quotes\" and 'apostrophes'" in result
|
||||
assert "Document with\nnewlines\tand\ttabs" in result
|
||||
assert "Document with special chars: @#$%^&*()" in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document with \"quotes\" and 'apostrophes'" in chunk_ids
|
||||
assert "Document with\nnewlines\tand\ttabs" in chunk_ids
|
||||
assert "Document with special chars: @#$%^&*()" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
|
|
@ -288,7 +285,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
|
|
@ -306,7 +303,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=-1
|
||||
)
|
||||
|
||||
|
|
@ -324,7 +321,7 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -341,60 +338,54 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5], # 5D vector
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results for each vector
|
||||
mock_results_1 = [{"entity": {"doc": "Document from 2D vector"}}]
|
||||
mock_results_2 = [{"entity": {"doc": "Document from 4D vector"}}]
|
||||
mock_results_3 = [{"entity": {"doc": "Document from 3D vector"}}]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"chunk_id": "Document 1"}},
|
||||
{"entity": {"chunk_id": "Document 2"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Verify all vectors were searched
|
||||
assert processor.vecstore.search.call_count == 3
|
||||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
assert "Document from 2D vector" in result
|
||||
assert "Document from 4D vector" in result
|
||||
assert "Document from 3D vector" in result
|
||||
|
||||
# Verify search was called with the vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document 1" in chunk_ids
|
||||
assert "Document 2" in chunk_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_duplicate_documents(self, processor):
|
||||
"""Test querying document embeddings with duplicate documents in results"""
|
||||
async def test_query_document_embeddings_multiple_results(self, processor):
|
||||
"""Test querying document embeddings with multiple results"""
|
||||
query = DocumentEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with duplicates across vectors
|
||||
mock_results_1 = [
|
||||
{"entity": {"doc": "Document A"}},
|
||||
{"entity": {"doc": "Document B"}},
|
||||
|
||||
# Mock search results with multiple documents
|
||||
mock_results = [
|
||||
{"entity": {"chunk_id": "Document A"}},
|
||||
{"entity": {"chunk_id": "Document B"}},
|
||||
{"entity": {"chunk_id": "Document C"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"doc": "Document B"}}, # Duplicate
|
||||
{"entity": {"doc": "Document C"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_document_embeddings(query)
|
||||
|
||||
# Note: Unlike graph embeddings, doc embeddings don't deduplicate
|
||||
# This preserves ranking and allows multiple occurrences
|
||||
assert len(result) == 4
|
||||
assert result.count("Document B") == 2 # Should appear twice
|
||||
assert "Document A" in result
|
||||
assert "Document C" in result
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert "Document A" in chunk_ids
|
||||
assert "Document B" in chunk_ids
|
||||
assert "Document C" in chunk_ids
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
@ -458,5 +449,5 @@ class TestMilvusDocEmbeddingsQueryProcessor:
|
|||
|
||||
mock_launch.assert_called_once_with(
|
||||
default_ident,
|
||||
"\nDocument embeddings query service. Input is vector, output is an array\nof chunks\n"
|
||||
"\nDocument embeddings query service. Input is vector, output is an array\nof chunk_ids\n"
|
||||
)
|
||||
|
|
@ -18,10 +18,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
def mock_query_message(self):
|
||||
"""Create a mock query message for testing"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -103,7 +100,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_single_vector(self, processor):
|
||||
"""Test querying document embeddings with a single vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -179,7 +176,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_limit_handling(self, processor):
|
||||
"""Test that query respects the limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -208,7 +205,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_zero_limit(self, processor):
|
||||
"""Test querying with zero limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 0
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -226,7 +223,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_negative_limit(self, processor):
|
||||
"""Test querying with negative limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = -1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -242,12 +239,9 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying with vectors of different dimensions using same index"""
|
||||
"""Test querying with single vector (legacy test name, schema now uses single vector)"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -285,7 +279,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_empty_vectors_list(self, processor):
|
||||
"""Test querying with empty vectors list"""
|
||||
message = MagicMock()
|
||||
message.vectors = []
|
||||
message.vector = []
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -304,7 +298,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_no_results(self, processor):
|
||||
"""Test querying when index returns no results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -325,7 +319,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_unicode_content(self, processor):
|
||||
"""Test querying document embeddings with Unicode content results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -351,7 +345,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_large_content(self, processor):
|
||||
"""Test querying document embeddings with large content results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -377,7 +371,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_mixed_content_types(self, processor):
|
||||
"""Test querying document embeddings with mixed content types"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -409,7 +403,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_exception_handling(self, processor):
|
||||
"""Test that exceptions are properly raised"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -425,7 +419,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
async def test_query_document_embeddings_index_access_failure(self, processor):
|
||||
"""Test handling of index access failure"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -437,13 +431,9 @@ class TestPineconeDocEmbeddingsQueryProcessor:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_document_embeddings_vector_accumulation(self, processor):
|
||||
"""Test that results from multiple vectors are properly accumulated"""
|
||||
"""Test that results from single vector query are returned (legacy multi-vector test)"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -77,9 +78,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point1.payload = {'chunk_id': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
mock_point2.payload = {'chunk_id': 'second document chunk'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
|
|
@ -94,7 +95,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -112,72 +113,69 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected documents
|
||||
# Verify result contains expected ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
# Results should be strings (document chunks)
|
||||
assert isinstance(result[0], str)
|
||||
assert isinstance(result[1], str)
|
||||
# Results should be ChunkMatch objects
|
||||
assert isinstance(result[0], ChunkMatch)
|
||||
assert isinstance(result[1], ChunkMatch)
|
||||
# Verify content
|
||||
assert result[0] == 'first document chunk'
|
||||
assert result[1] == 'second document chunk'
|
||||
assert result[0].chunk_id == 'first document chunk'
|
||||
assert result[1].chunk_id == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
async def test_query_document_embeddings_multiple_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings returns multiple results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
|
||||
# Mock query response with multiple results
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from vector 1'}
|
||||
mock_point1.payload = {'chunk_id': 'document chunk 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from vector 2'}
|
||||
mock_point2.payload = {'chunk_id': 'document chunk 2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'doc': 'another document from vector 2'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
mock_point3.payload = {'chunk_id': 'document chunk 3'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
# Verify collection was queried correctly
|
||||
expected_collection = 'd_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify results from both vectors are combined
|
||||
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 3
|
||||
assert 'document from vector 1' in result
|
||||
assert 'document from vector 2' in result
|
||||
assert 'another document from vector 2' in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'document chunk 1' in chunk_ids
|
||||
assert 'document chunk 2' in chunk_ids
|
||||
assert 'document chunk 3' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -192,7 +190,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_point.payload = {'chunk_id': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
|
|
@ -208,7 +206,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
|
@ -248,7 +246,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
|
@ -262,58 +260,53 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
"""Test querying document embeddings with a higher dimension vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from 2D vector'}
|
||||
mock_point1.payload = {'chunk_id': 'document from 5D vector'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from 3D vector'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
mock_point2.payload = {'chunk_id': 'another 5D document'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
|
||||
# Create mock message with 5D vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.vector = [0.1, 0.2, 0.3, 0.4, 0.5] # 5D vector
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once with correct collection
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
# Call should use 5D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_5' # 5 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
# Verify results are ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
assert 'document from 2D vector' in result
|
||||
assert 'document from 3D vector' in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'document from 5D vector' in chunk_ids
|
||||
assert 'another 5D document' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -326,9 +319,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock query response with UTF-8 content
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point1.payload = {'chunk_id': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
mock_point2.payload = {'chunk_id': 'Chinese text: 你好世界'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
|
|
@ -343,7 +336,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'utf8_user'
|
||||
mock_message.collection = 'utf8_collection'
|
||||
|
|
@ -353,10 +346,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify UTF-8 content works correctly
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in result
|
||||
assert 'Chinese text: 你好世界' in result
|
||||
|
||||
# Verify UTF-8 content works correctly in ChunkMatch objects
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in chunk_ids
|
||||
assert 'Chinese text: 你好世界' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -379,7 +373,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
|
@ -399,7 +393,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock query response
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': 'document chunk'}
|
||||
mock_point.payload = {'chunk_id': 'document chunk'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
|
@ -413,7 +407,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
|
@ -426,10 +420,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0
|
||||
|
||||
# Result should contain all returned documents
|
||||
|
||||
# Result should contain all returned documents as ChunkMatch objects
|
||||
assert len(result) == 1
|
||||
assert result[0] == 'document chunk'
|
||||
assert isinstance(result[0], ChunkMatch)
|
||||
assert result[0].chunk_id == 'document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -442,9 +437,9 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Mock query response with fewer results than limit
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document 1'}
|
||||
mock_point1.payload = {'chunk_id': 'document 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document 2'}
|
||||
mock_point2.payload = {'chunk_id': 'document 2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
|
|
@ -459,7 +454,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with large limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 1000 # Large limit
|
||||
mock_message.user = 'large_user'
|
||||
mock_message.collection = 'large_collection'
|
||||
|
|
@ -472,11 +467,12 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 1000
|
||||
|
||||
# Result should contain all available documents
|
||||
|
||||
# Result should contain all available documents as ChunkMatch objects
|
||||
assert len(result) == 2
|
||||
assert 'document 1' in result
|
||||
assert 'document 2' in result
|
||||
chunk_ids = [r.chunk_id for r in result]
|
||||
assert 'document 1' in chunk_ids
|
||||
assert 'document 2' in chunk_ids
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
|
|
@ -487,11 +483,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with missing 'doc' key
|
||||
# Mock query response with missing 'chunk_id' key
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'valid document'}
|
||||
mock_point1.payload = {'chunk_id': 'valid document'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {} # Missing 'doc' key
|
||||
mock_point2.payload = {} # Missing 'chunk_id' key
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
|
||||
|
||||
|
|
@ -508,13 +504,13 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'payload_user'
|
||||
mock_message.collection = 'payload_collection'
|
||||
|
||||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['doc']
|
||||
# This should raise a KeyError when trying to access payload['chunk_id']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
|
||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -33,7 +33,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=10
|
||||
)
|
||||
return query
|
||||
|
|
@ -119,7 +119,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -138,55 +138,46 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are converted to Term objects
|
||||
# Verify results are converted to EntityMatch objects
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], Term)
|
||||
assert result[0].iri == "http://example.com/entity1"
|
||||
assert result[0].type == IRI
|
||||
assert isinstance(result[1], Term)
|
||||
assert result[1].iri == "http://example.com/entity2"
|
||||
assert result[1].type == IRI
|
||||
assert isinstance(result[2], Term)
|
||||
assert result[2].value == "literal entity"
|
||||
assert result[2].type == LITERAL
|
||||
assert isinstance(result[0], EntityMatch)
|
||||
assert result[0].entity.iri == "http://example.com/entity1"
|
||||
assert result[0].entity.type == IRI
|
||||
assert isinstance(result[1], EntityMatch)
|
||||
assert result[1].entity.iri == "http://example.com/entity2"
|
||||
assert result[1].entity.type == IRI
|
||||
assert isinstance(result[2], EntityMatch)
|
||||
assert result[2].entity.value == "literal entity"
|
||||
assert result[2].entity.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
async def test_query_graph_embeddings_multiple_results(self, processor):
|
||||
"""Test querying graph embeddings returns multiple results"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=3
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results - different results for each vector
|
||||
mock_results_1 = [
|
||||
|
||||
# Mock search results with multiple entities
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called twice with correct parameters including user/collection
|
||||
expected_calls = [
|
||||
(([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}),
|
||||
(([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}),
|
||||
]
|
||||
assert processor.vecstore.search.call_count == 2
|
||||
for i, (expected_args, expected_kwargs) in enumerate(expected_calls):
|
||||
actual_call = processor.vecstore.search.call_args_list[i]
|
||||
assert actual_call[0] == expected_args
|
||||
assert actual_call[1] == expected_kwargs
|
||||
|
||||
# Verify results are deduplicated and limited
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are EntityMatch objects
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
|
@ -197,7 +188,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=2
|
||||
)
|
||||
|
||||
|
|
@ -221,63 +212,57 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication(self, processor):
|
||||
"""Test that duplicate entities are properly deduplicated"""
|
||||
async def test_query_graph_embeddings_preserves_order(self, processor):
|
||||
"""Test that query results preserve order from the vector store"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results with duplicates
|
||||
mock_results_1 = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
mock_results_2 = [
|
||||
{"entity": {"entity": "http://example.com/entity2"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity1"}}, # Duplicate
|
||||
{"entity": {"entity": "http://example.com/entity3"}}, # New
|
||||
]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2]
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify duplicates are removed
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
||||
"""Test that querying stops early when limit is reached"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Mock search results - first vector returns enough results
|
||||
mock_results_1 = [
|
||||
# Mock search results in specific order
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results_1
|
||||
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify only first vector was searched (limit reached)
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4
|
||||
|
||||
# Verify results are in the same order as returned by the store
|
||||
assert len(result) == 3
|
||||
assert result[0].entity.iri == "http://example.com/entity1"
|
||||
assert result[1].entity.iri == "http://example.com/entity2"
|
||||
assert result[2].entity.iri == "http://example.com/entity3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_results_limited(self, processor):
|
||||
"""Test that results are properly limited when store returns more than requested"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
limit=2
|
||||
)
|
||||
|
||||
# Verify results are limited
|
||||
|
||||
# Mock search results - returns more results than limit
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
{"entity": {"entity": "http://example.com/entity3"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify search was called with the full vector
|
||||
processor.vecstore.search.assert_called_once_with(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], 'test_user', 'test_collection', limit=4
|
||||
)
|
||||
|
||||
# Verify results are limited to requested amount
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -286,7 +271,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[],
|
||||
vector=[],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -304,7 +289,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -327,7 +312,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -344,18 +329,18 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify all results are properly typed
|
||||
assert len(result) == 4
|
||||
|
||||
|
||||
# Check URI entities
|
||||
uri_results = [r for r in result if r.type == IRI]
|
||||
uri_results = [r for r in result if r.entity.type == IRI]
|
||||
assert len(uri_results) == 2
|
||||
uri_values = [r.iri for r in uri_results]
|
||||
uri_values = [r.entity.iri for r in uri_results]
|
||||
assert "http://example.com/uri_entity" in uri_values
|
||||
assert "https://example.com/another_uri" in uri_values
|
||||
|
||||
|
||||
# Check literal entities
|
||||
literal_results = [r for r in result if not r.type == IRI]
|
||||
literal_results = [r for r in result if not r.entity.type == IRI]
|
||||
assert len(literal_results) == 2
|
||||
literal_values = [r.value for r in literal_results]
|
||||
literal_values = [r.entity.value for r in literal_results]
|
||||
assert "literal entity text" in literal_values
|
||||
assert "another literal" in literal_values
|
||||
|
||||
|
|
@ -365,7 +350,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=5
|
||||
)
|
||||
|
||||
|
|
@ -447,7 +432,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[[0.1, 0.2, 0.3]],
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
limit=0
|
||||
)
|
||||
|
||||
|
|
@ -460,33 +445,29 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
assert len(result) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying graph embeddings with different vector dimensions"""
|
||||
async def test_query_graph_embeddings_longer_vector(self, processor):
|
||||
"""Test querying graph embeddings with a longer vector"""
|
||||
query = GraphEmbeddingsRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
[0.7, 0.8, 0.9] # 3D vector
|
||||
],
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Mock search results for each vector
|
||||
mock_results_1 = [{"entity": {"entity": "entity_2d"}}]
|
||||
mock_results_2 = [{"entity": {"entity": "entity_4d"}}]
|
||||
mock_results_3 = [{"entity": {"entity": "entity_3d"}}]
|
||||
processor.vecstore.search.side_effect = [mock_results_1, mock_results_2, mock_results_3]
|
||||
|
||||
|
||||
# Mock search results
|
||||
mock_results = [
|
||||
{"entity": {"entity": "http://example.com/entity1"}},
|
||||
{"entity": {"entity": "http://example.com/entity2"}},
|
||||
]
|
||||
processor.vecstore.search.return_value = mock_results
|
||||
|
||||
result = await processor.query_graph_embeddings(query)
|
||||
|
||||
# Verify all vectors were searched
|
||||
assert processor.vecstore.search.call_count == 3
|
||||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert "entity_2d" in entity_values
|
||||
assert "entity_4d" in entity_values
|
||||
assert "entity_3d" in entity_values
|
||||
|
||||
# Verify search was called once with the full vector
|
||||
processor.vecstore.search.assert_called_once()
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
entity_values = [r.entity.iri if r.entity.type == IRI else r.entity.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
|
|
@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
from trustgraph.schema import Term, IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -19,10 +19,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
def mock_query_message(self):
|
||||
"""Create a mock query message for testing"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -131,7 +128,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
"""Test querying graph embeddings with a single vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -162,45 +159,39 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
include_metadata=True
|
||||
)
|
||||
|
||||
# Verify results
|
||||
# Verify results use EntityMatch structure
|
||||
assert len(entities) == 3
|
||||
assert entities[0].value == 'http://example.org/entity1'
|
||||
assert entities[0].type == IRI
|
||||
assert entities[1].value == 'entity2'
|
||||
assert entities[1].type == LITERAL
|
||||
assert entities[2].value == 'http://example.org/entity3'
|
||||
assert entities[2].type == IRI
|
||||
assert entities[0].entity.iri == 'http://example.org/entity1'
|
||||
assert entities[0].entity.type == IRI
|
||||
assert entities[1].entity.value == 'entity2'
|
||||
assert entities[1].entity.type == LITERAL
|
||||
assert entities[2].entity.iri == 'http://example.org/entity3'
|
||||
assert entities[2].entity.type == IRI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
async def test_query_graph_embeddings_basic(self, processor, mock_query_message):
|
||||
"""Test basic graph embeddings query"""
|
||||
# Mock index and query results
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query results with distinct entities
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'})
|
||||
]
|
||||
|
||||
# Second query results
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(mock_query_message)
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
|
||||
# Verify deduplication occurred
|
||||
entity_values = [e.value for e in entities]
|
||||
|
||||
# Verify query was made once
|
||||
assert mock_index.query.call_count == 1
|
||||
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert len(entity_values) == 3
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
|
@ -210,7 +201,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_limit_handling(self, processor):
|
||||
"""Test that query respects the limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -234,7 +225,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_zero_limit(self, processor):
|
||||
"""Test querying with zero limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 0
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -252,7 +243,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_negative_limit(self, processor):
|
||||
"""Test querying with negative limit returns empty results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = -1
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -267,52 +258,41 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
|
||||
"""Test querying with vectors of different dimensions using same index"""
|
||||
async def test_query_graph_embeddings_2d_vector(self, processor):
|
||||
"""Test querying with a 2D vector"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6] # 4D vector
|
||||
]
|
||||
message.vector = [0.1, 0.2] # 2D vector
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
# Mock single index that handles all dimensions
|
||||
# Mock index
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock results for different vector queries
|
||||
mock_results_2d = MagicMock()
|
||||
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
# Mock results for 2D vector query
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
|
||||
|
||||
mock_results_4d = MagicMock()
|
||||
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
|
||||
|
||||
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Verify different indexes used for different dimensions
|
||||
assert processor.pinecone.Index.call_count == 2
|
||||
index_calls = processor.pinecone.Index.call_args_list
|
||||
index_names = [call[0][0] for call in index_calls]
|
||||
assert "t-test_user-test_collection-2" in index_names # 2D vector
|
||||
assert "t-test_user-test_collection-4" in index_names # 4D vector
|
||||
# Verify correct index used for 2D vector
|
||||
processor.pinecone.Index.assert_called_with("t-test_user-test_collection-2")
|
||||
|
||||
# Verify both queries were made
|
||||
assert mock_index.query.call_count == 2
|
||||
# Verify query was made
|
||||
assert mock_index.query.call_count == 1
|
||||
|
||||
# Verify results from both dimensions
|
||||
entity_values = [e.value for e in entities]
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert 'entity_2d' in entity_values
|
||||
assert 'entity_4d' in entity_values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_empty_vectors_list(self, processor):
|
||||
"""Test querying with empty vectors list"""
|
||||
message = MagicMock()
|
||||
message.vectors = []
|
||||
message.vector = []
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -331,7 +311,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_no_results(self, processor):
|
||||
"""Test querying when index returns no results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
@ -349,73 +329,60 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
assert entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_deduplication_across_vectors(self, processor):
|
||||
"""Test that deduplication works correctly across multiple vector queries"""
|
||||
async def test_query_graph_embeddings_deduplication_in_results(self, processor):
|
||||
"""Test that deduplication works correctly within query results"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 3
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Both queries return overlapping results
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query returns results with some duplicates
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity1'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}),
|
||||
MagicMock(metadata={'entity': 'entity4'})
|
||||
]
|
||||
|
||||
mock_results2 = MagicMock()
|
||||
mock_results2.matches = [
|
||||
MagicMock(metadata={'entity': 'entity2'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity3'}), # Duplicate
|
||||
MagicMock(metadata={'entity': 'entity5'})
|
||||
]
|
||||
|
||||
mock_index.query.side_effect = [mock_results1, mock_results2]
|
||||
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
|
||||
# Should get exactly 3 unique entities (respecting limit)
|
||||
assert len(entities) == 3
|
||||
entity_values = [e.value for e in entities]
|
||||
entity_values = [e.entity.value for e in entities]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_early_termination_on_limit(self, processor):
|
||||
"""Test that querying stops early when limit is reached"""
|
||||
async def test_query_graph_embeddings_respects_limit(self, processor):
|
||||
"""Test that query respects limit parameter"""
|
||||
message = MagicMock()
|
||||
message.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 2
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
||||
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# First query returns enough results to meet limit
|
||||
mock_results1 = MagicMock()
|
||||
mock_results1.matches = [
|
||||
|
||||
# Query returns more results than limit
|
||||
mock_results = MagicMock()
|
||||
mock_results.matches = [
|
||||
MagicMock(metadata={'entity': 'entity1'}),
|
||||
MagicMock(metadata={'entity': 'entity2'}),
|
||||
MagicMock(metadata={'entity': 'entity3'})
|
||||
]
|
||||
mock_index.query.return_value = mock_results1
|
||||
|
||||
mock_index.query.return_value = mock_results
|
||||
|
||||
entities = await processor.query_graph_embeddings(message)
|
||||
|
||||
# Should only make one query since limit was reached
|
||||
|
||||
# Should only return 2 entities (respecting limit)
|
||||
mock_index.query.assert_called_once()
|
||||
assert len(entities) == 2
|
||||
|
||||
|
|
@ -423,7 +390,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
async def test_query_graph_embeddings_exception_handling(self, processor):
|
||||
"""Test that exceptions are properly raised"""
|
||||
message = MagicMock()
|
||||
message.vectors = [[0.1, 0.2, 0.3]]
|
||||
message.vector = [0.1, 0.2, 0.3]
|
||||
message.limit = 5
|
||||
message.user = 'test_user'
|
||||
message.collection = 'test_collection'
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
from trustgraph.schema import IRI, LITERAL, EntityMatch
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -167,7 +167,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
|
@ -185,10 +185,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected entities
|
||||
# Verify result contains expected EntityMatch objects
|
||||
assert len(result) == 2
|
||||
assert all(hasattr(entity, 'value') for entity in result)
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert all(isinstance(entity, EntityMatch) for entity in result)
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
||||
|
|
@ -221,35 +221,32 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
|
||||
# Verify both collections were queried (both 2-dimensional vectors)
|
||||
# Verify collection was queried
|
||||
expected_collection = 't_multi_user_multi_collection_2' # 2 dimensions
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
||||
entity_values = [entity.value for entity in result]
|
||||
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -280,7 +277,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.vector = [0.1, 0.2, 0.3]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
|
@ -320,7 +317,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
|
@ -358,34 +355,29 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
|
||||
# Create mock message with single vector
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.vector = [0.1, 0.2] # 2D vector
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
# Verify query was called once
|
||||
assert mock_qdrant_instance.query_points.call_count == 1
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
# Call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' # 2 dimensions
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' # 3 dimensions
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
entity_values = [entity.value for entity in result]
|
||||
# Verify results with EntityMatch structure
|
||||
entity_values = [entity.entity.value for entity in result]
|
||||
assert 'entity2d' in entity_values
|
||||
assert 'entity3d' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -417,7 +409,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'uri_user'
|
||||
mock_message.collection = 'uri_collection'
|
||||
|
|
@ -427,18 +419,18 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if entity.type == IRI]
|
||||
uri_entities = [entity for entity in result if entity.entity.type == IRI]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.iri for entity in uri_entities]
|
||||
uri_values = [entity.entity.iri for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if entity.type == LITERAL]
|
||||
regular_entities = [entity for entity in result if entity.entity.type == LITERAL]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
assert regular_entities[0].entity.value == 'regular entity'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -461,7 +453,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
|
@ -495,7 +487,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.vector = [0.1, 0.2]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
|
@ -512,7 +504,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# With zero limit, the logic still adds one entity before checking the limit
|
||||
# So it returns one result (current behavior, not ideal but actual)
|
||||
assert len(result) == 1
|
||||
assert result[0].value == 'entity1'
|
||||
assert result[0].entity.value == 'entity1'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
|
|||
|
|
@ -118,8 +118,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
|
|
@ -182,8 +182,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -219,8 +219,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -256,8 +256,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -293,8 +293,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -331,8 +331,8 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
assert result[0].p.value == 'all_predicate'
|
||||
assert result[0].s.iri == 'all_subject'
|
||||
assert result[0].p.iri == 'all_predicate'
|
||||
assert result[0].o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
|
||||
|
|
@ -637,8 +637,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].s.iri == 'result_subject'
|
||||
assert result[0].p.iri == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -678,8 +678,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].s.iri == 'test_subject'
|
||||
assert result[0].p.iri == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -802,7 +802,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
# Verify all results were returned
|
||||
assert len(result) == 5
|
||||
for i, triple in enumerate(result):
|
||||
assert triple.s.value == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.s.iri == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
|
||||
|
|
|
|||
1
tests/unit/test_rdf/__init__.py
Normal file
1
tests/unit/test_rdf/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
309
tests/unit/test_rdf/test_rdf_primitives.py
Normal file
309
tests/unit/test_rdf/test_rdf_primitives.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""
|
||||
Tests for RDF 1.2 type system primitives: Term dataclass (IRI, blank node,
|
||||
typed literal, language-tagged literal, quoted triple), Triple/Quad dataclass
|
||||
with named graph support, and the knowledge/defs helper types.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import Term, Triple, IRI, BLANK, LITERAL, TRIPLE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTypeConstants:
|
||||
|
||||
def test_iri_constant(self):
|
||||
assert IRI == "i"
|
||||
|
||||
def test_blank_constant(self):
|
||||
assert BLANK == "b"
|
||||
|
||||
def test_literal_constant(self):
|
||||
assert LITERAL == "l"
|
||||
|
||||
def test_triple_constant(self):
|
||||
assert TRIPLE == "t"
|
||||
|
||||
def test_constants_are_distinct(self):
|
||||
vals = {IRI, BLANK, LITERAL, TRIPLE}
|
||||
assert len(vals) == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# IRI terms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIriTerm:
|
||||
|
||||
def test_create_iri(self):
|
||||
t = Term(type=IRI, iri="http://example.org/Alice")
|
||||
assert t.type == IRI
|
||||
assert t.iri == "http://example.org/Alice"
|
||||
|
||||
def test_iri_defaults_empty(self):
|
||||
t = Term(type=IRI)
|
||||
assert t.iri == ""
|
||||
|
||||
def test_iri_with_fragment(self):
|
||||
t = Term(type=IRI, iri="http://example.org/ontology#Person")
|
||||
assert "#Person" in t.iri
|
||||
|
||||
def test_iri_with_unicode(self):
|
||||
t = Term(type=IRI, iri="http://example.org/概念")
|
||||
assert "概念" in t.iri
|
||||
|
||||
def test_iri_other_fields_default(self):
|
||||
t = Term(type=IRI, iri="http://example.org/x")
|
||||
assert t.id == ""
|
||||
assert t.value == ""
|
||||
assert t.datatype == ""
|
||||
assert t.language == ""
|
||||
assert t.triple is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blank node terms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBlankNodeTerm:
|
||||
|
||||
def test_create_blank_node(self):
|
||||
t = Term(type=BLANK, id="_:b0")
|
||||
assert t.type == BLANK
|
||||
assert t.id == "_:b0"
|
||||
|
||||
def test_blank_node_defaults_empty(self):
|
||||
t = Term(type=BLANK)
|
||||
assert t.id == ""
|
||||
|
||||
def test_blank_node_arbitrary_id(self):
|
||||
t = Term(type=BLANK, id="node-abc-123")
|
||||
assert t.id == "node-abc-123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed literals (XSD datatypes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTypedLiteral:
|
||||
|
||||
def test_plain_literal(self):
|
||||
t = Term(type=LITERAL, value="hello")
|
||||
assert t.type == LITERAL
|
||||
assert t.value == "hello"
|
||||
assert t.datatype == ""
|
||||
assert t.language == ""
|
||||
|
||||
def test_xsd_integer(self):
|
||||
t = Term(
|
||||
type=LITERAL, value="42",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#integer",
|
||||
)
|
||||
assert t.value == "42"
|
||||
assert "integer" in t.datatype
|
||||
|
||||
def test_xsd_boolean(self):
|
||||
t = Term(
|
||||
type=LITERAL, value="true",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#boolean",
|
||||
)
|
||||
assert t.datatype.endswith("#boolean")
|
||||
|
||||
def test_xsd_date(self):
|
||||
t = Term(
|
||||
type=LITERAL, value="2026-03-13",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#date",
|
||||
)
|
||||
assert t.value == "2026-03-13"
|
||||
assert t.datatype.endswith("#date")
|
||||
|
||||
def test_xsd_double(self):
|
||||
t = Term(
|
||||
type=LITERAL, value="3.14",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#double",
|
||||
)
|
||||
assert t.datatype.endswith("#double")
|
||||
|
||||
def test_empty_value_literal(self):
|
||||
t = Term(type=LITERAL, value="")
|
||||
assert t.value == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Language-tagged literals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLanguageTaggedLiteral:
|
||||
|
||||
def test_english_tag(self):
|
||||
t = Term(type=LITERAL, value="hello", language="en")
|
||||
assert t.language == "en"
|
||||
assert t.datatype == ""
|
||||
|
||||
def test_french_tag(self):
|
||||
t = Term(type=LITERAL, value="bonjour", language="fr")
|
||||
assert t.language == "fr"
|
||||
|
||||
def test_bcp47_subtag(self):
|
||||
t = Term(type=LITERAL, value="colour", language="en-GB")
|
||||
assert t.language == "en-GB"
|
||||
|
||||
def test_language_and_datatype_mutually_exclusive(self):
|
||||
"""Both can be set on the dataclass, but semantically only one should be used."""
|
||||
t = Term(type=LITERAL, value="x", language="en",
|
||||
datatype="http://www.w3.org/2001/XMLSchema#string")
|
||||
# Dataclass allows both — translators should respect mutual exclusivity
|
||||
assert t.language == "en"
|
||||
assert t.datatype != ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quoted triples (RDF-star)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuotedTriple:
|
||||
|
||||
def test_term_with_nested_triple(self):
|
||||
inner = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/Alice"),
|
||||
p=Term(type=IRI, iri="http://xmlns.com/foaf/0.1/knows"),
|
||||
o=Term(type=IRI, iri="http://example.org/Bob"),
|
||||
)
|
||||
qt = Term(type=TRIPLE, triple=inner)
|
||||
assert qt.type == TRIPLE
|
||||
assert qt.triple is inner
|
||||
assert qt.triple.s.iri == "http://example.org/Alice"
|
||||
|
||||
def test_quoted_triple_as_object(self):
|
||||
"""A triple whose object is a quoted triple (RDF-star)."""
|
||||
inner = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/Hope"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="A feeling of expectation"),
|
||||
)
|
||||
outer = Triple(
|
||||
s=Term(type=IRI, iri="urn:subgraph:123"),
|
||||
p=Term(type=IRI, iri="http://trustgraph.ai/tg/contains"),
|
||||
o=Term(type=TRIPLE, triple=inner),
|
||||
)
|
||||
assert outer.o.type == TRIPLE
|
||||
assert outer.o.triple.o.value == "A feeling of expectation"
|
||||
|
||||
def test_quoted_triple_none(self):
|
||||
t = Term(type=TRIPLE, triple=None)
|
||||
assert t.triple is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Triple / Quad (named graph)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTripleQuad:
|
||||
|
||||
def test_default_graph_is_none(self):
|
||||
t = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/s"),
|
||||
p=Term(type=IRI, iri="http://example.org/p"),
|
||||
o=Term(type=LITERAL, value="val"),
|
||||
)
|
||||
assert t.g is None
|
||||
|
||||
def test_named_graph(self):
|
||||
t = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/s"),
|
||||
p=Term(type=IRI, iri="http://example.org/p"),
|
||||
o=Term(type=LITERAL, value="val"),
|
||||
g="urn:graph:source",
|
||||
)
|
||||
assert t.g == "urn:graph:source"
|
||||
|
||||
def test_empty_string_graph(self):
|
||||
t = Triple(g="")
|
||||
assert t.g == ""
|
||||
|
||||
def test_triple_with_all_none_terms(self):
|
||||
t = Triple()
|
||||
assert t.s is None
|
||||
assert t.p is None
|
||||
assert t.o is None
|
||||
assert t.g is None
|
||||
|
||||
def test_triple_equality(self):
|
||||
"""Dataclass equality based on field values."""
|
||||
t1 = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/A"),
|
||||
p=Term(type=IRI, iri="http://example.org/B"),
|
||||
o=Term(type=LITERAL, value="C"),
|
||||
)
|
||||
t2 = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/A"),
|
||||
p=Term(type=IRI, iri="http://example.org/B"),
|
||||
o=Term(type=LITERAL, value="C"),
|
||||
)
|
||||
assert t1 == t2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# knowledge/defs helper types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKnowledgeDefs:
|
||||
|
||||
def test_uri_type(self):
|
||||
from trustgraph.knowledge.defs import Uri
|
||||
u = Uri("http://example.org/x")
|
||||
assert u.is_uri() is True
|
||||
assert u.is_literal() is False
|
||||
assert u.is_triple() is False
|
||||
assert str(u) == "http://example.org/x"
|
||||
|
||||
def test_literal_type(self):
|
||||
from trustgraph.knowledge.defs import Literal
|
||||
l = Literal("hello world")
|
||||
assert l.is_uri() is False
|
||||
assert l.is_literal() is True
|
||||
assert l.is_triple() is False
|
||||
assert str(l) == "hello world"
|
||||
|
||||
def test_quoted_triple_type(self):
|
||||
from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
|
||||
qt = QuotedTriple(
|
||||
s=Uri("http://example.org/s"),
|
||||
p=Uri("http://example.org/p"),
|
||||
o=Literal("val"),
|
||||
)
|
||||
assert qt.is_uri() is False
|
||||
assert qt.is_literal() is False
|
||||
assert qt.is_triple() is True
|
||||
assert qt.s == "http://example.org/s"
|
||||
assert qt.o == "val"
|
||||
|
||||
def test_quoted_triple_repr(self):
|
||||
from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
|
||||
qt = QuotedTriple(
|
||||
s=Uri("http://example.org/A"),
|
||||
p=Uri("http://example.org/B"),
|
||||
o=Literal("C"),
|
||||
)
|
||||
r = repr(qt)
|
||||
assert "<<" in r
|
||||
assert ">>" in r
|
||||
assert "http://example.org/A" in r
|
||||
|
||||
def test_quoted_triple_nested(self):
|
||||
"""QuotedTriple can contain another QuotedTriple as object."""
|
||||
from trustgraph.knowledge.defs import QuotedTriple, Uri, Literal
|
||||
inner = QuotedTriple(
|
||||
s=Uri("http://example.org/s"),
|
||||
p=Uri("http://example.org/p"),
|
||||
o=Literal("v"),
|
||||
)
|
||||
outer = QuotedTriple(
|
||||
s=Uri("http://example.org/s2"),
|
||||
p=Uri("http://example.org/p2"),
|
||||
o=inner,
|
||||
)
|
||||
assert outer.o.is_triple() is True
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue