Merge 2.0 to master (#651)

This commit is contained in:
cybermaggedon 2026-02-28 11:03:14 +00:00 committed by GitHub
parent 3666ece2c5
commit b9d7bf9a8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
212 changed files with 13940 additions and 6180 deletions

View file

@ -22,7 +22,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup packages
run: make update-package-versions VERSION=1.8.999
run: make update-package-versions VERSION=2.0.999
- name: Setup environment
run: python3 -m venv env

View file

@ -5,7 +5,7 @@ VERSION=0.0.0
DOCKER=podman
all: container
all: containers
# Not used
wheels:
@ -49,7 +49,9 @@ update-package-versions:
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
container: update-package-versions
FORCE:
containers: FORCE
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
@ -70,8 +72,8 @@ some-containers:
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.vertexai \
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.mcp \
# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \

View file

@ -16,7 +16,7 @@ RUN dnf install -y python3.13 && \
dnf clean all
RUN pip3 install --no-cache-dir \
anthropic cohere mistralai openai google-generativeai \
anthropic cohere mistralai openai \
ollama \
langchain==0.3.25 langchain-core==0.3.60 \
langchain-text-splitters==0.3.8 \

View file

@ -0,0 +1,260 @@
# Entity-Centric Knowledge Graph Storage on Cassandra
## Overview
This document describes a storage model for RDF-style knowledge graphs on Apache Cassandra. The model uses an **entity-centric** approach where every entity knows every quad it participates in and the role it plays. This replaces a traditional multi-table SPO permutation approach with just two tables.
## Background and Motivation
### The Traditional Approach
A standard RDF quad store on Cassandra requires multiple denormalised tables to cover query patterns — typically 6 or more tables representing different permutations of Subject, Predicate, Object, and Dataset (SPOD). Each quad is written to every table, resulting in significant write amplification, operational overhead, and schema complexity.
Additionally, label resolution (fetching human-readable names for entities) requires separate round-trip queries, which is particularly costly in AI and GraphRAG use cases where labels are essential for LLM context.
### The Entity-Centric Insight
Every quad `(D, S, P, O)` involves up to 4 entities. By writing a row for each entity's participation in the quad, we guarantee that **any query with at least one known element will hit a partition key**. This covers all 16 query patterns with a single data table.
Key benefits:
- **2 tables** instead of 7+
- **4 writes per quad** instead of 6+
- **Label resolution for free** — an entity's labels are co-located with its relationships, naturally warming the application cache
- **All 16 query patterns** served by single-partition reads
- **Simpler operations** — one data table to tune, compact, and repair
## Schema
### Table 1: quads_by_entity
The primary data table. Every entity has a partition containing all quads it participates in. Named to reflect the query pattern (lookup by entity).
```sql
CREATE TABLE quads_by_entity (
collection text, -- Collection/tenant scope (always specified)
entity text, -- The entity this row is about
role text, -- 'S', 'P', 'O', 'G' — how this entity participates
p text, -- Predicate of the quad
otype text, -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
s text, -- Subject of the quad
o text, -- Object of the quad
d text, -- Dataset/graph of the quad
dtype text, -- XSD datatype (when otype = 'L'), e.g. 'xsd:string'
lang text, -- Language tag (when otype = 'L'), e.g. 'en', 'fr'
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d)
);
```
**Partition key**: `(collection, entity)` — scoped to collection, one partition per entity.
**Clustering column order rationale**:
1. **role** — most queries start with "where is this entity a subject/object"
2. **p** — next most common filter, "give me all `knows` relationships"
3. **otype** — enables filtering by URI-valued vs literal-valued relationships
4. **s, o, d** — remaining columns for uniqueness
### Table 2: quads_by_collection
Supports collection-level queries and deletion. Provides a manifest of all quads belonging to a collection. Named to reflect the query pattern (lookup by collection).
```sql
CREATE TABLE quads_by_collection (
collection text,
d text, -- Dataset/graph of the quad
s text, -- Subject of the quad
p text, -- Predicate of the quad
o text, -- Object of the quad
otype text, -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
dtype text, -- XSD datatype (when otype = 'L')
lang text, -- Language tag (when otype = 'L')
PRIMARY KEY (collection, d, s, p, o)
);
```
Clustered by dataset first, enabling deletion at either collection or dataset granularity.
## Write Path
For each incoming quad `(D, S, P, O)` within a collection `C`, write **4 rows** to `quads_by_entity` and **1 row** to `quads_by_collection`.
### Example
Given the quad in collection `tenant1`:
```
Dataset: https://example.org/graph1
Subject: https://example.org/Alice
Predicate: https://example.org/knows
Object: https://example.org/Bob
```
Write 4 rows to `quads_by_entity`:
| collection | entity | role | p | otype | s | o | d |
|---|---|---|---|---|---|---|---|
| tenant1 | https://example.org/graph1 | G | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
| tenant1 | https://example.org/Alice | S | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
| tenant1 | https://example.org/knows | P | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
| tenant1 | https://example.org/Bob | O | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
Write 1 row to `quads_by_collection`:
| collection | d | s | p | o | otype | dtype | lang |
|---|---|---|---|---|---|---|---|
| tenant1 | https://example.org/graph1 | https://example.org/Alice | https://example.org/knows | https://example.org/Bob | U | | |
### Literal Example
For a label triple:
```
Dataset: https://example.org/graph1
Subject: https://example.org/Alice
Predicate: http://www.w3.org/2000/01/rdf-schema#label
Object: "Alice Smith" (lang: en)
```
The `otype` is `'L'`, `dtype` is `'xsd:string'`, and `lang` is `'en'`. The literal value `"Alice Smith"` is stored in `o`. Only 3 rows are needed in `quads_by_entity` — no row is written for the literal as entity, since literals are not independently queryable entities.
## Query Patterns
### All 16 DSPO Patterns
In the table below, "Perfect prefix" means the query uses a contiguous prefix of the clustering columns. "Partition scan + filter" means Cassandra reads a slice of one partition and filters in memory — still efficient, just not a pure prefix match.
| # | Known | Lookup entity | Clustering prefix | Efficiency |
|---|---|---|---|---|
| 1 | D,S,P,O | entity=S, role='S', p=P | Full match | Perfect prefix |
| 2 | D,S,P,? | entity=S, role='S', p=P | Filter on D | Partition scan + filter |
| 3 | D,S,?,O | entity=S, role='S' | Filter on D, O | Partition scan + filter |
| 4 | D,?,P,O | entity=O, role='O', p=P | Filter on D | Partition scan + filter |
| 5 | ?,S,P,O | entity=S, role='S', p=P | Filter on O | Partition scan + filter |
| 6 | D,S,?,? | entity=S, role='S' | Filter on D | Partition scan + filter |
| 7 | D,?,P,? | entity=P, role='P' | Filter on D | Partition scan + filter |
| 8 | D,?,?,O | entity=O, role='O' | Filter on D | Partition scan + filter |
| 9 | ?,S,P,? | entity=S, role='S', p=P | — | **Perfect prefix** |
| 10 | ?,S,?,O | entity=S, role='S' | Filter on O | Partition scan + filter |
| 11 | ?,?,P,O | entity=O, role='O', p=P | — | **Perfect prefix** |
| 12 | D,?,?,? | entity=D, role='G' | — | **Perfect prefix** |
| 13 | ?,S,?,? | entity=S, role='S' | — | **Perfect prefix** |
| 14 | ?,?,P,? | entity=P, role='P' | — | **Perfect prefix** |
| 15 | ?,?,?,O | entity=O, role='O' | — | **Perfect prefix** |
| 16 | ?,?,?,? | — | Full scan | Exploration only |
**Key result**: 7 of the 15 non-trivial patterns are perfect clustering prefix hits. The remaining 8 are single-partition reads with in-partition filtering. Every query with at least one known element hits a partition key.
Pattern 16 (?,?,?,?) does not occur in practice since collection is always specified, reducing it to pattern 12.
### Common Query Examples
**Everything about an entity:**
```sql
SELECT * FROM quads_by_entity
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice';
```
**All outgoing relationships for an entity:**
```sql
SELECT * FROM quads_by_entity
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
AND role = 'S';
```
**Specific predicate for an entity:**
```sql
SELECT * FROM quads_by_entity
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
AND role = 'S' AND p = 'https://example.org/knows';
```
**Label for an entity (specific language):**
```sql
SELECT * FROM quads_by_entity
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
AND role = 'S' AND p = 'http://www.w3.org/2000/01/rdf-schema#label'
AND otype = 'L';
```
Then filter by `lang = 'en'` application-side if needed.
**Only URI-valued relationships (entity-to-entity links):**
```sql
SELECT * FROM quads_by_entity
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
AND role = 'S' AND p = 'https://example.org/knows' AND otype = 'U';
```
**Reverse lookup — what points to this entity:**
```sql
SELECT * FROM quads_by_entity
WHERE collection = 'tenant1' AND entity = 'https://example.org/Bob'
AND role = 'O';
```
## Label Resolution and Cache Warming
One of the most significant advantages of the entity-centric model is that **label resolution becomes a free side effect**.
In the traditional multi-table model, fetching labels requires separate round-trip queries: retrieve triples, identify entity URIs in the results, then fetch `rdfs:label` for each. This N+1 pattern is expensive.
In the entity-centric model, querying an entity returns **all** its quads — including its labels, types, and other properties. When the application caches query results, labels are pre-warmed before anything asks for them.
Two usage regimes confirm this works well in practice:
- **Human-facing queries**: naturally small result sets, labels essential. Entity reads pre-warm the cache.
- **AI/bulk queries**: large result sets with hard limits. Labels either unnecessary or needed only for a curated subset of entities already in cache.
The theoretical concern of resolving labels for huge result sets (e.g. 30,000 entities) is mitigated by the practical observation that no human or AI consumer usefully processes that many labels. Application-level query limits ensure cache pressure remains manageable.
## Wide Partitions and Reification
Reification (RDF-star style statements about statements) creates hub entities — e.g. a source document that supports thousands of extracted facts. This can produce wide partitions.
Mitigating factors:
- **Application-level query limits**: all GraphRAG and human-facing queries enforce hard limits, so wide partitions are never fully scanned on the hot read path
- **Cassandra handles partial reads efficiently**: a clustering column scan with an early stop is fast even on large partitions
- **Collection deletion** (the only operation that might traverse full partitions) is an accepted background process
## Collection Deletion
Triggered by API call, runs in the background (eventually consistent).
1. Read `quads_by_collection` for the target collection to get all quads
2. Extract unique entities from the quads (s, p, o, d values)
3. For each unique entity, delete the partition from `quads_by_entity`
4. Delete the rows from `quads_by_collection`
The `quads_by_collection` table provides the index needed to locate all entity partitions without a full table scan. Partition-level deletes are efficient since `(collection, entity)` is the partition key.
## Migration Path from Multi-Table Model
The entity-centric model can coexist with the existing multi-table model during migration:
1. Deploy `quads_by_entity` and `quads_by_collection` tables alongside existing tables
2. Dual-write new quads to both old and new tables
3. Backfill existing data into the new tables
4. Migrate read paths one query pattern at a time
5. Decommission old tables once all reads are migrated
## Summary
| Aspect | Traditional (6-table) | Entity-centric (2-table) |
|---|---|---|
| Tables | 7+ | 2 |
| Writes per quad | 6+ | 5 (4 data + 1 manifest) |
| Label resolution | Separate round trips | Free via cache warming |
| Query patterns | 16 across 6 tables | 16 on 1 table |
| Schema complexity | High | Low |
| Operational overhead | 6 tables to tune/repair | 1 data table |
| Reification support | Additional complexity | Natural fit |
| Object type filtering | Not available | Native (via otype clustering) |

View file

@ -0,0 +1,573 @@
# Graph Contexts Technical Specification
## Overview
This specification describes changes to TrustGraph's core graph primitives to
align with RDF 1.2 and support full RDF Dataset semantics. This is a breaking
change for the 2.x release series.
### Versioning
- **2.0**: Early adopter release. Core features available, may not be fully
production-ready.
- **2.1 / 2.2**: Production release. Stability and completeness validated.
Flexibility on maturity is intentional - early adopters can access new
capabilities before all features are production-hardened.
## Goals
The primary goals for this work are to enable metadata about facts/statements:
- **Temporal information**: Associate facts with time metadata
- When a fact was believed to be true
- When a fact became true
- When a fact was discovered to be false
- **Provenance/Sources**: Track which sources support a fact
- "This fact was supported by source X"
- Link facts back to their origin documents
- **Veracity/Trust**: Record assertions about truth
- "Person P asserted this was true"
- "Person Q claims this is false"
- Enable trust scoring and conflict detection
**Hypothesis**: Reification (RDF-star / quoted triples) is the key mechanism
to achieve these outcomes, as all require making statements about statements.
## Background
To express "the fact (Alice knows Bob) was discovered on 2024-01-15" or
"source X supports the claim (Y causes Z)", you need to reference an edge
as a thing you can make statements about. Standard triples don't support this.
### Current Limitations
The current `Value` class in `trustgraph-base/trustgraph/schema/core/primitives.py`
can represent:
- URI nodes (`is_uri=True`)
- Literal values (`is_uri=False`)
The `type` field exists but is not used to represent XSD datatypes.
## Technical Design
### RDF Features to Support
#### Core Features (Related to Reification Goals)
These features are directly related to the temporal, provenance, and veracity
goals:
1. **RDF 1.2 Quoted Triples (RDF-star)**
- Edges that point at other edges
- A Triple can appear as the subject or object of another Triple
- Enables statements about statements (reification)
- Core mechanism for annotating individual facts
2. **RDF Dataset / Named Graphs**
- Support for multiple named graphs within a dataset
- Each graph identified by an IRI
- Moves from triples (s, p, o) to quads (s, p, o, g)
- Includes a default graph plus zero or more named graphs
- The graph IRI can be a subject in statements, e.g.:
```
<graph-source-A> <discoveredOn> "2024-01-15"
<graph-source-A> <hasVeracity> "high"
```
- Note: Named graphs are a separate feature from reification. They have
uses beyond statement annotation (partitioning, access control, dataset
organization) and should be treated as a distinct capability.
3. **Blank Nodes** (Limited Support)
- Anonymous nodes without a global URI
- Supported for compatibility when loading external RDF data
- **Limited status**: No guarantees about stable identity after loading
- Find them via wildcard queries (match by connections, not by ID)
- Not a first-class feature - don't rely on precise blank node handling
#### Opportunistic Fixes (2.0 Breaking Change)
These features are not directly related to the reification goals but are
valuable improvements to include while making breaking changes:
4. **Literal Datatypes**
- Properly use the `type` field for XSD datatypes
- Examples: xsd:string, xsd:integer, xsd:dateTime, etc.
- Fixes current limitation: cannot represent dates or integers properly
5. **Language Tags**
- Support for language attributes on string literals (@en, @fr, etc.)
- Note: A literal has either a language tag OR a datatype, not both
(except for rdf:langString)
- Important for AI/multilingual use cases
### Data Models
#### Term (rename from Value)
The `Value` class will be renamed to `Term` to better reflect RDF terminology.
This rename serves two purposes:
1. Aligns naming with RDF concepts (a "Term" can be an IRI, literal, blank
node, or quoted triple - not just a "value")
2. Forces code review at the breaking change interface - any code still
referencing `Value` is visibly broken and needs updating
A Term can represent:
- **IRI/URI** - A named node/resource
- **Blank Node** - An anonymous node with local scope
- **Literal** - A data value with either:
- A datatype (XSD type), OR
- A language tag
- **Quoted Triple** - A triple used as a term (RDF 1.2)
##### Chosen Approach: Single Class with Type Discriminator
Serialization requirements drive the structure - a type discriminator is needed
in the wire format regardless of the Python representation. A single class with
a type field is the natural fit and aligns with the current `Value` pattern.
Single-character type codes provide compact serialization:
```python
from dataclasses import dataclass
# Term type constants
IRI = "i" # IRI/URI node
BLANK = "b" # Blank node
LITERAL = "l" # Literal value
TRIPLE = "t" # Quoted triple (RDF-star)
@dataclass
class Term:
type: str = "" # One of: IRI, BLANK, LITERAL, TRIPLE
# For IRI terms (type == IRI)
iri: str = ""
# For blank nodes (type == BLANK)
id: str = ""
# For literals (type == LITERAL)
value: str = ""
datatype: str = "" # XSD datatype URI (mutually exclusive with language)
language: str = "" # Language tag (mutually exclusive with datatype)
# For quoted triples (type == TRIPLE)
triple: "Triple | None" = None
```
Usage examples:
```python
# IRI term
node = Term(type=IRI, iri="http://example.org/Alice")
# Literal with datatype
age = Term(type=LITERAL, value="42", datatype="xsd:integer")
# Literal with language tag
label = Term(type=LITERAL, value="Hello", language="en")
# Blank node
anon = Term(type=BLANK, id="_:b1")
# Quoted triple (statement about a statement)
inner = Triple(
s=Term(type=IRI, iri="http://example.org/Alice"),
p=Term(type=IRI, iri="http://example.org/knows"),
o=Term(type=IRI, iri="http://example.org/Bob"),
)
reified = Term(type=TRIPLE, triple=inner)
```
##### Alternatives Considered
**Option B: Union of specialized classes** (`Term = IRI | BlankNode | Literal | QuotedTriple`)
- Rejected: Serialization would still need a type discriminator, adding complexity
**Option C: Base class with subclasses**
- Rejected: Same serialization issue, plus dataclass inheritance quirks
#### Triple / Quad
The `Triple` class gains an optional graph field to become a quad:
```python
@dataclass
class Triple:
s: Term | None = None # Subject
p: Term | None = None # Predicate
o: Term | None = None # Object
g: str | None = None # Graph name (IRI), None = default graph
```
Design decisions:
- **Field name**: `g` for consistency with `s`, `p`, `o`
- **Optional**: `None` means the default graph (unnamed)
- **Type**: Plain string (IRI) rather than Term
- Graph names are always IRIs
- Blank nodes as graph names ruled out (too confusing)
- No need for the full Term machinery
Note: The class name stays `Triple` even though it's technically a quad now.
This avoids churn and "triple" is still the common terminology for the s/p/o
portion. The graph context is metadata about where the triple lives.
### Candidate Query Patterns
The current query engine accepts combinations of S, P, O terms. With quoted
triples, a triple itself becomes a valid term in those positions. Below are
candidate query patterns that support the original goals.
#### Graph Parameter Semantics
Following SPARQL conventions for backward compatibility:
- **`g` omitted / None**: Query the default graph only
- **`g` = specific IRI**: Query that named graph only
- **`g` = wildcard / `*`**: Query across all graphs (equivalent to SPARQL
`GRAPH ?g { ... }`)
This keeps simple queries simple and makes named graph queries opt-in.
Cross-graph queries (g=wildcard) are fully supported. The Cassandra schema
includes dedicated tables (SPOG, POSG, OSPG) where g is a clustering column
rather than a partition key, enabling efficient queries across all graphs.
#### Temporal Queries
**Find all facts discovered after a given date:**
```
S: ? # any quoted triple
P: <discoveredOn>
O: > "2024-01-15"^^xsd:date # date comparison
```
**Find when a specific fact was believed true:**
```
S: << <Alice> <knows> <Bob> >> # quoted triple as subject
P: <believedTrueFrom>
O: ? # returns the date
```
**Find facts that became false:**
```
S: ? # any quoted triple
P: <discoveredFalseOn>
O: ? # has any value (exists)
```
#### Provenance Queries
**Find all facts supported by a specific source:**
```
S: ? # any quoted triple
P: <supportedBy>
O: <source:document-123>
```
**Find which sources support a specific fact:**
```
S: << <DrugA> <treats> <DiseaseB> >> # quoted triple as subject
P: <supportedBy>
O: ? # returns source IRIs
```
#### Veracity Queries
**Find assertions a person marked as true:**
```
S: ? # any quoted triple
P: <assertedTrueBy>
O: <person:Alice>
```
**Find conflicting assertions (same fact, different veracity):**
```
# First query: facts asserted true
S: ?
P: <assertedTrueBy>
O: ?
# Second query: facts asserted false
S: ?
P: <assertedFalseBy>
O: ?
# Application logic: find intersection of subjects
```
**Find facts with trust score below threshold:**
```
S: ? # any quoted triple
P: <trustScore>
O: < 0.5 # numeric comparison
```
### Architecture
Significant changes required across multiple components:
#### This Repository (trustgraph)
- **Schema primitives** (`trustgraph-base/trustgraph/schema/core/primitives.py`)
- Value → Term rename
- New Term structure with type discriminator
- Triple gains `g` field for graph context
- **Message translators** (`trustgraph-base/trustgraph/messaging/translators/`)
- Update for new Term/Triple structures
- Serialization/deserialization for new fields
- **Gateway components**
- Handle new Term and quad structures
- **Knowledge cores**
- Core changes to support quads and reification
- **Knowledge manager**
- Schema changes propagate here
- **Storage layers**
- Cassandra: Schema redesign (see Implementation Details)
- Other backends: Deferred to later phases
- **Command-line utilities**
- Update for new data structures
- **REST API documentation**
- OpenAPI spec updates
#### External Repositories
- **Python API** (this repo)
- Client library updates for new structures
- **TypeScript APIs** (separate repo)
- Client library updates
- **Workbench** (separate repo)
- Significant state management changes
### APIs
#### REST API
- Documented in OpenAPI spec
- Will need updates for new Term/Triple structures
- New endpoints may be needed for graph context operations
#### Python API (this repo)
- Client library changes to match new primitives
- Breaking changes to Term (was Value) and Triple
#### TypeScript API (separate repo)
- Parallel changes to Python API
- Separate release coordination
#### Workbench (separate repo)
- Significant state management changes
- UI updates for graph context features
### Implementation Details
#### Phased Storage Implementation
Multiple graph store backends exist (Cassandra, Neo4j, etc.). Implementation
will proceed in phases:
1. **Phase 1: Cassandra**
- Start with the home-grown Cassandra store
- Full control over the storage layer enables rapid iteration
- Schema will be redesigned from scratch for quads + reification
- Validate the data model and query patterns against real use cases
#### Cassandra Schema Design
Cassandra requires multiple tables to support different query access patterns
(each table efficiently queries by its partition key + clustering columns).
##### Query Patterns
With quads (g, s, p, o), each position can be specified or wildcard, giving
16 possible query patterns:
| # | g | s | p | o | Description |
|---|---|---|---|---|-------------|
| 1 | ? | ? | ? | ? | All quads |
| 2 | ? | ? | ? | o | By object |
| 3 | ? | ? | p | ? | By predicate |
| 4 | ? | ? | p | o | By predicate + object |
| 5 | ? | s | ? | ? | By subject |
| 6 | ? | s | ? | o | By subject + object |
| 7 | ? | s | p | ? | By subject + predicate |
| 8 | ? | s | p | o | Full triple (which graphs?) |
| 9 | g | ? | ? | ? | By graph |
| 10 | g | ? | ? | o | By graph + object |
| 11 | g | ? | p | ? | By graph + predicate |
| 12 | g | ? | p | o | By graph + predicate + object |
| 13 | g | s | ? | ? | By graph + subject |
| 14 | g | s | ? | o | By graph + subject + object |
| 15 | g | s | p | ? | By graph + subject + predicate |
| 16 | g | s | p | o | Exact quad |
##### Table Design
Cassandra constraint: You can only efficiently query by partition key, then
filter on clustering columns left-to-right. For g-wildcard queries, g must be
a clustering column. For g-specified queries, g in the partition key is more
efficient.
**Two table families needed:**
**Family A: g-wildcard queries** (g in clustering columns)
| Table | Partition | Clustering | Supports patterns |
|-------|-----------|------------|-------------------|
| SPOG | (user, collection, s) | p, o, g | 5, 7, 8 |
| POSG | (user, collection, p) | o, s, g | 3, 4 |
| OSPG | (user, collection, o) | s, p, g | 2, 6 |
**Family B: g-specified queries** (g in partition key)
| Table | Partition | Clustering | Supports patterns |
|-------|-----------|------------|-------------------|
| GSPO | (user, collection, g, s) | p, o | 9, 13, 15, 16 |
| GPOS | (user, collection, g, p) | o, s | 11, 12 |
| GOSP | (user, collection, g, o) | s, p | 10, 14 |
**Collection table** (for iteration and bulk deletion)
| Table | Partition | Clustering | Purpose |
|-------|-----------|------------|---------|
| COLL | (user, collection) | g, s, p, o | Enumerate all quads in collection |
##### Write and Delete Paths
**Write path**: Insert into all 7 tables.
**Delete collection path**:
1. Iterate COLL table for `(user, collection)`
2. For each quad, delete from all 6 query tables
3. Delete from COLL table (or range delete)
**Delete single quad path**: Delete from all 7 tables directly.
##### Storage Cost
Each quad is stored 7 times. This is the cost of flexible querying combined
with efficient collection deletion.
##### Quoted Triples in Storage
Subject or object can be a triple itself. Options:
**Option A: Serialize quoted triples to canonical string**
```
S: "<<http://ex/Alice|http://ex/knows|http://ex/Bob>>"
P: http://ex/discoveredOn
O: "2024-01-15"
G: null
```
- Store quoted triple as serialized string in S or O columns
- Query by exact match on serialized form
- Pro: Simple, fits existing index patterns
- Con: Can't query "find triples where quoted subject's predicate is X"
**Option B: Triple IDs / Hashes**
```
Triple table:
id: hash(s,p,o,g)
s, p, o, g: ...
Metadata table:
subject_triple_id: <hash>
p: http://ex/discoveredOn
o: "2024-01-15"
```
- Assign each triple an ID (hash of components)
- Reification metadata references triples by ID
- Pro: Clean separation, can index triple IDs
- Con: Requires computing/managing triple identity, two-phase lookups
**Recommendation**: Start with Option A (serialized strings) for simplicity.
Option B may be needed if advanced query patterns over quoted triple
components are required.
2. **Phase 2+: Other Backends**
- Neo4j and other stores implemented in subsequent stages
- Lessons learned from Cassandra inform these implementations
This approach de-risks the design by validating on a fully-controlled backend
before committing to implementations across all stores.
#### Value → Term Rename
The `Value` class will be renamed to `Term`. This affects ~78 files across
the codebase. The rename acts as a forcing function: any code still using
`Value` is immediately identifiable as needing review/update for 2.0
compatibility.
## Security Considerations
Named graphs are not a security feature. Users and collections remain the
security boundaries. Named graphs are purely for data organization and
reification support.
## Performance Considerations
- Quoted triples add nesting depth - may impact query performance
- Named graph indexing strategies needed for efficient graph-scoped queries
- Cassandra schema design will need to accommodate quad storage efficiently
### Vector Store Boundary
Vector stores always reference IRIs only:
- Never edges (quoted triples)
- Never literal values
- Never blank nodes
This keeps the vector store simple - it handles semantic similarity of named
entities. The graph structure handles relationships, reification, and metadata.
Quoted triples and named graphs don't complicate vector operations.
## Testing Strategy
Use existing test strategy. As this is a breaking change, extensive focus on
the end-to-end test suite to validate the new structures work correctly across
all components.
## Migration Plan
- 2.0 is a breaking release; no backward compatibility required
- Existing data may need migration to new schema (TBD based on final design)
- Consider migration tooling for converting existing triples
## Open Questions
- **Blank nodes**: Limited support confirmed. May need to decide on
skolemization strategy (generate IRIs on load, or preserve blank node IDs).
- **Query syntax**: What is the concrete syntax for specifying quoted triples
in queries? Need to define the query API.
- ~~**Predicate vocabulary**~~: Resolved. Any valid RDF predicates permitted,
including custom user-defined. Minimal assumptions about RDF validity.
Very few locked-in values (e.g., `rdfs:label` used in some places).
Strategy: avoid locking anything in unless absolutely necessary.
- ~~**Vector store impact**~~: Resolved. Vector stores always point to IRIs
only - never edges, literals, or blank nodes. Quoted triples and
reification don't affect the vector store.
- ~~**Named graph semantics**~~: Resolved. Queries default to the default
graph (matches SPARQL behavior, backward compatible). Explicit graph
parameter required to query named graphs or all graphs.
## References
- [RDF 1.2 Concepts](https://www.w3.org/TR/rdf12-concepts/)
- [RDF-star and SPARQL-star](https://w3c.github.io/rdf-star/)
- [RDF Dataset](https://www.w3.org/TR/rdf11-concepts/#section-dataset)

View file

@ -0,0 +1,455 @@
# JSONL Prompt Output Technical Specification
## Overview
This specification describes the implementation of JSONL (JSON Lines) output
format for prompt responses in TrustGraph. JSONL enables truncation-resilient
extraction of structured data from LLM responses, addressing critical issues
with JSON array outputs being corrupted when LLM responses hit output token
limits.
This implementation supports the following use cases:
1. **Truncation-Resilient Extraction**: Extract valid partial results even when
LLM output is truncated mid-response
2. **Large-Scale Extraction**: Handle extraction of many items without risk of
complete failure due to token limits
3. **Mixed-Type Extraction**: Support extraction of multiple entity types
(definitions, relationships, entities, attributes) in a single prompt
4. **Streaming-Compatible Output**: Enable future streaming/incremental
processing of extraction results
## Goals
- **Backward Compatibility**: Existing prompts using `response-type: "text"` and
`response-type: "json"` continue to work without modification
- **Truncation Resilience**: Partial LLM outputs yield partial valid results
rather than complete failure
- **Schema Validation**: Support JSON Schema validation for individual objects
- **Discriminated Unions**: Support mixed-type outputs using a `type` field
discriminator
- **Minimal API Changes**: Extend existing prompt configuration with new
response type and schema key
## Background
### Current Architecture
The prompt service supports two response types:
1. `response-type: "text"` - Raw text response returned as-is
2. `response-type: "json"` - JSON parsed from response, validated against
optional `schema`
Current implementation in `trustgraph-flow/trustgraph/template/prompt_manager.py`:
```python
class Prompt:
def __init__(self, template, response_type = "text", terms=None, schema=None):
self.template = template
self.response_type = response_type
self.terms = terms
self.schema = schema
```
### Current Limitations
When extraction prompts request output as JSON arrays (`[{...}, {...}, ...]`):
- **Truncation corruption**: If the LLM hits output token limits mid-array, the
entire response becomes invalid JSON and cannot be parsed
- **All-or-nothing parsing**: Must receive complete output before parsing
- **No partial results**: A truncated response yields zero usable data
- **Unreliable for large extractions**: More extracted items = higher failure risk
This specification addresses these limitations by introducing JSONL format for
extraction prompts, where each extracted item is a complete JSON object on its
own line.
## Technical Design
### Response Type Extension
Add a new response type `"jsonl"` alongside existing `"text"` and `"json"` types.
#### Configuration Changes
**New response type value:**
```
"response-type": "jsonl"
```
**Schema interpretation:**
The existing `"schema"` key is used for both `"json"` and `"jsonl"` response
types. The interpretation depends on the response type:
- `"json"`: Schema describes the entire response (typically an array or object)
- `"jsonl"`: Schema describes each individual line/object
```json
{
"response-type": "jsonl",
"schema": {
"type": "object",
"properties": {
"entity": { "type": "string" },
"definition": { "type": "string" }
},
"required": ["entity", "definition"]
}
}
```
This avoids changes to prompt configuration tooling and editors.
### JSONL Format Specification
#### Simple Extraction
For prompts extracting a single type of object (definitions, relationships,
topics, rows), the output is one JSON object per line with no wrapper:
**Prompt output format:**
```
{"entity": "photosynthesis", "definition": "Process by which plants convert sunlight"}
{"entity": "chlorophyll", "definition": "Green pigment in plants"}
{"entity": "mitochondria", "definition": "Powerhouse of the cell"}
```
**Contrast with previous JSON array format:**
```json
[
{"entity": "photosynthesis", "definition": "Process by which plants convert sunlight"},
{"entity": "chlorophyll", "definition": "Green pigment in plants"},
{"entity": "mitochondria", "definition": "Powerhouse of the cell"}
]
```
If the LLM truncates after line 2, the JSON array format yields invalid JSON,
while JSONL yields two valid objects.
#### Mixed-Type Extraction (Discriminated Unions)
For prompts extracting multiple types of objects (e.g., both definitions and
relationships, or entities, relationships, and attributes), use a `"type"`
field as discriminator:
**Prompt output format:**
```
{"type": "definition", "entity": "DNA", "definition": "Molecule carrying genetic instructions"}
{"type": "relationship", "subject": "DNA", "predicate": "located_in", "object": "cell nucleus", "object-entity": true}
{"type": "definition", "entity": "RNA", "definition": "Molecule that carries genetic information"}
{"type": "relationship", "subject": "RNA", "predicate": "transcribed_from", "object": "DNA", "object-entity": true}
```
**Schema for discriminated unions uses `oneOf`:**
```json
{
"response-type": "jsonl",
"schema": {
"oneOf": [
{
"type": "object",
"properties": {
"type": { "const": "definition" },
"entity": { "type": "string" },
"definition": { "type": "string" }
},
"required": ["type", "entity", "definition"]
},
{
"type": "object",
"properties": {
"type": { "const": "relationship" },
"subject": { "type": "string" },
"predicate": { "type": "string" },
"object": { "type": "string" },
"object-entity": { "type": "boolean" }
},
"required": ["type", "subject", "predicate", "object", "object-entity"]
}
]
}
}
```
#### Ontology Extraction
For ontology-based extraction with entities, relationships, and attributes:
**Prompt output format:**
```
{"type": "entity", "entity": "Cornish pasty", "entity_type": "fo/Recipe"}
{"type": "entity", "entity": "beef", "entity_type": "fo/Food"}
{"type": "relationship", "subject": "Cornish pasty", "subject_type": "fo/Recipe", "relation": "fo/has_ingredient", "object": "beef", "object_type": "fo/Food"}
{"type": "attribute", "entity": "Cornish pasty", "entity_type": "fo/Recipe", "attribute": "fo/serves", "value": "4 people"}
```
### Implementation Details
#### Prompt Class
The existing `Prompt` class requires no changes. The `schema` field is reused
for JSONL, with its interpretation determined by `response_type`:
```python
class Prompt:
def __init__(self, template, response_type="text", terms=None, schema=None):
self.template = template
self.response_type = response_type
self.terms = terms
self.schema = schema # Interpretation depends on response_type
```
#### PromptManager.load_config
No changes required - existing configuration loading already handles the
`schema` key.
#### JSONL Parsing
Add a new parsing method for JSONL responses:
```python
def parse_jsonl(self, text):
"""
Parse JSONL response, returning list of valid objects.
Invalid lines (malformed JSON, empty lines) are skipped with warnings.
This provides truncation resilience - partial output yields partial results.
"""
results = []
for line_num, line in enumerate(text.strip().split('\n'), 1):
line = line.strip()
# Skip empty lines
if not line:
continue
# Skip markdown code fence markers if present
if line.startswith('```'):
continue
try:
obj = json.loads(line)
results.append(obj)
except json.JSONDecodeError as e:
# Log warning but continue - this provides truncation resilience
logger.warning(f"JSONL parse error on line {line_num}: {e}")
return results
```
#### PromptManager.invoke Changes
Extend the invoke method to handle the new response type:
```python
async def invoke(self, id, input, llm):
logger.debug("Invoking prompt template...")
terms = self.terms | self.prompts[id].terms | input
resp_type = self.prompts[id].response_type
prompt = {
"system": self.system_template.render(terms),
"prompt": self.render(id, input)
}
resp = await llm(**prompt)
if resp_type == "text":
return resp
if resp_type == "json":
try:
obj = self.parse_json(resp)
except:
logger.error(f"JSON parse failed: {resp}")
raise RuntimeError("JSON parse fail")
if self.prompts[id].schema:
try:
validate(instance=obj, schema=self.prompts[id].schema)
logger.debug("Schema validation successful")
except Exception as e:
raise RuntimeError(f"Schema validation fail: {e}")
return obj
if resp_type == "jsonl":
objects = self.parse_jsonl(resp)
if not objects:
logger.warning("JSONL parse returned no valid objects")
return []
# Validate each object against schema if provided
if self.prompts[id].schema:
validated = []
for i, obj in enumerate(objects):
try:
validate(instance=obj, schema=self.prompts[id].schema)
validated.append(obj)
except Exception as e:
logger.warning(f"Object {i} failed schema validation: {e}")
return validated
return objects
raise RuntimeError(f"Response type {resp_type} not known")
```
### Affected Prompts
The following prompts should be migrated to JSONL format:
| Prompt ID | Description | Type Field |
|-----------|-------------|------------|
| `extract-definitions` | Entity/definition extraction | No (single type) |
| `extract-relationships` | Relationship extraction | No (single type) |
| `extract-topics` | Topic/definition extraction | No (single type) |
| `extract-rows` | Structured row extraction | No (single type) |
| `agent-kg-extract` | Combined definition + relationship extraction | Yes: `"definition"`, `"relationship"` |
| `extract-with-ontologies` / `ontology-extract` | Ontology-based extraction | Yes: `"entity"`, `"relationship"`, `"attribute"` |
### API Changes
#### Client Perspective
JSONL parsing is transparent to prompt service API callers. The parsing occurs
server-side in the prompt service, and the response is returned via the standard
`PromptResponse.object` field as a serialized JSON array.
When clients call the prompt service (via `PromptClient.prompt()` or similar):
- **`response-type: "json"`** with array schema → client receives Python `list`
- **`response-type: "jsonl"`** → client receives Python `list`
From the client's perspective, both return identical data structures. The
difference is entirely in how the LLM output is parsed server-side:
- JSON array format: Single `json.loads()` call; fails completely if truncated
- JSONL format: Line-by-line parsing; yields partial results if truncated
This means existing client code expecting a list from extraction prompts
requires no changes when migrating prompts from JSON to JSONL format.
#### Server Return Value
For `response-type: "jsonl"`, the `PromptManager.invoke()` method returns a
`list[dict]` containing all successfully parsed and validated objects. This
list is then serialized to JSON for the `PromptResponse.object` field.
#### Error Handling
- Empty results: Returns empty list `[]` with warning log
- Partial parse failure: Returns list of successfully parsed objects with
warning logs for failures
- Complete parse failure: Returns empty list `[]` with warning logs
This differs from `response-type: "json"` which raises `RuntimeError` on
parse failure. The lenient behavior for JSONL is intentional to provide
truncation resilience.
### Configuration Example
Complete prompt configuration example:
```json
{
"prompt": "Extract all entities and their definitions from the following text. Output one JSON object per line.\n\nText:\n{{text}}\n\nOutput format per line:\n{\"entity\": \"<name>\", \"definition\": \"<definition>\"}",
"response-type": "jsonl",
"schema": {
"type": "object",
"properties": {
"entity": {
"type": "string",
"description": "The entity name"
},
"definition": {
"type": "string",
"description": "A clear definition of the entity"
}
},
"required": ["entity", "definition"]
}
}
```
## Security Considerations
- **Input Validation**: JSON parsing uses standard `json.loads()` which is safe
against injection attacks
- **Schema Validation**: Uses `jsonschema.validate()` for schema enforcement
- **No New Attack Surface**: JSONL parsing is strictly safer than JSON array
parsing due to line-by-line processing
## Performance Considerations
- **Memory**: Line-by-line parsing uses less peak memory than loading full JSON
arrays
- **Latency**: Parsing performance is comparable to JSON array parsing
- **Validation**: Schema validation runs per-object, which adds overhead but
enables partial results on validation failure
## Testing Strategy
### Unit Tests
- JSONL parsing with valid input
- JSONL parsing with empty lines
- JSONL parsing with markdown code fences
- JSONL parsing with truncated final line
- JSONL parsing with invalid JSON lines interspersed
- Schema validation with `oneOf` discriminated unions
- Backward compatibility: existing `"text"` and `"json"` prompts unchanged
### Integration Tests
- End-to-end extraction with JSONL prompts
- Extraction with simulated truncation (artificially limited response)
- Mixed-type extraction with type discriminator
- Ontology extraction with all three types
### Extraction Quality Tests
- Compare extraction results: JSONL vs JSON array format
- Verify truncation resilience: JSONL yields partial results where JSON fails
## Migration Plan
### Phase 1: Implementation
1. Implement `parse_jsonl()` method in `PromptManager`
2. Extend `invoke()` to handle `response-type: "jsonl"`
3. Add unit tests
### Phase 2: Prompt Migration
1. Update `extract-definitions` prompt and configuration
2. Update `extract-relationships` prompt and configuration
3. Update `extract-topics` prompt and configuration
4. Update `extract-rows` prompt and configuration
5. Update `agent-kg-extract` prompt and configuration
6. Update `extract-with-ontologies` prompt and configuration
### Phase 3: Downstream Updates
1. Update any code consuming extraction results to handle list return type
2. Update code that categorizes mixed-type extractions by `type` field
3. Update tests that assert on extraction output format
## Open Questions
None at this time.
## References
- Current implementation: `trustgraph-flow/trustgraph/template/prompt_manager.py`
- JSON Lines specification: https://jsonlines.org/
- JSON Schema `oneOf`: https://json-schema.org/understanding-json-schema/reference/combining.html#oneof
- Related specification: Streaming LLM Responses (`docs/tech-specs/streaming-llm-responses.md`)

View file

@ -0,0 +1,613 @@
# Structured Data Technical Specification (Part 2)
## Overview
This specification addresses issues and gaps identified during the initial implementation of TrustGraph's structured data integration, as described in `structured-data.md`.
## Problem Statements
### 1. Naming Inconsistency: "Object" vs "Row"
The current implementation uses "object" terminology throughout (e.g., `ExtractedObject`, object extraction, object embeddings). This naming is too generic and causes confusion:
- "Object" is an overloaded term in software (Python objects, JSON objects, etc.)
- The data being handled is fundamentally tabular - rows in tables with defined schemas
- "Row" more accurately describes the data model and aligns with database terminology
This inconsistency appears in module names, class names, message types, and documentation.
### 2. Row Store Query Limitations
The current row store implementation has significant query limitations:
**Natural Language Mismatch**: Queries struggle with real-world data variations. For example:
- A street database containing `"CHESTNUT ST"` is difficult to find when asking about `"Chestnut Street"`
- Abbreviations, case differences, and formatting variations break exact-match queries
- Users expect semantic understanding, but the store provides literal matching
**Schema Evolution Issues**: Changing schemas causes problems:
- Existing data may not conform to updated schemas
- Table structure changes can break queries and data integrity
- No clear migration path for schema updates
### 3. Row Embeddings Required
Related to problem 2, the system needs vector embeddings for row data to enable:
- Semantic search across structured data (finding "Chestnut Street" when data contains "CHESTNUT ST")
- Similarity matching for fuzzy queries
- Hybrid search combining structured filters with semantic similarity
- Better natural language query support
The embedding service was specified but not implemented.
### 4. Row Data Ingestion Incomplete
The structured data ingestion pipeline is not fully operational:
- Diagnostic prompts exist to classify input formats (CSV, JSON, etc.)
- The ingestion service that uses these prompts is not plumbed into the system
- No end-to-end path for loading pre-structured data into the row store
## Goals
- **Schema Flexibility**: Enable schema evolution without breaking existing data or requiring migrations
- **Consistent Naming**: Standardize on "row" terminology throughout the codebase
- **Semantic Queryability**: Support fuzzy/semantic matching via row embeddings
- **Complete Ingestion Pipeline**: Provide end-to-end path for loading structured data
## Technical Design
### Unified Row Storage Schema
The previous implementation created a separate Cassandra table for each schema. This caused problems when schemas evolved, as table structure changes required migrations.
The new design uses a single unified table for all row data:
```sql
CREATE TABLE rows (
collection text,
schema_name text,
index_name text,
index_value frozen<list<text>>,
data map<text, text>,
source text,
PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
```
#### Column Definitions
| Column | Type | Description |
|--------|------|-------------|
| `collection` | `text` | Data collection/import identifier (from metadata) |
| `schema_name` | `text` | Name of the schema this row conforms to |
| `index_name` | `text` | Name of the indexed field(s), comma-joined for composites |
| `index_value` | `frozen<list<text>>` | Index value(s) as a list |
| `data` | `map<text, text>` | Row data as key-value pairs |
| `source` | `text` | Optional URI linking to provenance information in the knowledge graph. Empty string or NULL indicates no source. |
#### Index Handling
Each row is stored multiple times - once per indexed field defined in the schema. The primary key fields are treated as an index with no special marker, providing future flexibility.
**Single-field index example:**
- Schema defines `email` as indexed
- `index_name = "email"`
- `index_value = ['foo@bar.com']`
**Composite index example:**
- Schema defines composite index on `region` and `status`
- `index_name = "region,status"` (field names sorted and comma-joined)
- `index_value = ['US', 'active']` (values in same order as field names)
**Primary key example:**
- Schema defines `customer_id` as primary key
- `index_name = "customer_id"`
- `index_value = ['CUST001']`
#### Query Patterns
All queries follow the same pattern regardless of which index is used:
```sql
SELECT * FROM rows
WHERE collection = 'import_2024'
AND schema_name = 'customers'
AND index_name = 'email'
AND index_value = ['foo@bar.com']
```
#### Design Trade-offs
**Advantages:**
- Schema changes don't require table structure changes
- Row data is opaque to Cassandra - field additions/removals are transparent
- Consistent query pattern for all access methods
- No Cassandra secondary indexes (which can be slow at scale)
- Native Cassandra types throughout (`map`, `frozen<list>`)
**Trade-offs:**
- Write amplification: each row insert = N inserts (one per indexed field)
- Storage overhead from duplicated row data
- Type information stored in schema config, conversion at application layer
#### Consistency Model
The design accepts certain simplifications:
1. **No row updates**: The system is append-only. This eliminates consistency concerns about updating multiple copies of the same row.
2. **Schema change tolerance**: When schemas change (e.g., indexes added/removed), existing rows retain their original indexing. Old rows won't be discoverable via new indexes. Users can delete and recreate a schema to ensure consistency if needed.
### Partition Tracking and Deletion
#### The Problem
With the partition key `(collection, schema_name, index_name)`, efficient deletion requires knowing all partition keys to delete. Deleting by just `collection` or `collection + schema_name` requires knowing all the `index_name` values that have data.
#### Partition Tracking Table
A secondary lookup table tracks which partitions exist:
```sql
CREATE TABLE row_partitions (
collection text,
schema_name text,
index_name text,
PRIMARY KEY ((collection), schema_name, index_name)
)
```
This enables efficient discovery of partitions for deletion operations.
#### Row Writer Behavior
The row writer maintains an in-memory cache of registered `(collection, schema_name)` pairs. When processing a row:
1. Check if `(collection, schema_name)` is in the cache
2. If not cached (first row for this pair):
- Look up the schema config to get all index names
- Insert entries into `row_partitions` for each `(collection, schema_name, index_name)`
- Add the pair to the cache
3. Proceed with writing the row data
The row writer also monitors schema config change events. When a schema changes, relevant cache entries are cleared so the next row triggers re-registration with the updated index names.
This approach ensures:
- Lookup table writes happen once per `(collection, schema_name)` pair, not per row
- The lookup table reflects the indexes that were active when data was written
- Schema changes mid-import are picked up correctly
#### Deletion Operations
**Delete collection:**
```sql
-- 1. Discover all partitions
SELECT schema_name, index_name FROM row_partitions WHERE collection = 'X';
-- 2. Delete each partition from rows table
DELETE FROM rows WHERE collection = 'X' AND schema_name = '...' AND index_name = '...';
-- (repeat for each discovered partition)
-- 3. Clean up the lookup table
DELETE FROM row_partitions WHERE collection = 'X';
```
**Delete collection + schema:**
```sql
-- 1. Discover partitions for this schema
SELECT index_name FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
-- 2. Delete each partition from rows table
DELETE FROM rows WHERE collection = 'X' AND schema_name = 'Y' AND index_name = '...';
-- (repeat for each discovered partition)
-- 3. Clean up the lookup table entries
DELETE FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
```
### Row Embeddings
Row embeddings enable semantic/fuzzy matching on indexed values, solving the natural language mismatch problem (e.g., finding "CHESTNUT ST" when querying for "Chestnut Street").
#### Design Overview
Each indexed value is embedded and stored in a vector store (Qdrant). At query time, the query is embedded, similar vectors are found, and the associated metadata is used to look up the actual rows in Cassandra.
#### Qdrant Collection Structure
One Qdrant collection per `(user, collection, schema_name, dimension)` tuple:
- **Collection naming:** `rows_{user}_{collection}_{schema_name}_{dimension}`
- Names are sanitized (non-alphanumeric characters replaced with `_`, lowercased, numeric prefixes get `r_` prefix)
- **Rationale:** Enables clean deletion of a `(user, collection, schema_name)` instance by dropping matching Qdrant collections; dimension suffix allows different embedding models to coexist
#### What Gets Embedded
The text representation of index values:
| Index Type | Example `index_value` | Text to Embed |
|------------|----------------------|---------------|
| Single-field | `['foo@bar.com']` | `"foo@bar.com"` |
| Composite | `['US', 'active']` | `"US active"` (space-joined) |
#### Point Structure
Each Qdrant point contains:
```json
{
"id": "<uuid>",
"vector": [0.1, 0.2, ...],
"payload": {
"index_name": "street_name",
"index_value": ["CHESTNUT ST"],
"text": "CHESTNUT ST"
}
}
```
| Payload Field | Description |
|---------------|-------------|
| `index_name` | The indexed field(s) this embedding represents |
| `index_value` | The original list of values (for Cassandra lookup) |
| `text` | The text that was embedded (for debugging/display) |
Note: `user`, `collection`, and `schema_name` are implicit from the Qdrant collection name.
#### Query Flow
1. User queries for "Chestnut Street" within user U, collection X, schema Y
2. Embed the query text
3. Determine Qdrant collection name(s) matching prefix `rows_U_X_Y_`
4. Search matching Qdrant collection(s) for nearest vectors
5. Get matching points with payloads containing `index_name` and `index_value`
6. Query Cassandra:
```sql
SELECT * FROM rows
WHERE collection = 'X'
AND schema_name = 'Y'
AND index_name = '<from payload>'
AND index_value = <from payload>
```
7. Return matched rows
#### Optional: Filtering by Index Name
Queries can optionally filter by `index_name` in Qdrant to search only specific fields:
- **"Find any field matching 'Chestnut'"** → search all vectors in the collection
- **"Find street_name matching 'Chestnut'"** → filter where `payload.index_name = 'street_name'`
#### Architecture
Row embeddings follow the **two-stage pattern** used by GraphRAG (graph-embeddings, document-embeddings):
- **Stage 1: Embedding computation** (`trustgraph-flow/trustgraph/embeddings/row_embeddings/`) - Consumes `ExtractedObject`, computes embeddings via the embeddings service, outputs `RowEmbeddings`
- **Stage 2: Embedding storage** (`trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/`) - Consumes `RowEmbeddings`, writes vectors to Qdrant
The Cassandra row writer is a separate parallel consumer:
- **Cassandra row writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra`) - Consumes `ExtractedObject`, writes rows to Cassandra
All three services consume from the same flow, keeping them decoupled. This allows:
- Independent scaling of Cassandra writes vs embedding generation vs vector storage
- Embedding services can be disabled if not needed
- Failures in one service don't affect the others
- Consistent architecture with GraphRAG pipelines
#### Write Path
**Stage 1 (row-embeddings processor):** When receiving an `ExtractedObject`:
1. Look up the schema to find indexed fields
2. For each indexed field:
- Build the text representation of the index value
- Compute embedding via the embeddings service
3. Output a `RowEmbeddings` message containing all computed vectors
**Stage 2 (row-embeddings-write-qdrant):** When receiving a `RowEmbeddings`:
1. For each embedding in the message:
- Determine Qdrant collection from `(user, collection, schema_name, dimension)`
- Create collection if needed (lazy creation on first write)
- Upsert point with vector and payload
#### Message Types
```python
@dataclass
class RowIndexEmbedding:
index_name: str # The indexed field name(s)
index_value: list[str] # The field value(s)
text: str # Text that was embedded
vectors: list[list[float]] # Computed embedding vectors
@dataclass
class RowEmbeddings:
metadata: Metadata
schema_name: str
embeddings: list[RowIndexEmbedding]
```
#### Deletion Integration
Qdrant collections are discovered by prefix matching on the collection name pattern:
**Delete `(user, collection)`:**
1. List all Qdrant collections matching prefix `rows_{user}_{collection}_`
2. Delete each matching collection
3. Delete Cassandra rows partitions (as documented above)
4. Clean up `row_partitions` entries
**Delete `(user, collection, schema_name)`:**
1. List all Qdrant collections matching prefix `rows_{user}_{collection}_{schema_name}_`
2. Delete each matching collection (handles multiple dimensions)
3. Delete Cassandra rows partitions
4. Clean up `row_partitions`
#### Module Locations
| Stage | Module | Entry Point |
|-------|--------|-------------|
| Stage 1 | `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | `row-embeddings` |
| Stage 2 | `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | `row-embeddings-write-qdrant` |
### Row Embeddings Query API
The row embeddings query is a **separate API** from the GraphQL row query service:
| API | Purpose | Backend |
|-----|---------|---------|
| Row Query (GraphQL) | Exact matching on indexed fields | Cassandra |
| Row Embeddings Query | Fuzzy/semantic matching | Qdrant |
This separation keeps concerns clean:
- GraphQL service focuses on exact, structured queries
- Embeddings API handles semantic similarity
- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
#### Request/Response Schema
```python
@dataclass
class RowEmbeddingsRequest:
vectors: list[list[float]] # Query vectors (pre-computed embeddings)
user: str = ""
collection: str = ""
schema_name: str = ""
index_name: str = "" # Optional: filter to specific index
limit: int = 10 # Max results per vector
@dataclass
class RowIndexMatch:
index_name: str = "" # The matched index field(s)
index_value: list[str] = [] # The matched value(s)
text: str = "" # Original text that was embedded
score: float = 0.0 # Similarity score
@dataclass
class RowEmbeddingsResponse:
error: Error | None = None
matches: list[RowIndexMatch] = []
```
#### Query Processor
Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
Entry point: `row-embeddings-query-qdrant`
The processor:
1. Receives `RowEmbeddingsRequest` with query vectors
2. Finds the appropriate Qdrant collection by prefix matching
3. Searches for nearest vectors with optional `index_name` filter
4. Returns `RowEmbeddingsResponse` with matching index information
#### API Gateway Integration
The gateway exposes row embeddings queries via the standard request/response pattern:
| Component | Location |
|-----------|----------|
| Dispatcher | `trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py` |
| Registration | Add `"row-embeddings"` to `request_response_dispatchers` in `manager.py` |
Flow interface name: `row-embeddings`
Interface definition in flow blueprint:
```json
{
"interfaces": {
"row-embeddings": {
"request": "non-persistent://tg/request/row-embeddings:{id}",
"response": "non-persistent://tg/response/row-embeddings:{id}"
}
}
}
```
#### Python SDK Support
The SDK provides methods for row embeddings queries:
```python
# Flow-scoped query (preferred)
api = Api(url)
flow = api.flow().id("default")
# Query with text (SDK computes embeddings)
matches = flow.row_embeddings_query(
text="Chestnut Street",
collection="my_collection",
schema_name="addresses",
index_name="street_name", # Optional filter
limit=10
)
# Query with pre-computed vectors
matches = flow.row_embeddings_query(
vectors=[[0.1, 0.2, ...]],
collection="my_collection",
schema_name="addresses"
)
# Each match contains:
for match in matches:
print(match.index_name) # e.g., "street_name"
print(match.index_value) # e.g., ["CHESTNUT ST"]
print(match.text) # e.g., "CHESTNUT ST"
print(match.score) # e.g., 0.95
```
#### CLI Utility
Command: `tg-invoke-row-embeddings`
```bash
# Query by text (computes embedding automatically)
tg-invoke-row-embeddings \
--text "Chestnut Street" \
--collection my_collection \
--schema addresses \
--index street_name \
--limit 10
# Query by vector file
tg-invoke-row-embeddings \
--vectors vectors.json \
--collection my_collection \
--schema addresses
# Output formats
tg-invoke-row-embeddings --text "..." --format json
tg-invoke-row-embeddings --text "..." --format table
```
#### Typical Usage Pattern
The row embeddings query is typically used as part of a fuzzy-to-exact lookup flow:
```python
# Step 1: Fuzzy search via embeddings
matches = flow.row_embeddings_query(
text="chestnut street",
collection="geo",
schema_name="streets"
)
# Step 2: Exact lookup via GraphQL for full row data
for match in matches:
query = f'''
query {{
streets(where: {{ {match.index_name}: {{ eq: "{match.index_value[0]}" }} }}) {{
street_name
city
zip_code
}}
}}
'''
rows = flow.rows_query(query, collection="geo")
```
This two-step pattern enables:
- Finding "CHESTNUT ST" when user searches for "Chestnut Street"
- Retrieving complete row data with all fields
- Combining semantic similarity with structured data access
### Row Data Ingestion
Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
## Implementation Impact
### Current State Analysis
The existing implementation has two main components:
| Component | Location | Lines | Description |
|-----------|----------|-------|-------------|
| Query Service | `trustgraph-flow/trustgraph/query/objects/cassandra/service.py` | ~740 | Monolithic: GraphQL schema generation, filter parsing, Cassandra queries, request handling |
| Writer | `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` | ~540 | Per-schema table creation, secondary indexes, insert/delete |
**Current Query Pattern:**
```sql
SELECT * FROM {keyspace}.o_{schema_name}
WHERE collection = 'X' AND email = 'foo@bar.com'
ALLOW FILTERING
```
**New Query Pattern:**
```sql
SELECT * FROM {keyspace}.rows
WHERE collection = 'X' AND schema_name = 'customers'
AND index_name = 'email' AND index_value = ['foo@bar.com']
```
### Key Changes
1. **Query semantics simplify**: The new schema only supports exact matches on `index_value`. The current GraphQL filters (`gt`, `lt`, `contains`, etc.) either:
- Become post-filtering on returned data (if still needed)
- Are removed in favor of using the embeddings API for fuzzy matching
2. **GraphQL code is tightly coupled**: The current `service.py` bundles Strawberry type generation, filter parsing, and Cassandra-specific queries. Adding another row store backend would duplicate ~400 lines of GraphQL code.
### Proposed Refactor
The refactor has two parts:
#### 1. Break Out GraphQL Code
Extract reusable GraphQL components into a shared module:
```
trustgraph-flow/trustgraph/query/graphql/
├── __init__.py
├── types.py # Filter types (IntFilter, StringFilter, FloatFilter)
├── schema.py # Dynamic schema generation from RowSchema
└── filters.py # Filter parsing utilities
```
This enables:
- Reuse across different row store backends
- Cleaner separation of concerns
- Easier testing of GraphQL logic independently
#### 2. Implement New Table Schema
Refactor the Cassandra-specific code to use the unified table:
**Writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra/`):
- Single `rows` table instead of per-schema tables
- Write N copies per row (one per index)
- Register to `row_partitions` table
- Simpler table creation (one-time setup)
**Query Service** (`trustgraph-flow/trustgraph/query/rows/cassandra/`):
- Query the unified `rows` table
- Use extracted GraphQL module for schema generation
- Simplified filter handling (exact match only at DB level)
### Module Renames
As part of the "object" → "row" naming cleanup:
| Current | New |
|---------|-----|
| `storage/objects/cassandra/` | `storage/rows/cassandra/` |
| `query/objects/cassandra/` | `query/rows/cassandra/` |
| `embeddings/object_embeddings/` | `embeddings/row_embeddings/` |
### New Modules
| Module | Purpose |
|--------|---------|
| `trustgraph-flow/trustgraph/query/graphql/` | Shared GraphQL utilities |
| `trustgraph-flow/trustgraph/query/row_embeddings/qdrant/` | Row embeddings query API |
| `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | Row embeddings computation (Stage 1) |
| `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | Row embeddings storage (Stage 2) |
## References
- [Structured Data Technical Specification](structured-data.md)

View file

@ -0,0 +1,39 @@
type: object
description: |
Row embeddings query request - find similar rows by vector similarity on indexed fields.
Enables semantic/fuzzy matching on structured data.
required:
- vectors
- schema_name
properties:
vectors:
type: array
description: Query embedding vector
items:
type: number
example: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156]
schema_name:
type: string
description: Schema name to search within
example: customers
index_name:
type: string
description: Optional index name to filter search to specific index
example: full_name
limit:
type: integer
description: Maximum number of matches to return
default: 10
minimum: 1
maximum: 1000
example: 20
user:
type: string
description: User identifier
default: trustgraph
example: alice
collection:
type: string
description: Collection to search
default: default
example: sales

View file

@ -0,0 +1,53 @@
type: object
description: Row embeddings query response with matching row index information
properties:
error:
type: object
description: Error information if query failed
properties:
type:
type: string
description: Error type identifier
example: row-embeddings-query-error
message:
type: string
description: Human-readable error message
example: Schema not found
matches:
type: array
description: List of matching row index entries with similarity scores
items:
type: object
properties:
index_name:
type: string
description: Name of the indexed field(s)
example: full_name
index_value:
type: array
description: Values of the indexed fields for this row
items:
type: string
example: ["John", "Smith"]
text:
type: string
description: The text that was embedded for this index entry
example: "John Smith"
score:
type: number
description: Similarity score (higher is more similar)
example: 0.89
example:
matches:
- index_name: full_name
index_value: ["John", "Smith"]
text: "John Smith"
score: 0.95
- index_name: full_name
index_value: ["Jon", "Smythe"]
text: "Jon Smythe"
score: 0.82
- index_name: full_name
index_value: ["Jonathan", "Schmidt"]
text: "Jonathan Schmidt"
score: 0.76

View file

@ -1,6 +1,6 @@
type: object
description: |
Objects query request - GraphQL query over knowledge graph.
Rows query request - GraphQL query over structured data.
required:
- query
properties:

View file

@ -1,5 +1,5 @@
type: object
description: Objects query response (GraphQL format)
description: Rows query response (GraphQL format)
properties:
data:
description: GraphQL response data (JSON object or null)

View file

@ -121,8 +121,8 @@ paths:
$ref: './paths/flow/mcp-tool.yaml'
/api/v1/flow/{flow}/service/triples:
$ref: './paths/flow/triples.yaml'
/api/v1/flow/{flow}/service/objects:
$ref: './paths/flow/objects.yaml'
/api/v1/flow/{flow}/service/rows:
$ref: './paths/flow/rows.yaml'
/api/v1/flow/{flow}/service/nlp-query:
$ref: './paths/flow/nlp-query.yaml'
/api/v1/flow/{flow}/service/structured-query:
@ -133,6 +133,8 @@ paths:
$ref: './paths/flow/graph-embeddings.yaml'
/api/v1/flow/{flow}/service/document-embeddings:
$ref: './paths/flow/document-embeddings.yaml'
/api/v1/flow/{flow}/service/row-embeddings:
$ref: './paths/flow/row-embeddings.yaml'
/api/v1/flow/{flow}/service/text-load:
$ref: './paths/flow/text-load.yaml'
/api/v1/flow/{flow}/service/document-load:

View file

@ -34,7 +34,7 @@ post:
```
1. User asks: "Who does Alice know?"
2. NLP Query generates GraphQL
3. Execute via /api/v1/flow/{flow}/service/objects
3. Execute via /api/v1/flow/{flow}/service/rows
4. Return results to user
```

View file

@ -0,0 +1,101 @@
post:
tags:
- Flow Services
summary: Row Embeddings Query - semantic search on structured data
description: |
Query row embeddings to find similar rows by vector similarity on indexed fields.
Enables fuzzy/semantic matching on structured data.
## Row Embeddings Query Overview
Find rows whose indexed field values are semantically similar to a query:
- **Input**: Query embedding vector, schema name, optional index filter
- **Search**: Compare against stored row index embeddings
- **Output**: Matching rows with index values and similarity scores
Core component of semantic search on structured data.
## Use Cases
- **Fuzzy name matching**: Find customers by approximate name
- **Semantic field search**: Find products by description similarity
- **Data deduplication**: Identify potential duplicate records
- **Entity resolution**: Match records across datasets
## Process
1. Obtain query embedding (via embeddings service)
2. Query stored row index embeddings for the specified schema
3. Calculate cosine similarity
4. Return top N most similar index entries
5. Use index values to retrieve full rows via GraphQL
## Response Format
Each match includes:
- `index_name`: The indexed field(s) that matched
- `index_value`: The actual values for those fields
- `text`: The text that was embedded
- `score`: Similarity score (higher = more similar)
operationId: rowEmbeddingsQueryService
security:
- bearerAuth: []
parameters:
- name: flow
in: path
required: true
schema:
type: string
description: Flow instance ID
example: my-flow
requestBody:
required: true
content:
application/json:
schema:
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
examples:
basicQuery:
summary: Find similar customer names
value:
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
schema_name: customers
limit: 10
user: alice
collection: sales
filteredQuery:
summary: Search specific index
value:
vectors: [0.1, -0.2, 0.3, -0.4, 0.5]
schema_name: products
index_name: description
limit: 20
responses:
'200':
description: Successful response
content:
application/json:
schema:
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml'
examples:
similarRows:
summary: Similar rows found
value:
matches:
- index_name: full_name
index_value: ["John", "Smith"]
text: "John Smith"
score: 0.95
- index_name: full_name
index_value: ["Jon", "Smythe"]
text: "Jon Smythe"
score: 0.82
- index_name: full_name
index_value: ["Jonathan", "Schmidt"]
text: "Jonathan Schmidt"
score: 0.76
'401':
$ref: '../../components/responses/Unauthorized.yaml'
'500':
$ref: '../../components/responses/Error.yaml'

View file

@ -1,19 +1,19 @@
post:
tags:
- Flow Services
summary: Objects query - GraphQL over knowledge graph
summary: Rows query - GraphQL over structured data
description: |
Query knowledge graph using GraphQL for object-oriented data access.
Query structured data using GraphQL for row-oriented data access.
## Objects Query Overview
## Rows Query Overview
GraphQL interface to knowledge graph:
GraphQL interface to structured data:
- **Schema-driven**: Predefined types and relationships
- **Flexible queries**: Request exactly what you need
- **Nested data**: Traverse relationships in single query
- **Type-safe**: Strong typing with introspection
Abstracts RDF triples into familiar object model.
Abstracts structured rows into familiar object model.
## GraphQL Benefits
@ -61,7 +61,7 @@ post:
Schema defines available types via config service.
Use introspection query to discover schema.
operationId: objectsQueryService
operationId: rowsQueryService
security:
- bearerAuth: []
parameters:
@ -77,7 +77,7 @@ post:
content:
application/json:
schema:
$ref: '../../components/schemas/query/ObjectsQueryRequest.yaml'
$ref: '../../components/schemas/query/RowsQueryRequest.yaml'
examples:
simpleQuery:
summary: Simple query
@ -129,7 +129,7 @@ post:
content:
application/json:
schema:
$ref: '../../components/schemas/query/ObjectsQueryResponse.yaml'
$ref: '../../components/schemas/query/RowsQueryResponse.yaml'
examples:
successfulQuery:
summary: Successful query

View file

@ -9,7 +9,7 @@ post:
Combines two operations in one call:
1. **NLP Query**: Generate GraphQL from question
2. **Objects Query**: Execute generated query
2. **Rows Query**: Execute generated query
3. **Return Results**: Direct answer data
Simplest way to query knowledge graph with natural language.
@ -21,7 +21,7 @@ post:
- **Output**: Query results (data)
- **Use when**: Want simple, direct answers
### NLP Query + Objects Query (separate calls)
### NLP Query + Rows Query (separate calls)
- **Step 1**: Convert question → GraphQL
- **Step 2**: Execute GraphQL → results
- **Use when**: Need to inspect/modify query before execution

View file

@ -25,12 +25,13 @@ payload:
- $ref: './requests/EmbeddingsRequest.yaml'
- $ref: './requests/McpToolRequest.yaml'
- $ref: './requests/TriplesRequest.yaml'
- $ref: './requests/ObjectsRequest.yaml'
- $ref: './requests/RowsRequest.yaml'
- $ref: './requests/NlpQueryRequest.yaml'
- $ref: './requests/StructuredQueryRequest.yaml'
- $ref: './requests/StructuredDiagRequest.yaml'
- $ref: './requests/GraphEmbeddingsRequest.yaml'
- $ref: './requests/DocumentEmbeddingsRequest.yaml'
- $ref: './requests/RowEmbeddingsRequest.yaml'
- $ref: './requests/TextLoadRequest.yaml'
- $ref: './requests/DocumentLoadRequest.yaml'

View file

@ -0,0 +1,30 @@
type: object
description: WebSocket request for row-embeddings service (flow-hosted service)
required:
- id
- service
- flow
- request
properties:
id:
type: string
description: Unique request identifier
service:
type: string
const: row-embeddings
description: Service identifier for row-embeddings service
flow:
type: string
description: Flow ID
request:
$ref: '../../../../api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
examples:
- id: req-1
service: row-embeddings
flow: my-flow
request:
vectors: [0.023, -0.142, 0.089, 0.234]
schema_name: customers
limit: 10
user: trustgraph
collection: default

View file

@ -1,5 +1,5 @@
type: object
description: WebSocket request for objects service (flow-hosted service)
description: WebSocket request for rows service (flow-hosted service)
required:
- id
- service
@ -11,16 +11,16 @@ properties:
description: Unique request identifier
service:
type: string
const: objects
description: Service identifier for objects service
const: rows
description: Service identifier for rows service
flow:
type: string
description: Flow ID
request:
$ref: '../../../../api/components/schemas/query/ObjectsQueryRequest.yaml'
$ref: '../../../../api/components/schemas/query/RowsQueryRequest.yaml'
examples:
- id: req-1
service: objects
service: rows
flow: my-flow
request:
query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }"

View file

@ -15,10 +15,10 @@ from trustgraph.schema import (
TextCompletionRequest, TextCompletionResponse,
DocumentRagQuery, DocumentRagResponse,
AgentRequest, AgentResponse, AgentStep,
Chunk, Triple, Triples, Value, Error,
Chunk, Triple, Triples, Term, Error,
EntityContext, EntityContexts,
GraphEmbeddings, EntityEmbeddings,
Metadata
Metadata, IRI, LITERAL
)
@ -43,7 +43,7 @@ def schema_registry():
"Chunk": Chunk,
"Triple": Triple,
"Triples": Triples,
"Value": Value,
"Term": Term,
"Error": Error,
"EntityContext": EntityContext,
"EntityContexts": EntityContexts,
@ -98,26 +98,22 @@ def sample_message_data():
"collection": "test_collection",
"metadata": []
},
"Value": {
"value": "http://example.com/entity",
"is_uri": True,
"type": ""
"Term": {
"type": IRI,
"iri": "http://example.com/entity"
},
"Triple": {
"s": Value(
value="http://example.com/subject",
is_uri=True,
type=""
"s": Term(
type=IRI,
iri="http://example.com/subject"
),
"p": Value(
value="http://example.com/predicate",
is_uri=True,
type=""
"p": Term(
type=IRI,
iri="http://example.com/predicate"
),
"o": Value(
value="Object value",
is_uri=False,
type=""
"o": Term(
type=LITERAL,
value="Object value"
)
}
}
@ -139,10 +135,10 @@ def invalid_message_data():
{"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
{"query": "test"}, # Missing required fields
],
"Value": [
{"value": None, "is_uri": True, "type": ""}, # Invalid value (None)
{"value": "test", "is_uri": "not_boolean", "type": ""}, # Invalid is_uri
{"value": 123, "is_uri": True, "type": ""}, # Invalid value (not string)
"Term": [
{"type": IRI, "iri": None}, # Invalid iri (None)
{"type": "invalid_type", "value": "test"}, # Invalid type
{"type": LITERAL, "value": 123}, # Invalid value (not string)
]
}

View file

@ -15,14 +15,14 @@ from trustgraph.schema import (
TextCompletionRequest, TextCompletionResponse,
DocumentRagQuery, DocumentRagResponse,
AgentRequest, AgentResponse, AgentStep,
Chunk, Triple, Triples, Value, Error,
Chunk, Triple, Triples, Term, Error,
EntityContext, EntityContexts,
GraphEmbeddings, EntityEmbeddings,
Metadata, Field, RowSchema,
StructuredDataSubmission, ExtractedObject,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
StructuredQueryRequest, StructuredQueryResponse,
StructuredObjectEmbedding
StructuredObjectEmbedding, IRI, LITERAL
)
from .conftest import validate_schema_contract, serialize_deserialize_test
@ -271,52 +271,51 @@ class TestAgentMessageContracts:
class TestGraphMessageContracts:
"""Contract tests for Graph/Knowledge message schemas"""
def test_value_schema_contract(self, sample_message_data):
"""Test Value schema contract"""
def test_term_schema_contract(self, sample_message_data):
"""Test Term schema contract"""
# Arrange
value_data = sample_message_data["Value"]
term_data = sample_message_data["Term"]
# Act & Assert
assert validate_schema_contract(Value, value_data)
# Test URI value
uri_value = Value(**value_data)
assert uri_value.value == "http://example.com/entity"
assert uri_value.is_uri is True
assert validate_schema_contract(Term, term_data)
# Test literal value
literal_value = Value(
value="Literal text value",
is_uri=False,
type=""
# Test URI term
uri_term = Term(**term_data)
assert uri_term.iri == "http://example.com/entity"
assert uri_term.type == IRI
# Test literal term
literal_term = Term(
type=LITERAL,
value="Literal text value"
)
assert literal_value.value == "Literal text value"
assert literal_value.is_uri is False
assert literal_term.value == "Literal text value"
assert literal_term.type == LITERAL
def test_triple_schema_contract(self, sample_message_data):
"""Test Triple schema contract"""
# Arrange
triple_data = sample_message_data["Triple"]
# Act & Assert - Triple uses Value objects, not dict validation
# Act & Assert - Triple uses Term objects, not dict validation
triple = Triple(
s=triple_data["s"],
p=triple_data["p"],
p=triple_data["p"],
o=triple_data["o"]
)
assert triple.s.value == "http://example.com/subject"
assert triple.p.value == "http://example.com/predicate"
assert triple.s.iri == "http://example.com/subject"
assert triple.p.iri == "http://example.com/predicate"
assert triple.o.value == "Object value"
assert triple.s.is_uri is True
assert triple.p.is_uri is True
assert triple.o.is_uri is False
assert triple.s.type == IRI
assert triple.p.type == IRI
assert triple.o.type == LITERAL
def test_triples_schema_contract(self, sample_message_data):
"""Test Triples (batch) schema contract"""
# Arrange
metadata = Metadata(**sample_message_data["Metadata"])
triple = Triple(**sample_message_data["Triple"])
triples_data = {
"metadata": metadata,
"triples": [triple]
@ -324,11 +323,11 @@ class TestGraphMessageContracts:
# Act & Assert
assert validate_schema_contract(Triples, triples_data)
triples = Triples(**triples_data)
assert triples.metadata.id == "test-doc-123"
assert len(triples.triples) == 1
assert triples.triples[0].s.value == "http://example.com/subject"
assert triples.triples[0].s.iri == "http://example.com/subject"
def test_chunk_schema_contract(self, sample_message_data):
"""Test Chunk schema contract"""
@ -349,29 +348,29 @@ class TestGraphMessageContracts:
def test_entity_context_schema_contract(self):
"""Test EntityContext schema contract"""
# Arrange
entity_value = Value(value="http://example.com/entity", is_uri=True, type="")
entity_term = Term(type=IRI, iri="http://example.com/entity")
entity_context_data = {
"entity": entity_value,
"entity": entity_term,
"context": "Context information about the entity"
}
# Act & Assert
assert validate_schema_contract(EntityContext, entity_context_data)
entity_context = EntityContext(**entity_context_data)
assert entity_context.entity.value == "http://example.com/entity"
assert entity_context.entity.iri == "http://example.com/entity"
assert entity_context.context == "Context information about the entity"
def test_entity_contexts_batch_schema_contract(self, sample_message_data):
"""Test EntityContexts (batch) schema contract"""
# Arrange
metadata = Metadata(**sample_message_data["Metadata"])
entity_value = Value(value="http://example.com/entity", is_uri=True, type="")
entity_term = Term(type=IRI, iri="http://example.com/entity")
entity_context = EntityContext(
entity=entity_value,
entity=entity_term,
context="Entity context"
)
entity_contexts_data = {
"metadata": metadata,
"entities": [entity_context]
@ -379,7 +378,7 @@ class TestGraphMessageContracts:
# Act & Assert
assert validate_schema_contract(EntityContexts, entity_contexts_data)
entity_contexts = EntityContexts(**entity_contexts_data)
assert entity_contexts.metadata.id == "test-doc-123"
assert len(entity_contexts.entities) == 1
@ -417,10 +416,10 @@ class TestMetadataMessageContracts:
# Act & Assert
assert validate_schema_contract(Metadata, metadata_data)
metadata = Metadata(**metadata_data)
assert len(metadata.metadata) == 1
assert metadata.metadata[0].s.value == "http://example.com/subject"
assert metadata.metadata[0].s.iri == "http://example.com/subject"
def test_error_schema_contract(self):
"""Test Error schema contract"""
@ -532,7 +531,7 @@ class TestSerializationContracts:
# Test each schema in the registry
for schema_name, schema_class in schema_registry.items():
if schema_name in sample_message_data:
# Skip Triple schema as it requires special handling with Value objects
# Skip Triple schema as it requires special handling with Term objects
if schema_name == "Triple":
continue
@ -541,36 +540,36 @@ class TestSerializationContracts:
assert serialize_deserialize_test(schema_class, data), f"Serialization failed for {schema_name}"
def test_triple_serialization_contract(self, sample_message_data):
"""Test Triple schema serialization contract with Value objects"""
"""Test Triple schema serialization contract with Term objects"""
# Arrange
triple_data = sample_message_data["Triple"]
# Act
triple = Triple(
s=triple_data["s"],
p=triple_data["p"],
p=triple_data["p"],
o=triple_data["o"]
)
# Assert - Test that Value objects are properly constructed and accessible
assert triple.s.value == "http://example.com/subject"
assert triple.p.value == "http://example.com/predicate"
# Assert - Test that Term objects are properly constructed and accessible
assert triple.s.iri == "http://example.com/subject"
assert triple.p.iri == "http://example.com/predicate"
assert triple.o.value == "Object value"
assert isinstance(triple.s, Value)
assert isinstance(triple.p, Value)
assert isinstance(triple.o, Value)
assert isinstance(triple.s, Term)
assert isinstance(triple.p, Term)
assert isinstance(triple.o, Term)
def test_nested_schema_serialization_contract(self, sample_message_data):
"""Test serialization of nested schemas"""
# Test Triples (contains Metadata and Triple objects)
metadata = Metadata(**sample_message_data["Metadata"])
triple = Triple(**sample_message_data["Triple"])
triples = Triples(metadata=metadata, triples=[triple])
# Verify nested objects maintain their contracts
assert triples.metadata.id == "test-doc-123"
assert triples.triples[0].s.value == "http://example.com/subject"
assert triples.triples[0].s.iri == "http://example.com/subject"
def test_array_field_serialization_contract(self):
"""Test serialization of array fields"""

View file

@ -1,8 +1,8 @@
"""
Contract tests for Cassandra Object Storage
Contract tests for Cassandra Row Storage
These tests verify the message contracts and schema compatibility
for the objects storage processor.
for the rows storage processor.
"""
import pytest
@ -10,12 +10,12 @@ import json
from pulsar.schema import AvroSchema
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.storage.rows.cassandra.write import Processor
@pytest.mark.contract
class TestObjectsCassandraContracts:
"""Contract tests for Cassandra object storage messages"""
class TestRowsCassandraContracts:
"""Contract tests for Cassandra row storage messages"""
def test_extracted_object_input_contract(self):
"""Test that ExtractedObject schema matches expected input format"""
@ -145,50 +145,6 @@ class TestObjectsCassandraContracts:
assert required_field_keys.issubset(field.keys())
assert set(field.keys()).issubset(required_field_keys | optional_field_keys)
def test_cassandra_type_mapping_contract(self):
"""Test that all supported field types have Cassandra mappings"""
processor = Processor.__new__(Processor)
# All field types that should be supported
supported_types = [
("string", "text"),
("integer", "int"), # or bigint based on size
("float", "float"), # or double based on size
("boolean", "boolean"),
("timestamp", "timestamp"),
("date", "date"),
("time", "time"),
("uuid", "uuid")
]
for field_type, expected_cassandra_type in supported_types:
cassandra_type = processor.get_cassandra_type(field_type)
# For integer and float, the exact type depends on size
if field_type in ["integer", "float"]:
assert cassandra_type in ["int", "bigint", "float", "double"]
else:
assert cassandra_type == expected_cassandra_type
def test_value_conversion_contract(self):
"""Test value conversion for all supported types"""
processor = Processor.__new__(Processor)
# Test conversions maintain data integrity
test_cases = [
# (input_value, field_type, expected_output, expected_type)
("123", "integer", 123, int),
("123.45", "float", 123.45, float),
("true", "boolean", True, bool),
("false", "boolean", False, bool),
("test string", "string", "test string", str),
(None, "string", None, type(None)),
]
for input_val, field_type, expected_val, expected_type in test_cases:
result = processor.convert_value(input_val, field_type)
assert result == expected_val
assert isinstance(result, expected_type) or result is None
@pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
def test_extracted_object_serialization_contract(self):
"""Test that ExtractedObject can be serialized/deserialized correctly"""
@ -222,43 +178,31 @@ class TestObjectsCassandraContracts:
assert decoded.confidence == original.confidence
assert decoded.source_span == original.source_span
def test_cassandra_table_naming_contract(self):
def test_cassandra_name_sanitization_contract(self):
"""Test Cassandra naming conventions and constraints"""
processor = Processor.__new__(Processor)
# Test table naming (always gets o_ prefix)
table_test_names = [
("simple_name", "o_simple_name"),
("Name-With-Dashes", "o_name_with_dashes"),
("name.with.dots", "o_name_with_dots"),
("123_numbers", "o_123_numbers"),
("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores
("UPPERCASE", "o_uppercase"),
("CamelCase", "o_camelcase"),
("", "o_"), # Edge case - empty string becomes o_
]
for input_name, expected_name in table_test_names:
result = processor.sanitize_table(input_name)
assert result == expected_name
# Verify result is valid Cassandra identifier (starts with letter)
assert result.startswith('o_')
assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_'
# Test regular name sanitization (only adds o_ prefix if starts with number)
# Test name sanitization for Cassandra identifiers
# - Non-alphanumeric chars (except underscore) become underscores
# - Names starting with non-letter get 'r_' prefix
# - All names converted to lowercase
name_test_cases = [
("simple_name", "simple_name"),
("Name-With-Dashes", "name_with_dashes"),
("name.with.dots", "name_with_dots"),
("123_numbers", "o_123_numbers"), # Only this gets o_ prefix
("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number)
("special!@#chars", "special___chars"), # 3 special chars become 3 underscores
("UPPERCASE", "uppercase"),
("CamelCase", "camelcase"),
("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore)
]
for input_name, expected_name in name_test_cases:
result = processor.sanitize_name(input_name)
assert result == expected_name
assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}"
# Verify result is valid Cassandra identifier (starts with letter)
if result: # Skip empty string case
assert result[0].isalpha(), f"Result {result} should start with a letter"
def test_primary_key_structure_contract(self):
"""Test that primary key structure follows Cassandra best practices"""
@ -308,8 +252,8 @@ class TestObjectsCassandraContracts:
@pytest.mark.contract
class TestObjectsCassandraContractsBatch:
"""Contract tests for Cassandra object storage batch processing"""
class TestRowsCassandraContractsBatch:
"""Contract tests for Cassandra row storage batch processing"""
def test_extracted_object_batch_input_contract(self):
"""Test that batched ExtractedObject schema matches expected input format"""

View file

@ -1,26 +1,26 @@
"""
Contract tests for Objects GraphQL Query Service
Contract tests for Rows GraphQL Query Service
These tests verify the message contracts and schema compatibility
for the objects GraphQL query processor.
for the rows GraphQL query processor.
"""
import pytest
import json
from pulsar.schema import AvroSchema
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
@pytest.mark.contract
class TestObjectsGraphQLQueryContracts:
class TestRowsGraphQLQueryContracts:
"""Contract tests for GraphQL query service messages"""
def test_objects_query_request_contract(self):
"""Test ObjectsQueryRequest schema structure and required fields"""
def test_rows_query_request_contract(self):
"""Test RowsQueryRequest schema structure and required fields"""
# Create test request with all required fields
test_request = ObjectsQueryRequest(
test_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name email } }',
@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts:
assert test_request.variables["status"] == "active"
assert test_request.operation_name == "GetCustomers"
def test_objects_query_request_minimal(self):
"""Test ObjectsQueryRequest with minimal required fields"""
def test_rows_query_request_minimal(self):
"""Test RowsQueryRequest with minimal required fields"""
# Create request with only essential fields
minimal_request = ObjectsQueryRequest(
minimal_request = RowsQueryRequest(
user="user",
collection="collection",
query='{ test }',
@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts:
assert test_error.path == ["customers", "0", "nonexistent"]
assert test_error.extensions["code"] == "FIELD_ERROR"
def test_objects_query_response_success_contract(self):
"""Test ObjectsQueryResponse schema for successful queries"""
def test_rows_query_response_success_contract(self):
"""Test RowsQueryResponse schema for successful queries"""
# Create successful response
success_response = ObjectsQueryResponse(
success_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=[],
@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts:
assert len(parsed_data["customers"]) == 1
assert parsed_data["customers"][0]["id"] == "1"
def test_objects_query_response_error_contract(self):
"""Test ObjectsQueryResponse schema for error cases"""
def test_rows_query_response_error_contract(self):
"""Test RowsQueryResponse schema for error cases"""
# Create GraphQL errors - work around Pulsar Array(Record) validation bug
# by creating a response without the problematic errors array first
error_response = ObjectsQueryResponse(
error_response = RowsQueryResponse(
error=None, # System error is None - these are GraphQL errors
data=None, # No data due to errors
errors=[], # Empty errors array to avoid Pulsar bug
@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts:
assert validation_error.path == ["customers", "email"]
assert validation_error.extensions["details"] == "Invalid email format"
def test_objects_query_response_system_error_contract(self):
"""Test ObjectsQueryResponse schema for system errors"""
def test_rows_query_response_system_error_contract(self):
"""Test RowsQueryResponse schema for system errors"""
from trustgraph.schema import Error
# Create system error response
system_error_response = ObjectsQueryResponse(
system_error_response = RowsQueryResponse(
error=Error(
type="objects-query-error",
type="rows-query-error",
message="Failed to connect to Cassandra cluster"
),
data=None,
@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts:
# Verify system error structure
assert system_error_response.error is not None
assert system_error_response.error.type == "objects-query-error"
assert system_error_response.error.type == "rows-query-error"
assert "Cassandra" in system_error_response.error.message
assert system_error_response.data is None
assert len(system_error_response.errors) == 0
@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts:
def test_request_response_serialization_contract(self):
"""Test that request/response can be serialized/deserialized correctly"""
# Create original request
original_request = ObjectsQueryRequest(
original_request = RowsQueryRequest(
user="serialization_test",
collection="test_data",
query='{ orders(limit: 5) { id total customer { name } } }',
@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts:
)
# Test request serialization using Pulsar schema
request_schema = AvroSchema(ObjectsQueryRequest)
request_schema = AvroSchema(RowsQueryRequest)
# Encode and decode request
encoded_request = request_schema.encode(original_request)
@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts:
assert decoded_request.operation_name == original_request.operation_name
# Create original response - work around Pulsar Array(Record) bug
original_response = ObjectsQueryResponse(
original_response = RowsQueryResponse(
error=None,
data='{"orders": []}',
errors=[], # Empty to avoid Pulsar validation bug
@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts:
)
# Test response serialization
response_schema = AvroSchema(ObjectsQueryResponse)
response_schema = AvroSchema(RowsQueryResponse)
# Encode and decode response
encoded_response = response_schema.encode(original_response)
@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts:
def test_graphql_query_format_contract(self):
"""Test supported GraphQL query formats"""
# Test basic query
basic_query = ObjectsQueryRequest(
basic_query = RowsQueryRequest(
user="test", collection="test", query='{ customers { id } }',
variables={}, operation_name=""
)
@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts:
assert basic_query.query.strip().endswith('}')
# Test query with variables
parameterized_query = ObjectsQueryRequest(
parameterized_query = RowsQueryRequest(
user="test", collection="test",
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
variables={"status": "active", "limit": "10"},
@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts:
assert parameterized_query.operation_name == "GetCustomers"
# Test complex nested query
nested_query = ObjectsQueryRequest(
nested_query = RowsQueryRequest(
user="test", collection="test",
query='''
{
@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts:
# Note: Current schema uses Map(String()) which only supports string values
# This test verifies the current contract, though ideally we'd support all JSON types
variables_test = ObjectsQueryRequest(
variables_test = RowsQueryRequest(
user="test", collection="test", query='{ test }',
variables={
"string_var": "test_value",
@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts:
def test_cassandra_context_fields_contract(self):
"""Test that request contains necessary fields for Cassandra operations"""
# Verify request has fields needed for Cassandra keyspace/table targeting
request = ObjectsQueryRequest(
request = RowsQueryRequest(
user="keyspace_name", # Maps to Cassandra keyspace
collection="partition_collection", # Used in partition key
query='{ objects { id } }',
@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts:
def test_graphql_extensions_contract(self):
"""Test GraphQL extensions field format and usage"""
# Extensions should support query metadata
response_with_extensions = ObjectsQueryResponse(
response_with_extensions = RowsQueryResponse(
error=None,
data='{"test": "data"}',
errors=[],
@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts:
'''
# Request to execute specific operation
multi_op_request = ObjectsQueryRequest(
multi_op_request = RowsQueryRequest(
user="test", collection="test",
query=multi_op_query,
variables={},
@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts:
assert "GetOrders" in multi_op_request.query
# Test single operation (operation_name optional)
single_op_request = ObjectsQueryRequest(
single_op_request = RowsQueryRequest(
user="test", collection="test",
query='{ customers { id } }',
variables={}, operation_name=""

View file

@ -15,7 +15,7 @@ from trustgraph.schema import (
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
StructuredQueryRequest, StructuredQueryResponse,
StructuredObjectEmbedding, Field, RowSchema,
Metadata, Error, Value
Metadata, Error
)
from .conftest import serialize_deserialize_test

View file

@ -12,7 +12,7 @@ import json
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.template.prompt_manager import PromptManager
@ -30,38 +30,16 @@ class TestAgentKgExtractionIntegration:
# Mock agent client
agent_client = AsyncMock()
# Mock successful agent response
# Mock successful agent response in JSONL format
def mock_agent_response(recipient, question):
# Simulate agent processing and return structured response
# Simulate agent processing and return structured JSONL response
mock_response = MagicMock()
mock_response.error = None
mock_response.answer = '''```json
{
"definitions": [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
],
"relationships": [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": true
},
{
"subject": "Neural Networks",
"predicate": "used_in",
"object": "Machine Learning",
"object-entity": true
}
]
}
{"type": "definition", "entity": "Machine Learning", "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."}
{"type": "definition", "entity": "Neural Networks", "definition": "Computing systems inspired by biological neural networks that process information."}
{"type": "relationship", "subject": "Machine Learning", "predicate": "is_subset_of", "object": "Artificial Intelligence", "object-entity": true}
{"type": "relationship", "subject": "Neural Networks", "predicate": "used_in", "object": "Machine Learning", "object-entity": true}
```'''
return mock_response.answer
@ -100,9 +78,9 @@ class TestAgentKgExtractionIntegration:
id="doc123",
metadata=[
Triple(
s=Value(value="doc123", is_uri=True),
p=Value(value="http://example.org/type", is_uri=True),
o=Value(value="document", is_uri=False)
s=Term(type=IRI, iri="doc123"),
p=Term(type=IRI, iri="http://example.org/type"),
o=Term(type=LITERAL, value="document")
)
]
)
@ -120,7 +98,7 @@ class TestAgentKgExtractionIntegration:
# Copy the methods we want to test
extractor.to_uri = real_extractor.to_uri
extractor.parse_json = real_extractor.parse_json
extractor.parse_jsonl = real_extractor.parse_jsonl
extractor.process_extraction_data = real_extractor.process_extraction_data
extractor.emit_triples = real_extractor.emit_triples
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
@ -156,7 +134,7 @@ class TestAgentKgExtractionIntegration:
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
# Parse and process
extraction_data = extractor.parse_json(agent_response)
extraction_data = extractor.parse_jsonl(agent_response)
triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)
# Add metadata triples
@ -200,15 +178,15 @@ class TestAgentKgExtractionIntegration:
assert len(sent_triples.triples) > 0
# Check that we have definition triples
definition_triples = [t for t in sent_triples.triples if t.p.value == DEFINITION]
definition_triples = [t for t in sent_triples.triples if t.p.iri == DEFINITION]
assert len(definition_triples) >= 2 # Should have definitions for ML and Neural Networks
# Check that we have label triples
label_triples = [t for t in sent_triples.triples if t.p.value == RDF_LABEL]
label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL]
assert len(label_triples) >= 2 # Should have labels for entities
# Check subject-of relationships
subject_of_triples = [t for t in sent_triples.triples if t.p.value == SUBJECT_OF]
subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) >= 2 # Entities should be linked to document
# Verify entity contexts were emitted
@ -220,7 +198,7 @@ class TestAgentKgExtractionIntegration:
assert len(sent_contexts.entities) >= 2 # Should have contexts for both entities
# Verify entity URIs are properly formed
entity_uris = [ec.entity.value for ec in sent_contexts.entities]
entity_uris = [ec.entity.iri for ec in sent_contexts.entities]
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
@ -248,22 +226,28 @@ class TestAgentKgExtractionIntegration:
@pytest.mark.asyncio
async def test_invalid_json_response_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
"""Test handling of invalid JSON responses from agent"""
"""Test handling of invalid JSON responses from agent - JSONL is lenient and skips invalid lines"""
# Arrange - mock invalid JSON response
agent_client = mock_flow_context("agent-request")
def mock_invalid_json_response(recipient, question):
return "This is not valid JSON at all"
agent_client.invoke = mock_invalid_json_response
mock_message = MagicMock()
mock_message.value.return_value = sample_chunk
mock_consumer = MagicMock()
# Act & Assert
with pytest.raises((ValueError, json.JSONDecodeError)):
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
# Act - JSONL parsing is lenient, invalid lines are skipped
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
# Assert - should emit triples (with just metadata) but no entity contexts
triples_publisher = mock_flow_context("triples")
triples_publisher.send.assert_called_once()
entity_contexts_publisher = mock_flow_context("entity-contexts")
entity_contexts_publisher.send.assert_not_called()
@pytest.mark.asyncio
async def test_empty_extraction_results(self, configured_agent_extractor, sample_chunk, mock_flow_context):
@ -272,7 +256,8 @@ class TestAgentKgExtractionIntegration:
agent_client = mock_flow_context("agent-request")
def mock_empty_response(recipient, question):
return '{"definitions": [], "relationships": []}'
# Return empty JSONL (just empty/whitespace)
return ''
agent_client.invoke = mock_empty_response
@ -303,7 +288,8 @@ class TestAgentKgExtractionIntegration:
agent_client = mock_flow_context("agent-request")
def mock_malformed_response(recipient, question):
return '''{"definitions": [{"entity": "Missing Definition"}], "relationships": [{"subject": "Missing Object"}]}'''
# JSONL with definition missing required field
return '{"type": "definition", "entity": "Missing Definition"}'
agent_client.invoke = mock_malformed_response
@ -330,7 +316,7 @@ class TestAgentKgExtractionIntegration:
def capture_prompt(recipient, question):
# Verify the prompt contains the test text
assert test_text in question
return '{"definitions": [], "relationships": []}'
return '' # Empty JSONL response
agent_client.invoke = capture_prompt
@ -361,7 +347,7 @@ class TestAgentKgExtractionIntegration:
responses = []
def mock_response(recipient, question):
response = f'{{"definitions": [{{"entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}], "relationships": []}}'
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
responses.append(response)
return response
@ -398,7 +384,7 @@ class TestAgentKgExtractionIntegration:
# Verify unicode text was properly decoded and included
assert "学习机器" in question
assert "人工知能" in question
return '''{"definitions": [{"entity": "機械学習", "definition": "人工知能の一分野"}], "relationships": []}'''
return '{"type": "definition", "entity": "機械学習", "definition": "人工知能の一分野"}'
agent_client.invoke = mock_unicode_response
@ -415,7 +401,7 @@ class TestAgentKgExtractionIntegration:
sent_triples = triples_publisher.send.call_args[0][0]
# Check that unicode entity was properly processed
entity_labels = [t for t in sent_triples.triples if t.p.value == RDF_LABEL and t.o.value == "機械学習"]
entity_labels = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL and t.o.value == "機械学習"]
assert len(entity_labels) > 0
@pytest.mark.asyncio
@ -433,7 +419,7 @@ class TestAgentKgExtractionIntegration:
def mock_large_text_response(recipient, question):
# Verify large text was included
assert len(question) > 10000
return '''{"definitions": [{"entity": "Machine Learning", "definition": "Important AI technique"}], "relationships": []}'''
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'
agent_client.invoke = mock_large_text_response

View file

@ -12,7 +12,7 @@ from argparse import ArgumentParser
# Import processors that use Cassandra configuration
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow:
assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3']
assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
"""Test complete flow from environment variables to Cassandra Cluster connection."""
env_vars = {
@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
# Trigger Cassandra connection
processor.connect_cassandra()
@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd:
class TestMultipleHostsHandling:
"""Test multiple Cassandra hosts handling end-to-end."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
"""Test that multiple hosts are correctly passed to Cassandra cluster."""
env_vars = {
@ -333,7 +333,7 @@ class TestMultipleHostsHandling:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify all hosts were passed to Cluster
@ -386,8 +386,8 @@ class TestMultipleHostsHandling:
class TestAuthenticationFlow:
"""Test authentication configuration flow end-to-end."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
"""Test that authentication is enabled when both username and password are provided."""
env_vars = {
@ -402,7 +402,7 @@ class TestAuthenticationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Auth provider should be created
@ -416,8 +416,8 @@ class TestAuthenticationFlow:
assert 'auth_provider' in call_args.kwargs
assert call_args.kwargs['auth_provider'] == mock_auth_instance
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
"""Test that authentication is not used when credentials are missing."""
env_vars = {
@ -429,7 +429,7 @@ class TestAuthenticationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Auth provider should not be created
@ -439,11 +439,11 @@ class TestAuthenticationFlow:
call_args = mock_cluster.call_args
assert 'auth_provider' not in call_args.kwargs
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
"""Test that authentication is not used when only username is provided."""
processor = ObjectsWriter(
processor = RowsWriter(
taskgroup=MagicMock(),
cassandra_host='partial-auth-host',
cassandra_username='partial-user'

View file

@ -16,7 +16,7 @@ from .cassandra_test_helper import cassandra_container
from trustgraph.direct.cassandra_kg import KnowledgeGraph
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
from trustgraph.schema import Triple, Term, Metadata, Triples, TriplesQueryRequest, IRI, LITERAL
@pytest.mark.integration
@ -118,19 +118,19 @@ class TestCassandraIntegration:
metadata=Metadata(user="testuser", collection="testcol"),
triples=[
Triple(
s=Value(value="http://example.org/person1", is_uri=True),
p=Value(value="http://example.org/name", is_uri=True),
o=Value(value="Alice Smith", is_uri=False)
s=Term(type=IRI, iri="http://example.org/person1"),
p=Term(type=IRI, iri="http://example.org/name"),
o=Term(type=LITERAL, value="Alice Smith")
),
Triple(
s=Value(value="http://example.org/person1", is_uri=True),
p=Value(value="http://example.org/age", is_uri=True),
o=Value(value="25", is_uri=False)
s=Term(type=IRI, iri="http://example.org/person1"),
p=Term(type=IRI, iri="http://example.org/age"),
o=Term(type=LITERAL, value="25")
),
Triple(
s=Value(value="http://example.org/person1", is_uri=True),
p=Value(value="http://example.org/department", is_uri=True),
o=Value(value="Engineering", is_uri=False)
s=Term(type=IRI, iri="http://example.org/person1"),
p=Term(type=IRI, iri="http://example.org/department"),
o=Term(type=LITERAL, value="Engineering")
)
]
)
@ -181,19 +181,19 @@ class TestCassandraIntegration:
metadata=Metadata(user="testuser", collection="testcol"),
triples=[
Triple(
s=Value(value="http://example.org/alice", is_uri=True),
p=Value(value="http://example.org/knows", is_uri=True),
o=Value(value="http://example.org/bob", is_uri=True)
s=Term(type=IRI, iri="http://example.org/alice"),
p=Term(type=IRI, iri="http://example.org/knows"),
o=Term(type=IRI, iri="http://example.org/bob")
),
Triple(
s=Value(value="http://example.org/alice", is_uri=True),
p=Value(value="http://example.org/age", is_uri=True),
o=Value(value="30", is_uri=False)
s=Term(type=IRI, iri="http://example.org/alice"),
p=Term(type=IRI, iri="http://example.org/age"),
o=Term(type=LITERAL, value="30")
),
Triple(
s=Value(value="http://example.org/bob", is_uri=True),
p=Value(value="http://example.org/knows", is_uri=True),
o=Value(value="http://example.org/charlie", is_uri=True)
s=Term(type=IRI, iri="http://example.org/bob"),
p=Term(type=IRI, iri="http://example.org/knows"),
o=Term(type=IRI, iri="http://example.org/charlie")
)
]
)
@ -208,7 +208,7 @@ class TestCassandraIntegration:
# Test S query (find all relationships for Alice)
s_query = TriplesQueryRequest(
s=Value(value="http://example.org/alice", is_uri=True),
s=Term(type=IRI, iri="http://example.org/alice"),
p=None, # None for wildcard
o=None, # None for wildcard
limit=10,
@ -218,18 +218,18 @@ class TestCassandraIntegration:
s_results = await query_processor.query_triples(s_query)
print(f"Query processor results: {len(s_results)}")
for result in s_results:
print(f" S={result.s.value}, P={result.p.value}, O={result.o.value}")
print(f" S={result.s.iri}, P={result.p.iri}, O={result.o.iri if result.o.type == IRI else result.o.value}")
assert len(s_results) == 2
s_predicates = [t.p.value for t in s_results]
s_predicates = [t.p.iri for t in s_results]
assert "http://example.org/knows" in s_predicates
assert "http://example.org/age" in s_predicates
print("✓ Subject queries via processor working")
# Test P query (find all "knows" relationships)
p_query = TriplesQueryRequest(
s=None, # None for wildcard
p=Value(value="http://example.org/knows", is_uri=True),
p=Term(type=IRI, iri="http://example.org/knows"),
o=None, # None for wildcard
limit=10,
user="testuser",
@ -238,8 +238,8 @@ class TestCassandraIntegration:
p_results = await query_processor.query_triples(p_query)
print(p_results)
assert len(p_results) == 2 # Alice knows Bob, Bob knows Charlie
p_subjects = [t.s.value for t in p_results]
p_subjects = [t.s.iri for t in p_results]
assert "http://example.org/alice" in p_subjects
assert "http://example.org/bob" in p_subjects
print("✓ Predicate queries via processor working")
@ -262,19 +262,19 @@ class TestCassandraIntegration:
metadata=Metadata(user="concurrent_test", collection="people"),
triples=[
Triple(
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
p=Value(value="http://example.org/name", is_uri=True),
o=Value(value=name, is_uri=False)
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
p=Term(type=IRI, iri="http://example.org/name"),
o=Term(type=LITERAL, value=name)
),
Triple(
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
p=Value(value="http://example.org/age", is_uri=True),
o=Value(value=str(age), is_uri=False)
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
p=Term(type=IRI, iri="http://example.org/age"),
o=Term(type=LITERAL, value=str(age))
),
Triple(
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
p=Value(value="http://example.org/department", is_uri=True),
o=Value(value=department, is_uri=False)
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
p=Term(type=IRI, iri="http://example.org/department"),
o=Term(type=LITERAL, value=department)
)
]
)
@ -333,36 +333,36 @@ class TestCassandraIntegration:
triples=[
# People and their types
Triple(
s=Value(value="http://company.org/alice", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://company.org/Employee", is_uri=True)
s=Term(type=IRI, iri="http://company.org/alice"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://company.org/Employee")
),
Triple(
s=Value(value="http://company.org/bob", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://company.org/Employee", is_uri=True)
s=Term(type=IRI, iri="http://company.org/bob"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://company.org/Employee")
),
# Relationships
Triple(
s=Value(value="http://company.org/alice", is_uri=True),
p=Value(value="http://company.org/reportsTo", is_uri=True),
o=Value(value="http://company.org/bob", is_uri=True)
s=Term(type=IRI, iri="http://company.org/alice"),
p=Term(type=IRI, iri="http://company.org/reportsTo"),
o=Term(type=IRI, iri="http://company.org/bob")
),
Triple(
s=Value(value="http://company.org/alice", is_uri=True),
p=Value(value="http://company.org/worksIn", is_uri=True),
o=Value(value="http://company.org/engineering", is_uri=True)
s=Term(type=IRI, iri="http://company.org/alice"),
p=Term(type=IRI, iri="http://company.org/worksIn"),
o=Term(type=IRI, iri="http://company.org/engineering")
),
# Personal info
Triple(
s=Value(value="http://company.org/alice", is_uri=True),
p=Value(value="http://company.org/fullName", is_uri=True),
o=Value(value="Alice Johnson", is_uri=False)
s=Term(type=IRI, iri="http://company.org/alice"),
p=Term(type=IRI, iri="http://company.org/fullName"),
o=Term(type=LITERAL, value="Alice Johnson")
),
Triple(
s=Value(value="http://company.org/alice", is_uri=True),
p=Value(value="http://company.org/email", is_uri=True),
o=Value(value="alice@company.org", is_uri=False)
s=Term(type=IRI, iri="http://company.org/alice"),
p=Term(type=IRI, iri="http://company.org/email"),
o=Term(type=LITERAL, value="alice@company.org")
),
]
)

View file

@ -51,10 +51,10 @@ class MockWebSocket:
"metadata": {
"id": "test-id",
"metadata": {},
"user": "test-user",
"user": "test-user",
"collection": "test-collection"
},
"triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
"triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
}
@ -118,7 +118,7 @@ async def test_import_graceful_shutdown_integration(mock_backend):
"user": "test-user",
"collection": "test-collection"
},
"triples": [{"s": {"v": f"subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"object-{i}", "e": False}}]
"triples": [{"s": {"t": "l", "v": f"subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"object-{i}"}}]
}
messages.append(msg_data)
@ -163,7 +163,7 @@ async def test_export_no_message_loss_integration(mock_backend):
"user": "test-user",
"collection": "test-collection"
},
"triples": [{"s": {"v": f"export-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"export-object-{i}", "e": False}}]
"triples": [{"s": {"t": "l", "v": f"export-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"export-object-{i}"}}]
}
# Create Triples object instead of raw dict
from trustgraph.schema import Triples, Metadata
@ -302,7 +302,7 @@ async def test_concurrent_import_export_shutdown():
"user": "test-user",
"collection": "test-collection"
},
"triples": [{"s": {"v": f"concurrent-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
"triples": [{"s": {"t": "l", "v": f"concurrent-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
}
await import_handler.receive(msg)
@ -359,7 +359,7 @@ async def test_websocket_close_during_message_processing():
"user": "test-user",
"collection": "test-collection"
},
"triples": [{"s": {"v": f"slow-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
"triples": [{"s": {"t": "l", "v": f"slow-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
}
task = asyncio.create_task(import_handler.receive(msg))
message_tasks.append(task)
@ -423,7 +423,7 @@ async def test_backpressure_during_shutdown():
# Simulate receiving and processing a message
msg_data = {
"metadata": {"id": f"msg-{i}"},
"triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
"triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
}
await ws.send_json(msg_data)
# Check if we should stop

View file

@ -15,8 +15,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsProcessor
from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
@ -147,6 +147,8 @@ class TestKnowledgeGraphPipelineIntegration:
processor.emit_triples = DefinitionsProcessor.emit_triples.__get__(processor, DefinitionsProcessor)
processor.emit_ecs = DefinitionsProcessor.emit_ecs.__get__(processor, DefinitionsProcessor)
processor.on_message = DefinitionsProcessor.on_message.__get__(processor, DefinitionsProcessor)
processor.triples_batch_size = 50
processor.entity_batch_size = 5
return processor
@pytest.fixture
@ -156,6 +158,7 @@ class TestKnowledgeGraphPipelineIntegration:
processor.to_uri = RelationshipsProcessor.to_uri.__get__(processor, RelationshipsProcessor)
processor.emit_triples = RelationshipsProcessor.emit_triples.__get__(processor, RelationshipsProcessor)
processor.on_message = RelationshipsProcessor.on_message.__get__(processor, RelationshipsProcessor)
processor.triples_batch_size = 50
return processor
@pytest.mark.asyncio
@ -253,24 +256,24 @@ class TestKnowledgeGraphPipelineIntegration:
if s and o:
s_uri = definitions_processor.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
s_term = Term(type=IRI, iri=str(s_uri))
o_term = Term(type=LITERAL, value=str(o))
# Generate triples as the processor would
triples.append(Triple(
s=s_value,
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=s, is_uri=False)
s=s_term,
p=Term(type=IRI, iri=RDF_LABEL),
o=Term(type=LITERAL, value=s)
))
triples.append(Triple(
s=s_value,
p=Value(value=DEFINITION, is_uri=True),
o=o_value
s=s_term,
p=Term(type=IRI, iri=DEFINITION),
o=o_term
))
entities.append(EntityContext(
entity=s_value,
entity=s_term,
context=defn["definition"]
))
@ -279,16 +282,16 @@ class TestKnowledgeGraphPipelineIntegration:
assert len(entities) == 3 # 1 entity context per entity
# Verify triple structure
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
definition_triples = [t for t in triples if t.p.value == DEFINITION]
label_triples = [t for t in triples if t.p.iri == RDF_LABEL]
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
assert len(label_triples) == 3
assert len(definition_triples) == 3
# Verify entity contexts
for entity in entities:
assert entity.entity.is_uri is True
assert entity.entity.value.startswith(TRUSTGRAPH_ENTITIES)
assert entity.entity.type == IRI
assert entity.entity.iri.startswith(TRUSTGRAPH_ENTITIES)
assert len(entity.context) > 0
@pytest.mark.asyncio
@ -309,52 +312,52 @@ class TestKnowledgeGraphPipelineIntegration:
s = rel["subject"]
p = rel["predicate"]
o = rel["object"]
if s and p and o:
s_uri = relationships_processor.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
s_term = Term(type=IRI, iri=str(s_uri))
p_uri = relationships_processor.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
p_term = Term(type=IRI, iri=str(p_uri))
if rel["object-entity"]:
o_uri = relationships_processor.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
o_term = Term(type=IRI, iri=str(o_uri))
else:
o_value = Value(value=str(o), is_uri=False)
o_term = Term(type=LITERAL, value=str(o))
# Main relationship triple
triples.append(Triple(s=s_value, p=p_value, o=o_value))
triples.append(Triple(s=s_term, p=p_term, o=o_term))
# Label triples
triples.append(Triple(
s=s_value,
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=str(s), is_uri=False)
s=s_term,
p=Term(type=IRI, iri=RDF_LABEL),
o=Term(type=LITERAL, value=str(s))
))
triples.append(Triple(
s=p_value,
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=str(p), is_uri=False)
s=p_term,
p=Term(type=IRI, iri=RDF_LABEL),
o=Term(type=LITERAL, value=str(p))
))
if rel["object-entity"]:
triples.append(Triple(
s=o_value,
p=Value(value=RDF_LABEL, is_uri=True),
o=Value(value=str(o), is_uri=False)
s=o_term,
p=Term(type=IRI, iri=RDF_LABEL),
o=Term(type=LITERAL, value=str(o))
))
# Assert
assert len(triples) > 0
# Verify relationship triples exist
relationship_triples = [t for t in triples if t.p.value.endswith("is_subset_of") or t.p.value.endswith("is_used_in")]
relationship_triples = [t for t in triples if t.p.iri.endswith("is_subset_of") or t.p.iri.endswith("is_used_in")]
assert len(relationship_triples) >= 2
# Verify label triples
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
label_triples = [t for t in triples if t.p.iri == RDF_LABEL]
assert len(label_triples) > 0
@pytest.mark.asyncio
@ -374,9 +377,9 @@ class TestKnowledgeGraphPipelineIntegration:
),
triples=[
Triple(
s=Value(value="http://trustgraph.ai/e/machine-learning", is_uri=True),
p=Value(value=DEFINITION, is_uri=True),
o=Value(value="A subset of AI", is_uri=False)
s=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"),
p=Term(type=IRI, iri=DEFINITION),
o=Term(type=LITERAL, value="A subset of AI")
)
]
)
@ -405,9 +408,14 @@ class TestKnowledgeGraphPipelineIntegration:
collection="test_collection",
metadata=[]
),
entities=[]
entities=[
EntityEmbeddings(
entity=Term(type=IRI, iri="http://example.org/entity"),
vectors=[[0.1, 0.2, 0.3]]
)
]
)
mock_msg = MagicMock()
mock_msg.value.return_value = sample_embeddings
@ -496,12 +504,12 @@ class TestKnowledgeGraphPipelineIntegration:
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
# Assert
# Should still call producers but with empty results
# Should NOT call producers with empty results (avoids Cassandra NULL issues)
triples_producer = mock_flow_context("triples")
entity_contexts_producer = mock_flow_context("entity-contexts")
triples_producer.send.assert_called_once()
entity_contexts_producer.send.assert_called_once()
triples_producer.send.assert_not_called()
entity_contexts_producer.send.assert_not_called()
@pytest.mark.asyncio
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
@ -602,9 +610,9 @@ class TestKnowledgeGraphPipelineIntegration:
collection="test_collection",
metadata=[
Triple(
s=Value(value="doc:test", is_uri=True),
p=Value(value="dc:title", is_uri=True),
o=Value(value="Test Document", is_uri=False)
s=Term(type=IRI, iri="doc:test"),
p=Term(type=IRI, iri="dc:title"),
o=Term(type=LITERAL, value="Test Document")
)
]
)

View file

@ -11,7 +11,7 @@ import json
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Mock flow with failing prompt service
@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration

View file

@ -1,608 +0,0 @@
"""
Integration tests for Cassandra Object Storage
These tests verify the end-to-end functionality of storing ExtractedObjects
in Cassandra, including table creation, data insertion, and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import uuid
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
@pytest.mark.integration
class TestObjectsCassandraIntegration:
"""Integration tests for Cassandra object storage"""
@pytest.fixture
def mock_cassandra_session(self):
"""Mock Cassandra session for integration tests"""
session = MagicMock()
# Track if keyspaces have been created
created_keyspaces = set()
# Mock the execute method to return a valid result for keyspace checks
def execute_mock(query, *args, **kwargs):
result = MagicMock()
query_str = str(query)
# Track keyspace creation
if "CREATE KEYSPACE" in query_str:
# Extract keyspace name from query
import re
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
if match:
created_keyspaces.add(match.group(1))
# For keyspace existence checks
if "system_schema.keyspaces" in query_str:
# Check if this keyspace was created
if args and args[0] in created_keyspaces:
result.one.return_value = MagicMock() # Exists
else:
result.one.return_value = None # Doesn't exist
else:
result.one.return_value = None
return result
session.execute = MagicMock(side_effect=execute_mock)
return session
@pytest.fixture
def mock_cassandra_cluster(self, mock_cassandra_session):
"""Mock Cassandra cluster"""
cluster = MagicMock()
cluster.connect.return_value = mock_cassandra_session
cluster.shutdown = MagicMock()
return cluster
@pytest.fixture
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
"""Create processor with mocked Cassandra dependencies"""
processor = MagicMock()
processor.graph_host = "localhost"
processor.graph_username = None
processor.graph_password = None
processor.config_key = "schema"
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.cluster = None
processor.session = None
# Bind actual methods
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.create_collection = Processor.create_collection.__get__(processor, Processor)
return processor, mock_cassandra_cluster, mock_cassandra_session
@pytest.mark.asyncio
async def test_end_to_end_object_storage(self, processor_with_mocks):
"""Test complete flow from schema config to object storage"""
processor, mock_cluster, mock_session = processor_with_mocks
# Mock Cluster creation
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Step 1: Configure schema
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer information",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True},
{"name": "age", "type": "integer"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
# Step 1.5: Create the collection first (simulate tg-set-collection)
await processor.create_collection("test_user", "import_2024", {})
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
id="doc-001",
user="test_user",
collection="import_2024",
metadata=[]
),
schema_name="customer_records",
values=[{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": "30"
}],
confidence=0.95,
source_span="Customer: John Doe..."
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify Cassandra interactions
assert mock_cluster.connect.called
# Verify keyspace creation
keyspace_calls = [call for call in mock_session.execute.call_args_list
if "CREATE KEYSPACE" in str(call)]
assert len(keyspace_calls) == 1
assert "test_user" in str(keyspace_calls[0])
# Verify table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "o_customer_records" in str(table_calls[0]) # Table gets o_ prefix
assert "collection text" in str(table_calls[0])
assert "PRIMARY KEY ((collection, customer_id))" in str(table_calls[0])
# Verify index creation
index_calls = [call for call in mock_session.execute.call_args_list
if "CREATE INDEX" in str(call)]
assert len(index_calls) == 1
assert "email" in str(index_calls[0])
# Verify data insertion
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 1
insert_call = insert_calls[0]
assert "test_user.o_customer_records" in str(insert_call) # Table gets o_ prefix
# Check inserted values
values = insert_call[0][1]
assert "import_2024" in values # collection
assert "CUST001" in values # customer_id
assert "John Doe" in values # name
assert "john@example.com" in values # email
assert 30 in values # age (converted to int)
@pytest.mark.asyncio
async def test_multi_schema_handling(self, processor_with_mocks):
"""Test handling multiple schemas and objects"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure multiple schemas
config = {
"schema": {
"products": json.dumps({
"name": "products",
"fields": [
{"name": "product_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string"},
{"name": "price", "type": "float"}
]
}),
"orders": json.dumps({
"name": "orders",
"fields": [
{"name": "order_id", "type": "string", "primary_key": True},
{"name": "customer_id", "type": "string"},
{"name": "total", "type": "float"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
# Create collections first
await processor.create_collection("shop", "catalog", {})
await processor.create_collection("shop", "sales", {})
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
source_span="Product..."
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
source_span="Order..."
)
# Process both objects
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Verify separate tables were created
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 2
assert any("o_products" in str(call) for call in table_calls) # Tables get o_ prefix
assert any("o_orders" in str(call) for call in table_calls) # Tables get o_ prefix
@pytest.mark.asyncio
async def test_missing_required_fields(self, processor_with_mocks):
"""Test handling of objects with missing required fields"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema with required field
processor.schemas["test_schema"] = RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True, required=True),
Field(name="required_field", type="string", size=100, required=True)
]
)
# Create collection first
await processor.create_collection("test", "test", {})
# Create object missing required field
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="test_schema",
values=[{"id": "123"}], # missing required_field
confidence=0.8,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
# Should still process (Cassandra doesn't enforce NOT NULL)
await processor.on_object(msg, None, None)
# Verify insert was attempted
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 1
@pytest.mark.asyncio
async def test_schema_without_primary_key(self, processor_with_mocks):
"""Test handling schemas without defined primary keys"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema without primary key
processor.schemas["events"] = RowSchema(
name="events",
description="Event log",
fields=[
Field(name="event_type", type="string", size=50),
Field(name="timestamp", type="timestamp", size=0)
]
)
# Create collection first
await processor.create_collection("logger", "app_events", {})
# Process object
test_obj = ExtractedObject(
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
schema_name="events",
values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
confidence=1.0,
source_span="Event"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify synthetic_id was added
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "synthetic_id uuid" in str(table_calls[0])
# Verify insert includes UUID
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 1
values = insert_calls[0][0][1]
# Check that a UUID was generated (will be in values list)
uuid_found = any(isinstance(v, uuid.UUID) for v in values)
assert uuid_found
@pytest.mark.asyncio
async def test_authentication_handling(self, processor_with_mocks):
"""Test Cassandra authentication"""
processor, mock_cluster, mock_session = processor_with_mocks
processor.cassandra_username = "cassandra_user"
processor.cassandra_password = "cassandra_pass"
with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class:
with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
mock_cluster_class.return_value = mock_cluster
# Trigger connection
processor.connect_cassandra()
# Verify authentication was configured
mock_auth.assert_called_once_with(
username="cassandra_user",
password="cassandra_pass"
)
mock_cluster_class.assert_called_once()
call_kwargs = mock_cluster_class.call_args[1]
assert 'auth_provider' in call_kwargs
@pytest.mark.asyncio
async def test_error_handling_during_insert(self, processor_with_mocks):
"""Test error handling when insertion fails"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["test"] = RowSchema(
name="test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Make insert fail
mock_result = MagicMock()
mock_result.one.return_value = MagicMock() # Keyspace exists
mock_session.execute.side_effect = [
mock_result, # keyspace existence check succeeds
None, # table creation succeeds
Exception("Connection timeout") # insert fails
]
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="test",
values=[{"id": "123"}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
# Should raise the exception
with pytest.raises(Exception, match="Connection timeout"):
await processor.on_object(msg, None, None)
@pytest.mark.asyncio
async def test_collection_partitioning(self, processor_with_mocks):
"""Test that objects are properly partitioned by collection"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["data"] = RowSchema(
name="data",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Process objects from different collections
collections = ["import_jan", "import_feb", "import_mar"]
# Create all collections first
for coll in collections:
await processor.create_collection("analytics", coll, {})
for coll in collections:
obj = ExtractedObject(
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
schema_name="data",
values=[{"id": f"ID-{coll}"}],
confidence=0.9,
source_span="Data"
)
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Verify all inserts include collection in values
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 3
# Check each insert has the correct collection
for i, call in enumerate(insert_calls):
values = call[0][1]
assert collections[i] in values
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
"""Test processing objects with batched values"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema
config = {
"schema": {
"batch_customers": json.dumps({
"name": "batch_customers",
"description": "Customer batch data",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True}
]
})
}
}
await processor.on_schema_config(config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
metadata=[]
),
schema_name="batch_customers",
values=[
{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com"
},
{
"customer_id": "CUST002",
"name": "Jane Smith",
"email": "jane@example.com"
},
{
"customer_id": "CUST003",
"name": "Bob Johnson",
"email": "bob@example.com"
}
],
confidence=0.92,
source_span="Multiple customers extracted from document"
)
# Create collection first
await processor.create_collection("test_user", "batch_import", {})
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Verify table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "o_batch_customers" in str(table_calls[0])
# Verify multiple inserts for batch values
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
# Should have 3 separate inserts for the 3 objects in the batch
assert len(insert_calls) == 3
# Check each insert has correct data
for i, call in enumerate(insert_calls):
values = call[0][1]
assert "batch_import" in values # collection
assert f"CUST00{i+1}" in values # customer_id
if i == 0:
assert "John Doe" in values
assert "john@example.com" in values
elif i == 1:
assert "Jane Smith" in values
assert "jane@example.com" in values
elif i == 2:
assert "Bob Johnson" in values
assert "bob@example.com" in values
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
"""Test processing objects with empty values array"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Create collection first
await processor.create_collection("test", "empty", {})
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
source_span="No objects found"
)
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
# Should still create table
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
# Should not create any insert statements for empty batch
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 0
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
"""Test processing mix of single and batch objects"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["mixed_test"] = RowSchema(
name="mixed_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
# Create collection first
await processor.create_collection("test", "mixed", {})
# Single object (backward compatibility)
single_obj = ExtractedObject(
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
schema_name="mixed_test",
values=[{"id": "single-1", "data": "single data"}], # Array with single item
confidence=0.9,
source_span="Single object"
)
# Batch object
batch_obj = ExtractedObject(
metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
schema_name="mixed_test",
values=[
{"id": "batch-1", "data": "batch data 1"},
{"id": "batch-2", "data": "batch data 2"}
],
confidence=0.85,
source_span="Batch objects"
)
# Process both
for obj in [single_obj, batch_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Should have 3 total inserts (1 + 2)
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 3

View file

@ -0,0 +1,492 @@
"""
Integration tests for Cassandra Row Storage (Unified Table Implementation)
These tests verify the end-to-end functionality of storing ExtractedObjects
in the unified Cassandra rows table, including table creation, data insertion,
and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
@pytest.mark.integration
class TestRowsCassandraIntegration:
"""Integration tests for Cassandra row storage with unified table"""
@pytest.fixture
def mock_cassandra_session(self):
"""Mock Cassandra session for integration tests"""
session = MagicMock()
# Track if keyspaces have been created
created_keyspaces = set()
# Mock the execute method to return a valid result for keyspace checks
def execute_mock(query, *args, **kwargs):
result = MagicMock()
query_str = str(query)
# Track keyspace creation
if "CREATE KEYSPACE" in query_str:
import re
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
if match:
created_keyspaces.add(match.group(1))
# For keyspace existence checks
if "system_schema.keyspaces" in query_str:
if args and args[0] in created_keyspaces:
result.one.return_value = MagicMock() # Exists
else:
result.one.return_value = None # Doesn't exist
else:
result.one.return_value = None
return result
session.execute = MagicMock(side_effect=execute_mock)
return session
@pytest.fixture
def mock_cassandra_cluster(self, mock_cassandra_session):
"""Mock Cassandra cluster"""
cluster = MagicMock()
cluster.connect.return_value = mock_cassandra_session
cluster.shutdown = MagicMock()
return cluster
@pytest.fixture
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
"""Create processor with mocked Cassandra dependencies"""
processor = MagicMock()
processor.cassandra_host = ["localhost"]
processor.cassandra_username = None
processor.cassandra_password = None
processor.config_key = "schema"
processor.schemas = {}
processor.known_keyspaces = set()
processor.tables_initialized = set()
processor.registered_partitions = set()
processor.cluster = None
processor.session = None
# Bind actual methods from the new unified table implementation
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.collection_exists = MagicMock(return_value=True)
return processor, mock_cassandra_cluster, mock_cassandra_session
@pytest.mark.asyncio
async def test_end_to_end_object_storage(self, processor_with_mocks):
"""Test complete flow from schema config to object storage"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Step 1: Configure schema
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer information",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True},
{"name": "age", "type": "integer"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
id="doc-001",
user="test_user",
collection="import_2024",
metadata=[]
),
schema_name="customer_records",
values=[{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": "30"
}],
confidence=0.95,
source_span="Customer: John Doe..."
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify Cassandra interactions
assert mock_cluster.connect.called
# Verify keyspace creation
keyspace_calls = [call for call in mock_session.execute.call_args_list
if "CREATE KEYSPACE" in str(call)]
assert len(keyspace_calls) == 1
assert "test_user" in str(keyspace_calls[0])
# Verify unified table creation (rows table, not per-schema table)
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 2 # rows table + row_partitions table
assert any("rows" in str(call) for call in table_calls)
assert any("row_partitions" in str(call) for call in table_calls)
# Verify the rows table has correct structure
rows_table_call = [call for call in table_calls if ".rows" in str(call)][0]
assert "collection text" in str(rows_table_call)
assert "schema_name text" in str(rows_table_call)
assert "index_name text" in str(rows_table_call)
assert "data map<text, text>" in str(rows_table_call)
# Verify data insertion into unified table
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
# Should have 2 data inserts: one for customer_id (primary), one for email (indexed)
assert len(rows_insert_calls) == 2
@pytest.mark.asyncio
async def test_multi_schema_handling(self, processor_with_mocks):
"""Test handling multiple schemas stored in unified table"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Configure multiple schemas
config = {
"schema": {
"products": json.dumps({
"name": "products",
"fields": [
{"name": "product_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string"},
{"name": "price", "type": "float"}
]
}),
"orders": json.dumps({
"name": "orders",
"fields": [
{"name": "order_id", "type": "string", "primary_key": True},
{"name": "customer_id", "type": "string"},
{"name": "total", "type": "float"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
source_span="Product..."
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
source_span="Order..."
)
# Process both objects
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# All data goes into the same unified rows table
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
# Should only create 2 tables: rows + row_partitions (not per-schema tables)
assert len(table_calls) == 2
# Verify data inserts go to unified rows table
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) > 0
for call in rows_insert_calls:
assert ".rows" in str(call)
@pytest.mark.asyncio
async def test_multi_index_storage(self, processor_with_mocks):
"""Test that rows are stored with multiple indexes"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Schema with multiple indexed fields
processor.schemas["indexed_data"] = RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="indexed_data",
values=[{
"id": "123",
"category": "electronics",
"status": "active",
"description": "A product"
}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Should have 3 data inserts (one per indexed field: id, category, status)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) == 3
# Verify different index names were used
index_names = set()
for call in rows_insert_calls:
values = call[0][1]
index_names.add(values[2]) # index_name is 3rd parameter
assert index_names == {"id", "category", "status"}
@pytest.mark.asyncio
async def test_authentication_handling(self, processor_with_mocks):
"""Test Cassandra authentication"""
processor, mock_cluster, mock_session = processor_with_mocks
processor.cassandra_username = "cassandra_user"
processor.cassandra_password = "cassandra_pass"
with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class:
with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth:
mock_cluster_class.return_value = mock_cluster
# Trigger connection
processor.connect_cassandra()
# Verify authentication was configured
mock_auth.assert_called_once_with(
username="cassandra_user",
password="cassandra_pass"
)
mock_cluster_class.assert_called_once()
call_kwargs = mock_cluster_class.call_args[1]
assert 'auth_provider' in call_kwargs
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
"""Test processing objects with batched values"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema
config = {
"schema": {
"batch_customers": json.dumps({
"name": "batch_customers",
"description": "Customer batch data",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True}
]
})
}
}
await processor.on_schema_config(config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
metadata=[]
),
schema_name="batch_customers",
values=[
{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com"
},
{
"customer_id": "CUST002",
"name": "Jane Smith",
"email": "jane@example.com"
},
{
"customer_id": "CUST003",
"name": "Bob Johnson",
"email": "bob@example.com"
}
],
confidence=0.92,
source_span="Multiple customers extracted from document"
)
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Verify unified table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 2 # rows + row_partitions
# Each row in batch gets 2 data inserts (customer_id primary + email indexed)
# 3 rows * 2 indexes = 6 data inserts
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) == 6
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
"""Test processing objects with empty values array"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
source_span="No objects found"
)
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
# Should not create any data insert statements for empty batch
# (partition registration may still happen)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) == 0
@pytest.mark.asyncio
async def test_data_stored_as_map(self, processor_with_mocks):
"""Test that data is stored as map<text, text>"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["map_test"] = RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="map_test",
values=[{"id": "123", "name": "Test Item", "count": "42"}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify insert uses map for data
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) >= 1
# Check that data is passed as a dict (will be map in Cassandra)
insert_call = rows_insert_calls[0]
values = insert_call[0][1]
# Values are: (collection, schema_name, index_name, index_value, data, source)
# values[4] should be the data map
data_map = values[4]
assert isinstance(data_map, dict)
assert data_map["id"] == "123"
assert data_map["name"] == "Test Item"
assert data_map["count"] == "42"
@pytest.mark.asyncio
async def test_partition_registration(self, processor_with_mocks):
"""Test that partitions are registered for efficient querying"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["partition_test"] = RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
schema_name="partition_test",
values=[{"id": "123", "category": "test"}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify partition registration
partition_inserts = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and "row_partitions" in str(call)]
# Should register partitions for each index (id, category)
assert len(partition_inserts) == 2
# Verify cache was updated
assert ("my_collection", "partition_test") in processor.registered_partitions

View file

@ -1,5 +1,5 @@
"""
Integration tests for Objects GraphQL Query Service
Integration tests for Rows GraphQL Query Service
These tests verify end-to-end functionality including:
- Real Cassandra database operations
@ -24,8 +24,8 @@ except Exception:
DOCKER_AVAILABLE = False
CassandraContainer = None
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata
@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration:
processor.connect_cassandra()
# Create mock message
request = ObjectsQueryRequest(
request = RowsQueryRequest(
user="msg_test_user",
collection="msg_test_collection",
query='{ customer_objects { customer_id name } }',
@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration:
# Verify response structure
sent_response = mock_response_producer.send.call_args[0][0]
assert isinstance(sent_response, ObjectsQueryResponse)
assert isinstance(sent_response, RowsQueryResponse)
# Should have no system error (even if no data)
assert sent_response.error is None

View file

@ -2,7 +2,7 @@
Integration tests for Structured Query Service
These tests verify the end-to-end functionality of the structured query service,
testing orchestration between nlp-query and objects-query services.
testing orchestration between nlp-query and rows-query services.
Following the TEST_STRATEGY.md approach for integration testing.
"""
@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
ObjectsQueryRequest, ObjectsQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects Query Service Response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
errors=None,
@ -99,7 +99,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration:
# Verify Objects service call
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, ObjectsQueryRequest)
assert isinstance(objects_call_args, RowsQueryRequest)
assert "customers" in objects_call_args.query
assert "orders" in objects_call_args.query
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects service failure
objects_error_response = ObjectsQueryResponse(
objects_error_response = RowsQueryResponse(
error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
data=None,
errors=None,
@ -237,7 +237,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -255,7 +255,7 @@ class TestStructuredQueryServiceIntegration:
assert response.error is not None
assert response.error.type == "structured-query-error"
assert "Objects query service error" in response.error.message
assert "Rows query service error" in response.error.message
assert "nonexistent_table" in response.error.message
@pytest.mark.asyncio
@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration:
)
]
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None, # No data when validation fails
errors=validation_errors,
@ -315,7 +315,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -422,7 +422,7 @@ class TestStructuredQueryServiceIntegration:
]
}
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=json.dumps(complex_data),
errors=None,
@ -443,7 +443,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -503,7 +503,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock empty Objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": []}', # Empty result set
errors=None,
@ -520,7 +520,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -577,7 +577,7 @@ class TestStructuredQueryServiceIntegration:
confidence=0.9
)
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
errors=None,
@ -599,7 +599,7 @@ class TestStructuredQueryServiceIntegration:
if service_name == "nlp-query-request":
service_call_count += 1
return nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
service_call_count += 1
return objects_client
elif service_name == "response":
@ -700,7 +700,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
errors=None,
@ -717,7 +717,7 @@ class TestStructuredQueryServiceIntegration:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response

View file

@ -19,4 +19,5 @@ markers =
integration: marks tests as integration tests
unit: marks tests as unit tests
contract: marks tests as contract tests (service interface validation)
vertexai: marks tests as vertex ai specific tests
vertexai: marks tests as vertex ai specific tests
asyncio: marks tests that use asyncio

View file

@ -88,8 +88,13 @@ async def test_subscriber_deferred_acknowledgment_success():
@pytest.mark.asyncio
async def test_subscriber_deferred_acknowledgment_failure():
"""Verify Subscriber negative acks on delivery failure."""
async def test_subscriber_dropped_message_still_acks():
"""Verify Subscriber acks even when message is dropped (backpressure).
This prevents redelivery storms on shared topics - if we negative_ack
a dropped message, it gets redelivered to all subscribers, none of
whom can handle it either, causing a tight redelivery loop.
"""
mock_backend = MagicMock()
mock_consumer = MagicMock()
mock_backend.create_consumer.return_value = mock_consumer
@ -103,24 +108,66 @@ async def test_subscriber_deferred_acknowledgment_failure():
max_size=1, # Very small queue
backpressure_strategy="drop_new"
)
# Start subscriber to initialize consumer
await subscriber.start()
# Create queue and fill it
queue = await subscriber.subscribe("test-queue")
await queue.put({"existing": "data"})
# Create mock message - should be dropped
msg = create_mock_message("msg-1", {"data": "test"})
# Process message (should fail due to full queue + drop_new strategy)
# Create mock message - should be dropped due to full queue
msg = create_mock_message("test-queue", {"data": "test"})
# Process message (should be dropped due to full queue + drop_new strategy)
await subscriber._process_message(msg)
# Should negative acknowledge failed delivery
mock_consumer.negative_acknowledge.assert_called_once_with(msg)
mock_consumer.acknowledge.assert_not_called()
# Should acknowledge even though delivery failed - prevents redelivery storm
mock_consumer.acknowledge.assert_called_once_with(msg)
mock_consumer.negative_acknowledge.assert_not_called()
# Clean up
await subscriber.stop()
@pytest.mark.asyncio
async def test_subscriber_orphaned_message_acks():
"""Verify Subscriber acks orphaned messages (no matching waiter).
On shared response topics, if a message arrives for a waiter that
no longer exists (e.g., client disconnected, request timed out),
we must acknowledge it to prevent redelivery storms.
"""
mock_backend = MagicMock()
mock_consumer = MagicMock()
mock_backend.create_consumer.return_value = mock_consumer
subscriber = Subscriber(
backend=mock_backend,
topic="test-topic",
subscription="test-subscription",
consumer_name="test-consumer",
schema=dict,
max_size=10,
backpressure_strategy="block"
)
# Start subscriber to initialize consumer
await subscriber.start()
# Don't create any queues - message will be orphaned
# This simulates a response arriving after the waiter has unsubscribed
# Create mock message with an ID that has no matching waiter
msg = create_mock_message("non-existent-waiter-id", {"data": "orphaned"})
# Process message (should be orphaned - no matching waiter)
await subscriber._process_message(msg)
# Should acknowledge orphaned message - prevents redelivery storm
mock_consumer.acknowledge.assert_called_once_with(msg)
mock_consumer.negative_acknowledge.assert_not_called()
# Clean up
await subscriber.stop()

View file

@ -55,6 +55,9 @@ class TestSetToolStructuredQuery:
mcp_tool=None,
collection="sales_data",
template=None,
schema_name=None,
index_name=None,
limit=None,
arguments=[],
group=None,
state=None,
@ -92,6 +95,9 @@ class TestSetToolStructuredQuery:
mcp_tool=None,
collection=None, # No collection specified
template=None,
schema_name=None,
index_name=None,
limit=None,
arguments=[],
group=None,
state=None,
@ -132,6 +138,9 @@ class TestSetToolStructuredQuery:
mcp_tool=None,
collection='sales_data',
template=None,
schema_name=None,
index_name=None,
limit=None,
arguments=[],
group=None,
state=None,
@ -201,6 +210,144 @@ class TestSetToolStructuredQuery:
assert 'Exception:' in printed_output or 'invalid choice:' in printed_output.lower()
class TestSetToolRowEmbeddingsQuery:
"""Test the set_tool function with row-embeddings-query type."""
@patch('trustgraph.cli.set_tool.Api')
def test_set_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
"""Test setting a row-embeddings-query tool with all parameters."""
mock_api_class.return_value, mock_config = mock_api
mock_config.get.return_value = []
set_tool(
url="http://test.com",
id="customer_search",
name="find_customer",
description="Find customers by name using semantic search",
type="row-embeddings-query",
mcp_tool=None,
collection="sales",
template=None,
schema_name="customers",
index_name="full_name",
limit=20,
arguments=[],
group=None,
state=None,
applicable_states=None
)
captured = capsys.readouterr()
assert "Tool set." in captured.out
# Verify the tool was stored correctly
call_args = mock_config.put.call_args[0][0]
assert len(call_args) == 1
config_value = call_args[0]
assert config_value.type == "tool"
assert config_value.key == "customer_search"
stored_tool = json.loads(config_value.value)
assert stored_tool["name"] == "find_customer"
assert stored_tool["type"] == "row-embeddings-query"
assert stored_tool["collection"] == "sales"
assert stored_tool["schema-name"] == "customers"
assert stored_tool["index-name"] == "full_name"
assert stored_tool["limit"] == 20
@patch('trustgraph.cli.set_tool.Api')
def test_set_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
"""Test setting row-embeddings-query tool with minimal parameters."""
mock_api_class.return_value, mock_config = mock_api
mock_config.get.return_value = []
set_tool(
url="http://test.com",
id="product_search",
name="find_product",
description="Find products using semantic search",
type="row-embeddings-query",
mcp_tool=None,
collection=None,
template=None,
schema_name="products",
index_name=None, # No index filter
limit=None, # Use default
arguments=[],
group=None,
state=None,
applicable_states=None
)
captured = capsys.readouterr()
assert "Tool set." in captured.out
call_args = mock_config.put.call_args[0][0]
stored_tool = json.loads(call_args[0].value)
assert stored_tool["type"] == "row-embeddings-query"
assert stored_tool["schema-name"] == "products"
assert "index-name" not in stored_tool # Should not be included if None
assert "limit" not in stored_tool # Should not be included if None
assert "collection" not in stored_tool # Should not be included if None
def test_set_main_row_embeddings_query_with_all_options(self):
"""Test set main() with row-embeddings-query tool type and all options."""
test_args = [
'tg-set-tool',
'--id', 'customer_search',
'--name', 'find_customer',
'--type', 'row-embeddings-query',
'--description', 'Find customers by name',
'--schema-name', 'customers',
'--collection', 'sales',
'--index-name', 'full_name',
'--limit', '25',
'--api-url', 'http://custom.com'
]
with patch('sys.argv', test_args), \
patch('trustgraph.cli.set_tool.set_tool') as mock_set:
set_main()
mock_set.assert_called_once_with(
url='http://custom.com',
id='customer_search',
name='find_customer',
description='Find customers by name',
type='row-embeddings-query',
mcp_tool=None,
collection='sales',
template=None,
schema_name='customers',
index_name='full_name',
limit=25,
arguments=[],
group=None,
state=None,
applicable_states=None,
token=None
)
def test_valid_types_includes_row_embeddings_query(self):
"""Test that 'row-embeddings-query' is included in valid tool types."""
test_args = [
'tg-set-tool',
'--id', 'test_tool',
'--name', 'test_tool',
'--type', 'row-embeddings-query',
'--description', 'Test tool',
'--schema-name', 'test_schema'
]
with patch('sys.argv', test_args), \
patch('trustgraph.cli.set_tool.set_tool') as mock_set:
# Should not raise an exception about invalid type
set_main()
mock_set.assert_called_once()
class TestShowToolsStructuredQuery:
"""Test the show_tools function with structured-query tools."""
@ -259,9 +406,9 @@ class TestShowToolsStructuredQuery:
@patch('trustgraph.cli.show_tools.Api')
def test_show_mixed_tool_types(self, mock_api_class, mock_api, capsys):
"""Test displaying multiple tool types including structured-query."""
"""Test displaying multiple tool types including structured-query and row-embeddings-query."""
mock_api_class.return_value, mock_config = mock_api
tools = [
{
"name": "ask_knowledge",
@ -270,37 +417,47 @@ class TestShowToolsStructuredQuery:
"collection": "docs"
},
{
"name": "query_data",
"name": "query_data",
"description": "Query structured data",
"type": "structured-query",
"collection": "sales"
},
{
"name": "find_customer",
"description": "Find customers by semantic search",
"type": "row-embeddings-query",
"schema-name": "customers",
"collection": "crm"
},
{
"name": "complete_text",
"description": "Generate text",
"type": "text-completion"
}
]
config_values = [
ConfigValue(type="tool", key=f"tool_{i}", value=json.dumps(tool))
for i, tool in enumerate(tools)
]
mock_config.get_values.return_value = config_values
show_config("http://test.com")
captured = capsys.readouterr()
output = captured.out
# All tool types should be displayed
assert "knowledge-query" in output
assert "structured-query" in output
assert "structured-query" in output
assert "row-embeddings-query" in output
assert "text-completion" in output
# Collections should be shown for appropriate tools
assert "docs" in output # knowledge-query collection
assert "sales" in output # structured-query collection
assert "crm" in output # row-embeddings-query collection
assert "customers" in output # row-embeddings-query schema-name
def test_show_main_parses_args_correctly(self):
"""Test that show main() parses arguments correctly."""
@ -317,6 +474,76 @@ class TestShowToolsStructuredQuery:
mock_show.assert_called_once_with(url='http://custom.com', token=None)
class TestShowToolsRowEmbeddingsQuery:
"""Test the show_tools function with row-embeddings-query tools."""
@patch('trustgraph.cli.show_tools.Api')
def test_show_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
"""Test displaying a row-embeddings-query tool with all fields."""
mock_api_class.return_value, mock_config = mock_api
tool_config = {
"name": "find_customer",
"description": "Find customers by name using semantic search",
"type": "row-embeddings-query",
"collection": "sales",
"schema-name": "customers",
"index-name": "full_name",
"limit": 20
}
config_value = ConfigValue(
type="tool",
key="customer_search",
value=json.dumps(tool_config)
)
mock_config.get_values.return_value = [config_value]
show_config("http://test.com")
captured = capsys.readouterr()
output = captured.out
# Check that tool information is displayed
assert "customer_search" in output
assert "find_customer" in output
assert "row-embeddings-query" in output
assert "sales" in output # Collection
assert "customers" in output # Schema name
assert "full_name" in output # Index name
assert "20" in output # Limit
@patch('trustgraph.cli.show_tools.Api')
def test_show_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
"""Test displaying row-embeddings-query tool with minimal fields."""
mock_api_class.return_value, mock_config = mock_api
tool_config = {
"name": "find_product",
"description": "Find products using semantic search",
"type": "row-embeddings-query",
"schema-name": "products"
# No collection, index-name, or limit
}
config_value = ConfigValue(
type="tool",
key="product_search",
value=json.dumps(tool_config)
)
mock_config.get_values.return_value = [config_value]
show_config("http://test.com")
captured = capsys.readouterr()
output = captured.out
# Should display the tool with schema-name
assert "product_search" in output
assert "row-embeddings-query" in output
assert "products" in output # Schema name
class TestStructuredQueryToolValidation:
"""Test validation specific to structured-query tools."""

View file

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
from unittest.mock import call
from trustgraph.cores.knowledge import KnowledgeManager
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Value, EntityEmbeddings
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
@pytest.fixture
@ -71,15 +71,15 @@ def sample_triples():
return Triples(
metadata=Metadata(
id="test-doc-id",
user="test-user",
user="test-user",
collection="default", # This should be overridden
metadata=[]
),
triples=[
Triple(
s=Value(value="http://example.org/john", is_uri=True),
p=Value(value="http://example.org/name", is_uri=True),
o=Value(value="John Smith", is_uri=False)
s=Term(type=IRI, iri="http://example.org/john"),
p=Term(type=IRI, iri="http://example.org/name"),
o=Term(type=LITERAL, value="John Smith")
)
]
)
@ -97,7 +97,7 @@ def sample_graph_embeddings():
),
entities=[
EntityEmbeddings(
entity=Value(value="http://example.org/john", is_uri=True),
entity=Term(type=IRI, iri="http://example.org/john"),
vectors=[[0.1, 0.2, 0.3]]
)
]

View file

@ -0,0 +1,599 @@
"""
Unit tests for EntityCentricKnowledgeGraph class
Tests the entity-centric knowledge graph implementation without requiring
an actual Cassandra connection. Uses mocking to verify correct behavior.
"""
import pytest
from unittest.mock import MagicMock, patch, call
import os
class TestEntityCentricKnowledgeGraph:
"""Test cases for EntityCentricKnowledgeGraph"""
@pytest.fixture
def mock_cluster(self):
"""Create a mock Cassandra cluster"""
with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cluster_cls:
mock_cluster = MagicMock()
mock_session = MagicMock()
mock_cluster.connect.return_value = mock_session
mock_cluster_cls.return_value = mock_cluster
yield mock_cluster_cls, mock_cluster, mock_session
@pytest.fixture
def entity_kg(self, mock_cluster):
"""Create an EntityCentricKnowledgeGraph instance with mocked Cassandra"""
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
mock_cluster_cls, mock_cluster, mock_session = mock_cluster
# Create instance
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
return kg, mock_session
def test_init_creates_entity_centric_schema(self, mock_cluster):
"""Test that initialization creates the 2-table entity-centric schema"""
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
mock_cluster_cls, mock_cluster, mock_session = mock_cluster
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
# Verify schema tables were created
execute_calls = mock_session.execute.call_args_list
executed_statements = [str(c) for c in execute_calls]
# Check for keyspace creation
keyspace_created = any('create keyspace' in str(c).lower() for c in execute_calls)
assert keyspace_created
# Check for quads_by_entity table
entity_table_created = any('quads_by_entity' in str(c) for c in execute_calls)
assert entity_table_created
# Check for quads_by_collection table
collection_table_created = any('quads_by_collection' in str(c) for c in execute_calls)
assert collection_table_created
# Check for collection_metadata table
metadata_table_created = any('collection_metadata' in str(c) for c in execute_calls)
assert metadata_table_created
def test_prepare_statements_initialized(self, entity_kg):
"""Test that prepared statements are initialized"""
kg, mock_session = entity_kg
# Verify prepare was called for various statements
assert mock_session.prepare.called
prepare_calls = mock_session.prepare.call_args_list
# Check that key prepared statements exist
prepared_queries = [str(c) for c in prepare_calls]
# Insert statements
insert_entity_stmt = any('INSERT INTO' in str(c) and 'quads_by_entity' in str(c)
for c in prepare_calls)
assert insert_entity_stmt
insert_collection_stmt = any('INSERT INTO' in str(c) and 'quads_by_collection' in str(c)
for c in prepare_calls)
assert insert_collection_stmt
def test_insert_uri_object_creates_4_entity_rows(self, entity_kg):
"""Test that inserting a quad with URI object creates 4 entity rows"""
kg, mock_session = entity_kg
# Reset mocks to track only insert-related calls
mock_session.reset_mock()
kg.insert(
collection='test_collection',
s='http://example.org/Alice',
p='http://example.org/knows',
o='http://example.org/Bob',
g='http://example.org/graph1',
otype='u'
)
# Verify batch was executed
mock_session.execute.assert_called()
def test_insert_literal_object_creates_3_entity_rows(self, entity_kg):
"""Test that inserting a quad with literal object creates 3 entity rows"""
kg, mock_session = entity_kg
mock_session.reset_mock()
kg.insert(
collection='test_collection',
s='http://example.org/Alice',
p='http://www.w3.org/2000/01/rdf-schema#label',
o='Alice Smith',
g=None,
otype='l',
dtype='xsd:string',
lang='en'
)
# Verify batch was executed
mock_session.execute.assert_called()
def test_insert_default_graph(self, entity_kg):
"""Test that None graph is stored as empty string"""
kg, mock_session = entity_kg
mock_session.reset_mock()
kg.insert(
collection='test_collection',
s='http://example.org/Alice',
p='http://example.org/knows',
o='http://example.org/Bob',
g=None,
otype='u'
)
mock_session.execute.assert_called()
def test_insert_auto_detects_otype(self, entity_kg):
"""Test that otype is auto-detected when not provided"""
kg, mock_session = entity_kg
mock_session.reset_mock()
# URI should be auto-detected
kg.insert(
collection='test_collection',
s='http://example.org/Alice',
p='http://example.org/knows',
o='http://example.org/Bob'
)
mock_session.execute.assert_called()
mock_session.reset_mock()
# Literal should be auto-detected
kg.insert(
collection='test_collection',
s='http://example.org/Alice',
p='http://example.org/name',
o='Alice'
)
mock_session.execute.assert_called()
def test_get_s_returns_quads_for_subject(self, entity_kg):
"""Test get_s queries by subject"""
kg, mock_session = entity_kg
# Mock the query result
mock_result = [
MagicMock(p='http://example.org/knows', o='http://example.org/Bob',
d='', otype='u', dtype='', lang='', s='http://example.org/Alice')
]
mock_session.execute.return_value = mock_result
results = kg.get_s('test_collection', 'http://example.org/Alice')
# Verify query was executed
mock_session.execute.assert_called()
# Results should be QuadResult objects
assert len(results) == 1
assert results[0].s == 'http://example.org/Alice'
assert results[0].p == 'http://example.org/knows'
assert results[0].o == 'http://example.org/Bob'
def test_get_p_returns_quads_for_predicate(self, entity_kg):
"""Test get_p queries by predicate"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(s='http://example.org/Alice', o='http://example.org/Bob',
d='', otype='u', dtype='', lang='', p='http://example.org/knows')
]
mock_session.execute.return_value = mock_result
results = kg.get_p('test_collection', 'http://example.org/knows')
mock_session.execute.assert_called()
assert len(results) == 1
def test_get_o_returns_quads_for_object(self, entity_kg):
"""Test get_o queries by object"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
d='', otype='u', dtype='', lang='', o='http://example.org/Bob')
]
mock_session.execute.return_value = mock_result
results = kg.get_o('test_collection', 'http://example.org/Bob')
mock_session.execute.assert_called()
assert len(results) == 1
def test_get_sp_returns_quads_for_subject_predicate(self, entity_kg):
"""Test get_sp queries by subject and predicate"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(o='http://example.org/Bob', d='', otype='u', dtype='', lang='')
]
mock_session.execute.return_value = mock_result
results = kg.get_sp('test_collection', 'http://example.org/Alice',
'http://example.org/knows')
mock_session.execute.assert_called()
assert len(results) == 1
def test_get_po_returns_quads_for_predicate_object(self, entity_kg):
"""Test get_po queries by predicate and object"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(s='http://example.org/Alice', d='', otype='u', dtype='', lang='',
o='http://example.org/Bob')
]
mock_session.execute.return_value = mock_result
results = kg.get_po('test_collection', 'http://example.org/knows',
'http://example.org/Bob')
mock_session.execute.assert_called()
assert len(results) == 1
def test_get_os_returns_quads_for_object_subject(self, entity_kg):
"""Test get_os queries by object and subject"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(p='http://example.org/knows', d='', otype='u', dtype='', lang='',
s='http://example.org/Alice', o='http://example.org/Bob')
]
mock_session.execute.return_value = mock_result
results = kg.get_os('test_collection', 'http://example.org/Bob',
'http://example.org/Alice')
mock_session.execute.assert_called()
assert len(results) == 1
def test_get_spo_returns_quads_for_subject_predicate_object(self, entity_kg):
"""Test get_spo queries by subject, predicate, and object"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(d='', otype='u', dtype='', lang='',
o='http://example.org/Bob')
]
mock_session.execute.return_value = mock_result
results = kg.get_spo('test_collection', 'http://example.org/Alice',
'http://example.org/knows', 'http://example.org/Bob')
mock_session.execute.assert_called()
assert len(results) == 1
def test_get_g_returns_quads_for_graph(self, entity_kg):
"""Test get_g queries by graph"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
o='http://example.org/Bob', otype='u', dtype='', lang='')
]
mock_session.execute.return_value = mock_result
results = kg.get_g('test_collection', 'http://example.org/graph1')
mock_session.execute.assert_called()
def test_get_all_returns_all_quads_in_collection(self, entity_kg):
"""Test get_all returns all quads"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
o='http://example.org/Bob', otype='u', dtype='', lang='')
]
mock_session.execute.return_value = mock_result
results = kg.get_all('test_collection')
mock_session.execute.assert_called()
def test_graph_wildcard_returns_all_graphs(self, entity_kg):
"""Test that g='*' returns quads from all graphs"""
from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD
kg, mock_session = entity_kg
mock_result = [
MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
otype='u', dtype='', lang='', s='http://example.org/Alice',
o='http://example.org/Bob'),
MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
otype='u', dtype='', lang='', s='http://example.org/Alice',
o='http://example.org/Charlie')
]
mock_session.execute.return_value = mock_result
results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD)
# Should return quads from both graphs
assert len(results) == 2
def test_specific_graph_filters_results(self, entity_kg):
"""Test that specifying a graph filters results"""
kg, mock_session = entity_kg
mock_result = [
MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
otype='u', dtype='', lang='', s='http://example.org/Alice',
o='http://example.org/Bob'),
MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
otype='u', dtype='', lang='', s='http://example.org/Alice',
o='http://example.org/Charlie')
]
mock_session.execute.return_value = mock_result
results = kg.get_s('test_collection', 'http://example.org/Alice',
g='http://example.org/graph1')
# Should only return quads from graph1
assert len(results) == 1
assert results[0].g == 'http://example.org/graph1'
def test_collection_exists_returns_true_when_exists(self, entity_kg):
"""Test collection_exists returns True for existing collection"""
kg, mock_session = entity_kg
mock_result = [MagicMock(collection='test_collection')]
mock_session.execute.return_value = mock_result
exists = kg.collection_exists('test_collection')
assert exists is True
def test_collection_exists_returns_false_when_not_exists(self, entity_kg):
"""Test collection_exists returns False for non-existing collection"""
kg, mock_session = entity_kg
mock_session.execute.return_value = []
exists = kg.collection_exists('nonexistent_collection')
assert exists is False
def test_create_collection_inserts_metadata(self, entity_kg):
"""Test create_collection inserts metadata row"""
kg, mock_session = entity_kg
mock_session.reset_mock()
kg.create_collection('test_collection')
# Verify INSERT was executed for collection_metadata
mock_session.execute.assert_called()
def test_delete_collection_removes_all_data(self, entity_kg):
"""Test delete_collection removes entity partitions and collection rows"""
kg, mock_session = entity_kg
# Mock reading quads from collection
mock_quads = [
MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
o='http://example.org/Bob', otype='u')
]
mock_session.execute.return_value = mock_quads
mock_session.reset_mock()
kg.delete_collection('test_collection')
# Verify delete operations were executed
assert mock_session.execute.called
def test_close_shuts_down_connections(self, entity_kg):
"""Test close shuts down session and cluster"""
kg, mock_session = entity_kg
kg.close()
mock_session.shutdown.assert_called_once()
kg.cluster.shutdown.assert_called_once()
class TestQuadResult:
"""Test cases for QuadResult class"""
def test_quad_result_stores_all_fields(self):
"""Test QuadResult stores all quad fields"""
from trustgraph.direct.cassandra_kg import QuadResult
result = QuadResult(
s='http://example.org/Alice',
p='http://example.org/knows',
o='http://example.org/Bob',
g='http://example.org/graph1',
otype='u',
dtype='',
lang=''
)
assert result.s == 'http://example.org/Alice'
assert result.p == 'http://example.org/knows'
assert result.o == 'http://example.org/Bob'
assert result.g == 'http://example.org/graph1'
assert result.otype == 'u'
assert result.dtype == ''
assert result.lang == ''
def test_quad_result_defaults(self):
"""Test QuadResult default values"""
from trustgraph.direct.cassandra_kg import QuadResult
result = QuadResult(
s='http://example.org/s',
p='http://example.org/p',
o='literal value',
g=''
)
assert result.otype == 'u' # Default otype
assert result.dtype == ''
assert result.lang == ''
def test_quad_result_with_literal_metadata(self):
"""Test QuadResult with literal metadata"""
from trustgraph.direct.cassandra_kg import QuadResult
result = QuadResult(
s='http://example.org/Alice',
p='http://www.w3.org/2000/01/rdf-schema#label',
o='Alice Smith',
g='',
otype='l',
dtype='xsd:string',
lang='en'
)
assert result.otype == 'l'
assert result.dtype == 'xsd:string'
assert result.lang == 'en'
class TestWriteHelperFunctions:
"""Test cases for helper functions in write.py"""
def test_get_term_otype_for_iri(self):
"""Test get_term_otype returns 'u' for IRI terms"""
from trustgraph.storage.triples.cassandra.write import get_term_otype
from trustgraph.schema import Term, IRI
term = Term(type=IRI, iri='http://example.org/Alice')
assert get_term_otype(term) == 'u'
def test_get_term_otype_for_literal(self):
"""Test get_term_otype returns 'l' for LITERAL terms"""
from trustgraph.storage.triples.cassandra.write import get_term_otype
from trustgraph.schema import Term, LITERAL
term = Term(type=LITERAL, value='Alice Smith')
assert get_term_otype(term) == 'l'
def test_get_term_otype_for_blank(self):
"""Test get_term_otype returns 'u' for BLANK terms"""
from trustgraph.storage.triples.cassandra.write import get_term_otype
from trustgraph.schema import Term, BLANK
term = Term(type=BLANK, id='_:b1')
assert get_term_otype(term) == 'u'
def test_get_term_otype_for_triple(self):
"""Test get_term_otype returns 't' for TRIPLE terms"""
from trustgraph.storage.triples.cassandra.write import get_term_otype
from trustgraph.schema import Term, TRIPLE
term = Term(type=TRIPLE)
assert get_term_otype(term) == 't'
def test_get_term_otype_for_none(self):
"""Test get_term_otype returns 'u' for None"""
from trustgraph.storage.triples.cassandra.write import get_term_otype
assert get_term_otype(None) == 'u'
def test_get_term_dtype_for_literal(self):
"""Test get_term_dtype extracts datatype from LITERAL"""
from trustgraph.storage.triples.cassandra.write import get_term_dtype
from trustgraph.schema import Term, LITERAL
term = Term(type=LITERAL, value='42', datatype='xsd:integer')
assert get_term_dtype(term) == 'xsd:integer'
def test_get_term_dtype_for_non_literal(self):
"""Test get_term_dtype returns empty string for non-LITERAL"""
from trustgraph.storage.triples.cassandra.write import get_term_dtype
from trustgraph.schema import Term, IRI
term = Term(type=IRI, iri='http://example.org/Alice')
assert get_term_dtype(term) == ''
def test_get_term_dtype_for_none(self):
"""Test get_term_dtype returns empty string for None"""
from trustgraph.storage.triples.cassandra.write import get_term_dtype
assert get_term_dtype(None) == ''
def test_get_term_lang_for_literal(self):
"""Test get_term_lang extracts language from LITERAL"""
from trustgraph.storage.triples.cassandra.write import get_term_lang
from trustgraph.schema import Term, LITERAL
term = Term(type=LITERAL, value='Alice Smith', language='en')
assert get_term_lang(term) == 'en'
def test_get_term_lang_for_non_literal(self):
"""Test get_term_lang returns empty string for non-LITERAL"""
from trustgraph.storage.triples.cassandra.write import get_term_lang
from trustgraph.schema import Term, IRI
term = Term(type=IRI, iri='http://example.org/Alice')
assert get_term_lang(term) == ''
class TestServiceHelperFunctions:
"""Test cases for helper functions in service.py"""
def test_create_term_with_uri_otype(self):
"""Test create_term creates IRI Term for otype='u'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import IRI
term = create_term('http://example.org/Alice', otype='u')
assert term.type == IRI
assert term.iri == 'http://example.org/Alice'
def test_create_term_with_literal_otype(self):
"""Test create_term creates LITERAL Term for otype='l'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import LITERAL
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
assert term.type == LITERAL
assert term.value == 'Alice Smith'
assert term.datatype == 'xsd:string'
assert term.language == 'en'
def test_create_term_with_triple_otype(self):
"""Test create_term creates IRI Term for otype='t'"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import IRI
term = create_term('http://example.org/statement1', otype='t')
assert term.type == IRI
assert term.iri == 'http://example.org/statement1'
def test_create_term_heuristic_fallback_uri(self):
"""Test create_term uses URL heuristic when otype not provided"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import IRI
term = create_term('http://example.org/Alice')
assert term.type == IRI
assert term.iri == 'http://example.org/Alice'
def test_create_term_heuristic_fallback_literal(self):
"""Test create_term uses literal heuristic when otype not provided"""
from trustgraph.query.triples.cassandra.service import create_term
from trustgraph.schema import LITERAL
term = create_term('Alice Smith')
assert term.type == LITERAL
assert term.value == 'Alice Smith'

View file

@ -0,0 +1,380 @@
"""
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
Tests the Stage 1 processor that computes embeddings for row index fields.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
"""Test row embeddings processor functionality"""
async def test_processor_initialization(self):
"""Test basic processor initialization"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-row-embeddings'
}
processor = Processor(**config)
assert hasattr(processor, 'schemas')
assert processor.schemas == {}
assert processor.batch_size == 10 # default
async def test_processor_initialization_with_custom_batch_size(self):
"""Test processor initialization with custom batch size"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-row-embeddings',
'batch_size': 25
}
processor = Processor(**config)
assert processor.batch_size == 25
async def test_get_index_names_single_index(self):
"""Test getting index names with single indexed field"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import RowSchema, Field
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
schema = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
Field(name='email', type='text', indexed=False),
]
)
index_names = processor.get_index_names(schema)
# Should include primary key and indexed field
assert 'id' in index_names
assert 'name' in index_names
assert 'email' not in index_names
async def test_get_index_names_no_indexes(self):
"""Test getting index names when no fields are indexed"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import RowSchema, Field
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
schema = RowSchema(
name='logs',
description='Log records',
fields=[
Field(name='timestamp', type='text'),
Field(name='message', type='text'),
]
)
index_names = processor.get_index_names(schema)
assert index_names == []
async def test_build_index_value_single_field(self):
"""Test building index value for single field"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'id': 'CUST001',
'name': 'John Doe',
'email': 'john@example.com'
}
result = processor.build_index_value(value_map, 'name')
assert result == ['John Doe']
async def test_build_index_value_composite_index(self):
"""Test building index value for composite index"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'first_name': 'John',
'last_name': 'Doe',
'city': 'New York'
}
result = processor.build_index_value(value_map, 'first_name, last_name')
assert result == ['John', 'Doe']
async def test_build_index_value_missing_field(self):
"""Test building index value when field is missing"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'name': 'John Doe'
}
result = processor.build_index_value(value_map, 'missing_field')
assert result == ['']
async def test_build_text_for_embedding_single_value(self):
"""Test building text representation for single value"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
result = processor.build_text_for_embedding(['John Doe'])
assert result == 'John Doe'
async def test_build_text_for_embedding_multiple_values(self):
"""Test building text representation for multiple values"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])
assert result == 'John Doe NYC'
async def test_on_schema_config_loads_schemas(self):
"""Test that schema configuration is loaded correctly"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
import json
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
schema_def = {
'name': 'customers',
'description': 'Customer records',
'fields': [
{'name': 'id', 'type': 'text', 'primary_key': True},
{'name': 'name', 'type': 'text', 'indexed': True},
{'name': 'email', 'type': 'text'}
]
}
config_data = {
'schema': {
'customers': json.dumps(schema_def)
}
}
await processor.on_schema_config(config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
config_data = {
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
assert processor.schemas == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='customers',
values=[{'id': '123', 'name': 'Test'}]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
mock_flow = MagicMock()
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Flow should not be called for output
mock_flow.assert_not_called()
async def test_on_message_drops_unknown_schema(self):
"""Test that messages for unknown schemas are dropped"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# No schemas registered
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='unknown_schema',
values=[{'id': '123', 'name': 'Test'}]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
mock_flow = MagicMock()
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Flow should not be called for output
mock_flow.assert_not_called()
async def test_on_message_processes_embeddings(self):
"""Test processing a message and computing embeddings"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject, RowSchema, Field
import json
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='customers',
values=[
{'id': 'CUST001', 'name': 'John Doe'},
{'id': 'CUST002', 'name': 'Jane Smith'}
]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
# Mock the flow
mock_embeddings_request = AsyncMock()
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
mock_output = AsyncMock()
def flow_factory(name):
if name == 'embeddings-request':
return mock_embeddings_request
elif name == 'output':
return mock_output
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Should have called embed for each unique text
# 4 values: CUST001, John Doe, CUST002, Jane Smith
assert mock_embeddings_request.embed.call_count == 4
# Should have sent output
mock_output.send.assert_called()
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -7,7 +7,7 @@ collecting labels and definitions for entity embedding and retrieval.
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
from trustgraph.schema.knowledge.graph import EntityContext
@ -25,9 +25,9 @@ class TestEntityContextBuilding:
"""Test that entity context is built from rdfs:label."""
triples = [
Triple(
s=Value(value="https://example.com/entity/cornish-pasty", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Cornish Pasty", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/cornish-pasty"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Cornish Pasty")
)
]
@ -35,16 +35,16 @@ class TestEntityContextBuilding:
assert len(contexts) == 1, "Should create one entity context"
assert isinstance(contexts[0], EntityContext)
assert contexts[0].entity.value == "https://example.com/entity/cornish-pasty"
assert contexts[0].entity.iri == "https://example.com/entity/cornish-pasty"
assert "Label: Cornish Pasty" in contexts[0].context
def test_builds_context_from_definition(self, processor):
"""Test that entity context includes definitions."""
triples = [
Triple(
s=Value(value="https://example.com/entity/pasty", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="A baked pastry filled with savory ingredients", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/pasty"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="A baked pastry filled with savory ingredients")
)
]
@ -57,14 +57,14 @@ class TestEntityContextBuilding:
"""Test that label and definition are combined in context."""
triples = [
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Pasty Recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Pasty Recipe")
),
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="Traditional Cornish pastry recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="Traditional Cornish pastry recipe")
)
]
@ -80,14 +80,14 @@ class TestEntityContextBuilding:
"""Test that only the first label is used in context."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="First Label", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="First Label")
),
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Second Label", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Second Label")
)
]
@ -101,14 +101,14 @@ class TestEntityContextBuilding:
"""Test that all definitions are included in context."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="First definition", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="First definition")
),
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="Second definition", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="Second definition")
)
]
@ -123,9 +123,9 @@ class TestEntityContextBuilding:
"""Test that schema.org description is treated as definition."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="https://schema.org/description", is_uri=True),
o=Value(value="A delicious food item", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="https://schema.org/description"),
o=Term(type=LITERAL, value="A delicious food item")
)
]
@ -138,26 +138,26 @@ class TestEntityContextBuilding:
"""Test that contexts are created for multiple entities."""
triples = [
Triple(
s=Value(value="https://example.com/entity/entity1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity One", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity One")
),
Triple(
s=Value(value="https://example.com/entity/entity2", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity Two", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity Two")
),
Triple(
s=Value(value="https://example.com/entity/entity3", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity Three", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity3"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity Three")
)
]
contexts = processor.build_entity_contexts(triples)
assert len(contexts) == 3, "Should create context for each entity"
entity_uris = [ctx.entity.value for ctx in contexts]
entity_uris = [ctx.entity.iri for ctx in contexts]
assert "https://example.com/entity/entity1" in entity_uris
assert "https://example.com/entity/entity2" in entity_uris
assert "https://example.com/entity/entity3" in entity_uris
@ -166,9 +166,9 @@ class TestEntityContextBuilding:
"""Test that URI objects are ignored (only literal labels/definitions)."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="https://example.com/some/uri", is_uri=True) # URI, not literal
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=IRI, iri="https://example.com/some/uri") # URI, not literal
)
]
@ -181,14 +181,14 @@ class TestEntityContextBuilding:
"""Test that other predicates are ignored."""
triples = [
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://example.com/Food", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://example.com/Food")
),
Triple(
s=Value(value="https://example.com/entity/food1", is_uri=True),
p=Value(value="http://example.com/produces", is_uri=True),
o=Value(value="https://example.com/entity/food2", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/food1"),
p=Term(type=IRI, iri="http://example.com/produces"),
o=Term(type=IRI, iri="https://example.com/entity/food2")
)
]
@ -205,29 +205,29 @@ class TestEntityContextBuilding:
assert len(contexts) == 0, "Empty triple list should return empty contexts"
def test_entity_context_has_value_object(self, processor):
"""Test that EntityContext.entity is a Value object."""
def test_entity_context_has_term_object(self, processor):
"""Test that EntityContext.entity is a Term object."""
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Test Entity", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Test Entity")
)
]
contexts = processor.build_entity_contexts(triples)
assert len(contexts) == 1
assert isinstance(contexts[0].entity, Value), "Entity should be Value object"
assert contexts[0].entity.is_uri, "Entity should be marked as URI"
assert isinstance(contexts[0].entity, Term), "Entity should be Term object"
assert contexts[0].entity.type == IRI, "Entity should be IRI type"
def test_entity_context_text_is_string(self, processor):
"""Test that EntityContext.context is a string."""
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Test Entity", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Test Entity")
)
]
@ -241,22 +241,22 @@ class TestEntityContextBuilding:
triples = [
# Entity with label - should create context
Triple(
s=Value(value="https://example.com/entity/entity1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Entity One", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Entity One")
),
# Entity with only rdf:type - should NOT create context
Triple(
s=Value(value="https://example.com/entity/entity2", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://example.com/Food", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://example.com/Food")
)
]
contexts = processor.build_entity_contexts(triples)
assert len(contexts) == 1, "Should only create context for entity with label/definition"
assert contexts[0].entity.value == "https://example.com/entity/entity1"
assert contexts[0].entity.iri == "https://example.com/entity/entity1"
class TestEntityContextEdgeCases:
@ -266,9 +266,9 @@ class TestEntityContextEdgeCases:
"""Test handling of unicode characters in labels."""
triples = [
Triple(
s=Value(value="https://example.com/entity/café", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Café Spécial", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/café"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Café Spécial")
)
]
@ -282,9 +282,9 @@ class TestEntityContextEdgeCases:
long_def = "This is a very long definition " * 50
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value=long_def, is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value=long_def)
)
]
@ -297,9 +297,9 @@ class TestEntityContextEdgeCases:
"""Test handling of special characters in context text."""
triples = [
Triple(
s=Value(value="https://example.com/entity/test", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Test & Entity <with> \"quotes\"", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/test"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Test & Entity <with> \"quotes\"")
)
]
@ -313,27 +313,27 @@ class TestEntityContextEdgeCases:
triples = [
# Label - relevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Cornish Pasty Recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
o=Term(type=LITERAL, value="Cornish Pasty Recipe")
),
# Type - irrelevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://example.com/Recipe", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri="http://example.com/Recipe")
),
# Property - irrelevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://example.com/produces", is_uri=True),
o=Value(value="https://example.com/entity/pasty", is_uri=True)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://example.com/produces"),
o=Term(type=IRI, iri="https://example.com/entity/pasty")
),
# Definition - relevant
Triple(
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
o=Value(value="Traditional British pastry recipe", is_uri=False)
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
o=Term(type=LITERAL, value="Traditional British pastry recipe")
)
]

View file

@ -9,7 +9,7 @@ the knowledge graph.
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
@pytest.fixture
@ -92,12 +92,12 @@ class TestOntologyTripleGeneration:
# Find type triples for Recipe class
recipe_type_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/Recipe"
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
]
assert len(recipe_type_triples) == 1, "Should generate exactly one type triple per class"
assert recipe_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#Class", \
assert recipe_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#Class", \
"Class type should be owl:Class"
def test_generates_class_labels(self, extractor, sample_ontology_subset):
@ -107,14 +107,14 @@ class TestOntologyTripleGeneration:
# Find label triples for Recipe class
recipe_label_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/Recipe"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
]
assert len(recipe_label_triples) == 1, "Should generate label triple for class"
assert recipe_label_triples[0].o.value == "Recipe", \
"Label should match class label from ontology"
assert not recipe_label_triples[0].o.is_uri, \
assert recipe_label_triples[0].o.type == LITERAL, \
"Label should be a literal, not URI"
def test_generates_class_comments(self, extractor, sample_ontology_subset):
@ -124,8 +124,8 @@ class TestOntologyTripleGeneration:
# Find comment triples for Recipe class
recipe_comment_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/Recipe"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#comment"
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#comment"
]
assert len(recipe_comment_triples) == 1, "Should generate comment triple for class"
@ -139,13 +139,13 @@ class TestOntologyTripleGeneration:
# Find type triples for ingredients property
ingredients_type_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/ingredients"
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
]
assert len(ingredients_type_triples) == 1, \
"Should generate exactly one type triple per object property"
assert ingredients_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#ObjectProperty", \
assert ingredients_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#ObjectProperty", \
"Object property type should be owl:ObjectProperty"
def test_generates_object_property_labels(self, extractor, sample_ontology_subset):
@ -155,8 +155,8 @@ class TestOntologyTripleGeneration:
# Find label triples for ingredients property
ingredients_label_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/ingredients"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
]
assert len(ingredients_label_triples) == 1, \
@ -171,15 +171,15 @@ class TestOntologyTripleGeneration:
# Find domain triples for ingredients property
ingredients_domain_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/ingredients"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#domain"
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#domain"
]
assert len(ingredients_domain_triples) == 1, \
"Should generate domain triple for object property"
assert ingredients_domain_triples[0].o.value == "http://purl.org/ontology/fo/Recipe", \
assert ingredients_domain_triples[0].o.iri == "http://purl.org/ontology/fo/Recipe", \
"Domain should be Recipe class URI"
assert ingredients_domain_triples[0].o.is_uri, \
assert ingredients_domain_triples[0].o.type == IRI, \
"Domain should be a URI reference"
def test_generates_object_property_range(self, extractor, sample_ontology_subset):
@ -189,13 +189,13 @@ class TestOntologyTripleGeneration:
# Find range triples for produces property
produces_range_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/produces"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
if t.s.iri == "http://purl.org/ontology/fo/produces"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
]
assert len(produces_range_triples) == 1, \
"Should generate range triple for object property"
assert produces_range_triples[0].o.value == "http://purl.org/ontology/fo/Food", \
assert produces_range_triples[0].o.iri == "http://purl.org/ontology/fo/Food", \
"Range should be Food class URI"
def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset):
@ -205,13 +205,13 @@ class TestOntologyTripleGeneration:
# Find type triples for serves property
serves_type_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/serves"
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
if t.s.iri == "http://purl.org/ontology/fo/serves"
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
]
assert len(serves_type_triples) == 1, \
"Should generate exactly one type triple per datatype property"
assert serves_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
assert serves_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
"Datatype property type should be owl:DatatypeProperty"
def test_generates_datatype_property_range(self, extractor, sample_ontology_subset):
@ -221,13 +221,13 @@ class TestOntologyTripleGeneration:
# Find range triples for serves property
serves_range_triples = [
t for t in triples
if t.s.value == "http://purl.org/ontology/fo/serves"
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
if t.s.iri == "http://purl.org/ontology/fo/serves"
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
]
assert len(serves_range_triples) == 1, \
"Should generate range triple for datatype property"
assert serves_range_triples[0].o.value == "http://www.w3.org/2001/XMLSchema#string", \
assert serves_range_triples[0].o.iri == "http://www.w3.org/2001/XMLSchema#string", \
"Range should be XSD type URI (xsd:string expanded)"
def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset):
@ -236,9 +236,9 @@ class TestOntologyTripleGeneration:
# Count unique class subjects
class_subjects = set(
t.s.value for t in triples
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and t.o.value == "http://www.w3.org/2002/07/owl#Class"
t.s.iri for t in triples
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and t.o.iri == "http://www.w3.org/2002/07/owl#Class"
)
assert len(class_subjects) == 3, \
@ -250,9 +250,9 @@ class TestOntologyTripleGeneration:
# Count unique property subjects (object + datatype properties)
property_subjects = set(
t.s.value for t in triples
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and ("ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value)
t.s.iri for t in triples
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
and ("ObjectProperty" in t.o.iri or "DatatypeProperty" in t.o.iri)
)
assert len(property_subjects) == 3, \
@ -276,7 +276,7 @@ class TestOntologyTripleGeneration:
# Should still generate proper RDF triples despite dict field names
label_triples = [
t for t in triples
if t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
if t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
]
assert len(label_triples) > 0, \
"Should generate rdfs:label triples from dict 'labels' field"

View file

@ -8,7 +8,7 @@ and extracts/validates triples from LLM responses.
import pytest
from trustgraph.extract.kg.ontology.extract import Processor
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
from trustgraph.schema.core.primitives import Triple, Value
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
@pytest.fixture
@ -248,9 +248,9 @@ class TestTripleParsing:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert len(validated) == 1, "Should parse one valid triple"
assert validated[0].s.value == "https://trustgraph.ai/food/cornish-pasty"
assert validated[0].p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
assert validated[0].s.iri == "https://trustgraph.ai/food/cornish-pasty"
assert validated[0].p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
def test_parse_multiple_triples(self, extractor, sample_ontology_subset):
"""Test parsing multiple triples."""
@ -307,11 +307,11 @@ class TestTripleParsing:
assert len(validated) == 1
# Subject should be expanded to entity URI
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
# Predicate should be expanded to ontology URI
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
# Object should be expanded to class URI
assert validated[0].o.value == "http://purl.org/ontology/fo/Food"
assert validated[0].o.iri == "http://purl.org/ontology/fo/Food"
def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset):
"""Test that Triple objects are properly created."""
@ -324,12 +324,12 @@ class TestTripleParsing:
assert len(validated) == 1
triple = validated[0]
assert isinstance(triple, Triple), "Should create Triple objects"
assert isinstance(triple.s, Value), "Subject should be Value object"
assert isinstance(triple.p, Value), "Predicate should be Value object"
assert isinstance(triple.o, Value), "Object should be Value object"
assert triple.s.is_uri, "Subject should be marked as URI"
assert triple.p.is_uri, "Predicate should be marked as URI"
assert not triple.o.is_uri, "Object literal should not be marked as URI"
assert isinstance(triple.s, Term), "Subject should be Term object"
assert isinstance(triple.p, Term), "Predicate should be Term object"
assert isinstance(triple.o, Term), "Object should be Term object"
assert triple.s.type == IRI, "Subject should be IRI type"
assert triple.p.type == IRI, "Predicate should be IRI type"
assert triple.o.type == LITERAL, "Object literal should be LITERAL type"
class TestURIExpansionInExtraction:
@ -343,8 +343,8 @@ class TestURIExpansionInExtraction:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
assert validated[0].o.is_uri, "Class reference should be URI"
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
assert validated[0].o.type == IRI, "Class reference should be URI"
def test_expands_property_names(self, extractor, sample_ontology_subset):
"""Test that property names are expanded to full URIs."""
@ -354,7 +354,7 @@ class TestURIExpansionInExtraction:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
def test_expands_entity_instances(self, extractor, sample_ontology_subset):
"""Test that entity instances get constructed URIs."""
@ -364,8 +364,8 @@ class TestURIExpansionInExtraction:
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
assert "my-special-recipe" in validated[0].s.value
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
assert "my-special-recipe" in validated[0].s.iri
class TestEdgeCases:

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock
from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value
from trustgraph.schema import Value, Triple
from trustgraph.schema import Term, Triple, IRI, LITERAL
class TestDispatchSerialize:
@ -14,55 +14,55 @@ class TestDispatchSerialize:
def test_to_value_with_uri(self):
"""Test to_value function with URI"""
input_data = {"v": "http://example.com/resource", "e": True}
input_data = {"t": "i", "i": "http://example.com/resource"}
result = to_value(input_data)
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_to_value_with_literal(self):
"""Test to_value function with literal value"""
input_data = {"v": "literal string", "e": False}
input_data = {"t": "l", "v": "literal string"}
result = to_value(input_data)
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_to_subgraph_with_multiple_triples(self):
"""Test to_subgraph function with multiple triples"""
input_data = [
{
"s": {"v": "subject1", "e": True},
"p": {"v": "predicate1", "e": True},
"o": {"v": "object1", "e": False}
"s": {"t": "i", "i": "subject1"},
"p": {"t": "i", "i": "predicate1"},
"o": {"t": "l", "v": "object1"}
},
{
"s": {"v": "subject2", "e": False},
"p": {"v": "predicate2", "e": True},
"o": {"v": "object2", "e": True}
"s": {"t": "l", "v": "subject2"},
"p": {"t": "i", "i": "predicate2"},
"o": {"t": "i", "i": "object2"}
}
]
result = to_subgraph(input_data)
assert len(result) == 2
assert all(isinstance(triple, Triple) for triple in result)
# Check first triple
assert result[0].s.value == "subject1"
assert result[0].s.is_uri is True
assert result[0].p.value == "predicate1"
assert result[0].p.is_uri is True
assert result[0].s.iri == "subject1"
assert result[0].s.type == IRI
assert result[0].p.iri == "predicate1"
assert result[0].p.type == IRI
assert result[0].o.value == "object1"
assert result[0].o.is_uri is False
assert result[0].o.type == LITERAL
# Check second triple
assert result[1].s.value == "subject2"
assert result[1].s.is_uri is False
assert result[1].s.type == LITERAL
def test_to_subgraph_with_empty_list(self):
"""Test to_subgraph function with empty input"""
@ -74,16 +74,16 @@ class TestDispatchSerialize:
def test_serialize_value_with_uri(self):
"""Test serialize_value function with URI value"""
value = Value(value="http://example.com/test", is_uri=True)
result = serialize_value(value)
assert result == {"v": "http://example.com/test", "e": True}
term = Term(type=IRI, iri="http://example.com/test")
result = serialize_value(term)
assert result == {"t": "i", "i": "http://example.com/test"}
def test_serialize_value_with_literal(self):
"""Test serialize_value function with literal value"""
value = Value(value="test literal", is_uri=False)
result = serialize_value(value)
assert result == {"v": "test literal", "e": False}
term = Term(type=LITERAL, value="test literal")
result = serialize_value(term)
assert result == {"t": "l", "v": "test literal"}

View file

@ -1,7 +1,7 @@
"""
Unit tests for objects import dispatcher.
Unit tests for rows import dispatcher.
Tests the business logic of objects import dispatcher
Tests the business logic of rows import dispatcher
while mocking the Publisher and websocket components.
"""
@ -11,7 +11,7 @@ import asyncio
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from aiohttp import web
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
from trustgraph.gateway.dispatch.rows_import import RowsImport
from trustgraph.schema import Metadata, ExtractedObject
@ -92,16 +92,16 @@ def minimal_objects_message():
}
class TestObjectsImportInitialization:
"""Test ObjectsImport initialization."""
class TestRowsImportInitialization:
"""Test RowsImport initialization."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport creates Publisher with correct parameters."""
"""Test that RowsImport creates Publisher with correct parameters."""
mock_publisher_instance = Mock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
)
# Verify instance variables are set correctly
assert objects_import.ws == mock_websocket
assert objects_import.running == mock_running
assert objects_import.publisher == mock_publisher_instance
assert rows_import.ws == mock_websocket
assert rows_import.running == mock_running
assert rows_import.publisher == mock_publisher_instance
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport stores all required references."""
objects_import = ObjectsImport(
"""Test that RowsImport stores all required references."""
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="objects-queue"
)
assert objects_import.ws is mock_websocket
assert objects_import.running is mock_running
assert rows_import.ws is mock_websocket
assert rows_import.running is mock_running
class TestObjectsImportLifecycle:
"""Test ObjectsImport lifecycle methods."""
class TestRowsImportLifecycle:
"""Test RowsImport lifecycle methods."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that start() calls publisher.start()."""
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.start = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.start()
await rows_import.start()
mock_publisher_instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that destroy() properly stops publisher and closes websocket."""
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.destroy()
await rows_import.destroy()
# Verify sequence of operations
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
"""Test that destroy() handles None websocket gracefully."""
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
)
# Should not raise exception
await objects_import.destroy()
await rows_import.destroy()
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
class TestObjectsImportMessageProcessing:
"""Test ObjectsImport message processing."""
class TestRowsImportMessageProcessing:
"""Test RowsImport message processing."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() processes complete message correctly."""
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = sample_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.metadata.collection == "testcollection"
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
"""Test that receive() handles message with minimal required fields."""
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = minimal_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == "" # Default value
assert len(sent_object.metadata.metadata) == 0 # Default empty list
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() uses appropriate default values for optional fields."""
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = message_data
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Get the sent object and verify defaults
sent_object = mock_publisher_instance.send.call_args[0][1]
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == ""
class TestObjectsImportRunMethod:
"""Test ObjectsImport run method."""
class TestRowsImportRunMethod:
"""Test RowsImport run method."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that run() loops while running.get() returns True."""
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
# Set up running state to return True twice, then False
mock_running.get.side_effect = [True, True, False]
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.run()
await rows_import.run()
# Verify sleep was called twice (for the two True iterations)
assert mock_sleep.call_count == 2
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
mock_websocket.close.assert_called_once()
# Verify websocket was set to None
assert objects_import.ws is None
assert rows_import.ws is None
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
"""Test that run() handles None websocket gracefully."""
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
mock_running.get.return_value = False # Exit immediately
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
)
# Should not raise exception
await objects_import.run()
await rows_import.run()
# Verify websocket remains None
assert objects_import.ws is None
assert rows_import.ws is None
class TestObjectsImportBatchProcessing:
"""Test ObjectsImport batch processing functionality."""
class TestRowsImportBatchProcessing:
"""Test RowsImport batch processing functionality."""
@pytest.fixture
def batch_objects_message(self):
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
"source_span": "Multiple people found in document"
}
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
"""Test that receive() processes batch message correctly."""
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = batch_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
assert sent_object.confidence == 0.85
assert sent_object.source_span == "Multiple people found in document"
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles empty batch correctly."""
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = empty_batch_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Should still send the message
mock_publisher_instance.send.assert_called_once()
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
assert len(sent_object.values) == 0
class TestObjectsImportErrorHandling:
"""Test error handling in ObjectsImport."""
class TestRowsImportErrorHandling:
"""Test error handling in RowsImport."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() propagates publisher send errors."""
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
mock_msg.json.return_value = sample_objects_message
with pytest.raises(Exception, match="Publisher error"):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles malformed JSON appropriately."""
mock_publisher_class.return_value = Mock()
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
with pytest.raises(json.JSONDecodeError):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)

View file

@ -6,11 +6,21 @@ import pytest
from unittest.mock import Mock, AsyncMock
# Mock schema classes for testing
class Value:
def __init__(self, value, is_uri, type):
self.value = value
self.is_uri = is_uri
# Term type constants
IRI = "i"
LITERAL = "l"
BLANK = "b"
TRIPLE = "t"
class Term:
def __init__(self, type, iri=None, value=None, id=None, datatype=None, language=None, triple=None):
self.type = type
self.iri = iri
self.value = value
self.id = id
self.datatype = datatype
self.language = language
self.triple = triple
class Triple:
def __init__(self, s, p, o):
@ -66,32 +76,30 @@ def sample_relationships():
@pytest.fixture
def sample_value_uri():
"""Sample URI Value object"""
return Value(
value="http://example.com/person/john-smith",
is_uri=True,
type=""
def sample_term_uri():
"""Sample URI Term object"""
return Term(
type=IRI,
iri="http://example.com/person/john-smith"
)
@pytest.fixture
def sample_value_literal():
"""Sample literal Value object"""
return Value(
value="John Smith",
is_uri=False,
type="string"
def sample_term_literal():
"""Sample literal Term object"""
return Term(
type=LITERAL,
value="John Smith"
)
@pytest.fixture
def sample_triple(sample_value_uri, sample_value_literal):
def sample_triple(sample_term_uri, sample_term_literal):
"""Sample Triple object"""
return Triple(
s=sample_value_uri,
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=sample_value_literal
s=sample_term_uri,
p=Term(type=IRI, iri="http://schema.org/name"),
o=sample_term_literal
)

View file

@ -11,7 +11,7 @@ import json
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from trustgraph.template.prompt_manager import PromptManager
@ -33,7 +33,7 @@ class TestAgentKgExtractor:
# Set up the methods we want to test
extractor.to_uri = real_extractor.to_uri
extractor.parse_json = real_extractor.parse_json
extractor.parse_jsonl = real_extractor.parse_jsonl
extractor.process_extraction_data = real_extractor.process_extraction_data
extractor.emit_triples = real_extractor.emit_triples
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
@ -53,48 +53,49 @@ class TestAgentKgExtractor:
id="doc123",
metadata=[
Triple(
s=Value(value="doc123", is_uri=True),
p=Value(value="http://example.org/type", is_uri=True),
o=Value(value="document", is_uri=False)
s=Term(type=IRI, iri="doc123"),
p=Term(type=IRI, iri="http://example.org/type"),
o=Term(type=LITERAL, value="document")
)
]
)
@pytest.fixture
def sample_extraction_data(self):
"""Sample extraction data in expected format"""
return {
"definitions": [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
],
"relationships": [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "used_in",
"object": "Machine Learning",
"object-entity": True
},
{
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
}
"""Sample extraction data in JSONL format (list with type discriminators)"""
return [
{
"type": "definition",
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"type": "definition",
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
},
{
"type": "relationship",
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"type": "relationship",
"subject": "Neural Networks",
"predicate": "used_in",
"object": "Machine Learning",
"object-entity": True
},
{
"type": "relationship",
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
def test_to_uri_conversion(self, agent_extractor):
"""Test URI conversion for entities"""
@ -113,148 +114,147 @@ class TestAgentKgExtractor:
expected = f"{TRUSTGRAPH_ENTITIES}"
assert uri == expected
def test_parse_json_with_code_blocks(self, agent_extractor):
"""Test JSON parsing from code blocks"""
# Test JSON in code blocks
def test_parse_jsonl_with_code_blocks(self, agent_extractor):
"""Test JSONL parsing from code blocks"""
# Test JSONL in code blocks - note: JSON uses lowercase true/false
response = '''```json
{
"definitions": [{"entity": "AI", "definition": "Artificial Intelligence"}],
"relationships": []
}
```'''
result = agent_extractor.parse_json(response)
assert result["definitions"][0]["entity"] == "AI"
assert result["definitions"][0]["definition"] == "Artificial Intelligence"
assert result["relationships"] == []
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}
{"type": "relationship", "subject": "AI", "predicate": "is", "object": "technology", "object-entity": false}
```'''
def test_parse_json_without_code_blocks(self, agent_extractor):
"""Test JSON parsing without code blocks"""
response = '''{"definitions": [{"entity": "ML", "definition": "Machine Learning"}], "relationships": []}'''
result = agent_extractor.parse_json(response)
assert result["definitions"][0]["entity"] == "ML"
assert result["definitions"][0]["definition"] == "Machine Learning"
result = agent_extractor.parse_jsonl(response)
def test_parse_json_invalid_format(self, agent_extractor):
"""Test JSON parsing with invalid format"""
invalid_response = "This is not JSON at all"
with pytest.raises(json.JSONDecodeError):
agent_extractor.parse_json(invalid_response)
assert len(result) == 2
assert result[0]["entity"] == "AI"
assert result[0]["definition"] == "Artificial Intelligence"
assert result[1]["type"] == "relationship"
def test_parse_json_malformed_code_blocks(self, agent_extractor):
"""Test JSON parsing with malformed code blocks"""
# Missing closing backticks
response = '''```json
{"definitions": [], "relationships": []}
'''
# Should still parse the JSON content
with pytest.raises(json.JSONDecodeError):
agent_extractor.parse_json(response)
def test_parse_jsonl_without_code_blocks(self, agent_extractor):
"""Test JSONL parsing without code blocks"""
response = '''{"type": "definition", "entity": "ML", "definition": "Machine Learning"}
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}'''
result = agent_extractor.parse_jsonl(response)
assert len(result) == 2
assert result[0]["entity"] == "ML"
assert result[1]["entity"] == "AI"
def test_parse_jsonl_invalid_lines_skipped(self, agent_extractor):
"""Test JSONL parsing skips invalid lines gracefully"""
response = '''{"type": "definition", "entity": "Valid", "definition": "Valid def"}
This is not JSON at all
{"type": "definition", "entity": "Also Valid", "definition": "Another def"}'''
result = agent_extractor.parse_jsonl(response)
# Should get 2 valid objects, skipping the invalid line
assert len(result) == 2
assert result[0]["entity"] == "Valid"
assert result[1]["entity"] == "Also Valid"
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
"""Test JSONL parsing handles truncated responses"""
# Simulates output cut off mid-line
response = '''{"type": "definition", "entity": "Complete", "definition": "Full def"}
{"type": "definition", "entity": "Trunca'''
result = agent_extractor.parse_jsonl(response)
# Should get 1 valid object, the truncated line is skipped
assert len(result) == 1
assert result[0]["entity"] == "Complete"
def test_process_extraction_data_definitions(self, agent_extractor, sample_metadata):
"""Test processing of definition data"""
data = {
"definitions": [
{
"entity": "Machine Learning",
"definition": "A subset of AI that enables learning from data."
}
],
"relationships": []
}
data = [
{
"type": "definition",
"entity": "Machine Learning",
"definition": "A subset of AI that enables learning from data."
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Check entity label triple
label_triple = next((t for t in triples if t.p.value == RDF_LABEL and t.o.value == "Machine Learning"), None)
label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None)
assert label_triple is not None
assert label_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert label_triple.s.is_uri == True
assert label_triple.o.is_uri == False
assert label_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert label_triple.s.type == IRI
assert label_triple.o.type == LITERAL
# Check definition triple
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
assert def_triple is not None
assert def_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert def_triple.o.value == "A subset of AI that enables learning from data."
# Check subject-of triple
subject_of_triple = next((t for t in triples if t.p.value == SUBJECT_OF), None)
subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None)
assert subject_of_triple is not None
assert subject_of_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert subject_of_triple.o.value == "doc123"
assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert subject_of_triple.o.iri == "doc123"
# Check entity context
assert len(entity_contexts) == 1
assert entity_contexts[0].entity.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
assert entity_contexts[0].context == "A subset of AI that enables learning from data."
def test_process_extraction_data_relationships(self, agent_extractor, sample_metadata):
"""Test processing of relationship data"""
data = {
"definitions": [],
"relationships": [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
}
]
}
data = [
{
"type": "relationship",
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Check that subject, predicate, and object labels are created
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
predicate_uri = f"{TRUSTGRAPH_ENTITIES}is_subset_of"
# Find label triples
subject_label = next((t for t in triples if t.s.value == subject_uri and t.p.value == RDF_LABEL), None)
subject_label = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == RDF_LABEL), None)
assert subject_label is not None
assert subject_label.o.value == "Machine Learning"
predicate_label = next((t for t in triples if t.s.value == predicate_uri and t.p.value == RDF_LABEL), None)
predicate_label = next((t for t in triples if t.s.iri == predicate_uri and t.p.iri == RDF_LABEL), None)
assert predicate_label is not None
assert predicate_label.o.value == "is_subset_of"
# Check main relationship triple
# NOTE: Current implementation has bugs:
# 1. Uses data.get("object-entity") instead of rel.get("object-entity")
# 2. Sets object_value to predicate_uri instead of actual object URI
# This test documents the current buggy behavior
rel_triple = next((t for t in triples if t.s.value == subject_uri and t.p.value == predicate_uri), None)
# Check main relationship triple
object_uri = f"{TRUSTGRAPH_ENTITIES}Artificial%20Intelligence"
rel_triple = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == predicate_uri), None)
assert rel_triple is not None
# Due to bug, object value is set to predicate_uri
assert rel_triple.o.value == predicate_uri
assert rel_triple.o.iri == object_uri
assert rel_triple.o.type == IRI
# Check subject-of relationships
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF and t.o.value == "doc123"]
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"]
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
"""Test processing of relationships with literal objects"""
data = {
"definitions": [],
"relationships": [
{
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
}
data = [
{
"type": "relationship",
"subject": "Deep Learning",
"predicate": "accuracy",
"object": "95%",
"object-entity": False
}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Check that object labels are not created for literal objects
object_labels = [t for t in triples if t.p.value == RDF_LABEL and t.o.value == "95%"]
object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"]
# Based on the code logic, it should not create object labels for non-entity objects
# But there might be a bug in the original implementation
@ -263,75 +263,62 @@ class TestAgentKgExtractor:
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
# Check that we have both definition and relationship triples
definition_triples = [t for t in triples if t.p.value == DEFINITION]
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
assert len(definition_triples) == 2 # Two definitions
# Check entity contexts are created for definitions
assert len(entity_contexts) == 2
entity_uris = [ec.entity.value for ec in entity_contexts]
entity_uris = [ec.entity.iri for ec in entity_contexts]
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
"""Test processing when metadata has no ID"""
metadata = Metadata(id=None, metadata=[])
data = {
"definitions": [
{"entity": "Test Entity", "definition": "Test definition"}
],
"relationships": []
}
data = [
{"type": "definition", "entity": "Test Entity", "definition": "Test definition"}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
# Should not create subject-of relationships when no metadata ID
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) == 0
# Should still create entity contexts
assert len(entity_contexts) == 1
def test_process_extraction_data_empty_data(self, agent_extractor, sample_metadata):
"""Test processing of empty extraction data"""
data = {"definitions": [], "relationships": []}
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Should only have metadata triples
assert len(entity_contexts) == 0
# Triples should only contain metadata triples if any
data = []
def test_process_extraction_data_missing_keys(self, agent_extractor, sample_metadata):
"""Test processing data with missing keys"""
# Test missing definitions key
data = {"relationships": []}
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
# Should have no entity contexts
assert len(entity_contexts) == 0
# Test missing relationships key
data = {"definitions": []}
# Triples should be empty
assert len(triples) == 0
def test_process_extraction_data_unknown_types_ignored(self, agent_extractor, sample_metadata):
"""Test processing data with unknown type values"""
data = [
{"type": "definition", "entity": "Valid", "definition": "Valid def"},
{"type": "unknown_type", "foo": "bar"}, # Unknown type - should be ignored
{"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True}
]
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
assert len(entity_contexts) == 0
# Test completely missing keys
data = {}
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
assert len(entity_contexts) == 0
# Should process valid items and ignore unknown types
assert len(entity_contexts) == 1 # Only the definition creates entity context
def test_process_extraction_data_malformed_entries(self, agent_extractor, sample_metadata):
"""Test processing data with malformed entries"""
# Test definition missing required fields
data = {
"definitions": [
{"entity": "Test"}, # Missing definition
{"definition": "Test def"} # Missing entity
],
"relationships": [
{"subject": "A", "predicate": "rel"}, # Missing object
{"subject": "B", "object": "C"} # Missing predicate
]
}
# Test items missing required fields - should raise KeyError
data = [
{"type": "definition", "entity": "Test"}, # Missing definition
]
# Should handle gracefully or raise appropriate errors
with pytest.raises(KeyError):
agent_extractor.process_extraction_data(data, sample_metadata)
@ -340,17 +327,17 @@ class TestAgentKgExtractor:
async def test_emit_triples(self, agent_extractor, sample_metadata):
"""Test emitting triples to publisher"""
mock_publisher = AsyncMock()
test_triples = [
Triple(
s=Value(value="test:subject", is_uri=True),
p=Value(value="test:predicate", is_uri=True),
o=Value(value="test object", is_uri=False)
s=Term(type=IRI, iri="test:subject"),
p=Term(type=IRI, iri="test:predicate"),
o=Term(type=LITERAL, value="test object")
)
]
await agent_extractor.emit_triples(mock_publisher, sample_metadata, test_triples)
mock_publisher.send.assert_called_once()
sent_triples = mock_publisher.send.call_args[0][0]
assert isinstance(sent_triples, Triples)
@ -361,22 +348,22 @@ class TestAgentKgExtractor:
# Note: metadata.metadata is now empty array in the new implementation
assert sent_triples.metadata.metadata == []
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.value == "test:subject"
assert sent_triples.triples[0].s.iri == "test:subject"
@pytest.mark.asyncio
async def test_emit_entity_contexts(self, agent_extractor, sample_metadata):
"""Test emitting entity contexts to publisher"""
mock_publisher = AsyncMock()
test_contexts = [
EntityContext(
entity=Value(value="test:entity", is_uri=True),
entity=Term(type=IRI, iri="test:entity"),
context="Test context"
)
]
await agent_extractor.emit_entity_contexts(mock_publisher, sample_metadata, test_contexts)
mock_publisher.send.assert_called_once()
sent_contexts = mock_publisher.send.call_args[0][0]
assert isinstance(sent_contexts, EntityContexts)
@ -387,7 +374,7 @@ class TestAgentKgExtractor:
# Note: metadata.metadata is now empty array in the new implementation
assert sent_contexts.metadata.metadata == []
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.value == "test:entity"
assert sent_contexts.entities[0].entity.iri == "test:entity"
def test_agent_extractor_initialization_params(self):
"""Test agent extractor parameter validation"""

View file

@ -11,7 +11,7 @@ import urllib.parse
from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
@ -32,11 +32,11 @@ class TestAgentKgExtractionEdgeCases:
# Set up the methods we want to test
extractor.to_uri = real_extractor.to_uri
extractor.parse_json = real_extractor.parse_json
extractor.parse_jsonl = real_extractor.parse_jsonl
extractor.process_extraction_data = real_extractor.process_extraction_data
extractor.emit_triples = real_extractor.emit_triples
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
return extractor
def test_to_uri_special_characters(self, agent_extractor):
@ -85,146 +85,116 @@ class TestAgentKgExtractionEdgeCases:
# Verify the URI is properly encoded
assert unicode_text not in uri # Original unicode should be encoded
def test_parse_json_whitespace_variations(self, agent_extractor):
"""Test JSON parsing with various whitespace patterns"""
# Test JSON with different whitespace patterns
def test_parse_jsonl_whitespace_variations(self, agent_extractor):
"""Test JSONL parsing with various whitespace patterns"""
# Test JSONL with different whitespace patterns
test_cases = [
# Extra whitespace around code blocks
" ```json\n{\"test\": true}\n``` ",
# Tabs and mixed whitespace
"\t\t```json\n\t{\"test\": true}\n\t```\t",
# Multiple newlines
"\n\n\n```json\n\n{\"test\": true}\n\n```\n\n",
# JSON without code blocks but with whitespace
" {\"test\": true} ",
# Mixed line endings
"```json\r\n{\"test\": true}\r\n```",
' ```json\n{"type": "definition", "entity": "test", "definition": "def"}\n``` ',
# Multiple newlines between lines
'{"type": "definition", "entity": "A", "definition": "def A"}\n\n\n{"type": "definition", "entity": "B", "definition": "def B"}',
# JSONL without code blocks but with whitespace
' {"type": "definition", "entity": "test", "definition": "def"} ',
]
for response in test_cases:
result = agent_extractor.parse_json(response)
assert result == {"test": True}
def test_parse_json_code_block_variations(self, agent_extractor):
"""Test JSON parsing with different code block formats"""
for response in test_cases:
result = agent_extractor.parse_jsonl(response)
assert len(result) >= 1
assert result[0].get("type") == "definition"
def test_parse_jsonl_code_block_variations(self, agent_extractor):
"""Test JSONL parsing with different code block formats"""
test_cases = [
# Standard json code block
"```json\n{\"valid\": true}\n```",
'```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
# jsonl code block
'```jsonl\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
# Code block without language
"```\n{\"valid\": true}\n```",
# Uppercase JSON
"```JSON\n{\"valid\": true}\n```",
# Mixed case
"```Json\n{\"valid\": true}\n```",
# Multiple code blocks (should take first one)
"```json\n{\"first\": true}\n```\n```json\n{\"second\": true}\n```",
# Code block with extra content
"Here's the result:\n```json\n{\"valid\": true}\n```\nDone!",
'```\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
# Code block with extra content before/after
'Here\'s the result:\n```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```\nDone!',
]
for i, response in enumerate(test_cases):
try:
result = agent_extractor.parse_json(response)
assert result.get("valid") == True or result.get("first") == True
except json.JSONDecodeError:
# Some cases may fail due to regex extraction issues
# This documents current behavior - the regex may not match all cases
print(f"Case {i} failed JSON parsing: {response[:50]}...")
pass
result = agent_extractor.parse_jsonl(response)
assert len(result) >= 1, f"Case {i} failed"
assert result[0].get("entity") == "A"
def test_parse_json_malformed_code_blocks(self, agent_extractor):
"""Test JSON parsing with malformed code block formats"""
# These should still work by falling back to treating entire text as JSON
test_cases = [
# Unclosed code block
"```json\n{\"test\": true}",
# No opening backticks
"{\"test\": true}\n```",
# Wrong number of backticks
"`json\n{\"test\": true}\n`",
# Nested backticks (should handle gracefully)
"```json\n{\"code\": \"```\", \"test\": true}\n```",
]
for response in test_cases:
try:
result = agent_extractor.parse_json(response)
assert "test" in result # Should successfully parse
except json.JSONDecodeError:
# This is also acceptable for malformed cases
pass
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
"""Test JSONL parsing with truncated responses"""
# Simulates LLM output being cut off mid-line
response = '''{"type": "definition", "entity": "Complete1", "definition": "Full definition"}
{"type": "definition", "entity": "Complete2", "definition": "Another full def"}
{"type": "definition", "entity": "Trunca'''
def test_parse_json_large_responses(self, agent_extractor):
"""Test JSON parsing with very large responses"""
# Create a large JSON structure
large_data = {
"definitions": [
{
"entity": f"Entity {i}",
"definition": f"Definition {i} " + "with more content " * 100
}
for i in range(100)
],
"relationships": [
{
"subject": f"Subject {i}",
"predicate": f"predicate_{i}",
"object": f"Object {i}",
"object-entity": i % 2 == 0
}
for i in range(50)
]
}
large_json_str = json.dumps(large_data)
response = f"```json\n{large_json_str}\n```"
result = agent_extractor.parse_json(response)
assert len(result["definitions"]) == 100
assert len(result["relationships"]) == 50
assert result["definitions"][0]["entity"] == "Entity 0"
result = agent_extractor.parse_jsonl(response)
# Should get 2 valid objects, the truncated line is skipped
assert len(result) == 2
assert result[0]["entity"] == "Complete1"
assert result[1]["entity"] == "Complete2"
def test_parse_jsonl_large_responses(self, agent_extractor):
"""Test JSONL parsing with very large responses"""
# Create a large JSONL response
lines = []
for i in range(100):
lines.append(json.dumps({
"type": "definition",
"entity": f"Entity {i}",
"definition": f"Definition {i} " + "with more content " * 100
}))
for i in range(50):
lines.append(json.dumps({
"type": "relationship",
"subject": f"Subject {i}",
"predicate": f"predicate_{i}",
"object": f"Object {i}",
"object-entity": i % 2 == 0
}))
response = f"```json\n{chr(10).join(lines)}\n```"
result = agent_extractor.parse_jsonl(response)
definitions = [r for r in result if r.get("type") == "definition"]
relationships = [r for r in result if r.get("type") == "relationship"]
assert len(definitions) == 100
assert len(relationships) == 50
assert definitions[0]["entity"] == "Entity 0"
def test_process_extraction_data_empty_metadata(self, agent_extractor):
"""Test processing with empty or minimal metadata"""
# Test with None metadata - may not raise AttributeError depending on implementation
try:
triples, contexts = agent_extractor.process_extraction_data(
{"definitions": [], "relationships": []},
None
)
triples, contexts = agent_extractor.process_extraction_data([], None)
# If it doesn't raise, check the results
assert len(triples) == 0
assert len(contexts) == 0
except (AttributeError, TypeError):
# This is expected behavior when metadata is None
pass
# Test with metadata without ID
metadata = Metadata(id=None, metadata=[])
triples, contexts = agent_extractor.process_extraction_data(
{"definitions": [], "relationships": []},
metadata
)
triples, contexts = agent_extractor.process_extraction_data([], metadata)
assert len(triples) == 0
assert len(contexts) == 0
# Test with metadata with empty string ID
metadata = Metadata(id="", metadata=[])
data = {
"definitions": [{"entity": "Test", "definition": "Test def"}],
"relationships": []
}
data = [{"type": "definition", "entity": "Test", "definition": "Test def"}]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should not create subject-of triples when ID is empty string
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
assert len(subject_of_triples) == 0
def test_process_extraction_data_special_entity_names(self, agent_extractor):
"""Test processing with special characters in entity names"""
metadata = Metadata(id="doc123", metadata=[])
special_entities = [
"Entity with spaces",
"Entity & Co.",
@ -237,71 +207,62 @@ class TestAgentKgExtractionEdgeCases:
"Quotes: \"test\"",
"Parentheses: (test)",
]
data = {
"definitions": [
{"entity": entity, "definition": f"Definition for {entity}"}
for entity in special_entities
],
"relationships": []
}
data = [
{"type": "definition", "entity": entity, "definition": f"Definition for {entity}"}
for entity in special_entities
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Verify all entities were processed
assert len(contexts) == len(special_entities)
# Verify URIs were properly encoded
for i, entity in enumerate(special_entities):
expected_uri = f"{TRUSTGRAPH_ENTITIES}{urllib.parse.quote(entity)}"
assert contexts[i].entity.value == expected_uri
assert contexts[i].entity.iri == expected_uri
def test_process_extraction_data_very_long_definitions(self, agent_extractor):
"""Test processing with very long entity definitions"""
metadata = Metadata(id="doc123", metadata=[])
# Create very long definition
long_definition = "This is a very long definition. " * 1000
data = {
"definitions": [
{"entity": "Test Entity", "definition": long_definition}
],
"relationships": []
}
data = [
{"type": "definition", "entity": "Test Entity", "definition": long_definition}
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should handle long definitions without issues
assert len(contexts) == 1
assert contexts[0].context == long_definition
# Find definition triple
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
assert def_triple is not None
assert def_triple.o.value == long_definition
def test_process_extraction_data_duplicate_entities(self, agent_extractor):
"""Test processing with duplicate entity names"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [
{"entity": "Machine Learning", "definition": "First definition"},
{"entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
{"entity": "AI", "definition": "AI definition"},
{"entity": "AI", "definition": "Another AI definition"}, # Duplicate
],
"relationships": []
}
data = [
{"type": "definition", "entity": "Machine Learning", "definition": "First definition"},
{"type": "definition", "entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
{"type": "definition", "entity": "AI", "definition": "AI definition"},
{"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should process all entries (including duplicates)
assert len(contexts) == 4
# Check that both definitions for "Machine Learning" are present
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.value]
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.iri]
assert len(ml_contexts) == 2
assert ml_contexts[0].context == "First definition"
assert ml_contexts[1].context == "Second definition"
@ -309,49 +270,44 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_empty_strings(self, agent_extractor):
"""Test processing with empty strings in data"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [
{"entity": "", "definition": "Definition for empty entity"},
{"entity": "Valid Entity", "definition": ""},
{"entity": " ", "definition": " "}, # Whitespace only
],
"relationships": [
{"subject": "", "predicate": "test", "object": "test", "object-entity": True},
{"subject": "test", "predicate": "", "object": "test", "object-entity": True},
{"subject": "test", "predicate": "test", "object": "", "object-entity": True},
]
}
data = [
{"type": "definition", "entity": "", "definition": "Definition for empty entity"},
{"type": "definition", "entity": "Valid Entity", "definition": ""},
{"type": "definition", "entity": " ", "definition": " "}, # Whitespace only
{"type": "relationship", "subject": "", "predicate": "test", "object": "test", "object-entity": True},
{"type": "relationship", "subject": "test", "predicate": "", "object": "test", "object-entity": True},
{"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True},
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should handle empty strings by creating URIs (even if empty)
assert len(contexts) == 3
# Empty entity should create empty URI after encoding
empty_entity_context = next((ec for ec in contexts if ec.entity.value == TRUSTGRAPH_ENTITIES), None)
empty_entity_context = next((ec for ec in contexts if ec.entity.iri == TRUSTGRAPH_ENTITIES), None)
assert empty_entity_context is not None
def test_process_extraction_data_nested_json_in_strings(self, agent_extractor):
"""Test processing when definitions contain JSON-like strings"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [
{
"entity": "JSON Entity",
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
},
{
"entity": "Array Entity",
"definition": 'Contains array: [1, 2, 3, "string"]'
}
],
"relationships": []
}
data = [
{
"type": "definition",
"entity": "JSON Entity",
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
},
{
"type": "definition",
"entity": "Array Entity",
"definition": 'Contains array: [1, 2, 3, "string"]'
}
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should handle JSON strings in definitions without parsing them
assert len(contexts) == 2
assert '{"key": "value"' in contexts[0].context
@ -360,32 +316,29 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor):
"""Test processing with various boolean values for object-entity"""
metadata = Metadata(id="doc123", metadata=[])
data = {
"definitions": [],
"relationships": [
# Explicit True
{"subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
# Explicit False
{"subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
# Missing object-entity (should default to True based on code)
{"subject": "A", "predicate": "rel3", "object": "C"},
# String "true" (should be treated as truthy)
{"subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
# String "false" (should be treated as truthy in Python)
{"subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
# Number 0 (falsy)
{"subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
# Number 1 (truthy)
{"subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
]
}
data = [
# Explicit True
{"type": "relationship", "subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
# Explicit False
{"type": "relationship", "subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
# Missing object-entity (should default to True based on code)
{"type": "relationship", "subject": "A", "predicate": "rel3", "object": "C"},
# String "true" (should be treated as truthy)
{"type": "relationship", "subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
# String "false" (should be treated as truthy in Python)
{"type": "relationship", "subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
# Number 0 (falsy)
{"type": "relationship", "subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
# Number 1 (truthy)
{"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
]
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
# Should process all relationships
# Note: The current implementation has some logic issues that these tests document
assert len([t for t in triples if t.p.value != RDF_LABEL and t.p.value != SUBJECT_OF]) >= 7
assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7
@pytest.mark.asyncio
async def test_emit_empty_collections(self, agent_extractor):
@ -437,41 +390,40 @@ class TestAgentKgExtractionEdgeCases:
def test_process_extraction_data_performance_large_dataset(self, agent_extractor):
"""Test performance with large extraction datasets"""
metadata = Metadata(id="large-doc", metadata=[])
# Create large dataset
# Create large dataset in JSONL format
num_definitions = 1000
num_relationships = 2000
large_data = {
"definitions": [
{
"entity": f"Entity_{i:04d}",
"definition": f"Definition for entity {i} with some detailed explanation."
}
for i in range(num_definitions)
],
"relationships": [
{
"subject": f"Entity_{i % num_definitions:04d}",
"predicate": f"predicate_{i % 10}",
"object": f"Entity_{(i + 1) % num_definitions:04d}",
"object-entity": True
}
for i in range(num_relationships)
]
}
large_data = [
{
"type": "definition",
"entity": f"Entity_{i:04d}",
"definition": f"Definition for entity {i} with some detailed explanation."
}
for i in range(num_definitions)
] + [
{
"type": "relationship",
"subject": f"Entity_{i % num_definitions:04d}",
"predicate": f"predicate_{i % 10}",
"object": f"Entity_{(i + 1) % num_definitions:04d}",
"object-entity": True
}
for i in range(num_relationships)
]
import time
start_time = time.time()
triples, contexts = agent_extractor.process_extraction_data(large_data, metadata)
end_time = time.time()
processing_time = end_time - start_time
# Should complete within reasonable time (adjust threshold as needed)
assert processing_time < 10.0 # 10 seconds threshold
# Verify results
assert len(contexts) == num_definitions
# Triples include labels, definitions, relationships, and subject-of relations

View file

@ -7,7 +7,7 @@ processing graph structures, and performing graph operations.
import pytest
from unittest.mock import Mock
from .conftest import Triple, Value, Metadata
from .conftest import Triple, Metadata
from collections import defaultdict, deque

View file

@ -76,7 +76,7 @@ def cities_schema():
def validator():
"""Create a mock processor with just the validation method"""
from unittest.mock import MagicMock
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
# Create a mock processor
mock_processor = MagicMock()

View file

@ -2,13 +2,13 @@
Unit tests for triple construction logic
Tests the core business logic for constructing RDF triples from extracted
entities and relationships, including URI generation, Value object creation,
entities and relationships, including URI generation, Term object creation,
and triple validation.
"""
import pytest
from unittest.mock import Mock
from .conftest import Triple, Triples, Value, Metadata
from .conftest import Triple, Triples, Term, Metadata, IRI, LITERAL
import re
import hashlib
@ -48,80 +48,82 @@ class TestTripleConstructionLogic:
generated_uri = generate_uri(text, entity_type)
assert generated_uri == expected_uri, f"URI generation failed for '{text}'"
def test_value_object_creation(self):
"""Test creation of Value objects for subjects, predicates, and objects"""
def test_term_object_creation(self):
"""Test creation of Term objects for subjects, predicates, and objects"""
# Arrange
def create_value_object(text, is_uri, value_type=""):
return Value(
value=text,
is_uri=is_uri,
type=value_type
)
def create_term_object(text, is_uri, datatype=""):
if is_uri:
return Term(type=IRI, iri=text)
else:
return Term(type=LITERAL, value=text, datatype=datatype if datatype else None)
test_cases = [
("http://trustgraph.ai/kg/person/john-smith", True, ""),
("John Smith", False, "string"),
("42", False, "integer"),
("http://schema.org/worksFor", True, "")
]
# Act & Assert
for value_text, is_uri, value_type in test_cases:
value_obj = create_value_object(value_text, is_uri, value_type)
assert isinstance(value_obj, Value)
assert value_obj.value == value_text
assert value_obj.is_uri == is_uri
assert value_obj.type == value_type
for value_text, is_uri, datatype in test_cases:
term_obj = create_term_object(value_text, is_uri, datatype)
assert isinstance(term_obj, Term)
if is_uri:
assert term_obj.type == IRI
assert term_obj.iri == value_text
else:
assert term_obj.type == LITERAL
assert term_obj.value == value_text
def test_triple_construction_from_relationship(self):
"""Test constructing Triple objects from relationships"""
# Arrange
relationship = {
"subject": "John Smith",
"predicate": "works_for",
"predicate": "works_for",
"object": "OpenAI",
"subject_type": "PERSON",
"object_type": "ORG"
}
def construct_triple(relationship, uri_base="http://trustgraph.ai/kg"):
# Generate URIs
subject_uri = f"{uri_base}/person/{relationship['subject'].lower().replace(' ', '-')}"
object_uri = f"{uri_base}/org/{relationship['object'].lower().replace(' ', '-')}"
# Map predicate to schema.org URI
predicate_mappings = {
"works_for": "http://schema.org/worksFor",
"located_in": "http://schema.org/location",
"developed": "http://schema.org/creator"
}
predicate_uri = predicate_mappings.get(relationship["predicate"],
predicate_uri = predicate_mappings.get(relationship["predicate"],
f"{uri_base}/predicate/{relationship['predicate']}")
# Create Value objects
subject_value = Value(value=subject_uri, is_uri=True, type="")
predicate_value = Value(value=predicate_uri, is_uri=True, type="")
object_value = Value(value=object_uri, is_uri=True, type="")
# Create Term objects
subject_term = Term(type=IRI, iri=subject_uri)
predicate_term = Term(type=IRI, iri=predicate_uri)
object_term = Term(type=IRI, iri=object_uri)
# Create Triple
return Triple(
s=subject_value,
p=predicate_value,
o=object_value
s=subject_term,
p=predicate_term,
o=object_term
)
# Act
triple = construct_triple(relationship)
# Assert
assert isinstance(triple, Triple)
assert triple.s.value == "http://trustgraph.ai/kg/person/john-smith"
assert triple.s.is_uri is True
assert triple.p.value == "http://schema.org/worksFor"
assert triple.p.is_uri is True
assert triple.o.value == "http://trustgraph.ai/kg/org/openai"
assert triple.o.is_uri is True
assert triple.s.iri == "http://trustgraph.ai/kg/person/john-smith"
assert triple.s.type == IRI
assert triple.p.iri == "http://schema.org/worksFor"
assert triple.p.type == IRI
assert triple.o.iri == "http://trustgraph.ai/kg/org/openai"
assert triple.o.type == IRI
def test_literal_value_handling(self):
"""Test handling of literal values vs URI values"""
@ -132,10 +134,10 @@ class TestTripleConstructionLogic:
("John Smith", "email", "john@example.com", False), # Literal email
("John Smith", "worksFor", "http://trustgraph.ai/kg/org/openai", True) # URI reference
]
def create_triple_with_literal(subject_uri, predicate, object_value, object_is_uri):
subject_val = Value(value=subject_uri, is_uri=True, type="")
subject_term = Term(type=IRI, iri=subject_uri)
# Determine predicate URI
predicate_mappings = {
"name": "http://schema.org/name",
@ -144,32 +146,37 @@ class TestTripleConstructionLogic:
"worksFor": "http://schema.org/worksFor"
}
predicate_uri = predicate_mappings.get(predicate, f"http://trustgraph.ai/kg/predicate/{predicate}")
predicate_val = Value(value=predicate_uri, is_uri=True, type="")
# Create object value with appropriate type
object_type = ""
if not object_is_uri:
predicate_term = Term(type=IRI, iri=predicate_uri)
# Create object term with appropriate type
if object_is_uri:
object_term = Term(type=IRI, iri=object_value)
else:
datatype = None
if predicate == "age":
object_type = "integer"
datatype = "integer"
elif predicate in ["name", "email"]:
object_type = "string"
object_val = Value(value=object_value, is_uri=object_is_uri, type=object_type)
return Triple(s=subject_val, p=predicate_val, o=object_val)
datatype = "string"
object_term = Term(type=LITERAL, value=object_value, datatype=datatype)
return Triple(s=subject_term, p=predicate_term, o=object_term)
# Act & Assert
for subject_uri, predicate, object_value, object_is_uri in test_data:
subject_full_uri = "http://trustgraph.ai/kg/person/john-smith"
triple = create_triple_with_literal(subject_full_uri, predicate, object_value, object_is_uri)
assert triple.o.is_uri == object_is_uri
assert triple.o.value == object_value
if object_is_uri:
assert triple.o.type == IRI
assert triple.o.iri == object_value
else:
assert triple.o.type == LITERAL
assert triple.o.value == object_value
if predicate == "age":
assert triple.o.type == "integer"
assert triple.o.datatype == "integer"
elif predicate in ["name", "email"]:
assert triple.o.type == "string"
assert triple.o.datatype == "string"
def test_namespace_management(self):
"""Test namespace prefix management and expansion"""
@ -216,63 +223,74 @@ class TestTripleConstructionLogic:
def test_triple_validation(self):
"""Test triple validation rules"""
# Arrange
def get_term_value(term):
"""Extract value from a Term"""
if term.type == IRI:
return term.iri
else:
return term.value
def validate_triple(triple):
errors = []
# Check required components
if not triple.s or not triple.s.value:
s_val = get_term_value(triple.s) if triple.s else None
p_val = get_term_value(triple.p) if triple.p else None
o_val = get_term_value(triple.o) if triple.o else None
if not triple.s or not s_val:
errors.append("Missing or empty subject")
if not triple.p or not triple.p.value:
if not triple.p or not p_val:
errors.append("Missing or empty predicate")
if not triple.o or not triple.o.value:
if not triple.o or not o_val:
errors.append("Missing or empty object")
# Check URI validity for URI values
uri_pattern = r'^https?://[^\s/$.?#].[^\s]*$'
if triple.s.is_uri and not re.match(uri_pattern, triple.s.value):
if triple.s.type == IRI and not re.match(uri_pattern, triple.s.iri or ""):
errors.append("Invalid subject URI format")
if triple.p.is_uri and not re.match(uri_pattern, triple.p.value):
if triple.p.type == IRI and not re.match(uri_pattern, triple.p.iri or ""):
errors.append("Invalid predicate URI format")
if triple.o.is_uri and not re.match(uri_pattern, triple.o.value):
if triple.o.type == IRI and not re.match(uri_pattern, triple.o.iri or ""):
errors.append("Invalid object URI format")
# Predicates should typically be URIs
if not triple.p.is_uri:
if triple.p.type != IRI:
errors.append("Predicate should be a URI")
return len(errors) == 0, errors
# Test valid triple
valid_triple = Triple(
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John Smith", is_uri=False, type="string")
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John Smith", datatype="string")
)
# Test invalid triples
invalid_triples = [
Triple(s=Value(value="", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John", is_uri=False, type="")), # Empty subject
Triple(s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="name", is_uri=False, type=""), # Non-URI predicate
o=Value(value="John", is_uri=False, type="")),
Triple(s=Value(value="invalid-uri", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John", is_uri=False, type="")) # Invalid URI format
Triple(s=Term(type=IRI, iri=""),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John")), # Empty subject
Triple(s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=LITERAL, value="name"), # Non-URI predicate
o=Term(type=LITERAL, value="John")),
Triple(s=Term(type=IRI, iri="invalid-uri"),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John")) # Invalid URI format
]
# Act & Assert
is_valid, errors = validate_triple(valid_triple)
assert is_valid, f"Valid triple failed validation: {errors}"
for invalid_triple in invalid_triples:
is_valid, errors = validate_triple(invalid_triple)
assert not is_valid, f"Invalid triple passed validation: {invalid_triple}"
@ -286,97 +304,97 @@ class TestTripleConstructionLogic:
{"text": "OpenAI", "type": "ORG"},
{"text": "San Francisco", "type": "PLACE"}
]
relationships = [
{"subject": "John Smith", "predicate": "works_for", "object": "OpenAI"},
{"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco"}
]
def construct_triple_batch(entities, relationships, document_id="doc-1"):
triples = []
# Create type triples for entities
for entity in entities:
entity_uri = f"http://trustgraph.ai/kg/{entity['type'].lower()}/{entity['text'].lower().replace(' ', '-')}"
type_uri = f"http://trustgraph.ai/kg/type/{entity['type']}"
type_triple = Triple(
s=Value(value=entity_uri, is_uri=True, type=""),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True, type=""),
o=Value(value=type_uri, is_uri=True, type="")
s=Term(type=IRI, iri=entity_uri),
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
o=Term(type=IRI, iri=type_uri)
)
triples.append(type_triple)
# Create relationship triples
for rel in relationships:
subject_uri = f"http://trustgraph.ai/kg/entity/{rel['subject'].lower().replace(' ', '-')}"
object_uri = f"http://trustgraph.ai/kg/entity/{rel['object'].lower().replace(' ', '-')}"
predicate_uri = f"http://schema.org/{rel['predicate'].replace('_', '')}"
rel_triple = Triple(
s=Value(value=subject_uri, is_uri=True, type=""),
p=Value(value=predicate_uri, is_uri=True, type=""),
o=Value(value=object_uri, is_uri=True, type="")
s=Term(type=IRI, iri=subject_uri),
p=Term(type=IRI, iri=predicate_uri),
o=Term(type=IRI, iri=object_uri)
)
triples.append(rel_triple)
return triples
# Act
triples = construct_triple_batch(entities, relationships)
# Assert
assert len(triples) == len(entities) + len(relationships) # Type triples + relationship triples
# Check that all triples are valid Triple objects
for triple in triples:
assert isinstance(triple, Triple)
assert triple.s.value != ""
assert triple.p.value != ""
assert triple.o.value != ""
assert triple.s.iri != ""
assert triple.p.iri != ""
assert triple.o.iri != ""
def test_triples_batch_object_creation(self):
"""Test creating Triples batch objects with metadata"""
# Arrange
sample_triples = [
Triple(
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="http://schema.org/name", is_uri=True, type=""),
o=Value(value="John Smith", is_uri=False, type="string")
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=IRI, iri="http://schema.org/name"),
o=Term(type=LITERAL, value="John Smith", datatype="string")
),
Triple(
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
p=Value(value="http://schema.org/worksFor", is_uri=True, type=""),
o=Value(value="http://trustgraph.ai/kg/org/openai", is_uri=True, type="")
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
p=Term(type=IRI, iri="http://schema.org/worksFor"),
o=Term(type=IRI, iri="http://trustgraph.ai/kg/org/openai")
)
]
metadata = Metadata(
id="test-doc-123",
user="test_user",
user="test_user",
collection="test_collection",
metadata=[]
)
# Act
triples_batch = Triples(
metadata=metadata,
triples=sample_triples
)
# Assert
assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.user == "test_user"
assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2
# Check that triples are properly embedded
for triple in triples_batch.triples:
assert isinstance(triple, Triple)
assert isinstance(triple.s, Value)
assert isinstance(triple.p, Value)
assert isinstance(triple.o, Value)
assert isinstance(triple.s, Term)
assert isinstance(triple.p, Term)
assert isinstance(triple.o, Term)
def test_uri_collision_handling(self):
"""Test handling of URI collisions and duplicate detection"""

View file

@ -339,7 +339,250 @@ class TestPromptManager:
"""Test PromptManager with minimal configuration"""
pm = PromptManager()
pm.load_config({}) # Empty config
assert pm.config.system_template == "Be helpful." # Default system
assert pm.terms == {} # Default empty terms
assert len(pm.prompts) == 0
assert len(pm.prompts) == 0
@pytest.mark.unit
class TestPromptManagerJsonl:
"""Unit tests for PromptManager JSONL functionality"""
@pytest.fixture
def jsonl_config(self):
"""Configuration with JSONL response type prompts"""
return {
"system": json.dumps("You are an extraction assistant."),
"template-index": json.dumps(["extract_simple", "extract_with_schema", "extract_mixed"]),
"template.extract_simple": json.dumps({
"prompt": "Extract entities from: {{ text }}",
"response-type": "jsonl"
}),
"template.extract_with_schema": json.dumps({
"prompt": "Extract definitions from: {{ text }}",
"response-type": "jsonl",
"schema": {
"type": "object",
"properties": {
"entity": {"type": "string"},
"definition": {"type": "string"}
},
"required": ["entity", "definition"]
}
}),
"template.extract_mixed": json.dumps({
"prompt": "Extract knowledge from: {{ text }}",
"response-type": "jsonl",
"schema": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {"const": "definition"},
"entity": {"type": "string"},
"definition": {"type": "string"}
},
"required": ["type", "entity", "definition"]
},
{
"type": "object",
"properties": {
"type": {"const": "relationship"},
"subject": {"type": "string"},
"predicate": {"type": "string"},
"object": {"type": "string"}
},
"required": ["type", "subject", "predicate", "object"]
}
]
}
})
}
@pytest.fixture
def prompt_manager(self, jsonl_config):
"""Create a PromptManager with JSONL configuration"""
pm = PromptManager()
pm.load_config(jsonl_config)
return pm
def test_parse_jsonl_basic(self, prompt_manager):
"""Test basic JSONL parsing"""
text = '{"entity": "cat", "definition": "A small furry animal"}\n{"entity": "dog", "definition": "A loyal pet"}'
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
assert result[0]["entity"] == "cat"
assert result[1]["entity"] == "dog"
def test_parse_jsonl_with_empty_lines(self, prompt_manager):
"""Test JSONL parsing skips empty lines"""
text = '{"entity": "cat"}\n\n\n{"entity": "dog"}\n'
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
def test_parse_jsonl_with_markdown_fences(self, prompt_manager):
"""Test JSONL parsing strips markdown code fences"""
text = '''```json
{"entity": "cat", "definition": "A furry animal"}
{"entity": "dog", "definition": "A loyal pet"}
```'''
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
assert result[0]["entity"] == "cat"
assert result[1]["entity"] == "dog"
def test_parse_jsonl_with_jsonl_fence(self, prompt_manager):
"""Test JSONL parsing strips jsonl-marked code fences"""
text = '''```jsonl
{"entity": "cat"}
{"entity": "dog"}
```'''
result = prompt_manager.parse_jsonl(text)
assert len(result) == 2
def test_parse_jsonl_truncation_resilience(self, prompt_manager):
"""Test JSONL parsing handles truncated final line"""
text = '{"entity": "cat", "definition": "Complete"}\n{"entity": "dog", "defi'
result = prompt_manager.parse_jsonl(text)
# Should get the first valid object, skip the truncated one
assert len(result) == 1
assert result[0]["entity"] == "cat"
def test_parse_jsonl_invalid_lines_skipped(self, prompt_manager):
"""Test JSONL parsing skips invalid JSON lines"""
text = '''{"entity": "valid1"}
not json at all
{"entity": "valid2"}
{broken json
{"entity": "valid3"}'''
result = prompt_manager.parse_jsonl(text)
assert len(result) == 3
assert result[0]["entity"] == "valid1"
assert result[1]["entity"] == "valid2"
assert result[2]["entity"] == "valid3"
def test_parse_jsonl_empty_input(self, prompt_manager):
"""Test JSONL parsing with empty input"""
result = prompt_manager.parse_jsonl("")
assert result == []
result = prompt_manager.parse_jsonl("\n\n\n")
assert result == []
@pytest.mark.asyncio
async def test_invoke_jsonl_response(self, prompt_manager):
"""Test invoking a prompt with JSONL response"""
mock_llm = AsyncMock()
mock_llm.return_value = '{"entity": "photosynthesis", "definition": "Plant process"}\n{"entity": "mitosis", "definition": "Cell division"}'
result = await prompt_manager.invoke(
"extract_simple",
{"text": "Biology text"},
mock_llm
)
assert isinstance(result, list)
assert len(result) == 2
assert result[0]["entity"] == "photosynthesis"
assert result[1]["entity"] == "mitosis"
@pytest.mark.asyncio
async def test_invoke_jsonl_with_schema_validation(self, prompt_manager):
"""Test JSONL response with schema validation"""
mock_llm = AsyncMock()
mock_llm.return_value = '{"entity": "cat", "definition": "A pet"}\n{"entity": "dog", "definition": "Another pet"}'
result = await prompt_manager.invoke(
"extract_with_schema",
{"text": "Animal text"},
mock_llm
)
assert len(result) == 2
assert all("entity" in obj and "definition" in obj for obj in result)
@pytest.mark.asyncio
async def test_invoke_jsonl_schema_filters_invalid(self, prompt_manager):
"""Test JSONL schema validation filters out invalid objects"""
mock_llm = AsyncMock()
# Second object is missing required 'definition' field
mock_llm.return_value = '{"entity": "valid", "definition": "Has both fields"}\n{"entity": "invalid_missing_definition"}\n{"entity": "also_valid", "definition": "Complete"}'
result = await prompt_manager.invoke(
"extract_with_schema",
{"text": "Test text"},
mock_llm
)
# Only the two valid objects should be returned
assert len(result) == 2
assert result[0]["entity"] == "valid"
assert result[1]["entity"] == "also_valid"
@pytest.mark.asyncio
async def test_invoke_jsonl_mixed_types(self, prompt_manager):
"""Test JSONL with discriminated union schema (oneOf)"""
mock_llm = AsyncMock()
mock_llm.return_value = '''{"type": "definition", "entity": "DNA", "definition": "Genetic material"}
{"type": "relationship", "subject": "DNA", "predicate": "found_in", "object": "nucleus"}
{"type": "definition", "entity": "RNA", "definition": "Messenger molecule"}'''
result = await prompt_manager.invoke(
"extract_mixed",
{"text": "Biology text"},
mock_llm
)
assert len(result) == 3
# Check definitions
definitions = [r for r in result if r.get("type") == "definition"]
assert len(definitions) == 2
# Check relationships
relationships = [r for r in result if r.get("type") == "relationship"]
assert len(relationships) == 1
assert relationships[0]["subject"] == "DNA"
@pytest.mark.asyncio
async def test_invoke_jsonl_empty_result(self, prompt_manager):
"""Test JSONL response that yields no valid objects"""
mock_llm = AsyncMock()
mock_llm.return_value = "No JSON here at all"
result = await prompt_manager.invoke(
"extract_simple",
{"text": "Test"},
mock_llm
)
assert result == []
@pytest.mark.asyncio
async def test_invoke_jsonl_without_schema(self, prompt_manager):
"""Test JSONL response without schema validation"""
mock_llm = AsyncMock()
mock_llm.return_value = '{"any": "structure"}\n{"completely": "different"}'
result = await prompt_manager.invoke(
"extract_simple",
{"text": "Test"},
mock_llm
)
assert len(result) == 2
assert result[0] == {"any": "structure"}
assert result[1] == {"completely": "different"}

View file

@ -167,7 +167,7 @@ class TestFlowClient:
expected_methods = [
'text_completion', 'agent', 'graph_rag', 'document_rag',
'graph_embeddings_query', 'embeddings', 'prompt',
'triples_query', 'objects_query'
'triples_query', 'rows_query'
]
for method in expected_methods:
@ -216,7 +216,7 @@ class TestSocketClient:
expected_methods = [
'agent', 'text_completion', 'graph_rag', 'document_rag',
'prompt', 'graph_embeddings_query', 'embeddings',
'triples_query', 'objects_query', 'mcp_tool'
'triples_query', 'rows_query', 'mcp_tool'
]
for method in expected_methods:
@ -243,7 +243,7 @@ class TestBulkClient:
'import_graph_embeddings',
'import_document_embeddings',
'import_entity_contexts',
'import_objects'
'import_rows'
]
for method in import_methods:

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.graph_embeddings.milvus.service import Processor
from trustgraph.schema import Value, GraphEmbeddingsRequest
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
class TestMilvusGraphEmbeddingsQueryProcessor:
@ -68,50 +68,50 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
)
# Verify results are converted to Value objects
# Verify results are converted to Term objects
assert len(result) == 3
assert isinstance(result[0], Value)
assert result[0].value == "http://example.com/entity1"
assert result[0].is_uri is True
assert isinstance(result[1], Value)
assert result[1].value == "http://example.com/entity2"
assert result[1].is_uri is True
assert isinstance(result[2], Value)
assert isinstance(result[0], Term)
assert result[0].iri == "http://example.com/entity1"
assert result[0].type == IRI
assert isinstance(result[1], Term)
assert result[1].iri == "http://example.com/entity2"
assert result[1].type == IRI
assert isinstance(result[2], Term)
assert result[2].value == "literal entity"
assert result[2].is_uri is False
assert result[2].type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor):
@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify results are deduplicated and limited
assert len(result) == 3
entity_values = [r.value for r in result]
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
assert "http://example.com/entity3" in entity_values
@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify duplicates are removed
assert len(result) == 3
entity_values = [r.value for r in result]
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert len(set(entity_values)) == 3 # All unique
assert "http://example.com/entity1" in entity_values
assert "http://example.com/entity2" in entity_values
@ -346,14 +346,14 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
assert len(result) == 4
# Check URI entities
uri_results = [r for r in result if r.is_uri]
uri_results = [r for r in result if r.type == IRI]
assert len(uri_results) == 2
uri_values = [r.value for r in uri_results]
uri_values = [r.iri for r in uri_results]
assert "http://example.com/uri_entity" in uri_values
assert "https://example.com/another_uri" in uri_values
# Check literal entities
literal_results = [r for r in result if not r.is_uri]
literal_results = [r for r in result if not r.type == IRI]
assert len(literal_results) == 2
literal_values = [r.value for r in literal_results]
assert "literal entity text" in literal_values
@ -486,7 +486,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
# Verify results from all dimensions
assert len(result) == 3
entity_values = [r.value for r in result]
entity_values = [r.iri if r.type == IRI else r.value for r in result]
assert "entity_2d" in entity_values
assert "entity_4d" in entity_values
assert "entity_3d" in entity_values

View file

@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
from trustgraph.query.graph_embeddings.pinecone.service import Processor
from trustgraph.schema import Value
from trustgraph.schema import Term, IRI, LITERAL
class TestPineconeGraphEmbeddingsQueryProcessor:
@ -105,27 +105,27 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
uri_entity = "http://example.org/entity"
value = processor.create_value(uri_entity)
assert isinstance(value, Value)
assert isinstance(value, Term)
assert value.value == uri_entity
assert value.is_uri == True
assert value.type == IRI
def test_create_value_https_uri(self, processor):
"""Test create_value method for HTTPS URI entities"""
uri_entity = "https://example.org/entity"
value = processor.create_value(uri_entity)
assert isinstance(value, Value)
assert isinstance(value, Term)
assert value.value == uri_entity
assert value.is_uri == True
assert value.type == IRI
def test_create_value_literal(self, processor):
"""Test create_value method for literal entities"""
literal_entity = "literal_entity"
value = processor.create_value(literal_entity)
assert isinstance(value, Value)
assert isinstance(value, Term)
assert value.value == literal_entity
assert value.is_uri == False
assert value.type == LITERAL
@pytest.mark.asyncio
async def test_query_graph_embeddings_single_vector(self, processor):
@ -165,11 +165,11 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
# Verify results
assert len(entities) == 3
assert entities[0].value == 'http://example.org/entity1'
assert entities[0].is_uri == True
assert entities[0].type == IRI
assert entities[1].value == 'entity2'
assert entities[1].is_uri == False
assert entities[1].type == LITERAL
assert entities[2].value == 'http://example.org/entity3'
assert entities[2].is_uri == True
assert entities[2].type == IRI
@pytest.mark.asyncio
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.query.graph_embeddings.qdrant.service import Processor
from trustgraph.schema import IRI, LITERAL
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
@ -85,10 +86,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
value = processor.create_value('http://example.com/entity')
# Assert
assert hasattr(value, 'value')
assert value.value == 'http://example.com/entity'
assert hasattr(value, 'is_uri')
assert value.is_uri == True
assert hasattr(value, 'iri')
assert value.iri == 'http://example.com/entity'
assert hasattr(value, 'type')
assert value.type == IRI
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -109,10 +110,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
value = processor.create_value('https://secure.example.com/entity')
# Assert
assert hasattr(value, 'value')
assert value.value == 'https://secure.example.com/entity'
assert hasattr(value, 'is_uri')
assert value.is_uri == True
assert hasattr(value, 'iri')
assert value.iri == 'https://secure.example.com/entity'
assert hasattr(value, 'type')
assert value.type == IRI
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -135,8 +136,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
# Assert
assert hasattr(value, 'value')
assert value.value == 'regular entity name'
assert hasattr(value, 'is_uri')
assert value.is_uri == False
assert hasattr(value, 'type')
assert value.type == LITERAL
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
@ -428,14 +429,14 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
assert len(result) == 3
# Check URI entities
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
uri_entities = [entity for entity in result if entity.type == IRI]
assert len(uri_entities) == 2
uri_values = [entity.value for entity in uri_entities]
uri_values = [entity.iri for entity in uri_entities]
assert 'http://example.com/entity1' in uri_values
assert 'https://secure.example.com/entity2' in uri_values
# Check regular entities
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
regular_entities = [entity for entity in result if entity.type == LITERAL]
assert len(regular_entities) == 1
assert regular_entities[0].value == 'regular entity'

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestMemgraphQueryUserCollectionIsolation:
@ -24,9 +24,9 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="test_object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="test_object"),
limit=1000
)
@ -65,8 +65,8 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=1000
)
@ -105,9 +105,9 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=Value(value="http://example.com/o", is_uri=True),
o=Term(type=IRI, iri="http://example.com/o"),
limit=1000
)
@ -145,7 +145,7 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None,
limit=1000
@ -185,8 +185,8 @@ class TestMemgraphQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="literal", is_uri=False),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="literal"),
limit=1000
)
@ -225,7 +225,7 @@ class TestMemgraphQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=1000
)
@ -265,7 +265,7 @@ class TestMemgraphQueryUserCollectionIsolation:
collection="test_collection",
s=None,
p=None,
o=Value(value="test_value", is_uri=False),
o=Term(type=LITERAL, value="test_value"),
limit=1000
)
@ -355,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
# Query without user/collection fields
query = TriplesQueryRequest(
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None,
limit=1000
@ -385,7 +385,7 @@ class TestMemgraphQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None,
limit=1000
@ -416,17 +416,17 @@ class TestMemgraphQueryUserCollectionIsolation:
assert len(result) == 2
# First triple (literal object)
assert result[0].s.value == "http://example.com/s"
assert result[0].s.is_uri == True
assert result[0].p.value == "http://example.com/p1"
assert result[0].p.is_uri == True
assert result[0].s.iri == "http://example.com/s"
assert result[0].s.type == IRI
assert result[0].p.iri == "http://example.com/p1"
assert result[0].p.type == IRI
assert result[0].o.value == "literal_value"
assert result[0].o.is_uri == False
assert result[0].o.type == LITERAL
# Second triple (URI object)
assert result[1].s.value == "http://example.com/s"
assert result[1].s.is_uri == True
assert result[1].p.value == "http://example.com/p2"
assert result[1].p.is_uri == True
assert result[1].o.value == "http://example.com/o"
assert result[1].o.is_uri == True
assert result[1].s.iri == "http://example.com/s"
assert result[1].s.type == IRI
assert result[1].p.iri == "http://example.com/p2"
assert result[1].p.type == IRI
assert result[1].o.iri == "http://example.com/o"
assert result[1].o.type == IRI

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
class TestNeo4jQueryUserCollectionIsolation:
@ -24,21 +24,23 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="test_object", is_uri=False)
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="test_object"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SPO query for literal includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN $src as src"
"RETURN $src as src "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -63,23 +65,25 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=None
s=Term(type=IRI, iri="http://example.com/s"),
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SP query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN dest.value as dest"
"RETURN dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
src="http://example.com/s",
@ -88,13 +92,14 @@ class TestNeo4jQueryUserCollectionIsolation:
collection="test_collection",
database_='neo4j'
)
# Verify SP query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest"
"RETURN dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -118,21 +123,23 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=Value(value="http://example.com/o", is_uri=True)
o=Term(type=IRI, iri="http://example.com/o"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify SO query for nodes includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"RETURN rel.uri as rel"
"RETURN rel.uri as rel "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -156,23 +163,25 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify S query includes user/collection
expected_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest"
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_query,
src="http://example.com/s",
@ -194,20 +203,22 @@ class TestNeo4jQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="literal", is_uri=False)
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="literal"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify PO query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src"
"RETURN src.uri as src "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -232,20 +243,22 @@ class TestNeo4jQueryUserCollectionIsolation:
user="test_user",
collection="test_collection",
s=None,
p=Value(value="http://example.com/p", is_uri=True),
o=None
p=Term(type=IRI, iri="http://example.com/p"),
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify P query includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest"
"RETURN src.uri as src, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -270,19 +283,21 @@ class TestNeo4jQueryUserCollectionIsolation:
collection="test_collection",
s=None,
p=None,
o=Value(value="test_value", is_uri=False)
o=Term(type=LITERAL, value="test_value"),
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify O query for literals includes user/collection
expected_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel"
"RETURN src.uri as src, rel.uri as rel "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -307,34 +322,37 @@ class TestNeo4jQueryUserCollectionIsolation:
collection="test_collection",
s=None,
p=None,
o=None
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
await processor.query_triples(query)
# Verify wildcard query for literals includes user/collection
expected_literal_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
expected_literal_query,
user="test_user",
collection="test_collection",
database_='neo4j'
)
# Verify wildcard query for nodes includes user/collection
expected_node_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest"
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT 10"
)
mock_driver.execute_query.assert_any_call(
@ -355,9 +373,10 @@ class TestNeo4jQueryUserCollectionIsolation:
# Query without user/collection fields
query = TriplesQueryRequest(
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None
o=None,
limit=10
)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
@ -384,47 +403,48 @@ class TestNeo4jQueryUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/s", is_uri=True),
s=Term(type=IRI, iri="http://example.com/s"),
p=None,
o=None
o=None,
limit=10
)
# Mock some results
mock_record1 = MagicMock()
mock_record1.data.return_value = {
"rel": "http://example.com/p1",
"dest": "literal_value"
}
mock_record2 = MagicMock()
mock_record2.data.return_value = {
"rel": "http://example.com/p2",
"dest": "http://example.com/o"
}
# Return results for literal query, empty for node query
mock_driver.execute_query.side_effect = [
([mock_record1], MagicMock(), MagicMock()), # Literal query
([mock_record2], MagicMock(), MagicMock()) # Node query
]
result = await processor.query_triples(query)
# Verify results are proper Triple objects
assert len(result) == 2
# First triple (literal object)
assert result[0].s.value == "http://example.com/s"
assert result[0].s.is_uri == True
assert result[0].p.value == "http://example.com/p1"
assert result[0].p.is_uri == True
assert result[0].s.iri == "http://example.com/s"
assert result[0].s.type == IRI
assert result[0].p.iri == "http://example.com/p1"
assert result[0].p.type == IRI
assert result[0].o.value == "literal_value"
assert result[0].o.is_uri == False
assert result[0].o.type == LITERAL
# Second triple (URI object)
assert result[1].s.value == "http://example.com/s"
assert result[1].s.is_uri == True
assert result[1].p.value == "http://example.com/p2"
assert result[1].p.is_uri == True
assert result[1].o.value == "http://example.com/o"
assert result[1].o.is_uri == True
assert result[1].s.iri == "http://example.com/s"
assert result[1].s.type == IRI
assert result[1].p.iri == "http://example.com/p2"
assert result[1].p.type == IRI
assert result[1].o.iri == "http://example.com/o"
assert result[1].o.type == IRI

View file

@ -1,10 +1,11 @@
"""
Unit tests for Cassandra Objects GraphQL Query Processor
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
Tests the business logic of the GraphQL query processor including:
- GraphQL schema generation from RowSchema
- Query execution and validation
- CQL translation logic
- Schema configuration handling
- Query execution using unified rows table
- Name sanitization
- GraphQL query execution
- Message processing logic
"""
@ -12,119 +13,91 @@ import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import strawberry
from strawberry import Schema
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field
class TestObjectsGraphQLQueryLogic:
"""Test business logic without external dependencies"""
def test_get_python_type_mapping(self):
"""Test schema field type conversion to Python types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_python_type("string") == str
assert processor.get_python_type("integer") == int
assert processor.get_python_type("float") == float
assert processor.get_python_type("boolean") == bool
assert processor.get_python_type("timestamp") == str
assert processor.get_python_type("date") == str
assert processor.get_python_type("time") == str
assert processor.get_python_type("uuid") == str
# Unknown type defaults to str
assert processor.get_python_type("unknown_type") == str
def test_create_graphql_type_basic_fields(self):
"""Test GraphQL type creation for basic field types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="test_table",
description="Test table",
fields=[
Field(
name="id",
type="string",
primary=True,
required=True,
description="Primary key"
),
Field(
name="name",
type="string",
required=True,
description="Name field"
),
Field(
name="age",
type="integer",
required=False,
description="Optional age"
),
Field(
name="active",
type="boolean",
required=False,
description="Status flag"
)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test_table", schema)
# Verify type was created
assert graphql_type is not None
assert hasattr(graphql_type, '__name__')
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
class TestRowsGraphQLQueryLogic:
"""Test business logic for unified table query implementation"""
def test_sanitize_name_cassandra_compatibility(self):
"""Test name sanitization for Cassandra field names"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test field name sanitization (matches storage processor)
# Test field name sanitization (uses r_ prefix like storage processor)
assert processor.sanitize_name("simple_field") == "simple_field"
assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
assert processor.sanitize_name("field.with.dots") == "field_with_dots"
assert processor.sanitize_name("123_field") == "o_123_field"
assert processor.sanitize_name("123_field") == "r_123_field"
assert processor.sanitize_name("field with spaces") == "field_with_spaces"
assert processor.sanitize_name("special!@#chars") == "special___chars"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
assert processor.sanitize_name("CamelCase") == "camelcase"
def test_sanitize_table_name(self):
"""Test table name sanitization (always gets o_ prefix)"""
def test_get_index_names(self):
"""Test extraction of index names from schema"""
processor = MagicMock()
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Table names always get o_ prefix
assert processor.sanitize_table("simple_table") == "o_simple_table"
assert processor.sanitize_table("Table-Name") == "o_table_name"
assert processor.sanitize_table("123table") == "o_123table"
assert processor.sanitize_table("") == "o_"
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
schema = RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string"), # Not indexed
Field(name="status", type="string", indexed=True)
]
)
index_names = processor.get_index_names(schema)
assert "id" in index_names
assert "category" in index_names
assert "status" in index_names
assert "name" not in index_names
assert len(index_names) == 3
def test_find_matching_index_exact_match(self):
"""Test finding matching index for exact match query"""
processor = MagicMock()
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
schema = RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string") # Not indexed
]
)
# Filter on indexed field should return match
filters = {"category": "electronics"}
result = processor.find_matching_index(schema, filters)
assert result is not None
assert result[0] == "category"
assert result[1] == ["electronics"]
# Filter on non-indexed field should return None
filters = {"name": "test"}
result = processor.find_matching_index(schema, filters)
assert result is None
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.graphql_types = {}
processor.graphql_schema = None
processor.config_key = "schema" # Set the config key
processor.generate_graphql_schema = AsyncMock()
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
schema_config = {
"schema": {
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
})
}
}
# Process config
await processor.on_schema_config(schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
# Verify fields
id_field = next(f for f in schema.fields if f.name == "id")
assert id_field.primary is True
# The field should have been created correctly from JSON
# Let's test what we can verify - that the field has the right attributes
assert hasattr(id_field, 'required') # Has the required attribute
assert hasattr(id_field, 'primary') # Has the primary attribute
email_field = next(f for f in schema.fields if f.name == "email")
assert email_field.indexed is True
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify GraphQL schema regeneration was called
processor.generate_graphql_schema.assert_called_once()
def test_cql_query_building_basic(self):
"""Test basic CQL query construction"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to capture the query
mock_result = []
processor.session.execute.return_value = mock_result
# Create test schema
schema = RowSchema(
name="test_table",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string", indexed=True),
Field(name="status", type="string")
]
)
# Test query building
asyncio = pytest.importorskip("asyncio")
async def run_test():
await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="test_table",
row_schema=schema,
filters={"name": "John", "invalid_filter": "ignored"},
limit=10
)
# Run the async test
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_test())
finally:
loop.close()
# Verify Cassandra connection and query execution
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify the query structure (can't easily test exact query without complex mocking)
call_args = processor.session.execute.call_args
query = call_args[0][0] # First positional argument is the query
params = call_args[0][1] # Second positional argument is parameters
# Basic query structure checks
assert "SELECT * FROM test_user.o_test_table" in query
assert "WHERE" in query
assert "collection = %s" in query
assert "LIMIT 10" in query
# Parameters should include collection and name filter
assert "test_collection" in params
assert "John" in params
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { id name } }',
variables={},
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value'] # keyword argument
context = call_args[1]['context_value']
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["collection"] == "test_collection"
# Verify result structure
assert "data" in result
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error instead of MagicMock
# Create a simple object to simulate GraphQL error
class MockError:
def __init__(self, message, path, extensions):
self.message = message
self.path = path
self.extensions = extensions
def __str__(self):
return self.message
mock_error = MockError(
message="Field 'invalid_field' doesn't exist",
path=["customers", "0", "invalid_field"],
extensions={"code": "FIELD_NOT_FOUND"}
)
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
user="test_user",
collection="test_collection"
)
# Verify error handling
assert "errors" in result
assert len(result["errors"]) == 1
error = result["errors"][0]
assert error["message"] == "Field 'invalid_field' doesn't exist"
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
assert error["path"] == ["customers", "0", "invalid_field"]
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
def test_schema_generation_basic_structure(self):
"""Test basic GraphQL schema generation structure"""
processor = MagicMock()
processor.schemas = {
"customer": RowSchema(
name="customer",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
processor.graphql_types = {}
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Test individual type creation (avoiding the full schema generation which has annotation issues)
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
processor.graphql_types["customer"] = graphql_type
# Verify type was created
assert len(processor.graphql_types) == 1
assert "customer" in processor.graphql_types
assert processor.graphql_types["customer"] is not None
@pytest.mark.asyncio
async def test_message_processing_success(self):
"""Test successful message processing flow"""
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock successful query result
processor.execute_graphql_query.return_value = {
"data": {"customers": [{"id": "1", "name": "John"}]},
"errors": [],
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
"extensions": {}
}
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
collection="test_collection",
query='{ customers { id name } }',
variables={},
operation_name=None
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-123"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
query='{ customers { id name } }',
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is None
assert '"customers"' in response_call.data # JSON encoded
assert len(response_call.errors) == 0
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock query execution error
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-456"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify error response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify error response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is not None
assert response_call.error.type == "objects-query-error"
assert response_call.error.type == "rows-query-error"
assert "No schema available" in response_call.error.message
assert response_call.data is None
class TestCQLQueryGeneration:
"""Test CQL query generation logic in isolation"""
def test_partition_key_inclusion(self):
"""Test that collection is always included in queries"""
class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
async def test_query_with_index_match(self):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Mock the query building (simplified version)
keyspace = processor.sanitize_name("test_user")
table = processor.sanitize_table("test_table")
query = f"SELECT * FROM {keyspace}.{table}"
where_clauses = ["collection = %s"]
assert "collection = %s" in where_clauses
assert keyspace == "test_user"
assert table == "o_test_table"
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
processor.session.execute.return_value = [mock_row]
schema = RowSchema(
name="products",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string")
]
)
# Query with filter on indexed field
results = await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="products",
row_schema=schema,
filters={"category": "electronics"},
limit=10
)
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = processor.session.execute.call_args
query = call_args[0][0]
params = call_args[0][1]
assert "SELECT data, source FROM test_user.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query
assert "index_value = %s" in query
assert params[0] == "test_collection"
assert params[1] == "products"
assert params[2] == "category"
assert params[3] == ["electronics"]
# Verify results
assert len(results) == 1
assert results[0]["id"] == "123"
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
async def test_query_without_index_match(self):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
processor.session.execute.return_value = [mock_row1, mock_row2]
schema = RowSchema(
name="products",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string"), # Not indexed
Field(name="price", type="string") # Not indexed
]
)
# Query with filter on non-indexed field
results = await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="products",
row_schema=schema,
filters={"name": "Product A"},
limit=10
)
# Query should use ALLOW FILTERING for scan
call_args = processor.session.execute.call_args
query = call_args[0][0]
assert "ALLOW FILTERING" in query
# Should post-filter results
assert len(results) == 1
assert results[0]["name"] == "Product A"
class TestFilterMatching:
"""Test filter matching logic"""
def test_matches_filters_exact_match(self):
"""Test exact match filter"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
row = {"status": "active", "name": "test"}
assert processor._matches_filters(row, {"status": "active"}, schema) is True
assert processor._matches_filters(row, {"status": "inactive"}, schema) is False
def test_matches_filters_comparison_operators(self):
"""Test comparison operators in filters"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
row = {"price": "100.0"}
# Greater than
assert processor._matches_filters(row, {"price_gt": 50}, schema) is True
assert processor._matches_filters(row, {"price_gt": 150}, schema) is False
# Less than
assert processor._matches_filters(row, {"price_lt": 150}, schema) is True
assert processor._matches_filters(row, {"price_lt": 50}, schema) is False
# Greater than or equal
assert processor._matches_filters(row, {"price_gte": 100}, schema) is True
assert processor._matches_filters(row, {"price_gte": 101}, schema) is False
# Less than or equal
assert processor._matches_filters(row, {"price_lte": 100}, schema) is True
assert processor._matches_filters(row, {"price_lte": 99}, schema) is False
def test_matches_filters_contains(self):
"""Test contains filter"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
row = {"description": "A great product for everyone"}
assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True
assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False
def test_matches_filters_in_list(self):
"""Test in-list filter"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
row = {"status": "active"}
assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
class TestIndexedFieldFiltering:
"""Test that only indexed or primary key fields can be directly filtered"""
def test_indexed_field_filtering(self):
"""Test that only indexed or primary key fields can be filtered"""
# Create schema with mixed field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", primary=True),
Field(name="indexed_field", type="string", indexed=True),
Field(name="indexed_field", type="string", indexed=True),
Field(name="normal_field", type="string", indexed=False),
Field(name="another_field", type="string")
]
)
filters = {
"id": "test123", # Primary key - should be included
"indexed_field": "value", # Indexed - should be included
"normal_field": "ignored", # Not indexed - should be ignored
"another_field": "also_ignored" # Not indexed - should be ignored
}
# Simulate the filtering logic from the processor
valid_filters = []
for field_name, value in filters.items():
@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
schema_field = next((f for f in schema.fields if f.name == field_name), None)
if schema_field and (schema_field.indexed or schema_field.primary):
valid_filters.append((field_name, value))
# Only id and indexed_field should be included
assert len(valid_filters) == 2
field_names = [f[0] for f in valid_filters]
@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
assert "indexed_field" in field_names
assert "normal_field" not in field_names
assert "another_field" not in field_names
class TestGraphQLSchemaGeneration:
"""Test GraphQL schema generation in detail"""
def test_field_type_annotations(self):
"""Test that GraphQL types have correct field annotations"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create schema with various field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", required=True, primary=True),
Field(name="count", type="integer", required=True),
Field(name="price", type="float", required=False),
Field(name="active", type="boolean", required=False),
Field(name="optional_text", type="string", required=False)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test", schema)
# Verify type was created successfully
assert graphql_type is not None
def test_basic_type_creation(self):
"""Test that GraphQL types are created correctly"""
processor = MagicMock()
processor.schemas = {
"customer": RowSchema(
name="customer",
fields=[Field(name="id", type="string", primary=True)]
)
}
processor.graphql_types = {}
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create GraphQL type directly
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
processor.graphql_types["customer"] = graphql_type
# Verify customer type was created
assert "customer" in processor.graphql_types
assert processor.graphql_types["customer"] is not None

View file

@ -5,8 +5,8 @@ Tests for Cassandra triples query service
import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.cassandra.service import Processor
from trustgraph.schema import Value
from trustgraph.query.triples.cassandra.service import Processor, create_term
from trustgraph.schema import Term, IRI, LITERAL
class TestCassandraQueryProcessor:
@ -21,94 +21,101 @@ class TestCassandraQueryProcessor:
graph_host='localhost'
)
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
def test_create_term_with_http_uri(self, processor):
"""Test create_term with HTTP URI"""
result = create_term("http://example.com/resource")
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
def test_create_term_with_https_uri(self, processor):
"""Test create_term with HTTPS URI"""
result = create_term("https://example.com/resource")
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_term_with_literal(self, processor):
"""Test create_term with literal value"""
result = create_term("just a literal string")
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
def test_create_term_with_empty_string(self, processor):
"""Test create_term with empty string"""
result = create_term("")
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
def test_create_term_with_partial_uri(self, processor):
"""Test create_term with string that looks like URI but isn't complete"""
result = create_term("http")
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
def test_create_term_with_ftp_uri(self, processor):
"""Test create_term with FTP URI (should not be detected as URI)"""
result = create_term("ftp://example.com/file")
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_spo_query(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_spo_query(self, mock_kg_class):
"""Test querying triples with subject, predicate, and object specified"""
from trustgraph.schema import TriplesQueryRequest, Value
# Setup mock TrustGraph
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
# Setup mock TrustGraph via factory function
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
mock_kg_class.return_value = mock_tg_instance
# SPO query returns a list of results (with mock graph attribute)
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
processor = Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
cassandra_host='localhost'
)
# Create query request with all SPO values
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
result = await processor.query_triples(query)
# Verify KnowledgeGraph was created with correct parameters
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['localhost'],
keyspace='test_user'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
)
# Verify result contains the queried triple
assert len(result) == 1
assert result[0].s.value == 'test_subject'
@ -143,154 +150,174 @@ class TestCassandraQueryProcessor:
assert processor.table is None
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_sp_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_sp_pattern(self, mock_kg_class):
"""Test SP query pattern (subject and predicate, no object)"""
from trustgraph.schema import TriplesQueryRequest, Value
# Setup mock TrustGraph and response
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
# Setup mock TrustGraph via factory function
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.o = 'result_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_sp.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=None,
limit=50
)
result = await processor.query_triples(query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_s_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_s_pattern(self, mock_kg_class):
"""Test S query pattern (subject only)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_result.o = 'result_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_s.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=None,
limit=25
)
result = await processor.query_triples(query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_p_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_p_pattern(self, mock_kg_class):
"""Test P query pattern (predicate only)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.o = 'result_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_p.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
p=Term(type=LITERAL, value='test_predicate'),
o=None,
limit=10
)
result = await processor.query_triples(query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_o_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_o_pattern(self, mock_kg_class):
"""Test O query pattern (object only)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.p = 'result_predicate'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_o.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=None,
p=None,
o=Value(value='test_object', is_uri=False),
o=Term(type=LITERAL, value='test_object'),
limit=75
)
result = await processor.query_triples(query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_get_all_pattern(self, mock_kg_class):
"""Test query pattern with no constraints (get all)"""
from trustgraph.schema import TriplesQueryRequest
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'all_subject'
mock_result.p = 'all_predicate'
mock_result.o = 'all_object'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_all.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
@ -299,9 +326,9 @@ class TestCassandraQueryProcessor:
o=None,
limit=1000
)
result = await processor.query_triples(query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.value == 'all_subject'
@ -372,37 +399,44 @@ class TestCassandraQueryProcessor:
run()
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o, g) quad pattern, some values may be\nnull. Output is a list of quads.\n')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_with_authentication(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_with_authentication(self, mock_kg_class):
"""Test querying with username and password authentication"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None
mock_kg_class.return_value = mock_tg_instance
# SPO query returns a list of results
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
processor = Processor(
taskgroup=MagicMock(),
cassandra_username='authuser',
cassandra_password='authpass'
)
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
await processor.query_triples(query)
# Verify KnowledgeGraph was created with authentication
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='test_user',
username='authuser',
@ -410,128 +444,154 @@ class TestCassandraQueryProcessor:
)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_reuse(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_table_reuse(self, mock_kg_class):
"""Test that TrustGraph is reused for same table"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_tg_instance.get_spo.return_value = None
mock_kg_class.return_value = mock_tg_instance
# SPO query returns a list of results
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
# First query should create TrustGraph
await processor.query_triples(query)
assert mock_trustgraph.call_count == 1
assert mock_kg_class.call_count == 1
# Second query with same table should reuse TrustGraph
await processor.query_triples(query)
assert mock_trustgraph.call_count == 1 # Should not increase
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_switching(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_table_switching(self, mock_kg_class):
"""Test table switching creates new TrustGraph"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance1 = MagicMock()
mock_tg_instance2 = MagicMock()
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
# Setup mock results for both instances
mock_result = MagicMock()
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_result.p = 'p'
mock_result.o = 'o'
mock_tg_instance1.get_s.return_value = [mock_result]
mock_tg_instance2.get_s.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
# First query
query1 = TriplesQueryRequest(
user='user1',
collection='collection1',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=None,
limit=100
)
await processor.query_triples(query1)
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
user='user2',
collection='collection2',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=None,
limit=100
)
await processor.query_triples(query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice
assert mock_trustgraph.call_count == 2
assert mock_kg_class.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_exception_handling(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_exception_handling(self, mock_kg_class):
"""Test exception handling during query execution"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=100
)
with pytest.raises(Exception, match="Query failed"):
await processor.query_triples(query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_multiple_results(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_query_triples_multiple_results(self, mock_kg_class):
"""Test query returning multiple results"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
# Mock multiple results
mock_result1 = MagicMock()
mock_result1.o = 'object1'
mock_result1.g = ''
mock_result1.otype = None
mock_result1.dtype = None
mock_result1.lang = None
mock_result2 = MagicMock()
mock_result2.o = 'object2'
mock_result2.g = ''
mock_result2.otype = None
mock_result2.dtype = None
mock_result2.lang = None
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
processor = Processor(taskgroup=MagicMock())
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
p=Value(value='test_predicate', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=Term(type=LITERAL, value='test_predicate'),
o=None,
limit=100
)
result = await processor.query_triples(query)
assert len(result) == 2
assert result[0].o.value == 'object1'
assert result[1].o.value == 'object2'
@ -541,16 +601,20 @@ class TestCassandraQueryPerformanceOptimizations:
"""Test cases for multi-table performance optimizations in query service"""
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_po_query_optimization(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_get_po_query_optimization(self, mock_kg_class):
"""Test that get_po queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.s = 'result_subject'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_po.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
@ -560,8 +624,8 @@ class TestCassandraQueryPerformanceOptimizations:
user='test_user',
collection='test_collection',
s=None,
p=Value(value='test_predicate', is_uri=False),
o=Value(value='test_object', is_uri=False),
p=Term(type=LITERAL, value='test_predicate'),
o=Term(type=LITERAL, value='test_object'),
limit=50
)
@ -569,7 +633,7 @@ class TestCassandraQueryPerformanceOptimizations:
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', limit=50
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
)
assert len(result) == 1
@ -578,16 +642,20 @@ class TestCassandraQueryPerformanceOptimizations:
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_get_os_query_optimization(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_get_os_query_optimization(self, mock_kg_class):
"""Test that get_os queries use optimized table (no ALLOW FILTERING)"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
mock_result = MagicMock()
mock_result.p = 'result_predicate'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_os.return_value = [mock_result]
processor = Processor(taskgroup=MagicMock())
@ -596,9 +664,9 @@ class TestCassandraQueryPerformanceOptimizations:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value='test_subject', is_uri=False),
s=Term(type=LITERAL, value='test_subject'),
p=None,
o=Value(value='test_object', is_uri=False),
o=Term(type=LITERAL, value='test_object'),
limit=25
)
@ -606,7 +674,7 @@ class TestCassandraQueryPerformanceOptimizations:
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', limit=25
'test_collection', 'test_object', 'test_subject', g=None, limit=25
)
assert len(result) == 1
@ -615,13 +683,13 @@ class TestCassandraQueryPerformanceOptimizations:
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_all_query_patterns_use_correct_tables(self, mock_kg_class):
"""Test that all query patterns route to their optimal tables"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
# Mock empty results for all queries
mock_tg_instance.get_all.return_value = []
@ -655,9 +723,9 @@ class TestCassandraQueryPerformanceOptimizations:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value=s, is_uri=False) if s else None,
p=Value(value=p, is_uri=False) if p else None,
o=Value(value=o, is_uri=False) if o else None,
s=Term(type=LITERAL, value=s) if s else None,
p=Term(type=LITERAL, value=p) if p else None,
o=Term(type=LITERAL, value=o) if o else None,
limit=10
)
@ -687,19 +755,23 @@ class TestCassandraQueryPerformanceOptimizations:
# Mode is determined in KnowledgeGraph initialization
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_performance_critical_po_query_no_filtering(self, mock_kg_class):
"""Test the performance-critical PO query that eliminates ALLOW FILTERING"""
from trustgraph.schema import TriplesQueryRequest, Value
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
# Mock multiple subjects for the same predicate-object pair
mock_results = []
for i in range(5):
mock_result = MagicMock()
mock_result.s = f'subject_{i}'
mock_result.g = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results
@ -711,8 +783,8 @@ class TestCassandraQueryPerformanceOptimizations:
user='large_dataset_user',
collection='massive_collection',
s=None,
p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True),
o=Value(value='http://example.com/Person', is_uri=True),
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
o=Term(type=IRI, iri='http://example.com/Person'),
limit=1000
)
@ -723,14 +795,15 @@ class TestCassandraQueryPerformanceOptimizations:
'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person',
g=None,
limit=1000
)
# Verify all results were returned
assert len(result) == 5
for i, triple in enumerate(result):
assert triple.s.value == f'subject_{i}'
assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.is_uri is True
assert triple.o.value == 'http://example.com/Person'
assert triple.o.is_uri is True
assert triple.s.value == f'subject_{i}' # Mock returns literal values
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
assert triple.p.type == IRI
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
assert triple.o.type == IRI

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.falkordb.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
class TestFalkorDBQueryProcessor:
@ -25,50 +25,50 @@ class TestFalkorDBQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
def test_processor_initialization_with_defaults(self, mock_falkordb):
@ -125,9 +125,9 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -138,8 +138,8 @@ class TestFalkorDBQueryProcessor:
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -166,8 +166,8 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -179,13 +179,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri_result"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -211,9 +211,9 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -224,12 +224,12 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -256,7 +256,7 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100
@ -269,13 +269,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -302,8 +302,8 @@ class TestFalkorDBQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -314,12 +314,12 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -347,7 +347,7 @@ class TestFalkorDBQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -359,13 +359,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -393,7 +393,7 @@ class TestFalkorDBQueryProcessor:
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -404,12 +404,12 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@ -449,13 +449,13 @@ class TestFalkorDBQueryProcessor:
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].s.iri == "http://example.com/s1"
assert result[0].p.iri == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
assert result[1].s.iri == "http://example.com/s2"
assert result[1].p.iri == "http://example.com/p2"
assert result[1].o.iri == "http://example.com/o2"
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
@pytest.mark.asyncio
@ -476,7 +476,7 @@ class TestFalkorDBQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.memgraph.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
class TestMemgraphQueryProcessor:
@ -25,50 +25,50 @@ class TestMemgraphQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
@ -124,9 +124,9 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -137,8 +137,8 @@ class TestMemgraphQueryProcessor:
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -166,8 +166,8 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -179,13 +179,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri_result"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -212,9 +212,9 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -225,12 +225,12 @@ class TestMemgraphQueryProcessor:
# Verify results contain different predicates
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -258,7 +258,7 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100
@ -271,13 +271,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different predicate-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -305,8 +305,8 @@ class TestMemgraphQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -317,12 +317,12 @@ class TestMemgraphQueryProcessor:
# Verify results contain different subjects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -351,7 +351,7 @@ class TestMemgraphQueryProcessor:
user='test_user',
collection='test_collection',
s=None,
p=Value(value="http://example.com/predicate", is_uri=True),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -363,13 +363,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different subject-object pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -398,7 +398,7 @@ class TestMemgraphQueryProcessor:
collection='test_collection',
s=None,
p=None,
o=Value(value="literal object", is_uri=False),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -409,12 +409,12 @@ class TestMemgraphQueryProcessor:
# Verify results contain different subject-predicate pairs
assert len(result) == 2
assert result[0].s.value == "http://example.com/subj1"
assert result[0].p.value == "http://example.com/pred1"
assert result[0].s.iri == "http://example.com/subj1"
assert result[0].p.iri == "http://example.com/pred1"
assert result[0].o.value == "literal object"
assert result[1].s.value == "http://example.com/subj2"
assert result[1].p.value == "http://example.com/pred2"
assert result[1].s.iri == "http://example.com/subj2"
assert result[1].p.iri == "http://example.com/pred2"
assert result[1].o.value == "literal object"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@ -455,13 +455,13 @@ class TestMemgraphQueryProcessor:
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].s.iri == "http://example.com/s1"
assert result[0].p.iri == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
assert result[1].s.iri == "http://example.com/s2"
assert result[1].p.iri == "http://example.com/p2"
assert result[1].o.iri == "http://example.com/o2"
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
@pytest.mark.asyncio
@ -480,7 +480,7 @@ class TestMemgraphQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.query.triples.neo4j.service import Processor
from trustgraph.schema import Value, TriplesQueryRequest
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
class TestNeo4jQueryProcessor:
@ -25,50 +25,50 @@ class TestNeo4jQueryProcessor:
def test_create_value_with_http_uri(self, processor):
"""Test create_value with HTTP URI"""
result = processor.create_value("http://example.com/resource")
assert isinstance(result, Value)
assert result.value == "http://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "http://example.com/resource"
assert result.type == IRI
def test_create_value_with_https_uri(self, processor):
"""Test create_value with HTTPS URI"""
result = processor.create_value("https://example.com/resource")
assert isinstance(result, Value)
assert result.value == "https://example.com/resource"
assert result.is_uri is True
assert isinstance(result, Term)
assert result.iri == "https://example.com/resource"
assert result.type == IRI
def test_create_value_with_literal(self, processor):
"""Test create_value with literal value"""
result = processor.create_value("just a literal string")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "just a literal string"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_empty_string(self, processor):
"""Test create_value with empty string"""
result = processor.create_value("")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == ""
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_partial_uri(self, processor):
"""Test create_value with string that looks like URI but isn't complete"""
result = processor.create_value("http")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "http"
assert result.is_uri is False
assert result.type == LITERAL
def test_create_value_with_ftp_uri(self, processor):
"""Test create_value with FTP URI (should not be detected as URI)"""
result = processor.create_value("ftp://example.com/file")
assert isinstance(result, Value)
assert isinstance(result, Term)
assert result.value == "ftp://example.com/file"
assert result.is_uri is False
assert result.type == LITERAL
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
def test_processor_initialization_with_defaults(self, mock_graph_db):
@ -124,9 +124,9 @@ class TestNeo4jQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal object", is_uri=False),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal object"),
limit=100
)
@ -137,8 +137,8 @@ class TestNeo4jQueryProcessor:
# Verify result contains the queried triple (appears twice - once from each query)
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal object"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@ -166,8 +166,8 @@ class TestNeo4jQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None,
limit=100
)
@ -179,13 +179,13 @@ class TestNeo4jQueryProcessor:
# Verify results contain different objects
assert len(result) == 2
assert result[0].s.value == "http://example.com/subject"
assert result[0].p.value == "http://example.com/predicate"
assert result[0].s.iri == "http://example.com/subject"
assert result[0].p.iri == "http://example.com/predicate"
assert result[0].o.value == "literal result"
assert result[1].s.value == "http://example.com/subject"
assert result[1].p.value == "http://example.com/predicate"
assert result[1].o.value == "http://example.com/uri_result"
assert result[1].s.iri == "http://example.com/subject"
assert result[1].p.iri == "http://example.com/predicate"
assert result[1].o.iri == "http://example.com/uri_result"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
@ -225,13 +225,13 @@ class TestNeo4jQueryProcessor:
# Verify results contain different triples
assert len(result) == 2
assert result[0].s.value == "http://example.com/s1"
assert result[0].p.value == "http://example.com/p1"
assert result[0].s.iri == "http://example.com/s1"
assert result[0].p.iri == "http://example.com/p1"
assert result[0].o.value == "literal1"
assert result[1].s.value == "http://example.com/s2"
assert result[1].p.value == "http://example.com/p2"
assert result[1].o.value == "http://example.com/o2"
assert result[1].s.iri == "http://example.com/s2"
assert result[1].p.iri == "http://example.com/p2"
assert result[1].o.iri == "http://example.com/o2"
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
@ -250,12 +250,12 @@ class TestNeo4jQueryProcessor:
query = TriplesQueryRequest(
user='test_user',
collection='test_collection',
s=Value(value="http://example.com/subject", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=None,
o=None,
limit=100
)
# Should raise the exception
with pytest.raises(Exception, match="Database connection failed"):
await processor.query_triples(query)

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
ObjectsQueryRequest, ObjectsQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=None,
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
# Verify objects query service was called correctly
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, ObjectsQueryRequest)
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
assert response.error is not None
assert "empty GraphQL query" in response.error.message
async def test_objects_query_service_error(self, processor):
async def test_rows_query_service_error(self, processor):
"""Test handling of objects query service errors"""
# Arrange
request = StructuredQueryRequest(
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service error
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
data=None,
errors=None,
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
response = response_call[0][0]
assert response.error is not None
assert "Objects query service error" in response.error.message
assert "Rows query service error" in response.error.message
assert "Table 'customers' not found" in response.error.message
async def test_graphql_errors_handling(self, processor):
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
)
]
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None,
errors=graphql_errors,
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
errors=None
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
confidence=0.9
)
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None, # Null data
errors=None,
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response

View file

@ -10,7 +10,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_password is None
class TestObjectsWriterConfiguration:
class TestRowsWriterConfiguration:
"""Test Cassandra configuration in objects writer processor."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_environment_variable_configuration(self, mock_cluster):
"""Test processor picks up configuration from environment variables."""
env_vars = {
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
assert processor.cassandra_username == 'obj-env-user'
assert processor.cassandra_password == 'obj-env-pass'
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
"""Test that Cassandra connection uses hosts list correctly."""
env_vars = {
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify cluster was called with hosts list
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
assert 'contact_points' in call_args.kwargs
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
"""Test authentication is configured when credentials are provided."""
env_vars = {
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify auth provider was created with correct credentials
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
def test_objects_writer_add_args(self):
"""Test that objects writer adds standard Cassandra arguments."""
import argparse
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
parser = argparse.ArgumentParser()
ObjectsWriter.add_args(parser)
RowsWriter.add_args(parser)
# Parse empty args to check that arguments exist
args = parser.parse_args([])

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.graph_embeddings.milvus.write import Processor
from trustgraph.schema import Value, EntityEmbeddings
from trustgraph.schema import Term, EntityEmbeddings, IRI, LITERAL
class TestMilvusGraphEmbeddingsStorageProcessor:
@ -22,11 +22,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
# Create test entities with embeddings
entity1 = EntityEmbeddings(
entity=Value(value='http://example.com/entity1', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity1'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
entity2 = EntityEmbeddings(
entity=Value(value='literal entity', is_uri=False),
entity=Term(type=LITERAL, value='literal entity'),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [entity1, entity2]
@ -84,7 +84,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
)
message.entities = [entity]
@ -136,7 +136,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
entity=Term(type=LITERAL, value=''),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
@ -155,7 +155,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
entity=Term(type=LITERAL, value=None),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
@ -174,15 +174,15 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
entity=Value(value='http://example.com/valid', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/valid'),
vectors=[[0.1, 0.2, 0.3]]
)
empty_entity = EntityEmbeddings(
entity=Value(value='', is_uri=False),
entity=Term(type=LITERAL, value=''),
vectors=[[0.4, 0.5, 0.6]]
)
none_entity = EntityEmbeddings(
entity=Value(value=None, is_uri=False),
entity=Term(type=LITERAL, value=None),
vectors=[[0.7, 0.8, 0.9]]
)
message.entities = [valid_entity, empty_entity, none_entity]
@ -217,7 +217,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[]
)
message.entities = [entity]
@ -236,7 +236,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value='http://example.com/entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/entity'),
vectors=[
[0.1, 0.2], # 2D vector
[0.3, 0.4, 0.5, 0.6], # 4D vector
@ -269,11 +269,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
entity=Value(value='http://example.com/uri_entity', is_uri=True),
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
vectors=[[0.1, 0.2, 0.3]]
)
literal_entity = EntityEmbeddings(
entity=Value(value='literal entity text', is_uri=False),
entity=Term(type=LITERAL, value='literal entity text'),
vectors=[[0.4, 0.5, 0.6]]
)
message.entities = [uri_entity, literal_entity]

View file

@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
from trustgraph.schema import IRI, LITERAL
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@ -67,7 +68,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'test_entity'
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'test_entity'
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
mock_message.entities = [mock_entity]
@ -120,11 +122,13 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
mock_entity1.entity.value = 'entity_one'
mock_entity1.entity.type = IRI
mock_entity1.entity.iri = 'entity_one'
mock_entity1.vectors = [[0.1, 0.2]]
mock_entity2 = MagicMock()
mock_entity2.entity.value = 'entity_two'
mock_entity2.entity.type = IRI
mock_entity2.entity.iri = 'entity_two'
mock_entity2.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity1, mock_entity2]
@ -179,7 +183,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'vector_collection'
mock_entity = MagicMock()
mock_entity.entity.value = 'multi_vector_entity'
mock_entity.entity.type = IRI
mock_entity.entity.iri = 'multi_vector_entity'
mock_entity.vectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
@ -231,11 +236,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock()
mock_entity_empty.entity.type = LITERAL
mock_entity_empty.entity.value = "" # Empty string
mock_entity_empty.vectors = [[0.1, 0.2]]
mock_entity_none = MagicMock()
mock_entity_none.entity.value = None # None value
mock_entity_none.entity = None # None entity
mock_entity_none.vectors = [[0.3, 0.4]]
mock_message.entities = [mock_entity_empty, mock_entity_none]

View file

@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch, call
from trustgraph.storage.triples.neo4j.write import Processor as StorageProcessor
from trustgraph.query.triples.neo4j.service import Processor as QueryProcessor
from trustgraph.schema import Triples, Triple, Value, Metadata
from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
from trustgraph.schema import TriplesQueryRequest
@ -60,9 +60,9 @@ class TestNeo4jUserCollectionIsolation:
)
triple = Triple(
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="literal_value", is_uri=False)
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal_value")
)
message = Triples(
@ -128,9 +128,9 @@ class TestNeo4jUserCollectionIsolation:
metadata = Metadata(id="test-id")
triple = Triple(
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="http://example.com/object", is_uri=True)
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=IRI, iri="http://example.com/object")
)
message = Triples(
@ -170,8 +170,8 @@ class TestNeo4jUserCollectionIsolation:
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Value(value="http://example.com/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None
)
@ -254,9 +254,9 @@ class TestNeo4jUserCollectionIsolation:
metadata=Metadata(user="user1", collection="coll1"),
triples=[
Triple(
s=Value(value="http://example.com/user1/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="user1_data", is_uri=False)
s=Term(type=IRI, iri="http://example.com/user1/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user1_data")
)
]
)
@ -265,9 +265,9 @@ class TestNeo4jUserCollectionIsolation:
metadata=Metadata(user="user2", collection="coll2"),
triples=[
Triple(
s=Value(value="http://example.com/user2/subject", is_uri=True),
p=Value(value="http://example.com/predicate", is_uri=True),
o=Value(value="user2_data", is_uri=False)
s=Term(type=IRI, iri="http://example.com/user2/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user2_data")
)
]
)
@ -429,9 +429,9 @@ class TestNeo4jUserCollectionRegression:
metadata=Metadata(user="user1", collection="coll1"),
triples=[
Triple(
s=Value(value=shared_uri, is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="user1_value", is_uri=False)
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user1_value")
)
]
)
@ -440,9 +440,9 @@ class TestNeo4jUserCollectionRegression:
metadata=Metadata(user="user2", collection="coll2"),
triples=[
Triple(
s=Value(value=shared_uri, is_uri=True),
p=Value(value="http://example.com/p", is_uri=True),
o=Value(value="user2_value", is_uri=False)
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user2_value")
)
]
)

View file

@ -1,533 +0,0 @@
"""
Unit tests for Cassandra Object Storage Processor
Tests the business logic of the object storage processor including:
- Schema configuration handling
- Type conversions
- Name sanitization
- Table structure generation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestObjectsCassandraStorageLogic:
"""Test business logic without FlowProcessor dependencies"""
def test_sanitize_name(self):
"""Test name sanitization for Cassandra compatibility"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test various name patterns (back to original logic)
assert processor.sanitize_name("simple_name") == "simple_name"
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
def test_get_cassandra_type(self):
"""Test field type conversion to Cassandra types"""
processor = MagicMock()
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_cassandra_type("string") == "text"
assert processor.get_cassandra_type("boolean") == "boolean"
assert processor.get_cassandra_type("timestamp") == "timestamp"
assert processor.get_cassandra_type("uuid") == "uuid"
# Integer types with size hints
assert processor.get_cassandra_type("integer", size=2) == "int"
assert processor.get_cassandra_type("integer", size=8) == "bigint"
# Float types with size hints
assert processor.get_cassandra_type("float", size=2) == "float"
assert processor.get_cassandra_type("float", size=8) == "double"
# Unknown type defaults to text
assert processor.get_cassandra_type("unknown_type") == "text"
def test_convert_value(self):
"""Test value conversion for different field types"""
processor = MagicMock()
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Integer conversions
assert processor.convert_value("123", "integer") == 123
assert processor.convert_value(123.5, "integer") == 123
assert processor.convert_value(None, "integer") is None
# Float conversions
assert processor.convert_value("123.45", "float") == 123.45
assert processor.convert_value(123, "float") == 123.0
# Boolean conversions
assert processor.convert_value("true", "boolean") is True
assert processor.convert_value("false", "boolean") is False
assert processor.convert_value("1", "boolean") is True
assert processor.convert_value("0", "boolean") is False
assert processor.convert_value("yes", "boolean") is True
assert processor.convert_value("no", "boolean") is False
# String conversions
assert processor.convert_value(123, "string") == "123"
assert processor.convert_value(True, "string") == "True"
def test_table_creation_cql_generation(self):
"""Test CQL generation for table creation"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="customer_records",
description="Test customer schema",
fields=[
Field(
name="customer_id",
type="string",
size=50,
primary=True,
required=True,
indexed=False
),
Field(
name="email",
type="string",
size=100,
required=True,
indexed=True
),
Field(
name="age",
type="integer",
size=4,
required=False,
indexed=False
)
]
)
# Call ensure_table
processor.ensure_table("test_user", "customer_records", schema)
# Verify keyspace was ensured (check that it was added to known_keyspaces)
assert "test_user" in processor.known_keyspaces
# Check the CQL that was executed (first call should be table creation)
all_calls = processor.session.execute.call_args_list
table_creation_cql = all_calls[0][0][0] # First call
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
assert "collection text" in table_creation_cql
assert "customer_id text" in table_creation_cql
assert "email text" in table_creation_cql
assert "age int" in table_creation_cql
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
def test_table_creation_without_primary_key(self):
"""Test table creation when no primary key is defined"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema without primary key
schema = RowSchema(
name="events",
description="Event log",
fields=[
Field(name="event_type", type="string", size=50),
Field(name="timestamp", type="timestamp", size=0)
]
)
# Call ensure_table
processor.ensure_table("test_user", "events", schema)
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
executed_cql = processor.session.execute.call_args[0][0]
assert "synthetic_id uuid" in executed_cql
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer data",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True
},
{
"name": "name",
"type": "string",
"required": True
},
{
"name": "balance",
"type": "float",
"size": 8
}
]
})
}
}
# Process configuration
await processor.on_schema_config(config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
# Check field properties
id_field = schema.fields[0]
assert id_field.name == "id"
assert id_field.type == "string"
assert id_field.primary is True
# Note: Field.required always returns False due to Pulsar schema limitations
# The actual required value is tracked during schema parsing
@pytest.mark.asyncio
async def test_object_processing_logic(self):
"""Test the logic for processing ExtractedObject"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="integer", size=4)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="test_schema",
values=[{"id": "123", "value": "456"}],
confidence=0.9,
source_span="test source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])
# Verify insert was executed (keyspace normal, table with o_ prefix)
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_test_schema" in insert_cql
assert "collection" in insert_cql
assert values[0] == "test_collection" # collection value
assert values[1] == "123" # id value (from values[0])
assert values[2] == 456 # converted integer value (from values[0])
def test_secondary_index_creation(self):
"""Test that secondary indexes are created for indexed fields"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
if keyspace not in processor.known_tables:
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema with indexed field
schema = RowSchema(
name="products",
description="Product catalog",
fields=[
Field(name="product_id", type="string", size=50, primary=True),
Field(name="category", type="string", size=30, indexed=True),
Field(name="price", type="float", size=8, indexed=True)
]
)
# Call ensure_table
processor.ensure_table("test_user", "products", schema)
# Should have 3 calls: create table + 2 indexes
assert processor.session.execute.call_count == 3
# Check index creation calls (table has o_ prefix, fields don't)
calls = processor.session.execute.call_args_list
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
assert len(index_calls) == 2
assert any("o_products_category_idx" in call for call in index_calls)
assert any("o_products_price_idx" in call for call in index_calls)
class TestObjectsCassandraStorageBatchLogic:
"""Test batch processing logic in Cassandra storage"""
@pytest.mark.asyncio
async def test_batch_object_processing_logic(self):
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
description="Test batch schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="value", type="integer", size=4)
]
)
}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
metadata=[]
),
schema_name="batch_schema",
values=[
{"id": "001", "name": "First", "value": "100"},
{"id": "002", "name": "Second", "value": "200"},
{"id": "003", "name": "Third", "value": "300"}
],
confidence=0.95,
source_span="batch source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = batch_obj
# Process batch object
await processor.on_object(msg, None, None)
# Verify table was ensured once
processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"])
# Verify 3 separate insert calls (one per batch item)
assert processor.session.execute.call_count == 3
# Check each insert call
calls = processor.session.execute.call_args_list
for i, call in enumerate(calls):
insert_cql = call[0][0]
values = call[0][1]
assert "INSERT INTO test_user.o_batch_schema" in insert_cql
assert "collection" in insert_cql
# Check values for each batch item
assert values[0] == "batch_collection" # collection
assert values[1] == f"00{i+1}" # id from batch item i
assert values[2] == f"First" if i == 0 else f"Second" if i == 1 else f"Third" # name
assert values[3] == (i+1) * 100 # converted integer value
@pytest.mark.asyncio
async def test_empty_batch_processing_logic(self):
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create empty batch object
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
metadata=[]
),
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span="empty source"
)
msg = MagicMock()
msg.value.return_value = empty_batch_obj
# Process empty batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@pytest.mark.asyncio
async def test_single_item_batch_processing_logic(self):
"""Test processing of single-item batch (backward compatibility)"""
processor = MagicMock()
processor.schemas = {
"single_schema": RowSchema(
name="single_schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create single-item batch object (backward compatibility case)
single_batch_obj = ExtractedObject(
metadata=Metadata(
id="single-001",
user="test_user",
collection="single_collection",
metadata=[]
),
schema_name="single_schema",
values=[{"id": "single-1", "data": "single data"}], # Array with one item
confidence=0.8,
source_span="single source"
)
msg = MagicMock()
msg.value.return_value = single_batch_obj
# Process single-item batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify exactly one insert call
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_single_schema" in insert_cql
assert values[0] == "single_collection" # collection
assert values[1] == "single-1" # id value
assert values[2] == "single data" # data value
def test_batch_value_conversion_logic(self):
"""Test value conversion works correctly for batch items"""
processor = MagicMock()
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Test various conversion scenarios that would occur in batch processing
test_cases = [
# Integer conversions for batch items
("123", "integer", 123),
("456", "integer", 456),
("789", "integer", 789),
# Float conversions for batch items
("12.5", "float", 12.5),
("34.7", "float", 34.7),
# Boolean conversions for batch items
("true", "boolean", True),
("false", "boolean", False),
("1", "boolean", True),
("0", "boolean", False),
# String conversions for batch items
(123, "string", "123"),
(45.6, "string", "45.6"),
]
for input_val, field_type, expected_output in test_cases:
result = processor.convert_value(input_val, field_type)
assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"

View file

@ -0,0 +1,435 @@
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
"""Test Qdrant row embeddings storage functionality"""
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_basic(self, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
mock_qdrant_client.assert_called_once_with(
url='http://localhost:6333', api_key='test-api-key'
)
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
"""Test processor initialization with default values"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
mock_qdrant_client.assert_called_once_with(
url='http://localhost:6333', api_key=None
)
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_sanitize_name(self, mock_qdrant_client):
"""Test name sanitization for Qdrant collections"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Test basic sanitization
assert processor.sanitize_name("simple") == "simple"
assert processor.sanitize_name("with-dash") == "with_dash"
assert processor.sanitize_name("with.dot") == "with_dot"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
# Test numeric prefix handling
assert processor.sanitize_name("123start") == "r_123start"
assert processor.sanitize_name("_underscore") == "r__underscore"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_get_collection_name(self, mock_qdrant_client):
"""Test Qdrant collection name generation"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
collection_name = processor.get_collection_name(
user="test_user",
collection="test_collection",
schema_name="customer_data",
dimension=384
)
assert collection_name == "rows_test_user_test_collection_customer_data_384"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
"""Test that ensure_collection creates a new collection when needed"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.ensure_collection("test_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
mock_qdrant_instance.create_collection.assert_called_once()
# Verify the collection is cached
assert "test_collection" in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
"""Test that ensure_collection skips creation when collection exists"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.ensure_collection("existing_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once()
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
"""Test that ensure_collection uses cache for previously created collections"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.created_collections.add("cached_collection")
processor.ensure_collection("cached_collection", 384)
# Should not check or create - just return
mock_qdrant_instance.collection_exists.assert_not_called()
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
"""Test processing basic row embeddings message"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid-123'
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# Create embeddings message
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
embedding = RowIndexEmbedding(
index_name='customer_id',
index_value=['CUST001'],
text='CUST001',
vectors=[[0.1, 0.2, 0.3]]
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='customers',
embeddings=[embedding]
)
# Mock message wrapper
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['index_name'] == 'customer_id'
assert point.payload['index_value'] == ['CUST001']
assert point.payload['text'] == 'CUST001'
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with multiple vectors"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid'
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
# Embedding with multiple vectors
embedding = RowIndexEmbedding(
index_name='name',
index_value=['John Doe'],
text='John Doe',
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='people',
embeddings=[embedding]
)
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
"""Test that embeddings with no vectors are skipped"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
# Embedding with no vectors
embedding = RowIndexEmbedding(
index_name='id',
index_value=['123'],
text='123',
vectors=[] # Empty vectors
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='items',
embeddings=[embedding]
)
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should not call upsert for empty vectors
mock_qdrant_instance.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
"""Test that messages for unknown collections are dropped"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
embedding = RowIndexEmbedding(
index_name='id',
index_value=['123'],
text='123',
vectors=[[0.1, 0.2]]
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='items',
embeddings=[embedding]
)
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should not call upsert for unknown collection
mock_qdrant_instance.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection(self, mock_qdrant_client):
"""Test deleting all collections for a user/collection"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
# Mock collections list
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
mock_coll3 = MagicMock()
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
mock_qdrant_instance.get_collections.return_value = mock_collections
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
await processor.delete_collection('test_user', 'test_collection')
# Should delete only the matching collections
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
"""Test deleting collections for a specific schema"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2]
mock_qdrant_instance.get_collections.return_value = mock_collections
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
await processor.delete_collection_schema(
'test_user', 'test_collection', 'customers'
)
# Should only delete the customers schema collection
mock_qdrant_instance.delete_collection.assert_called_once()
call_args = mock_qdrant_instance.delete_collection.call_args[0]
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,474 @@
"""
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
Tests the business logic of the row storage processor including:
- Schema configuration handling
- Name sanitization
- Unified table structure
- Index management
- Row storage with multi-index support
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
def test_sanitize_name(self):
"""Test name sanitization for Cassandra compatibility"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test various name patterns
assert processor.sanitize_name("simple_name") == "simple_name"
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number"
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
assert processor.sanitize_name("CamelCase") == "camelcase"
assert processor.sanitize_name("_underscore_start") == "r__underscore_start"
def test_get_index_names(self):
"""Test extraction of index names from schema"""
processor = MagicMock()
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
# Schema with primary and indexed fields
schema = RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string"), # Not indexed
Field(name="status", type="string", indexed=True)
]
)
index_names = processor.get_index_names(schema)
# Should include primary key and indexed fields
assert "id" in index_names
assert "category" in index_names
assert "status" in index_names
assert "name" not in index_names # Not indexed
assert len(index_names) == 3
def test_get_index_names_no_indexes(self):
"""Test schema with no indexed fields"""
processor = MagicMock()
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
schema = RowSchema(
name="no_index_schema",
fields=[
Field(name="data1", type="string"),
Field(name="data2", type="string")
]
)
index_names = processor.get_index_names(schema)
assert len(index_names) == 0
def test_build_index_value(self):
"""Test building index values from row data"""
processor = MagicMock()
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
value_map = {"id": "123", "category": "electronics", "name": "Widget"}
# Single field index
result = processor.build_index_value(value_map, "id")
assert result == ["123"]
result = processor.build_index_value(value_map, "category")
assert result == ["electronics"]
# Missing field returns empty string
result = processor.build_index_value(value_map, "missing")
assert result == [""]
def test_build_index_value_composite(self):
"""Test building composite index values"""
processor = MagicMock()
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
value_map = {"region": "us-west", "category": "electronics", "id": "123"}
# Composite index (comma-separated field names)
result = processor.build_index_value(value_map, "region,category")
assert result == ["us-west", "electronics"]
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.registered_partitions = set()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer data",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True
},
{
"name": "name",
"type": "string",
"required": True
},
{
"name": "category",
"type": "string",
"indexed": True
}
]
})
}
}
# Process configuration
await processor.on_schema_config(config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
# Check field properties
id_field = schema.fields[0]
assert id_field.name == "id"
assert id_field.type == "string"
assert id_field.primary is True
@pytest.mark.asyncio
async def test_object_processing_stores_data_map(self):
"""Test that row processing stores data as map<text, text>"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="test_schema",
values=[{"id": "123", "value": "test_data"}],
confidence=0.9,
source_span="test source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
# Verify insert was executed
processor.session.execute.assert_called()
insert_call = processor.session.execute.call_args
insert_cql = insert_call[0][0]
values = insert_call[0][1]
# Verify using unified rows table
assert "INSERT INTO test_user.rows" in insert_cql
# Values should be: (collection, schema_name, index_name, index_value, data, source)
assert values[0] == "test_collection" # collection
assert values[1] == "test_schema" # schema_name
assert values[2] == "id" # index_name (primary key field)
assert values[3] == ["123"] # index_value as list
assert values[4] == {"id": "123", "value": "test_data"} # data map
assert values[5] == "" # source
@pytest.mark.asyncio
async def test_object_processing_multiple_indexes(self):
"""Test that row is written once per indexed field"""
processor = MagicMock()
processor.schemas = {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="multi_index_schema",
values=[{"id": "123", "category": "electronics", "status": "active"}],
confidence=0.9,
source_span=""
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Should have 3 inserts (one per indexed field: id, category, status)
assert processor.session.execute.call_count == 3
# Check that different index_names were used
index_names_used = set()
for call in processor.session.execute.call_args_list:
values = call[0][1]
index_names_used.add(values[2]) # index_name is 3rd value
assert index_names_used == {"id", "category", "status"}
class TestRowsCassandraStorageBatchLogic:
"""Test batch processing logic for unified table implementation"""
@pytest.mark.asyncio
async def test_batch_object_processing(self):
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
metadata=[]
),
schema_name="batch_schema",
values=[
{"id": "001", "name": "First"},
{"id": "002", "name": "Second"},
{"id": "003", "name": "Third"}
],
confidence=0.95,
source_span=""
)
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Should have 3 inserts (one per row, one index per row since only primary key)
assert processor.session.execute.call_count == 3
# Check each insert has different id
ids_inserted = set()
for call in processor.session.execute.call_args_list:
values = call[0][1]
ids_inserted.add(tuple(values[3])) # index_value is 4th value
assert ids_inserted == {("001",), ("002",), ("003",)}
@pytest.mark.asyncio
async def test_empty_batch_processing(self):
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create empty batch object
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
metadata=[]
),
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span=""
)
msg = MagicMock()
msg.value.return_value = empty_batch_obj
await processor.on_object(msg, None, None)
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
class TestUnifiedTableStructure:
"""Test the unified rows table structure"""
def test_ensure_tables_creates_unified_structure(self):
"""Test that ensure_tables creates the unified rows table"""
processor = MagicMock()
processor.known_keyspaces = {"test_user"}
processor.tables_initialized = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.ensure_keyspace = MagicMock()
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
processor.ensure_tables("test_user")
# Should have 2 calls: create rows table + create row_partitions table
assert processor.session.execute.call_count == 2
# Check rows table creation
rows_cql = processor.session.execute.call_args_list[0][0][0]
assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
assert "collection text" in rows_cql
assert "schema_name text" in rows_cql
assert "index_name text" in rows_cql
assert "index_value frozen<list<text>>" in rows_cql
assert "data map<text, text>" in rows_cql
assert "source text" in rows_cql
assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql
# Check row_partitions table creation
partitions_cql = processor.session.execute.call_args_list[1][0][0]
assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql
# Verify keyspace added to initialized set
assert "test_user" in processor.tables_initialized
def test_ensure_tables_idempotent(self):
"""Test that ensure_tables is idempotent"""
processor = MagicMock()
processor.tables_initialized = {"test_user"} # Already initialized
processor.session = MagicMock()
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
processor.ensure_tables("test_user")
# Should not execute any CQL since already initialized
processor.session.execute.assert_not_called()
class TestPartitionRegistration:
"""Test partition registration for tracking what's stored"""
def test_register_partitions(self):
"""Test registering partitions for a collection/schema pair"""
processor = MagicMock()
processor.registered_partitions = set()
processor.session = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
}
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
# Should have 2 inserts (one per index: id, category)
assert processor.session.execute.call_count == 2
# Verify cache was updated
assert ("test_collection", "test_schema") in processor.registered_partitions
def test_register_partitions_idempotent(self):
"""Test that partition registration is idempotent"""
processor = MagicMock()
processor.registered_partitions = {("test_collection", "test_schema")} # Already registered
processor.session = MagicMock()
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
# Should not execute any CQL since already registered
processor.session.execute.assert_not_called()

View file

@ -6,7 +6,8 @@ import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.cassandra.write import Processor
from trustgraph.schema import Value, Triple
from trustgraph.schema import Triple, LITERAL, IRI
from trustgraph.direct.cassandra_kg import DEFAULT_GRAPH
class TestCassandraStorageProcessor:
@ -86,29 +87,29 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_username == 'new-user' # Only cassandra_* params work
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_switching_with_auth(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_table_switching_with_auth(self, mock_kg_class):
"""Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(
taskgroup=taskgroup_mock,
cassandra_username='testuser',
cassandra_password='testpass'
)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify KnowledgeGraph was called with auth parameters
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='user1',
username='testuser',
@ -117,128 +118,150 @@ class TestCassandraStorageProcessor:
assert processor.table == 'user1'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_switching_without_auth(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_table_switching_without_auth(self, mock_kg_class):
"""Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user2'
mock_message.metadata.collection = 'collection2'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify KnowledgeGraph was called without auth parameters
mock_trustgraph.assert_called_once_with(
mock_kg_class.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='user2'
)
assert processor.table == 'user2'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_reuse_when_same(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_table_reuse_when_same(self, mock_kg_class):
"""Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call should create TrustGraph
await processor.store_triples(mock_message)
assert mock_trustgraph.call_count == 1
assert mock_kg_class.call_count == 1
# Second call with same table should reuse TrustGraph
await processor.store_triples(mock_message)
assert mock_trustgraph.call_count == 1 # Should not increase
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_triple_insertion(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_triple_insertion(self, mock_kg_class):
"""Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock triples
# Create mock triples with proper Term structure
triple1 = MagicMock()
triple1.s.type = LITERAL
triple1.s.value = 'subject1'
triple1.s.datatype = ''
triple1.s.language = ''
triple1.p.type = LITERAL
triple1.p.value = 'predicate1'
triple1.o.type = LITERAL
triple1.o.value = 'object1'
triple1.o.datatype = ''
triple1.o.language = ''
triple1.g = None
triple2 = MagicMock()
triple2.s.type = LITERAL
triple2.s.value = 'subject2'
triple2.s.datatype = ''
triple2.s.language = ''
triple2.p.type = LITERAL
triple2.p.value = 'predicate2'
triple2.o.type = LITERAL
triple2.o.value = 'object2'
triple2.o.datatype = ''
triple2.o.language = ''
triple2.g = None
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
# Verify both triples were inserted
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2
mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1')
mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2')
mock_tg_instance.insert.assert_any_call(
'collection1', 'subject1', 'predicate1', 'object1',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
mock_tg_instance.insert.assert_any_call(
'collection1', 'subject2', 'predicate2', 'object2',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_triple_insertion_with_empty_list(self, mock_kg_class):
"""Test behavior when message has no triples"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
"""Test exception handling during TrustGraph creation"""
taskgroup_mock = MagicMock()
mock_trustgraph.side_effect = Exception("Connection failed")
mock_kg_class.side_effect = Exception("Connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
@ -326,92 +349,104 @@ class TestCassandraStorageProcessor:
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_table_switching_between_different_tables(self, mock_kg_class):
"""Test table switching when different tables are used in sequence"""
taskgroup_mock = MagicMock()
mock_tg_instance1 = MagicMock()
mock_tg_instance2 = MagicMock()
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
processor = Processor(taskgroup=taskgroup_mock)
# First message with table1
mock_message1 = MagicMock()
mock_message1.metadata.user = 'user1'
mock_message1.metadata.collection = 'collection1'
mock_message1.triples = []
await processor.store_triples(mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
mock_message2.metadata.user = 'user2'
mock_message2.metadata.collection = 'collection2'
mock_message2.triples = []
await processor.store_triples(mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables
assert mock_trustgraph.call_count == 2
assert mock_kg_class.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_with_special_characters_in_values(self, mock_kg_class):
"""Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create triple with special characters
# Create triple with special characters and proper Term structure
triple = MagicMock()
triple.s.type = LITERAL
triple.s.value = 'subject with spaces & symbols'
triple.s.datatype = ''
triple.s.language = ''
triple.p.type = LITERAL
triple.p.value = 'predicate:with/colons'
triple.o.type = LITERAL
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
triple.o.datatype = ''
triple.o.language = ''
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
'test_collection',
'subject with spaces & symbols',
'predicate:with/colons',
'object with "quotes" and unicode: ñáéíóú'
'object with "quotes" and unicode: ñáéíóú',
g=DEFAULT_GRAPH,
otype='l',
dtype='',
lang=''
)
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
"""Test that table remains unchanged when TrustGraph creation fails"""
taskgroup_mock = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
# Set an initial table
processor.table = ('old_user', 'old_collection')
# Mock TrustGraph to raise exception
mock_trustgraph.side_effect = Exception("Connection failed")
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
# TrustGraph should be set to None though
@ -422,12 +457,12 @@ class TestCassandraPerformanceOptimizations:
"""Test cases for multi-table performance optimizations"""
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_legacy_mode_uses_single_table(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_legacy_mode_uses_single_table(self, mock_kg_class):
"""Test that legacy mode still works with single table"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
processor = Processor(taskgroup=taskgroup_mock)
@ -440,16 +475,15 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples(mock_message)
# Verify KnowledgeGraph instance uses legacy mode
kg_instance = mock_trustgraph.return_value
assert kg_instance is not None
assert mock_tg_instance is not None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_optimized_mode_uses_multi_table(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_optimized_mode_uses_multi_table(self, mock_kg_class):
"""Test that optimized mode uses multi-table schema"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
processor = Processor(taskgroup=taskgroup_mock)
@ -462,24 +496,31 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples(mock_message)
# Verify KnowledgeGraph instance is in optimized mode
kg_instance = mock_trustgraph.return_value
assert kg_instance is not None
assert mock_tg_instance is not None
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_batch_write_consistency(self, mock_trustgraph):
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_batch_write_consistency(self, mock_kg_class):
"""Test that all tables stay consistent during batch writes"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_trustgraph.return_value = mock_tg_instance
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
# Create test triple
# Create test triple with proper Term structure
triple = MagicMock()
triple.s.type = LITERAL
triple.s.value = 'test_subject'
triple.s.datatype = ''
triple.s.language = ''
triple.p.type = LITERAL
triple.p.value = 'test_predicate'
triple.o.type = LITERAL
triple.o.value = 'test_object'
triple.o.datatype = ''
triple.o.language = ''
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
@ -490,7 +531,8 @@ class TestCassandraPerformanceOptimizations:
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(
'collection1', 'test_subject', 'test_predicate', 'test_object'
'collection1', 'test_subject', 'test_predicate', 'test_object',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
def test_environment_variable_controls_mode(self):

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.falkordb.write import Processor
from trustgraph.schema import Value, Triple
from trustgraph.schema import Term, Triple, IRI, LITERAL
class TestFalkorDBStorageProcessor:
@ -22,9 +22,9 @@ class TestFalkorDBStorageProcessor:
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=LITERAL, value='literal object')
)
message.triples = [triple]
@ -183,9 +183,9 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=IRI, iri='http://example.com/object')
)
message.triples = [triple]
@ -269,14 +269,14 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject1'),
p=Term(type=IRI, iri='http://example.com/predicate1'),
o=Term(type=LITERAL, value='literal object1')
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject2'),
p=Term(type=IRI, iri='http://example.com/predicate2'),
o=Term(type=IRI, iri='http://example.com/object2')
)
message.triples = [triple1, triple2]
@ -337,14 +337,14 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject1'),
p=Term(type=IRI, iri='http://example.com/predicate1'),
o=Term(type=LITERAL, value='literal object')
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject2'),
p=Term(type=IRI, iri='http://example.com/predicate2'),
o=Term(type=IRI, iri='http://example.com/object2')
)
message.triples = [triple1, triple2]

View file

@ -6,7 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
from trustgraph.schema import Value, Triple
from trustgraph.schema import Term, Triple, IRI, LITERAL
class TestMemgraphStorageProcessor:
@ -22,9 +22,9 @@ class TestMemgraphStorageProcessor:
# Create a test triple
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=LITERAL, value='literal object')
)
message.triples = [triple]
@ -231,9 +231,9 @@ class TestMemgraphStorageProcessor:
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='http://example.com/object', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=IRI, iri='http://example.com/object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
@ -265,9 +265,9 @@ class TestMemgraphStorageProcessor:
mock_tx = MagicMock()
triple = Triple(
s=Value(value='http://example.com/subject', is_uri=True),
p=Value(value='http://example.com/predicate', is_uri=True),
o=Value(value='literal object', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject'),
p=Term(type=IRI, iri='http://example.com/predicate'),
o=Term(type=LITERAL, value='literal object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
@ -347,14 +347,14 @@ class TestMemgraphStorageProcessor:
message.metadata.collection = 'test_collection'
triple1 = Triple(
s=Value(value='http://example.com/subject1', is_uri=True),
p=Value(value='http://example.com/predicate1', is_uri=True),
o=Value(value='literal object1', is_uri=False)
s=Term(type=IRI, iri='http://example.com/subject1'),
p=Term(type=IRI, iri='http://example.com/predicate1'),
o=Term(type=LITERAL, value='literal object1')
)
triple2 = Triple(
s=Value(value='http://example.com/subject2', is_uri=True),
p=Value(value='http://example.com/predicate2', is_uri=True),
o=Value(value='http://example.com/object2', is_uri=True)
s=Term(type=IRI, iri='http://example.com/subject2'),
p=Term(type=IRI, iri='http://example.com/predicate2'),
o=Term(type=IRI, iri='http://example.com/object2')
)
message.triples = [triple1, triple2]

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.storage.triples.neo4j.write import Processor
from trustgraph.schema import IRI, LITERAL
class TestNeo4jStorageProcessor:
@ -257,10 +258,12 @@ class TestNeo4jStorageProcessor:
# Create mock triple with URI object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = IRI
triple.o.iri = "http://example.com/object"
# Create mock message with metadata
mock_message = MagicMock()
@ -327,10 +330,12 @@ class TestNeo4jStorageProcessor:
# Create mock triple with literal object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "literal value"
triple.o.is_uri = False
# Create mock message with metadata
mock_message = MagicMock()
@ -398,16 +403,20 @@ class TestNeo4jStorageProcessor:
# Create mock triples
triple1 = MagicMock()
triple1.s.value = "http://example.com/subject1"
triple1.p.value = "http://example.com/predicate1"
triple1.o.value = "http://example.com/object1"
triple1.o.is_uri = True
triple1.s.type = IRI
triple1.s.iri = "http://example.com/subject1"
triple1.p.type = IRI
triple1.p.iri = "http://example.com/predicate1"
triple1.o.type = IRI
triple1.o.iri = "http://example.com/object1"
triple2 = MagicMock()
triple2.s.value = "http://example.com/subject2"
triple2.p.value = "http://example.com/predicate2"
triple2.s.type = IRI
triple2.s.iri = "http://example.com/subject2"
triple2.p.type = IRI
triple2.p.iri = "http://example.com/predicate2"
triple2.o.type = LITERAL
triple2.o.value = "literal value"
triple2.o.is_uri = False
# Create mock message with metadata
mock_message = MagicMock()
@ -550,10 +559,12 @@ class TestNeo4jStorageProcessor:
# Create triple with special characters
triple = MagicMock()
triple.s.value = "http://example.com/subject with spaces"
triple.p.value = "http://example.com/predicate:with/symbols"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject with spaces"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate:with/symbols"
triple.o.type = LITERAL
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
triple.o.is_uri = False
mock_message = MagicMock()
mock_message.triples = [triple]

View file

@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'client')
assert hasattr(processor, 'safety_settings')
assert len(processor.safety_settings) == 4 # 4 safety categories
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# Assert
# Verify Google AI Studio client was called with correct API key
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
# Verify processor has the client
assert processor.client == mock_genai_client

View file

@ -1,6 +1,6 @@
"""
Unit tests for trustgraph.model.text_completion.vertexai
Starting simple with one test to get the basics working
Updated for google-genai SDK
"""
import pytest
@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
"""Simple test for processor initialization"""
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test basic processor initialization with mocked dependencies"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
# Mock the parent class initialization to avoid taskgroup requirement
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert processor.default_model == 'gemini-2.0-flash-001'
assert hasattr(processor, 'generation_configs') # Cache dictionary
assert hasattr(processor, 'safety_settings')
assert hasattr(processor, 'model_clients') # LLM clients are now cached
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
mock_vertexai.init.assert_called_once()
assert hasattr(processor, 'client') # genai.Client
mock_service_account.Credentials.from_service_account_file.assert_called_once()
mock_genai.Client.assert_called_once_with(
vertexai=True,
project="test-project-123",
location="us-central1",
credentials=mock_credentials
)
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test successful content generation"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated response from Gemini"
mock_response.usage_metadata.prompt_token_count = 15
mock_response.usage_metadata.candidates_token_count = 8
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'gemini-2.0-flash-001'
# Check that the method was called (actual prompt format may vary)
mock_model.generate_content.assert_called_once()
# Verify the call was made with the expected parameters
call_args = mock_model.generate_content.call_args
# Generation config is now created dynamically per model
assert 'generation_config' in call_args[1]
assert call_args[1]['safety_settings'] == processor.safety_settings
mock_client.models.generate_content.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test rate limit error handling"""
# Arrange
from google.api_core.exceptions import ResourceExhausted
from trustgraph.exceptions import TooManyRequests
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of blocked content (safety filters)"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = None # Blocked content returns None
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 0
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'gemini-2.0-flash-001'
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
"""Test processor initialization without private key (uses default credentials)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Mock google.auth.default() to return credentials and project ID
mock_credentials = MagicMock()
mock_auth_default.return_value = (mock_credentials, "test-project-123")
# Mock GenerativeModel
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
config = {
'region': 'us-central1',
@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemini-2.0-flash-001'
mock_auth_default.assert_called_once()
mock_vertexai.init.assert_called_once_with(
location='us-central1',
project='test-project-123'
mock_genai.Client.assert_called_once_with(
vertexai=True,
project="test-project-123",
location="us-central1",
credentials=mock_credentials
)
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of generic exceptions"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = Exception("Network error")
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.side_effect = Exception("Network error")
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(Exception, match="Network error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test processor initialization with custom parameters"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert processor.default_model == 'gemini-1.5-pro'
# Verify that generation_config object exists (can't easily check internal values)
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
# Verify that generation_config cache exists
assert hasattr(processor, 'generation_configs')
assert processor.generation_configs == {} # Empty cache initially
# Verify that safety settings are configured
assert len(processor.safety_settings) == 4
# Verify service account was called with custom key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
# Verify that api_params dict has the correct values (this is accessible)
mock_service_account.Credentials.from_service_account_file.assert_called_once()
# Verify that api_params dict has the correct values
assert processor.api_params["temperature"] == 0.7
assert processor.api_params["max_output_tokens"] == 4096
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test that VertexAI is initialized correctly with credentials"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
# Verify VertexAI init was called with correct parameters
mock_vertexai.init.assert_called_once_with(
# Verify genai.Client was called with correct parameters
mock_genai.Client.assert_called_once_with(
vertexai=True,
project='test-project-123',
location='europe-west1',
credentials=mock_credentials,
project='test-project-123'
credentials=mock_credentials
)
# GenerativeModel is now created lazily on first use, not at initialization
mock_generative_model.assert_not_called()
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test content generation with empty prompts"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Default response"
mock_response.usage_metadata.prompt_token_count = 2
mock_response.usage_metadata.candidates_token_count = 3
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'gemini-2.0-flash-001'
# Verify the model was called with the combined empty prompts
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# The prompt should be "" + "\n\n" + "" = "\n\n"
assert call_args[0][0] == "\n\n"
# Verify the client was called
mock_client.models.generate_content.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
"""Test Anthropic processor initialization with private key credentials"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-456"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
# Mock AnthropicVertex
mock_anthropic_client = MagicMock()
mock_anthropic_vertex.return_value = mock_anthropic_client
@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'claude-3-sonnet@20240229'
# is_anthropic logic is now determined dynamically per request
# Verify service account was called with private key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
# Verify AnthropicVertex was initialized with credentials
mock_service_account.Credentials.from_service_account_file.assert_called_once()
# Verify AnthropicVertex was initialized with credentials (because model contains 'claude')
mock_anthropic_vertex.assert_called_once_with(
region='us-west1',
project_id='test-project-456',
credentials=mock_credentials
)
# Verify api_params are set correctly
assert processor.api_params["temperature"] == 0.5
assert processor.api_params["max_output_tokens"] == 2048
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test temperature parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom temperature"
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
mock_client.models.generate_content.assert_called_once()
# Verify Gemini API was called with overridden temperature
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Check that generation_config was created (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test model parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock different models
mock_model_default = MagicMock()
mock_model_override = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom model"
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_model_override.generate_content.return_value = mock_response
# GenerativeModel should return different models based on input
def model_factory(model_name):
if model_name == 'gemini-1.5-pro':
return mock_model_override
return mock_model_default
mock_generative_model.side_effect = model_factory
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.2, # Default temperature
'temperature': 0.2,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the overridden model was used
mock_model_override.generate_content.assert_called_once()
# Verify GenerativeModel was called with the override model
mock_generative_model.assert_called_with('gemini-1.5-pro')
# Verify the call was made with the override model
call_args = mock_client.models.generate_content.call_args
assert call_args.kwargs['model'] == "gemini-1.5-pro"
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with both overrides"
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
mock_client.models.generate_content.assert_called_once()
# Verify both overrides were used
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Verify model override
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
# Verify temperature override (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
# Verify the model override was used
call_args = mock_client.models.generate_content.call_args
assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])

View file

@ -73,6 +73,8 @@ from .async_metrics import AsyncMetrics
# Types
from .types import (
Triple,
Uri,
Literal,
ConfigKey,
ConfigValue,
DocumentMetadata,
@ -99,7 +101,7 @@ from .exceptions import (
LoadError,
LookupError,
NLPQueryError,
ObjectsQueryError,
RowsQueryError,
RequestError,
StructuredQueryError,
UnexpectedError,
@ -133,6 +135,8 @@ __all__ = [
# Types
"Triple",
"Uri",
"Literal",
"ConfigKey",
"ConfigValue",
"DocumentMetadata",
@ -157,7 +161,7 @@ __all__ = [
"LoadError",
"LookupError",
"NLPQueryError",
"ObjectsQueryError",
"RowsQueryError",
"RequestError",
"StructuredQueryError",
"UnexpectedError",

View file

@ -115,15 +115,15 @@ class AsyncBulkClient:
async for raw_message in websocket:
yield json.loads(raw_message)
async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import objects via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for obj in objects:
await websocket.send(json.dumps(obj))
async for row in rows:
await websocket.send(json.dumps(row))
async def aclose(self) -> None:
"""Close connections"""

View file

@ -612,8 +612,12 @@ class AsyncFlowInstance:
print(f"{entity['name']}: {entity['score']}")
```
"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request_data = {
"text": text,
"vectors": vectors,
"user": user,
"collection": collection,
"limit": limit
@ -704,18 +708,18 @@ class AsyncFlowInstance:
return await self.request("triples", request_data)
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs: Any):
async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs: Any):
"""
Execute a GraphQL query on stored objects.
Execute a GraphQL query on stored rows.
Queries structured data objects using GraphQL syntax. Supports complex
Queries structured data rows using GraphQL syntax. Supports complex
queries with variables and named operations.
Args:
query: GraphQL query string
user: User identifier
collection: Collection identifier containing objects
collection: Collection identifier containing rows
variables: Optional GraphQL query variables
operation_name: Optional operation name for multi-operation queries
**kwargs: Additional service-specific parameters
@ -739,7 +743,7 @@ class AsyncFlowInstance:
}
'''
result = await flow.objects_query(
result = await flow.rows_query(
query=query,
user="trustgraph",
collection="users",
@ -761,4 +765,64 @@ class AsyncFlowInstance:
request_data["operationName"] = operation_name
request_data.update(kwargs)
return await self.request("objects", request_data)
return await self.request("rows", request_data)
async def row_embeddings_query(
self, text: str, schema_name: str, user: str = "trustgraph",
collection: str = "default", index_name: Optional[str] = None,
limit: int = 10, **kwargs: Any
):
"""
Query row embeddings for semantic search on structured data.
Performs semantic search over row index embeddings to find rows whose
indexed field values are most similar to the input text. Enables
fuzzy/semantic matching on structured data.
Args:
text: Query text for semantic search
schema_name: Schema name to search within
user: User identifier (default: "trustgraph")
collection: Collection identifier (default: "default")
index_name: Optional index name to filter search to specific index
limit: Maximum number of results to return (default: 10)
**kwargs: Additional service-specific parameters
Returns:
dict: Response containing matches with index_name, index_value,
text, and score
Example:
```python
async_flow = await api.async_flow()
flow = async_flow.id("default")
# Search for customers by name similarity
results = await flow.row_embeddings_query(
text="John Smith",
schema_name="customers",
user="trustgraph",
collection="sales",
limit=5
)
for match in results.get("matches", []):
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
```
"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request_data = {
"vectors": vectors,
"schema_name": schema_name,
"user": user,
"collection": collection,
"limit": limit
}
if index_name:
request_data["index_name"] = index_name
request_data.update(kwargs)
return await self.request("row-embeddings", request_data)

View file

@ -282,8 +282,12 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request = {
"text": text,
"vectors": vectors,
"user": user,
"collection": collection,
"limit": limit
@ -316,9 +320,9 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("triples", self.flow_id, request)
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs):
"""GraphQL query"""
async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs):
"""GraphQL query against structured rows"""
request = {
"query": query,
"user": user,
@ -330,7 +334,7 @@ class AsyncSocketFlowInstance:
request["operationName"] = operation_name
request.update(kwargs)
return await self.client._send_request("objects", self.flow_id, request)
return await self.client._send_request("rows", self.flow_id, request)
async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
"""Execute MCP tool"""
@ -341,3 +345,26 @@ class AsyncSocketFlowInstance:
request.update(kwargs)
return await self.client._send_request("mcp-tool", self.flow_id, request)
async def row_embeddings_query(
self, text: str, schema_name: str, user: str = "trustgraph",
collection: str = "default", index_name: Optional[str] = None,
limit: int = 10, **kwargs
):
"""Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors
emb_result = await self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request = {
"vectors": vectors,
"schema_name": schema_name,
"user": user,
"collection": collection,
"limit": limit
}
if index_name:
request["index_name"] = index_name
request.update(kwargs)
return await self.client._send_request("row-embeddings", self.flow_id, request)

View file

@ -15,6 +15,15 @@ from . types import Triple
from . exceptions import ProtocolException
def _string_to_term(value: str) -> Dict[str, Any]:
"""Convert a string value to Term format for the gateway."""
# Treat URIs as IRI type, otherwise as literal
if value.startswith("http://") or value.startswith("https://") or "://" in value:
return {"t": "i", "i": value}
else:
return {"t": "l", "v": value}
class BulkClient:
"""
Synchronous bulk operations client for import/export.
@ -62,7 +71,12 @@ class BulkClient:
return loop.run_until_complete(coro)
def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None:
def import_triples(
self, flow: str, triples: Iterator[Triple],
metadata: Optional[Dict[str, Any]] = None,
batch_size: int = 100,
**kwargs: Any
) -> None:
"""
Bulk import RDF triples into a flow.
@ -71,6 +85,8 @@ class BulkClient:
Args:
flow: Flow identifier
triples: Iterator yielding Triple objects
metadata: Metadata dict with id, metadata, user, collection
batch_size: Number of triples per batch (default 100)
**kwargs: Additional parameters (reserved for future use)
Example:
@ -86,23 +102,47 @@ class BulkClient:
# ... more triples
# Import triples
bulk.import_triples(flow="default", triples=triple_generator())
bulk.import_triples(
flow="default",
triples=triple_generator(),
metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"}
)
```
"""
self._run_async(self._import_triples_async(flow, triples))
self._run_async(self._import_triples_async(flow, triples, metadata, batch_size))
async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None:
async def _import_triples_async(
self, flow: str, triples: Iterator[Triple],
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of triple import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
if metadata is None:
metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"}
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
batch = []
for triple in triples:
batch.append({
"s": _string_to_term(triple.s),
"p": _string_to_term(triple.p),
"o": _string_to_term(triple.o)
})
if len(batch) >= batch_size:
message = {
"metadata": metadata,
"triples": batch
}
await websocket.send(json.dumps(message))
batch = []
# Send remaining items
if batch:
message = {
"s": triple.s,
"p": triple.p,
"o": triple.o
"metadata": metadata,
"triples": batch
}
await websocket.send(json.dumps(message))
@ -362,7 +402,12 @@ class BulkClient:
async for raw_message in websocket:
yield json.loads(raw_message)
def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
def import_entity_contexts(
self, flow: str, contexts: Iterator[Dict[str, Any]],
metadata: Optional[Dict[str, Any]] = None,
batch_size: int = 100,
**kwargs: Any
) -> None:
"""
Bulk import entity contexts into a flow.
@ -373,6 +418,8 @@ class BulkClient:
Args:
flow: Flow identifier
contexts: Iterator yielding context dictionaries
metadata: Metadata dict with id, metadata, user, collection
batch_size: Number of contexts per batch (default 100)
**kwargs: Additional parameters (reserved for future use)
Example:
@ -381,27 +428,49 @@ class BulkClient:
# Generate entity contexts to import
def context_generator():
yield {"entity": "entity1", "context": "Description of entity1..."}
yield {"entity": "entity2", "context": "Description of entity2..."}
yield {"entity": {"v": "entity1", "e": True}, "context": "Description..."}
yield {"entity": {"v": "entity2", "e": True}, "context": "Description..."}
# ... more contexts
bulk.import_entity_contexts(
flow="default",
contexts=context_generator()
contexts=context_generator(),
metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"}
)
```
"""
self._run_async(self._import_entity_contexts_async(flow, contexts))
self._run_async(self._import_entity_contexts_async(flow, contexts, metadata, batch_size))
async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None:
async def _import_entity_contexts_async(
self, flow: str, contexts: Iterator[Dict[str, Any]],
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of entity contexts import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
if metadata is None:
metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"}
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
batch = []
for context in contexts:
await websocket.send(json.dumps(context))
batch.append(context)
if len(batch) >= batch_size:
message = {
"metadata": metadata,
"entities": batch
}
await websocket.send(json.dumps(message))
batch = []
# Send remaining items
if batch:
message = {
"metadata": metadata,
"entities": batch
}
await websocket.send(json.dumps(message))
def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
"""
@ -461,45 +530,45 @@ class BulkClient:
async for raw_message in websocket:
yield json.loads(raw_message)
def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
def import_rows(self, flow: str, rows: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
"""
Bulk import structured objects into a flow.
Bulk import structured rows into a flow.
Efficiently uploads structured data objects via WebSocket streaming
Efficiently uploads structured data rows via WebSocket streaming
for use in GraphQL queries.
Args:
flow: Flow identifier
objects: Iterator yielding object dictionaries
rows: Iterator yielding row dictionaries
**kwargs: Additional parameters (reserved for future use)
Example:
```python
bulk = api.bulk()
# Generate objects to import
def object_generator():
yield {"id": "obj1", "name": "Object 1", "value": 100}
yield {"id": "obj2", "name": "Object 2", "value": 200}
# ... more objects
# Generate rows to import
def row_generator():
yield {"id": "row1", "name": "Row 1", "value": 100}
yield {"id": "row2", "name": "Row 2", "value": 200}
# ... more rows
bulk.import_objects(
bulk.import_rows(
flow="default",
objects=object_generator()
rows=row_generator()
)
```
"""
self._run_async(self._import_objects_async(flow, objects))
self._run_async(self._import_rows_async(flow, rows))
async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of objects import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for obj in objects:
await websocket.send(json.dumps(obj))
for row in rows:
await websocket.send(json.dumps(row))
def close(self) -> None:
"""Close connections"""

View file

@ -71,8 +71,8 @@ class NLPQueryError(TrustGraphException):
pass
class ObjectsQueryError(TrustGraphException):
"""Objects query service error"""
class RowsQueryError(TrustGraphException):
"""Rows query service error"""
pass
@ -103,7 +103,7 @@ ERROR_TYPE_MAPPING = {
"load-error": LoadError,
"lookup-error": LookupError,
"nlp-query-error": NLPQueryError,
"objects-query-error": ObjectsQueryError,
"rows-query-error": RowsQueryError,
"request-error": RequestError,
"structured-query-error": StructuredQueryError,
"unexpected-error": UnexpectedError,

View file

@ -10,12 +10,27 @@ import json
import base64
from .. knowledge import hash, Uri, Literal
from .. schema import IRI, LITERAL
from . types import Triple
from . exceptions import ProtocolException
def to_value(x):
if x["e"]: return Uri(x["v"])
return Literal(x["v"])
"""Convert wire format to Uri or Literal."""
if x.get("t") == IRI:
return Uri(x.get("i", ""))
elif x.get("t") == LITERAL:
return Literal(x.get("v", ""))
# Fallback for any other type
return Literal(x.get("v", x.get("i", "")))
def from_value(v):
"""Convert Uri or Literal to wire format."""
if isinstance(v, Uri):
return {"t": IRI, "i": str(v)}
else:
return {"t": LITERAL, "v": str(v)}
class Flow:
"""
@ -569,9 +584,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
# Query graph embeddings for semantic search
input = {
"text": text,
"vectors": vectors,
"user": user,
"collection": collection,
"limit": limit
@ -582,6 +601,51 @@ class FlowInstance:
input
)
def document_embeddings_query(self, text, user, collection, limit=10):
"""
Query document chunks using semantic similarity.
Finds document chunks whose content is semantically similar to the
input text, using vector embeddings.
Args:
text: Query text for semantic search
user: User/keyspace identifier
collection: Collection identifier
limit: Maximum number of results (default: 10)
Returns:
dict: Query results with similar document chunks
Example:
```python
flow = api.flow().id("default")
results = flow.document_embeddings_query(
text="machine learning algorithms",
user="trustgraph",
collection="research-papers",
limit=5
)
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
# Query document embeddings for semantic search
input = {
"vectors": vectors,
"user": user,
"collection": collection,
"limit": limit
}
return self.request(
"service/document-embeddings",
input
)
def prompt(self, id, variables):
"""
Execute a prompt template with variable substitution.
@ -751,17 +815,17 @@ class FlowInstance:
if s:
if not isinstance(s, Uri):
raise RuntimeError("s must be Uri")
input["s"] = { "v": str(s), "e": isinstance(s, Uri), }
input["s"] = from_value(s)
if p:
if not isinstance(p, Uri):
raise RuntimeError("p must be Uri")
input["p"] = { "v": str(p), "e": isinstance(p, Uri), }
input["p"] = from_value(p)
if o:
if not isinstance(o, Uri) and not isinstance(o, Literal):
raise RuntimeError("o must be Uri or Literal")
input["o"] = { "v": str(o), "e": isinstance(o, Uri), }
input["o"] = from_value(o)
object = self.request(
"service/triples",
@ -834,9 +898,9 @@ class FlowInstance:
if metadata:
metadata.emit(
lambda t: triples.append({
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
"s": from_value(t["s"]),
"p": from_value(t["p"]),
"o": from_value(t["o"]),
})
)
@ -913,9 +977,9 @@ class FlowInstance:
if metadata:
metadata.emit(
lambda t: triples.append({
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
"s": from_value(t["s"]),
"p": from_value(t["p"]),
"o": from_value(t["o"]),
})
)
@ -937,12 +1001,12 @@ class FlowInstance:
input
)
def objects_query(
def rows_query(
self, query, user="trustgraph", collection="default",
variables=None, operation_name=None
):
"""
Execute a GraphQL query against structured objects in the knowledge graph.
Execute a GraphQL query against structured rows in the knowledge graph.
Queries structured data using GraphQL syntax, allowing complex queries
with filtering, aggregation, and relationship traversal.
@ -974,7 +1038,7 @@ class FlowInstance:
}
}
'''
result = flow.objects_query(
result = flow.rows_query(
query=query,
user="trustgraph",
collection="scientists"
@ -989,7 +1053,7 @@ class FlowInstance:
}
}
'''
result = flow.objects_query(
result = flow.rows_query(
query=query,
variables={"name": "Marie Curie"}
)
@ -1010,7 +1074,7 @@ class FlowInstance:
input["operation_name"] = operation_name
response = self.request(
"service/objects",
"service/rows",
input
)
@ -1233,3 +1297,78 @@ class FlowInstance:
return response["schema-matches"]
def row_embeddings_query(
self, text, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10
):
"""
Query row data using semantic similarity on indexed fields.
Finds rows whose indexed field values are semantically similar to the
input text, using vector embeddings. This enables fuzzy/semantic matching
on structured data.
Args:
text: Query text for semantic search
schema_name: Schema name to search within
user: User/keyspace identifier (default: "trustgraph")
collection: Collection identifier (default: "default")
index_name: Optional index name to filter search to specific index
limit: Maximum number of results (default: 10)
Returns:
dict: Query results with matches containing index_name, index_value,
text, and score
Example:
```python
flow = api.flow().id("default")
# Search for customers by name similarity
results = flow.row_embeddings_query(
text="John Smith",
schema_name="customers",
user="trustgraph",
collection="sales",
limit=5
)
# Filter to specific index
results = flow.row_embeddings_query(
text="machine learning engineer",
schema_name="employees",
index_name="job_title",
limit=10
)
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
# Query row embeddings for semantic search
input = {
"vectors": vectors,
"schema_name": schema_name,
"user": user,
"collection": collection,
"limit": limit
}
if index_name:
input["index_name"] = index_name
response = self.request(
"service/row-embeddings",
input
)
# Check for system-level error
if "error" in response and response["error"]:
error_type = response["error"].get("type", "unknown")
error_message = response["error"].get("message", "Unknown error")
raise ProtocolException(f"{error_type}: {error_message}")
return response

View file

@ -10,11 +10,18 @@ import json
import base64
from .. knowledge import hash, Uri, Literal
from .. schema import IRI, LITERAL
from . types import Triple
def to_value(x):
if x["e"]: return Uri(x["v"])
return Literal(x["v"])
"""Convert wire format to Uri or Literal."""
if x.get("t") == IRI:
return Uri(x.get("i", ""))
elif x.get("t") == LITERAL:
return Literal(x.get("v", ""))
# Fallback for any other type
return Literal(x.get("v", x.get("i", "")))
class Knowledge:
"""

View file

@ -12,13 +12,28 @@ import logging
from . types import DocumentMetadata, ProcessingMetadata, Triple
from .. knowledge import hash, Uri, Literal
from .. schema import IRI, LITERAL
from . exceptions import *
logger = logging.getLogger(__name__)
def to_value(x):
if x["e"]: return Uri(x["v"])
return Literal(x["v"])
"""Convert wire format to Uri or Literal."""
if x.get("t") == IRI:
return Uri(x.get("i", ""))
elif x.get("t") == LITERAL:
return Literal(x.get("v", ""))
# Fallback for any other type
return Literal(x.get("v", x.get("i", "")))
def from_value(v):
"""Convert Uri or Literal to wire format."""
if isinstance(v, Uri):
return {"t": IRI, "i": str(v)}
else:
return {"t": LITERAL, "v": str(v)}
class Library:
"""
@ -118,18 +133,18 @@ class Library:
if isinstance(metadata, list):
triples = [
{
"s": { "v": t.s, "e": isinstance(t.s, Uri) },
"p": { "v": t.p, "e": isinstance(t.p, Uri) },
"o": { "v": t.o, "e": isinstance(t.o, Uri) }
"s": from_value(t.s),
"p": from_value(t.p),
"o": from_value(t.o),
}
for t in metadata
]
elif hasattr(metadata, "emit"):
metadata.emit(
lambda t: triples.append({
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
"s": from_value(t["s"]),
"p": from_value(t["p"]),
"o": from_value(t["o"]),
})
)
else:
@ -315,9 +330,9 @@ class Library:
"comments": metadata.comments,
"metadata": [
{
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
"s": from_value(t["s"]),
"p": from_value(t["p"]),
"o": from_value(t["o"]),
}
for t in metadata.metadata
],

View file

@ -649,8 +649,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request = {
"text": text,
"vectors": vectors,
"user": user,
"collection": collection,
"limit": limit
@ -659,6 +663,54 @@ class SocketFlowInstance:
return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False)
def document_embeddings_query(
self,
text: str,
user: str,
collection: str,
limit: int = 10,
**kwargs: Any
) -> Dict[str, Any]:
"""
Query document chunks using semantic similarity.
Args:
text: Query text for semantic search
user: User/keyspace identifier
collection: Collection identifier
limit: Maximum number of results (default: 10)
**kwargs: Additional parameters passed to the service
Returns:
dict: Query results with similar document chunks
Example:
```python
socket = api.socket()
flow = socket.flow("default")
results = flow.document_embeddings_query(
text="machine learning algorithms",
user="trustgraph",
collection="research-papers",
limit=5
)
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request = {
"vectors": vectors,
"user": user,
"collection": collection,
"limit": limit
}
request.update(kwargs)
return self.client._send_request_sync("document-embeddings", self.flow_id, request, False)
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
"""
Generate vector embeddings for text.
@ -737,7 +789,7 @@ class SocketFlowInstance:
return self.client._send_request_sync("triples", self.flow_id, request, False)
def objects_query(
def rows_query(
self,
query: str,
user: str,
@ -747,7 +799,7 @@ class SocketFlowInstance:
**kwargs: Any
) -> Dict[str, Any]:
"""
Execute a GraphQL query against structured objects.
Execute a GraphQL query against structured rows.
Args:
query: GraphQL query string
@ -774,7 +826,7 @@ class SocketFlowInstance:
}
}
'''
result = flow.objects_query(
result = flow.rows_query(
query=query,
user="trustgraph",
collection="scientists"
@ -792,7 +844,7 @@ class SocketFlowInstance:
request["operationName"] = operation_name
request.update(kwargs)
return self.client._send_request_sync("objects", self.flow_id, request, False)
return self.client._send_request_sync("rows", self.flow_id, request, False)
def mcp_tool(
self,
@ -829,3 +881,73 @@ class SocketFlowInstance:
request.update(kwargs)
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
def row_embeddings_query(
self,
text: str,
schema_name: str,
user: str = "trustgraph",
collection: str = "default",
index_name: Optional[str] = None,
limit: int = 10,
**kwargs: Any
) -> Dict[str, Any]:
"""
Query row data using semantic similarity on indexed fields.
Finds rows whose indexed field values are semantically similar to the
input text, using vector embeddings. This enables fuzzy/semantic matching
on structured data.
Args:
text: Query text for semantic search
schema_name: Schema name to search within
user: User/keyspace identifier (default: "trustgraph")
collection: Collection identifier (default: "default")
index_name: Optional index name to filter search to specific index
limit: Maximum number of results (default: 10)
**kwargs: Additional parameters passed to the service
Returns:
dict: Query results with matches containing index_name, index_value,
text, and score
Example:
```python
socket = api.socket()
flow = socket.flow("default")
# Search for customers by name similarity
results = flow.row_embeddings_query(
text="John Smith",
schema_name="customers",
user="trustgraph",
collection="sales",
limit=5
)
# Filter to specific index
results = flow.row_embeddings_query(
text="machine learning engineer",
schema_name="employees",
index_name="job_title",
limit=10
)
```
"""
# First convert text to embeddings vectors
emb_result = self.embeddings(text=text)
vectors = emb_result.get("vectors", [])
request = {
"vectors": vectors,
"schema_name": schema_name,
"user": user,
"collection": collection,
"limit": limit
}
if index_name:
request["index_name"] = index_name
request.update(kwargs)
return self.client._send_request_sync("row-embeddings", self.flow_id, request, False)

View file

@ -34,5 +34,6 @@ from . tool_service import ToolService
from . tool_client import ToolClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
from . collection_config_handler import CollectionConfigHandler

View file

@ -7,7 +7,7 @@ embeddings.
import logging
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. schema import Error, Value
from .. schema import Error, Term
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
@ -16,7 +16,7 @@ from . producer_spec import ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-query"
default_ident = "doc-embeddings-query"
class DocumentEmbeddingsQueryService(FlowProcessor):

View file

@ -2,15 +2,21 @@
import logging
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, IRI, LITERAL
from .. knowledge import Uri, Literal
# Module logger
logger = logging.getLogger(__name__)
def to_value(x):
if x.is_uri: return Uri(x.value)
return Literal(x.value)
"""Convert schema Term to Uri or Literal."""
if x.type == IRI:
return Uri(x.iri)
elif x.type == LITERAL:
return Literal(x.value)
# Fallback
return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",

View file

@ -7,7 +7,7 @@ embeddings.
import logging
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import Error, Value
from .. schema import Error, Term
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
@ -16,7 +16,7 @@ from . producer_spec import ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-query"
default_ident = "graph-embeddings-query"
class GraphEmbeddingsQueryService(FlowProcessor):

View file

@ -0,0 +1,45 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600
):
request = RowEmbeddingsRequest(
vectors=vectors,
schema_name=schema_name,
user=user,
collection=collection,
limit=limit
)
if index_name:
request.index_name = index_name
resp = await self.request(request, timeout=timeout)
if resp.error:
raise RuntimeError(resp.error.message)
# Return matches as list of dicts
return [
{
"index_name": match.index_name,
"index_value": match.index_value,
"text": match.text,
"score": match.score
}
for match in (resp.matches or [])
]
class RowEmbeddingsQueryClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(RowEmbeddingsQueryClientSpec, self).__init__(
request_name = request_name,
request_schema = RowEmbeddingsRequest,
response_name = response_name,
response_schema = RowEmbeddingsResponse,
impl = RowEmbeddingsQueryClient,
)

View file

@ -222,35 +222,50 @@ class Subscriber:
# Store message for later acknowledgment
msg_id = str(uuid.uuid4())
self.pending_acks[msg_id] = msg
try:
id = msg.properties()["id"]
except:
id = None
value = msg.value()
delivery_success = False
has_matching_waiter = False
async with self.lock:
# Deliver to specific subscribers
if id in self.q:
has_matching_waiter = True
delivery_success = await self._deliver_to_queue(
self.q[id], value
)
# Deliver to all subscribers
for q in self.full.values():
has_matching_waiter = True
if await self._deliver_to_queue(q, value):
delivery_success = True
# Acknowledge only on successful delivery
if delivery_success:
self.consumer.acknowledge(msg)
del self.pending_acks[msg_id]
else:
# Negative acknowledge for retry
self.consumer.negative_acknowledge(msg)
del self.pending_acks[msg_id]
# Always acknowledge the message to prevent redelivery storms
# on shared topics. Negative acknowledging orphaned messages
# (no matching waiter) causes immediate redelivery to all
# subscribers, none of whom can handle it either.
self.consumer.acknowledge(msg)
del self.pending_acks[msg_id]
if not delivery_success:
if not has_matching_waiter:
# Message arrived for a waiter that no longer exists
# (likely due to client disconnect or timeout)
logger.debug(
f"Discarding orphaned message with id={id} - "
"no matching waiter"
)
else:
# Delivery failed (e.g., queue full with drop_new strategy)
logger.debug(
f"Message with id={id} dropped due to backpressure"
)
async def _deliver_to_queue(self, queue, value):
"""Deliver message to queue with backpressure handling"""

View file

@ -1,24 +1,34 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL
from .. knowledge import Uri, Literal
class Triple:
def __init__(self, s, p, o):
self.s = s
self.p = p
self.o = o
def to_value(x):
if x.is_uri: return Uri(x.value)
return Literal(x.value)
"""Convert schema Term to Uri or Literal."""
if x.type == IRI:
return Uri(x.iri)
elif x.type == LITERAL:
return Literal(x.value)
# Fallback
return Literal(x.value or x.iri)
def from_value(x):
if x is None: return None
"""Convert Uri or Literal to schema Term."""
if x is None:
return None
if isinstance(x, Uri):
return Value(value=str(x), is_uri=True)
return Term(type=IRI, iri=str(x))
else:
return Value(value=str(x), is_uri=False)
return Term(type=LITERAL, value=str(x))
class TriplesClient(RequestResponse):
async def query(self, s=None, p=None, o=None, limit=20,

View file

@ -7,7 +7,7 @@ null. Output is a list of triples.
import logging
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .. schema import Value, Triple
from .. schema import Term, Triple
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec

View file

@ -0,0 +1,60 @@
import _pulsar
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
from .. schema import row_embeddings_request_queue
from .. schema import row_embeddings_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class RowEmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
pulsar_api_key=None,
):
if input_queue == None:
input_queue = row_embeddings_request_queue
if output_queue == None:
output_queue = row_embeddings_response_queue
super(RowEmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
pulsar_api_key=pulsar_api_key,
input_schema=RowEmbeddingsRequest,
output_schema=RowEmbeddingsResponse,
)
def request(
self, vectors, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=300
):
kwargs = dict(
user=user, collection=collection,
vectors=vectors, schema_name=schema_name,
limit=limit, timeout=timeout
)
if index_name:
kwargs["index_name"] = index_name
response = self.call(**kwargs)
if response.error:
raise RuntimeError(f"{response.error.type}: {response.error.message}")
return response.matches

View file

@ -2,7 +2,7 @@
import _pulsar
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL
from .. schema import triples_request_queue
from .. schema import triples_response_queue
from . base import BaseClient
@ -46,9 +46,9 @@ class TriplesQueryClient(BaseClient):
if ent == None: return None
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Term(type=IRI, iri=ent)
return Value(value=ent, is_uri=False)
return Term(type=LITERAL, value=ent)
def request(
self,

View file

@ -19,9 +19,10 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
from .translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
)
from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
@ -107,15 +108,21 @@ TranslatorRegistry.register_service(
)
TranslatorRegistry.register_service(
"graph-embeddings-query",
GraphEmbeddingsRequestTranslator(),
"graph-embeddings-query",
GraphEmbeddingsRequestTranslator(),
GraphEmbeddingsResponseTranslator()
)
TranslatorRegistry.register_service(
"objects-query",
ObjectsQueryRequestTranslator(),
ObjectsQueryResponseTranslator()
"row-embeddings-query",
RowEmbeddingsRequestTranslator(),
RowEmbeddingsResponseTranslator()
)
TranslatorRegistry.register_service(
"rows-query",
RowsQueryRequestTranslator(),
RowsQueryResponseTranslator()
)
TranslatorRegistry.register_service(

View file

@ -1,5 +1,5 @@
from .base import Translator, MessageTranslator
from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator, RowSchemaTranslator, FieldTranslator, row_schema_translator, field_translator
from .primitives import TermTranslator, ValueTranslator, TripleTranslator, SubgraphTranslator, RowSchemaTranslator, FieldTranslator, row_schema_translator, field_translator
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
from .agent import AgentRequestTranslator, AgentResponseTranslator
from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
from .prompt import PromptRequestTranslator, PromptResponseTranslator
from .embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
)
from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator

View file

@ -1,7 +1,8 @@
from typing import Dict, Any, Tuple
from ...schema import (
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
GraphEmbeddingsRequest, GraphEmbeddingsResponse
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
)
from .base import MessageTranslator
from .primitives import ValueTranslator
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
class RowEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for RowEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
return RowEmbeddingsRequest(
vectors=data["vectors"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
schema_name=data.get("schema_name", ""),
index_name=data.get("index_name")
)
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
result = {
"vectors": obj.vectors,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection,
"schema_name": obj.schema_name,
}
if obj.index_name:
result["index_name"] = obj.index_name
return result
class RowEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for RowEmbeddingsResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.error is not None:
result["error"] = {
"type": obj.error.type,
"message": obj.error.message
}
if obj.matches is not None:
result["matches"] = [
{
"index_name": match.index_name,
"index_value": match.index_value,
"text": match.text,
"score": match.score
}
for match in obj.matches
]
return result
def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True

Some files were not shown because too many files have changed in this diff Show more