mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Merge 2.0 to master (#651)
This commit is contained in:
parent
3666ece2c5
commit
b9d7bf9a8b
212 changed files with 13940 additions and 6180 deletions
2
.github/workflows/pull-request.yaml
vendored
2
.github/workflows/pull-request.yaml
vendored
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup packages
|
||||
run: make update-package-versions VERSION=1.8.999
|
||||
run: make update-package-versions VERSION=2.0.999
|
||||
|
||||
- name: Setup environment
|
||||
run: python3 -m venv env
|
||||
|
|
|
|||
10
Makefile
10
Makefile
|
|
@ -5,7 +5,7 @@ VERSION=0.0.0
|
|||
|
||||
DOCKER=podman
|
||||
|
||||
all: container
|
||||
all: containers
|
||||
|
||||
# Not used
|
||||
wheels:
|
||||
|
|
@ -49,7 +49,9 @@ update-package-versions:
|
|||
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
|
||||
|
||||
container: update-package-versions
|
||||
FORCE:
|
||||
|
||||
containers: FORCE
|
||||
${DOCKER} build -f containers/Containerfile.base \
|
||||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
|
|
@ -70,8 +72,8 @@ some-containers:
|
|||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.mcp \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ RUN dnf install -y python3.13 && \
|
|||
dnf clean all
|
||||
|
||||
RUN pip3 install --no-cache-dir \
|
||||
anthropic cohere mistralai openai google-generativeai \
|
||||
anthropic cohere mistralai openai \
|
||||
ollama \
|
||||
langchain==0.3.25 langchain-core==0.3.60 \
|
||||
langchain-text-splitters==0.3.8 \
|
||||
|
|
|
|||
260
docs/tech-specs/entity-centric-graph.md
Normal file
260
docs/tech-specs/entity-centric-graph.md
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
# Entity-Centric Knowledge Graph Storage on Cassandra
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes a storage model for RDF-style knowledge graphs on Apache Cassandra. The model uses an **entity-centric** approach where every entity knows every quad it participates in and the role it plays. This replaces a traditional multi-table SPO permutation approach with just two tables.
|
||||
|
||||
## Background and Motivation
|
||||
|
||||
### The Traditional Approach
|
||||
|
||||
A standard RDF quad store on Cassandra requires multiple denormalised tables to cover query patterns — typically 6 or more tables representing different permutations of Subject, Predicate, Object, and Dataset (SPOD). Each quad is written to every table, resulting in significant write amplification, operational overhead, and schema complexity.
|
||||
|
||||
Additionally, label resolution (fetching human-readable names for entities) requires separate round-trip queries, which is particularly costly in AI and GraphRAG use cases where labels are essential for LLM context.
|
||||
|
||||
### The Entity-Centric Insight
|
||||
|
||||
Every quad `(D, S, P, O)` involves up to 4 entities. By writing a row for each entity's participation in the quad, we guarantee that **any query with at least one known element will hit a partition key**. This covers all 16 query patterns with a single data table.
|
||||
|
||||
Key benefits:
|
||||
|
||||
- **2 tables** instead of 7+
|
||||
- **4 writes per quad** instead of 6+
|
||||
- **Label resolution for free** — an entity's labels are co-located with its relationships, naturally warming the application cache
|
||||
- **All 16 query patterns** served by single-partition reads
|
||||
- **Simpler operations** — one data table to tune, compact, and repair
|
||||
|
||||
## Schema
|
||||
|
||||
### Table 1: quads_by_entity
|
||||
|
||||
The primary data table. Every entity has a partition containing all quads it participates in. Named to reflect the query pattern (lookup by entity).
|
||||
|
||||
```sql
|
||||
CREATE TABLE quads_by_entity (
|
||||
collection text, -- Collection/tenant scope (always specified)
|
||||
entity text, -- The entity this row is about
|
||||
role text, -- 'S', 'P', 'O', 'G' — how this entity participates
|
||||
p text, -- Predicate of the quad
|
||||
otype text, -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
|
||||
s text, -- Subject of the quad
|
||||
o text, -- Object of the quad
|
||||
d text, -- Dataset/graph of the quad
|
||||
dtype text, -- XSD datatype (when otype = 'L'), e.g. 'xsd:string'
|
||||
lang text, -- Language tag (when otype = 'L'), e.g. 'en', 'fr'
|
||||
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d)
|
||||
);
|
||||
```
|
||||
|
||||
**Partition key**: `(collection, entity)` — scoped to collection, one partition per entity.
|
||||
|
||||
**Clustering column order rationale**:
|
||||
|
||||
1. **role** — most queries start with "where is this entity a subject/object"
|
||||
2. **p** — next most common filter, "give me all `knows` relationships"
|
||||
3. **otype** — enables filtering by URI-valued vs literal-valued relationships
|
||||
4. **s, o, d** — remaining columns for uniqueness
|
||||
|
||||
### Table 2: quads_by_collection
|
||||
|
||||
Supports collection-level queries and deletion. Provides a manifest of all quads belonging to a collection. Named to reflect the query pattern (lookup by collection).
|
||||
|
||||
```sql
|
||||
CREATE TABLE quads_by_collection (
|
||||
collection text,
|
||||
d text, -- Dataset/graph of the quad
|
||||
s text, -- Subject of the quad
|
||||
p text, -- Predicate of the quad
|
||||
o text, -- Object of the quad
|
||||
otype text, -- 'U' (URI), 'L' (literal), 'T' (triple/reification)
|
||||
dtype text, -- XSD datatype (when otype = 'L')
|
||||
lang text, -- Language tag (when otype = 'L')
|
||||
PRIMARY KEY (collection, d, s, p, o)
|
||||
);
|
||||
```
|
||||
|
||||
Clustered by dataset first, enabling deletion at either collection or dataset granularity.
|
||||
|
||||
## Write Path
|
||||
|
||||
For each incoming quad `(D, S, P, O)` within a collection `C`, write up to **4 rows** to `quads_by_entity` (3 when the object is a literal — see the Literal Example below) and **1 row** to `quads_by_collection`.
|
||||
|
||||
### Example
|
||||
|
||||
Given the quad in collection `tenant1`:
|
||||
|
||||
```
|
||||
Dataset: https://example.org/graph1
|
||||
Subject: https://example.org/Alice
|
||||
Predicate: https://example.org/knows
|
||||
Object: https://example.org/Bob
|
||||
```
|
||||
|
||||
Write 4 rows to `quads_by_entity`:
|
||||
|
||||
| collection | entity | role | p | otype | s | o | d |
|
||||
|---|---|---|---|---|---|---|---|
|
||||
| tenant1 | https://example.org/graph1 | G | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
|
||||
| tenant1 | https://example.org/Alice | S | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
|
||||
| tenant1 | https://example.org/knows | P | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
|
||||
| tenant1 | https://example.org/Bob | O | https://example.org/knows | U | https://example.org/Alice | https://example.org/Bob | https://example.org/graph1 |
|
||||
|
||||
Write 1 row to `quads_by_collection`:
|
||||
|
||||
| collection | d | s | p | o | otype | dtype | lang |
|
||||
|---|---|---|---|---|---|---|---|
|
||||
| tenant1 | https://example.org/graph1 | https://example.org/Alice | https://example.org/knows | https://example.org/Bob | U | | |
|
||||
|
||||
### Literal Example
|
||||
|
||||
For a label triple:
|
||||
|
||||
```
|
||||
Dataset: https://example.org/graph1
|
||||
Subject: https://example.org/Alice
|
||||
Predicate: http://www.w3.org/2000/01/rdf-schema#label
|
||||
Object: "Alice Smith" (lang: en)
|
||||
```
|
||||
|
||||
The `otype` is `'L'`, `dtype` is `'xsd:string'`, and `lang` is `'en'`. The literal value `"Alice Smith"` is stored in `o`. Only 3 rows are needed in `quads_by_entity` — no row is written for the literal as entity, since literals are not independently queryable entities.
|
||||
|
||||
## Query Patterns
|
||||
|
||||
### All 16 DSPO Patterns
|
||||
|
||||
In the table below, "Perfect prefix" means the query uses a contiguous prefix of the clustering columns. "Partition scan + filter" means Cassandra reads a slice of one partition and filters in memory — still efficient, just not a pure prefix match.
|
||||
|
||||
| # | Known | Lookup entity | Clustering prefix | Efficiency |
|
||||
|---|---|---|---|---|
|
||||
| 1 | D,S,P,O | entity=S, role='S', p=P | Full match | Perfect prefix |
|
||||
| 2 | D,S,P,? | entity=S, role='S', p=P | Filter on D | Partition scan + filter |
|
||||
| 3 | D,S,?,O | entity=S, role='S' | Filter on D, O | Partition scan + filter |
|
||||
| 4 | D,?,P,O | entity=O, role='O', p=P | Filter on D | Partition scan + filter |
|
||||
| 5 | ?,S,P,O | entity=S, role='S', p=P | Filter on O | Partition scan + filter |
|
||||
| 6 | D,S,?,? | entity=S, role='S' | Filter on D | Partition scan + filter |
|
||||
| 7 | D,?,P,? | entity=P, role='P' | Filter on D | Partition scan + filter |
|
||||
| 8 | D,?,?,O | entity=O, role='O' | Filter on D | Partition scan + filter |
|
||||
| 9 | ?,S,P,? | entity=S, role='S', p=P | — | **Perfect prefix** |
|
||||
| 10 | ?,S,?,O | entity=S, role='S' | Filter on O | Partition scan + filter |
|
||||
| 11 | ?,?,P,O | entity=O, role='O', p=P | — | **Perfect prefix** |
|
||||
| 12 | D,?,?,? | entity=D, role='G' | — | **Perfect prefix** |
|
||||
| 13 | ?,S,?,? | entity=S, role='S' | — | **Perfect prefix** |
|
||||
| 14 | ?,?,P,? | entity=P, role='P' | — | **Perfect prefix** |
|
||||
| 15 | ?,?,?,O | entity=O, role='O' | — | **Perfect prefix** |
|
||||
| 16 | ?,?,?,? | — | Full scan | Exploration only |
|
||||
|
||||
**Key result**: 7 of the 15 non-trivial patterns are perfect clustering prefix hits. The remaining 8 are single-partition reads with in-partition filtering. Every query with at least one known element hits a partition key.
|
||||
|
||||
Pattern 16 (?,?,?,?) does not occur in practice since the collection is always specified; with only the collection known, the query becomes a single-partition read of `quads_by_collection` rather than a full scan.
|
||||
|
||||
### Common Query Examples
|
||||
|
||||
**Everything about an entity:**
|
||||
|
||||
```sql
|
||||
SELECT * FROM quads_by_entity
|
||||
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice';
|
||||
```
|
||||
|
||||
**All outgoing relationships for an entity:**
|
||||
|
||||
```sql
|
||||
SELECT * FROM quads_by_entity
|
||||
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
|
||||
AND role = 'S';
|
||||
```
|
||||
|
||||
**Specific predicate for an entity:**
|
||||
|
||||
```sql
|
||||
SELECT * FROM quads_by_entity
|
||||
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
|
||||
AND role = 'S' AND p = 'https://example.org/knows';
|
||||
```
|
||||
|
||||
**Label for an entity (specific language):**
|
||||
|
||||
```sql
|
||||
SELECT * FROM quads_by_entity
|
||||
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
|
||||
AND role = 'S' AND p = 'http://www.w3.org/2000/01/rdf-schema#label'
|
||||
AND otype = 'L';
|
||||
```
|
||||
|
||||
Then filter by `lang = 'en'` application-side if needed.
|
||||
|
||||
**Only URI-valued relationships (entity-to-entity links):**
|
||||
|
||||
```sql
|
||||
SELECT * FROM quads_by_entity
|
||||
WHERE collection = 'tenant1' AND entity = 'https://example.org/Alice'
|
||||
AND role = 'S' AND p = 'https://example.org/knows' AND otype = 'U';
|
||||
```
|
||||
|
||||
**Reverse lookup — what points to this entity:**
|
||||
|
||||
```sql
|
||||
SELECT * FROM quads_by_entity
|
||||
WHERE collection = 'tenant1' AND entity = 'https://example.org/Bob'
|
||||
AND role = 'O';
|
||||
```
|
||||
|
||||
## Label Resolution and Cache Warming
|
||||
|
||||
One of the most significant advantages of the entity-centric model is that **label resolution becomes a free side effect**.
|
||||
|
||||
In the traditional multi-table model, fetching labels requires separate round-trip queries: retrieve triples, identify entity URIs in the results, then fetch `rdfs:label` for each. This N+1 pattern is expensive.
|
||||
|
||||
In the entity-centric model, querying an entity returns **all** its quads — including its labels, types, and other properties. When the application caches query results, labels are pre-warmed before anything asks for them.
|
||||
|
||||
Two usage regimes confirm this works well in practice:
|
||||
|
||||
- **Human-facing queries**: naturally small result sets, labels essential. Entity reads pre-warm the cache.
|
||||
- **AI/bulk queries**: large result sets with hard limits. Labels either unnecessary or needed only for a curated subset of entities already in cache.
|
||||
|
||||
The theoretical concern of resolving labels for huge result sets (e.g. 30,000 entities) is mitigated by the practical observation that no human or AI consumer usefully processes that many labels. Application-level query limits ensure cache pressure remains manageable.
|
||||
|
||||
## Wide Partitions and Reification
|
||||
|
||||
Reification (RDF-star style statements about statements) creates hub entities — e.g. a source document that supports thousands of extracted facts. This can produce wide partitions.
|
||||
|
||||
Mitigating factors:
|
||||
|
||||
- **Application-level query limits**: all GraphRAG and human-facing queries enforce hard limits, so wide partitions are never fully scanned on the hot read path
|
||||
- **Cassandra handles partial reads efficiently**: a clustering column scan with an early stop is fast even on large partitions
|
||||
- **Collection deletion** (the only operation that might traverse full partitions) is an accepted background process
|
||||
|
||||
## Collection Deletion
|
||||
|
||||
Triggered by API call, runs in the background (eventually consistent).
|
||||
|
||||
1. Read `quads_by_collection` for the target collection to get all quads
|
||||
2. Extract unique entities from the quads (s, p, o, d values)
|
||||
3. For each unique entity, delete the partition from `quads_by_entity`
|
||||
4. Delete the rows from `quads_by_collection`
|
||||
|
||||
The `quads_by_collection` table provides the index needed to locate all entity partitions without a full table scan. Partition-level deletes are efficient since `(collection, entity)` is the partition key.
|
||||
|
||||
## Migration Path from Multi-Table Model
|
||||
|
||||
The entity-centric model can coexist with the existing multi-table model during migration:
|
||||
|
||||
1. Deploy `quads_by_entity` and `quads_by_collection` tables alongside existing tables
|
||||
2. Dual-write new quads to both old and new tables
|
||||
3. Backfill existing data into the new tables
|
||||
4. Migrate read paths one query pattern at a time
|
||||
5. Decommission old tables once all reads are migrated
|
||||
|
||||
## Summary
|
||||
|
||||
| Aspect | Traditional (multi-table) | Entity-centric (2-table) |
|
||||
|---|---|---|
|
||||
| Tables | 7+ | 2 |
|
||||
| Writes per quad | 6+ | 5 (4 data + 1 manifest) |
|
||||
| Label resolution | Separate round trips | Free via cache warming |
|
||||
| Query patterns | 16 across 6 tables | 16 on 1 table |
|
||||
| Schema complexity | High | Low |
|
||||
| Operational overhead | 7+ tables to tune/repair | 1 data table |
|
||||
| Reification support | Additional complexity | Natural fit |
|
||||
| Object type filtering | Not available | Native (via otype clustering) |
|
||||
|
||||
573
docs/tech-specs/graph-contexts.md
Normal file
573
docs/tech-specs/graph-contexts.md
Normal file
|
|
@ -0,0 +1,573 @@
|
|||
# Graph Contexts Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes changes to TrustGraph's core graph primitives to
|
||||
align with RDF 1.2 and support full RDF Dataset semantics. This is a breaking
|
||||
change for the 2.x release series.
|
||||
|
||||
### Versioning
|
||||
|
||||
- **2.0**: Early adopter release. Core features available, may not be fully
|
||||
production-ready.
|
||||
- **2.1 / 2.2**: Production release. Stability and completeness validated.
|
||||
|
||||
Flexibility on maturity is intentional - early adopters can access new
|
||||
capabilities before all features are production-hardened.
|
||||
|
||||
## Goals
|
||||
|
||||
The primary goals for this work are to enable metadata about facts/statements:
|
||||
|
||||
- **Temporal information**: Associate facts with time metadata
|
||||
- When a fact was believed to be true
|
||||
- When a fact became true
|
||||
- When a fact was discovered to be false
|
||||
|
||||
- **Provenance/Sources**: Track which sources support a fact
|
||||
- "This fact was supported by source X"
|
||||
- Link facts back to their origin documents
|
||||
|
||||
- **Veracity/Trust**: Record assertions about truth
|
||||
- "Person P asserted this was true"
|
||||
- "Person Q claims this is false"
|
||||
- Enable trust scoring and conflict detection
|
||||
|
||||
**Hypothesis**: Reification (RDF-star / quoted triples) is the key mechanism
|
||||
to achieve these outcomes, as all require making statements about statements.
|
||||
|
||||
## Background
|
||||
|
||||
To express "the fact (Alice knows Bob) was discovered on 2024-01-15" or
|
||||
"source X supports the claim (Y causes Z)", you need to reference an edge
|
||||
as a thing you can make statements about. Standard triples don't support this.
|
||||
|
||||
### Current Limitations
|
||||
|
||||
The current `Value` class in `trustgraph-base/trustgraph/schema/core/primitives.py`
|
||||
can represent:
|
||||
- URI nodes (`is_uri=True`)
|
||||
- Literal values (`is_uri=False`)
|
||||
|
||||
The `type` field exists but is not used to represent XSD datatypes.
|
||||
|
||||
## Technical Design
|
||||
|
||||
### RDF Features to Support
|
||||
|
||||
#### Core Features (Related to Reification Goals)
|
||||
|
||||
These features are directly related to the temporal, provenance, and veracity
|
||||
goals:
|
||||
|
||||
1. **RDF 1.2 Quoted Triples (RDF-star)**
|
||||
- Edges that point at other edges
|
||||
- A Triple can appear as the subject or object of another Triple
|
||||
- Enables statements about statements (reification)
|
||||
- Core mechanism for annotating individual facts
|
||||
|
||||
2. **RDF Dataset / Named Graphs**
|
||||
- Support for multiple named graphs within a dataset
|
||||
- Each graph identified by an IRI
|
||||
- Moves from triples (s, p, o) to quads (s, p, o, g)
|
||||
- Includes a default graph plus zero or more named graphs
|
||||
- The graph IRI can be a subject in statements, e.g.:
|
||||
```
|
||||
<graph-source-A> <discoveredOn> "2024-01-15"
|
||||
<graph-source-A> <hasVeracity> "high"
|
||||
```
|
||||
- Note: Named graphs are a separate feature from reification. They have
|
||||
uses beyond statement annotation (partitioning, access control, dataset
|
||||
organization) and should be treated as a distinct capability.
|
||||
|
||||
3. **Blank Nodes** (Limited Support)
|
||||
- Anonymous nodes without a global URI
|
||||
- Supported for compatibility when loading external RDF data
|
||||
- **Limited status**: No guarantees about stable identity after loading
|
||||
- Find them via wildcard queries (match by connections, not by ID)
|
||||
- Not a first-class feature - don't rely on precise blank node handling
|
||||
|
||||
#### Opportunistic Fixes (2.0 Breaking Change)
|
||||
|
||||
These features are not directly related to the reification goals but are
|
||||
valuable improvements to include while making breaking changes:
|
||||
|
||||
4. **Literal Datatypes**
|
||||
- Properly use the `type` field for XSD datatypes
|
||||
- Examples: xsd:string, xsd:integer, xsd:dateTime, etc.
|
||||
- Fixes current limitation: cannot represent dates or integers properly
|
||||
|
||||
5. **Language Tags**
|
||||
- Support for language attributes on string literals (@en, @fr, etc.)
|
||||
- Note: A literal has either a language tag OR a datatype, not both
|
||||
(except for rdf:langString)
|
||||
- Important for AI/multilingual use cases
|
||||
|
||||
### Data Models
|
||||
|
||||
#### Term (rename from Value)
|
||||
|
||||
The `Value` class will be renamed to `Term` to better reflect RDF terminology.
|
||||
This rename serves two purposes:
|
||||
1. Aligns naming with RDF concepts (a "Term" can be an IRI, literal, blank
|
||||
node, or quoted triple - not just a "value")
|
||||
2. Forces code review at the breaking change interface - any code still
|
||||
referencing `Value` is visibly broken and needs updating
|
||||
|
||||
A Term can represent:
|
||||
|
||||
- **IRI/URI** - A named node/resource
|
||||
- **Blank Node** - An anonymous node with local scope
|
||||
- **Literal** - A data value with either:
|
||||
- A datatype (XSD type), OR
|
||||
- A language tag
|
||||
- **Quoted Triple** - A triple used as a term (RDF 1.2)
|
||||
|
||||
##### Chosen Approach: Single Class with Type Discriminator
|
||||
|
||||
Serialization requirements drive the structure - a type discriminator is needed
|
||||
in the wire format regardless of the Python representation. A single class with
|
||||
a type field is the natural fit and aligns with the current `Value` pattern.
|
||||
|
||||
Single-character type codes provide compact serialization:
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Term type constants
|
||||
IRI = "i" # IRI/URI node
|
||||
BLANK = "b" # Blank node
|
||||
LITERAL = "l" # Literal value
|
||||
TRIPLE = "t" # Quoted triple (RDF-star)
|
||||
|
||||
@dataclass
|
||||
class Term:
|
||||
type: str = "" # One of: IRI, BLANK, LITERAL, TRIPLE
|
||||
|
||||
# For IRI terms (type == IRI)
|
||||
iri: str = ""
|
||||
|
||||
# For blank nodes (type == BLANK)
|
||||
id: str = ""
|
||||
|
||||
# For literals (type == LITERAL)
|
||||
value: str = ""
|
||||
datatype: str = "" # XSD datatype URI (mutually exclusive with language)
|
||||
language: str = "" # Language tag (mutually exclusive with datatype)
|
||||
|
||||
# For quoted triples (type == TRIPLE)
|
||||
triple: "Triple | None" = None
|
||||
```
|
||||
|
||||
Usage examples:
|
||||
|
||||
```python
|
||||
# IRI term
|
||||
node = Term(type=IRI, iri="http://example.org/Alice")
|
||||
|
||||
# Literal with datatype
|
||||
age = Term(type=LITERAL, value="42", datatype="xsd:integer")
|
||||
|
||||
# Literal with language tag
|
||||
label = Term(type=LITERAL, value="Hello", language="en")
|
||||
|
||||
# Blank node
|
||||
anon = Term(type=BLANK, id="_:b1")
|
||||
|
||||
# Quoted triple (statement about a statement)
|
||||
inner = Triple(
|
||||
s=Term(type=IRI, iri="http://example.org/Alice"),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=Term(type=IRI, iri="http://example.org/Bob"),
|
||||
)
|
||||
reified = Term(type=TRIPLE, triple=inner)
|
||||
```
|
||||
|
||||
##### Alternatives Considered
|
||||
|
||||
**Option B: Union of specialized classes** (`Term = IRI | BlankNode | Literal | QuotedTriple`)
|
||||
- Rejected: Serialization would still need a type discriminator, adding complexity
|
||||
|
||||
**Option C: Base class with subclasses**
|
||||
- Rejected: Same serialization issue, plus dataclass inheritance quirks
|
||||
|
||||
#### Triple / Quad
|
||||
|
||||
The `Triple` class gains an optional graph field to become a quad:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class Triple:
|
||||
s: Term | None = None # Subject
|
||||
p: Term | None = None # Predicate
|
||||
o: Term | None = None # Object
|
||||
g: str | None = None # Graph name (IRI), None = default graph
|
||||
```
|
||||
|
||||
Design decisions:
|
||||
- **Field name**: `g` for consistency with `s`, `p`, `o`
|
||||
- **Optional**: `None` means the default graph (unnamed)
|
||||
- **Type**: Plain string (IRI) rather than Term
|
||||
- Graph names are always IRIs
|
||||
- Blank nodes as graph names ruled out (too confusing)
|
||||
- No need for the full Term machinery
|
||||
|
||||
Note: The class name stays `Triple` even though it's technically a quad now.
|
||||
This avoids churn and "triple" is still the common terminology for the s/p/o
|
||||
portion. The graph context is metadata about where the triple lives.
|
||||
|
||||
### Candidate Query Patterns
|
||||
|
||||
The current query engine accepts combinations of S, P, O terms. With quoted
|
||||
triples, a triple itself becomes a valid term in those positions. Below are
|
||||
candidate query patterns that support the original goals.
|
||||
|
||||
#### Graph Parameter Semantics
|
||||
|
||||
Following SPARQL conventions for backward compatibility:
|
||||
|
||||
- **`g` omitted / None**: Query the default graph only
|
||||
- **`g` = specific IRI**: Query that named graph only
|
||||
- **`g` = wildcard / `*`**: Query across all graphs (equivalent to SPARQL
|
||||
`GRAPH ?g { ... }`)
|
||||
|
||||
This keeps simple queries simple and makes named graph queries opt-in.
|
||||
|
||||
Cross-graph queries (g=wildcard) are fully supported. The Cassandra schema
|
||||
includes dedicated tables (SPOG, POSG, OSPG) where g is a clustering column
|
||||
rather than a partition key, enabling efficient queries across all graphs.
|
||||
|
||||
#### Temporal Queries
|
||||
|
||||
**Find all facts discovered after a given date:**
|
||||
```
|
||||
S: ? # any quoted triple
|
||||
P: <discoveredOn>
|
||||
O: > "2024-01-15"^^xsd:date # date comparison
|
||||
```
|
||||
|
||||
**Find when a specific fact was believed true:**
|
||||
```
|
||||
S: << <Alice> <knows> <Bob> >> # quoted triple as subject
|
||||
P: <believedTrueFrom>
|
||||
O: ? # returns the date
|
||||
```
|
||||
|
||||
**Find facts that became false:**
|
||||
```
|
||||
S: ? # any quoted triple
|
||||
P: <discoveredFalseOn>
|
||||
O: ? # has any value (exists)
|
||||
```
|
||||
|
||||
#### Provenance Queries
|
||||
|
||||
**Find all facts supported by a specific source:**
|
||||
```
|
||||
S: ? # any quoted triple
|
||||
P: <supportedBy>
|
||||
O: <source:document-123>
|
||||
```
|
||||
|
||||
**Find which sources support a specific fact:**
|
||||
```
|
||||
S: << <DrugA> <treats> <DiseaseB> >> # quoted triple as subject
|
||||
P: <supportedBy>
|
||||
O: ? # returns source IRIs
|
||||
```
|
||||
|
||||
#### Veracity Queries
|
||||
|
||||
**Find assertions a person marked as true:**
|
||||
```
|
||||
S: ? # any quoted triple
|
||||
P: <assertedTrueBy>
|
||||
O: <person:Alice>
|
||||
```
|
||||
|
||||
**Find conflicting assertions (same fact, different veracity):**
|
||||
```
|
||||
# First query: facts asserted true
|
||||
S: ?
|
||||
P: <assertedTrueBy>
|
||||
O: ?
|
||||
|
||||
# Second query: facts asserted false
|
||||
S: ?
|
||||
P: <assertedFalseBy>
|
||||
O: ?
|
||||
|
||||
# Application logic: find intersection of subjects
|
||||
```
|
||||
|
||||
**Find facts with trust score below threshold:**
|
||||
```
|
||||
S: ? # any quoted triple
|
||||
P: <trustScore>
|
||||
O: < 0.5 # numeric comparison
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
Significant changes required across multiple components:
|
||||
|
||||
#### This Repository (trustgraph)
|
||||
|
||||
- **Schema primitives** (`trustgraph-base/trustgraph/schema/core/primitives.py`)
|
||||
- Value → Term rename
|
||||
- New Term structure with type discriminator
|
||||
- Triple gains `g` field for graph context
|
||||
|
||||
- **Message translators** (`trustgraph-base/trustgraph/messaging/translators/`)
|
||||
- Update for new Term/Triple structures
|
||||
- Serialization/deserialization for new fields
|
||||
|
||||
- **Gateway components**
|
||||
- Handle new Term and quad structures
|
||||
|
||||
- **Knowledge cores**
|
||||
- Core changes to support quads and reification
|
||||
|
||||
- **Knowledge manager**
|
||||
- Schema changes propagate here
|
||||
|
||||
- **Storage layers**
|
||||
- Cassandra: Schema redesign (see Implementation Details)
|
||||
- Other backends: Deferred to later phases
|
||||
|
||||
- **Command-line utilities**
|
||||
- Update for new data structures
|
||||
|
||||
- **REST API documentation**
|
||||
- OpenAPI spec updates
|
||||
|
||||
#### External Repositories
|
||||
|
||||
- **Python API** (this repo)
|
||||
- Client library updates for new structures
|
||||
|
||||
- **TypeScript APIs** (separate repo)
|
||||
- Client library updates
|
||||
|
||||
- **Workbench** (separate repo)
|
||||
- Significant state management changes
|
||||
|
||||
### APIs
|
||||
|
||||
#### REST API
|
||||
|
||||
- Documented in OpenAPI spec
|
||||
- Will need updates for new Term/Triple structures
|
||||
- New endpoints may be needed for graph context operations
|
||||
|
||||
#### Python API (this repo)
|
||||
|
||||
- Client library changes to match new primitives
|
||||
- Breaking changes to Term (was Value) and Triple
|
||||
|
||||
#### TypeScript API (separate repo)
|
||||
|
||||
- Parallel changes to Python API
|
||||
- Separate release coordination
|
||||
|
||||
#### Workbench (separate repo)
|
||||
|
||||
- Significant state management changes
|
||||
- UI updates for graph context features
|
||||
|
||||
### Implementation Details
|
||||
|
||||
#### Phased Storage Implementation
|
||||
|
||||
Multiple graph store backends exist (Cassandra, Neo4j, etc.). Implementation
|
||||
will proceed in phases:
|
||||
|
||||
1. **Phase 1: Cassandra**
|
||||
- Start with the home-grown Cassandra store
|
||||
- Full control over the storage layer enables rapid iteration
|
||||
- Schema will be redesigned from scratch for quads + reification
|
||||
- Validate the data model and query patterns against real use cases
|
||||
|
||||
#### Cassandra Schema Design
|
||||
|
||||
Cassandra requires multiple tables to support different query access patterns
|
||||
(each table efficiently queries by its partition key + clustering columns).
|
||||
|
||||
##### Query Patterns
|
||||
|
||||
With quads (g, s, p, o), each position can be specified or wildcard, giving
|
||||
16 possible query patterns:
|
||||
|
||||
| # | g | s | p | o | Description |
|
||||
|---|---|---|---|---|-------------|
|
||||
| 1 | ? | ? | ? | ? | All quads |
|
||||
| 2 | ? | ? | ? | o | By object |
|
||||
| 3 | ? | ? | p | ? | By predicate |
|
||||
| 4 | ? | ? | p | o | By predicate + object |
|
||||
| 5 | ? | s | ? | ? | By subject |
|
||||
| 6 | ? | s | ? | o | By subject + object |
|
||||
| 7 | ? | s | p | ? | By subject + predicate |
|
||||
| 8 | ? | s | p | o | Full triple (which graphs?) |
|
||||
| 9 | g | ? | ? | ? | By graph |
|
||||
| 10 | g | ? | ? | o | By graph + object |
|
||||
| 11 | g | ? | p | ? | By graph + predicate |
|
||||
| 12 | g | ? | p | o | By graph + predicate + object |
|
||||
| 13 | g | s | ? | ? | By graph + subject |
|
||||
| 14 | g | s | ? | o | By graph + subject + object |
|
||||
| 15 | g | s | p | ? | By graph + subject + predicate |
|
||||
| 16 | g | s | p | o | Exact quad |
|
||||
|
||||
##### Table Design
|
||||
|
||||
Cassandra constraint: You can only efficiently query by partition key, then
|
||||
filter on clustering columns left-to-right. For g-wildcard queries, g must be
|
||||
a clustering column. For g-specified queries, g in the partition key is more
|
||||
efficient.
|
||||
|
||||
**Two table families needed:**
|
||||
|
||||
**Family A: g-wildcard queries** (g in clustering columns)
|
||||
|
||||
| Table | Partition | Clustering | Supports patterns |
|
||||
|-------|-----------|------------|-------------------|
|
||||
| SPOG | (user, collection, s) | p, o, g | 5, 7, 8 |
|
||||
| POSG | (user, collection, p) | o, s, g | 3, 4 |
|
||||
| OSPG | (user, collection, o) | s, p, g | 2, 6 |
|
||||
|
||||
**Family B: g-specified queries** (g in partition key)
|
||||
|
||||
| Table | Partition | Clustering | Supports patterns |
|
||||
|-------|-----------|------------|-------------------|
|
||||
| GSPO | (user, collection, g, s) | p, o | 9, 13, 15, 16 |
|
||||
| GPOS | (user, collection, g, p) | o, s | 11, 12 |
|
||||
| GOSP | (user, collection, g, o) | s, p | 10, 14 |
|
||||
|
||||
**Collection table** (for iteration and bulk deletion)
|
||||
|
||||
| Table | Partition | Clustering | Purpose |
|
||||
|-------|-----------|------------|---------|
|
||||
| COLL | (user, collection) | g, s, p, o | Enumerate all quads in collection |
|
||||
|
||||
##### Write and Delete Paths
|
||||
|
||||
**Write path**: Insert into all 7 tables.
|
||||
|
||||
**Delete collection path**:
|
||||
1. Iterate COLL table for `(user, collection)`
|
||||
2. For each quad, delete from all 6 query tables
|
||||
3. Delete from COLL table (or range delete)
|
||||
|
||||
**Delete single quad path**: Delete from all 7 tables directly.
|
||||
|
||||
##### Storage Cost
|
||||
|
||||
Each quad is stored 7 times. This is the cost of flexible querying combined
|
||||
with efficient collection deletion.
|
||||
|
||||
##### Quoted Triples in Storage
|
||||
|
||||
Subject or object can be a triple itself. Options:
|
||||
|
||||
**Option A: Serialize quoted triples to canonical string**
|
||||
```
|
||||
S: "<<http://ex/Alice|http://ex/knows|http://ex/Bob>>"
|
||||
P: http://ex/discoveredOn
|
||||
O: "2024-01-15"
|
||||
G: null
|
||||
```
|
||||
- Store quoted triple as serialized string in S or O columns
|
||||
- Query by exact match on serialized form
|
||||
- Pro: Simple, fits existing index patterns
|
||||
- Con: Can't query "find triples where quoted subject's predicate is X"
|
||||
|
||||
**Option B: Triple IDs / Hashes**
|
||||
```
|
||||
Triple table:
|
||||
id: hash(s,p,o,g)
|
||||
s, p, o, g: ...
|
||||
|
||||
Metadata table:
|
||||
subject_triple_id: <hash>
|
||||
p: http://ex/discoveredOn
|
||||
o: "2024-01-15"
|
||||
```
|
||||
- Assign each triple an ID (hash of components)
|
||||
- Reification metadata references triples by ID
|
||||
- Pro: Clean separation, can index triple IDs
|
||||
- Con: Requires computing/managing triple identity, two-phase lookups
|
||||
|
||||
**Recommendation**: Start with Option A (serialized strings) for simplicity.
|
||||
Option B may be needed if advanced query patterns over quoted triple
|
||||
components are required.
|
||||
|
||||
2. **Phase 2+: Other Backends**
|
||||
- Neo4j and other stores implemented in subsequent stages
|
||||
- Lessons learned from Cassandra inform these implementations
|
||||
|
||||
This approach de-risks the design by validating on a fully-controlled backend
|
||||
before committing to implementations across all stores.
|
||||
|
||||
#### Value → Term Rename
|
||||
|
||||
The `Value` class will be renamed to `Term`. This affects ~78 files across
|
||||
the codebase. The rename acts as a forcing function: any code still using
|
||||
`Value` is immediately identifiable as needing review/update for 2.0
|
||||
compatibility.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
Named graphs are not a security feature. Users and collections remain the
|
||||
security boundaries. Named graphs are purely for data organization and
|
||||
reification support.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Quoted triples add nesting depth - may impact query performance
|
||||
- Named graph indexing strategies needed for efficient graph-scoped queries
|
||||
- Cassandra schema design will need to accommodate quad storage efficiently
|
||||
|
||||
### Vector Store Boundary
|
||||
|
||||
Vector stores always reference IRIs only:
|
||||
- Never edges (quoted triples)
|
||||
- Never literal values
|
||||
- Never blank nodes
|
||||
|
||||
This keeps the vector store simple - it handles semantic similarity of named
|
||||
entities. The graph structure handles relationships, reification, and metadata.
|
||||
Quoted triples and named graphs don't complicate vector operations.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
Use existing test strategy. As this is a breaking change, extensive focus on
|
||||
the end-to-end test suite to validate the new structures work correctly across
|
||||
all components.
|
||||
|
||||
## Migration Plan
|
||||
|
||||
- 2.0 is a breaking release; no backward compatibility required
|
||||
- Existing data may need migration to new schema (TBD based on final design)
|
||||
- Consider migration tooling for converting existing triples
|
||||
|
||||
## Open Questions
|
||||
|
||||
- **Blank nodes**: Limited support confirmed. May need to decide on
|
||||
skolemization strategy (generate IRIs on load, or preserve blank node IDs).
|
||||
- **Query syntax**: What is the concrete syntax for specifying quoted triples
|
||||
in queries? Need to define the query API.
|
||||
- ~~**Predicate vocabulary**~~: Resolved. Any valid RDF predicates permitted,
|
||||
including custom user-defined. Minimal assumptions about RDF validity.
|
||||
Very few locked-in values (e.g., `rdfs:label` used in some places).
|
||||
Strategy: avoid locking anything in unless absolutely necessary.
|
||||
- ~~**Vector store impact**~~: Resolved. Vector stores always point to IRIs
|
||||
only - never edges, literals, or blank nodes. Quoted triples and
|
||||
reification don't affect the vector store.
|
||||
- ~~**Named graph semantics**~~: Resolved. Queries default to the default
|
||||
graph (matches SPARQL behavior, backward compatible). Explicit graph
|
||||
parameter required to query named graphs or all graphs.
|
||||
|
||||
## References
|
||||
|
||||
- [RDF 1.2 Concepts](https://www.w3.org/TR/rdf12-concepts/)
|
||||
- [RDF-star and SPARQL-star](https://w3c.github.io/rdf-star/)
|
||||
- [RDF Dataset](https://www.w3.org/TR/rdf11-concepts/#section-dataset)
|
||||
455
docs/tech-specs/jsonl-prompt-output.md
Normal file
455
docs/tech-specs/jsonl-prompt-output.md
Normal file
|
|
@ -0,0 +1,455 @@
|
|||
# JSONL Prompt Output Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This specification describes the implementation of JSONL (JSON Lines) output
|
||||
format for prompt responses in TrustGraph. JSONL enables truncation-resilient
|
||||
extraction of structured data from LLM responses, addressing critical issues
|
||||
with JSON array outputs being corrupted when LLM responses hit output token
|
||||
limits.
|
||||
|
||||
This implementation supports the following use cases:
|
||||
|
||||
1. **Truncation-Resilient Extraction**: Extract valid partial results even when
|
||||
LLM output is truncated mid-response
|
||||
2. **Large-Scale Extraction**: Handle extraction of many items without risk of
|
||||
complete failure due to token limits
|
||||
3. **Mixed-Type Extraction**: Support extraction of multiple entity types
|
||||
(definitions, relationships, entities, attributes) in a single prompt
|
||||
4. **Streaming-Compatible Output**: Enable future streaming/incremental
|
||||
processing of extraction results
|
||||
|
||||
## Goals
|
||||
|
||||
- **Backward Compatibility**: Existing prompts using `response-type: "text"` and
|
||||
`response-type: "json"` continue to work without modification
|
||||
- **Truncation Resilience**: Partial LLM outputs yield partial valid results
|
||||
rather than complete failure
|
||||
- **Schema Validation**: Support JSON Schema validation for individual objects
|
||||
- **Discriminated Unions**: Support mixed-type outputs using a `type` field
|
||||
discriminator
|
||||
- **Minimal API Changes**: Extend existing prompt configuration with new
|
||||
response type and schema key
|
||||
|
||||
## Background
|
||||
|
||||
### Current Architecture
|
||||
|
||||
The prompt service supports two response types:
|
||||
|
||||
1. `response-type: "text"` - Raw text response returned as-is
|
||||
2. `response-type: "json"` - JSON parsed from response, validated against
|
||||
optional `schema`
|
||||
|
||||
Current implementation in `trustgraph-flow/trustgraph/template/prompt_manager.py`:
|
||||
|
||||
```python
|
||||
class Prompt:
|
||||
def __init__(self, template, response_type = "text", terms=None, schema=None):
|
||||
self.template = template
|
||||
self.response_type = response_type
|
||||
self.terms = terms
|
||||
self.schema = schema
|
||||
```
|
||||
|
||||
### Current Limitations
|
||||
|
||||
When extraction prompts request output as JSON arrays (`[{...}, {...}, ...]`):
|
||||
|
||||
- **Truncation corruption**: If the LLM hits output token limits mid-array, the
|
||||
entire response becomes invalid JSON and cannot be parsed
|
||||
- **All-or-nothing parsing**: Must receive complete output before parsing
|
||||
- **No partial results**: A truncated response yields zero usable data
|
||||
- **Unreliable for large extractions**: The more items an extraction produces, the higher the risk of hitting the token limit and losing the entire result
|
||||
|
||||
This specification addresses these limitations by introducing JSONL format for
|
||||
extraction prompts, where each extracted item is a complete JSON object on its
|
||||
own line.
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Response Type Extension
|
||||
|
||||
Add a new response type `"jsonl"` alongside existing `"text"` and `"json"` types.
|
||||
|
||||
#### Configuration Changes
|
||||
|
||||
**New response type value:**
|
||||
|
||||
```
|
||||
"response-type": "jsonl"
|
||||
```
|
||||
|
||||
**Schema interpretation:**
|
||||
|
||||
The existing `"schema"` key is used for both `"json"` and `"jsonl"` response
|
||||
types. The interpretation depends on the response type:
|
||||
|
||||
- `"json"`: Schema describes the entire response (typically an array or object)
|
||||
- `"jsonl"`: Schema describes each individual line/object
|
||||
|
||||
```json
|
||||
{
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": { "type": "string" },
|
||||
"definition": { "type": "string" }
|
||||
},
|
||||
"required": ["entity", "definition"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This avoids changes to prompt configuration tooling and editors.
|
||||
|
||||
### JSONL Format Specification
|
||||
|
||||
#### Simple Extraction
|
||||
|
||||
For prompts extracting a single type of object (definitions, relationships,
|
||||
topics, rows), the output is one JSON object per line with no wrapper:
|
||||
|
||||
**Prompt output format:**
|
||||
```
|
||||
{"entity": "photosynthesis", "definition": "Process by which plants convert sunlight"}
|
||||
{"entity": "chlorophyll", "definition": "Green pigment in plants"}
|
||||
{"entity": "mitochondria", "definition": "Powerhouse of the cell"}
|
||||
```
|
||||
|
||||
**Contrast with previous JSON array format:**
|
||||
```json
|
||||
[
|
||||
{"entity": "photosynthesis", "definition": "Process by which plants convert sunlight"},
|
||||
{"entity": "chlorophyll", "definition": "Green pigment in plants"},
|
||||
{"entity": "mitochondria", "definition": "Powerhouse of the cell"}
|
||||
]
|
||||
```
|
||||
|
||||
If the LLM truncates after line 2, the JSON array format yields invalid JSON,
|
||||
while JSONL yields two valid objects.
|
||||
|
||||
#### Mixed-Type Extraction (Discriminated Unions)
|
||||
|
||||
For prompts extracting multiple types of objects (e.g., both definitions and
|
||||
relationships, or entities, relationships, and attributes), use a `"type"`
|
||||
field as discriminator:
|
||||
|
||||
**Prompt output format:**
|
||||
```
|
||||
{"type": "definition", "entity": "DNA", "definition": "Molecule carrying genetic instructions"}
|
||||
{"type": "relationship", "subject": "DNA", "predicate": "located_in", "object": "cell nucleus", "object-entity": true}
|
||||
{"type": "definition", "entity": "RNA", "definition": "Molecule that carries genetic information"}
|
||||
{"type": "relationship", "subject": "RNA", "predicate": "transcribed_from", "object": "DNA", "object-entity": true}
|
||||
```
|
||||
|
||||
**Schema for discriminated unions uses `oneOf`:**
|
||||
```json
|
||||
{
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": { "const": "definition" },
|
||||
"entity": { "type": "string" },
|
||||
"definition": { "type": "string" }
|
||||
},
|
||||
"required": ["type", "entity", "definition"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": { "const": "relationship" },
|
||||
"subject": { "type": "string" },
|
||||
"predicate": { "type": "string" },
|
||||
"object": { "type": "string" },
|
||||
"object-entity": { "type": "boolean" }
|
||||
},
|
||||
"required": ["type", "subject", "predicate", "object", "object-entity"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Ontology Extraction
|
||||
|
||||
For ontology-based extraction with entities, relationships, and attributes:
|
||||
|
||||
**Prompt output format:**
|
||||
```
|
||||
{"type": "entity", "entity": "Cornish pasty", "entity_type": "fo/Recipe"}
|
||||
{"type": "entity", "entity": "beef", "entity_type": "fo/Food"}
|
||||
{"type": "relationship", "subject": "Cornish pasty", "subject_type": "fo/Recipe", "relation": "fo/has_ingredient", "object": "beef", "object_type": "fo/Food"}
|
||||
{"type": "attribute", "entity": "Cornish pasty", "entity_type": "fo/Recipe", "attribute": "fo/serves", "value": "4 people"}
|
||||
```
|
||||
|
||||
### Implementation Details
|
||||
|
||||
#### Prompt Class
|
||||
|
||||
The existing `Prompt` class requires no changes. The `schema` field is reused
|
||||
for JSONL, with its interpretation determined by `response_type`:
|
||||
|
||||
```python
|
||||
class Prompt:
|
||||
def __init__(self, template, response_type="text", terms=None, schema=None):
|
||||
self.template = template
|
||||
self.response_type = response_type
|
||||
self.terms = terms
|
||||
self.schema = schema # Interpretation depends on response_type
|
||||
```
|
||||
|
||||
#### PromptManager.load_config
|
||||
|
||||
No changes required - existing configuration loading already handles the
|
||||
`schema` key.
|
||||
|
||||
#### JSONL Parsing
|
||||
|
||||
Add a new parsing method for JSONL responses:
|
||||
|
||||
```python
|
||||
def parse_jsonl(self, text):
|
||||
"""
|
||||
Parse JSONL response, returning list of valid objects.
|
||||
|
||||
Invalid lines (malformed JSON, empty lines) are skipped with warnings.
|
||||
This provides truncation resilience - partial output yields partial results.
|
||||
"""
|
||||
results = []
|
||||
|
||||
for line_num, line in enumerate(text.strip().split('\n'), 1):
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Skip markdown code fence markers if present
|
||||
if line.startswith('```'):
|
||||
continue
|
||||
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
results.append(obj)
|
||||
except json.JSONDecodeError as e:
|
||||
# Log warning but continue - this provides truncation resilience
|
||||
logger.warning(f"JSONL parse error on line {line_num}: {e}")
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
#### PromptManager.invoke Changes
|
||||
|
||||
Extend the invoke method to handle the new response type:
|
||||
|
||||
```python
|
||||
async def invoke(self, id, input, llm):
|
||||
logger.debug("Invoking prompt template...")
|
||||
|
||||
terms = self.terms | self.prompts[id].terms | input
|
||||
resp_type = self.prompts[id].response_type
|
||||
|
||||
prompt = {
|
||||
"system": self.system_template.render(terms),
|
||||
"prompt": self.render(id, input)
|
||||
}
|
||||
|
||||
resp = await llm(**prompt)
|
||||
|
||||
if resp_type == "text":
|
||||
return resp
|
||||
|
||||
if resp_type == "json":
|
||||
try:
|
||||
obj = self.parse_json(resp)
|
||||
except:
|
||||
logger.error(f"JSON parse failed: {resp}")
|
||||
raise RuntimeError("JSON parse fail")
|
||||
|
||||
if self.prompts[id].schema:
|
||||
try:
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
logger.debug("Schema validation successful")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Schema validation fail: {e}")
|
||||
|
||||
return obj
|
||||
|
||||
if resp_type == "jsonl":
|
||||
objects = self.parse_jsonl(resp)
|
||||
|
||||
if not objects:
|
||||
logger.warning("JSONL parse returned no valid objects")
|
||||
return []
|
||||
|
||||
# Validate each object against schema if provided
|
||||
if self.prompts[id].schema:
|
||||
validated = []
|
||||
for i, obj in enumerate(objects):
|
||||
try:
|
||||
validate(instance=obj, schema=self.prompts[id].schema)
|
||||
validated.append(obj)
|
||||
except Exception as e:
|
||||
logger.warning(f"Object {i} failed schema validation: {e}")
|
||||
return validated
|
||||
|
||||
return objects
|
||||
|
||||
raise RuntimeError(f"Response type {resp_type} not known")
|
||||
```
|
||||
|
||||
### Affected Prompts
|
||||
|
||||
The following prompts should be migrated to JSONL format:
|
||||
|
||||
| Prompt ID | Description | Type Field |
|
||||
|-----------|-------------|------------|
|
||||
| `extract-definitions` | Entity/definition extraction | No (single type) |
|
||||
| `extract-relationships` | Relationship extraction | No (single type) |
|
||||
| `extract-topics` | Topic/definition extraction | No (single type) |
|
||||
| `extract-rows` | Structured row extraction | No (single type) |
|
||||
| `agent-kg-extract` | Combined definition + relationship extraction | Yes: `"definition"`, `"relationship"` |
|
||||
| `extract-with-ontologies` / `ontology-extract` | Ontology-based extraction | Yes: `"entity"`, `"relationship"`, `"attribute"` |
|
||||
|
||||
### API Changes
|
||||
|
||||
#### Client Perspective
|
||||
|
||||
JSONL parsing is transparent to prompt service API callers. The parsing occurs
|
||||
server-side in the prompt service, and the response is returned via the standard
|
||||
`PromptResponse.object` field as a serialized JSON array.
|
||||
|
||||
When clients call the prompt service (via `PromptClient.prompt()` or similar):
|
||||
|
||||
- **`response-type: "json"`** with array schema → client receives Python `list`
|
||||
- **`response-type: "jsonl"`** → client receives Python `list`
|
||||
|
||||
From the client's perspective, both return identical data structures. The
|
||||
difference is entirely in how the LLM output is parsed server-side:
|
||||
|
||||
- JSON array format: Single `json.loads()` call; fails completely if truncated
|
||||
- JSONL format: Line-by-line parsing; yields partial results if truncated
|
||||
|
||||
This means existing client code expecting a list from extraction prompts
|
||||
requires no changes when migrating prompts from JSON to JSONL format.
|
||||
|
||||
#### Server Return Value
|
||||
|
||||
For `response-type: "jsonl"`, the `PromptManager.invoke()` method returns a
|
||||
`list[dict]` containing all successfully parsed and validated objects. This
|
||||
list is then serialized to JSON for the `PromptResponse.object` field.
|
||||
|
||||
#### Error Handling
|
||||
|
||||
- Empty results: Returns empty list `[]` with warning log
|
||||
- Partial parse failure: Returns list of successfully parsed objects with
|
||||
warning logs for failures
|
||||
- Complete parse failure: Returns empty list `[]` with warning logs
|
||||
|
||||
This differs from `response-type: "json"` which raises `RuntimeError` on
|
||||
parse failure. The lenient behavior for JSONL is intentional to provide
|
||||
truncation resilience.
|
||||
|
||||
### Configuration Example
|
||||
|
||||
Complete prompt configuration example:
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "Extract all entities and their definitions from the following text. Output one JSON object per line.\n\nText:\n{{text}}\n\nOutput format per line:\n{\"entity\": \"<name>\", \"definition\": \"<definition>\"}",
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {
|
||||
"type": "string",
|
||||
"description": "The entity name"
|
||||
},
|
||||
"definition": {
|
||||
"type": "string",
|
||||
"description": "A clear definition of the entity"
|
||||
}
|
||||
},
|
||||
"required": ["entity", "definition"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- **Input Validation**: JSON parsing uses standard `json.loads()` which is safe
|
||||
against injection attacks
|
||||
- **Schema Validation**: Uses `jsonschema.validate()` for schema enforcement
|
||||
- **No New Attack Surface**: JSONL parsing is strictly safer than JSON array
|
||||
parsing due to line-by-line processing
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **Memory**: Line-by-line parsing uses less peak memory than loading full JSON
|
||||
arrays
|
||||
- **Latency**: Parsing performance is comparable to JSON array parsing
|
||||
- **Validation**: Schema validation runs per-object, which adds overhead but
|
||||
enables partial results on validation failure
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Tests
|
||||
|
||||
- JSONL parsing with valid input
|
||||
- JSONL parsing with empty lines
|
||||
- JSONL parsing with markdown code fences
|
||||
- JSONL parsing with truncated final line
|
||||
- JSONL parsing with invalid JSON lines interspersed
|
||||
- Schema validation with `oneOf` discriminated unions
|
||||
- Backward compatibility: existing `"text"` and `"json"` prompts unchanged
|
||||
|
||||
### Integration Tests
|
||||
|
||||
- End-to-end extraction with JSONL prompts
|
||||
- Extraction with simulated truncation (artificially limited response)
|
||||
- Mixed-type extraction with type discriminator
|
||||
- Ontology extraction with all three types
|
||||
|
||||
### Extraction Quality Tests
|
||||
|
||||
- Compare extraction results: JSONL vs JSON array format
|
||||
- Verify truncation resilience: JSONL yields partial results where JSON fails
|
||||
|
||||
## Migration Plan
|
||||
|
||||
### Phase 1: Implementation
|
||||
|
||||
1. Implement `parse_jsonl()` method in `PromptManager`
|
||||
2. Extend `invoke()` to handle `response-type: "jsonl"`
|
||||
3. Add unit tests
|
||||
|
||||
### Phase 2: Prompt Migration
|
||||
|
||||
1. Update `extract-definitions` prompt and configuration
|
||||
2. Update `extract-relationships` prompt and configuration
|
||||
3. Update `extract-topics` prompt and configuration
|
||||
4. Update `extract-rows` prompt and configuration
|
||||
5. Update `agent-kg-extract` prompt and configuration
|
||||
6. Update `extract-with-ontologies` prompt and configuration
|
||||
|
||||
### Phase 3: Downstream Updates
|
||||
|
||||
1. Update any code consuming extraction results to handle list return type
|
||||
2. Update code that categorizes mixed-type extractions by `type` field
|
||||
3. Update tests that assert on extraction output format
|
||||
|
||||
## Open Questions
|
||||
|
||||
None at this time.
|
||||
|
||||
## References
|
||||
|
||||
- Current implementation: `trustgraph-flow/trustgraph/template/prompt_manager.py`
|
||||
- JSON Lines specification: https://jsonlines.org/
|
||||
- JSON Schema `oneOf`: https://json-schema.org/understanding-json-schema/reference/combining.html#oneof
|
||||
- Related specification: Streaming LLM Responses (`docs/tech-specs/streaming-llm-responses.md`)
|
||||
613
docs/tech-specs/structured-data-2.md
Normal file
613
docs/tech-specs/structured-data-2.md
Normal file
|
|
@ -0,0 +1,613 @@
|
|||
# Structured Data Technical Specification (Part 2)
|
||||
|
||||
## Overview
|
||||
|
||||
This specification addresses issues and gaps identified during the initial implementation of TrustGraph's structured data integration, as described in `structured-data.md`.
|
||||
|
||||
## Problem Statements
|
||||
|
||||
### 1. Naming Inconsistency: "Object" vs "Row"
|
||||
|
||||
The current implementation uses "object" terminology throughout (e.g., `ExtractedObject`, object extraction, object embeddings). This naming is too generic and causes confusion:
|
||||
|
||||
- "Object" is an overloaded term in software (Python objects, JSON objects, etc.)
|
||||
- The data being handled is fundamentally tabular - rows in tables with defined schemas
|
||||
- "Row" more accurately describes the data model and aligns with database terminology
|
||||
|
||||
This inconsistency appears in module names, class names, message types, and documentation.
|
||||
|
||||
### 2. Row Store Query Limitations
|
||||
|
||||
The current row store implementation has significant query limitations:
|
||||
|
||||
**Natural Language Mismatch**: Queries struggle with real-world data variations. For example:
|
||||
- A street database containing `"CHESTNUT ST"` is difficult to find when asking about `"Chestnut Street"`
|
||||
- Abbreviations, case differences, and formatting variations break exact-match queries
|
||||
- Users expect semantic understanding, but the store provides literal matching
|
||||
|
||||
**Schema Evolution Issues**: Changing schemas causes problems:
|
||||
- Existing data may not conform to updated schemas
|
||||
- Table structure changes can break queries and data integrity
|
||||
- No clear migration path for schema updates
|
||||
|
||||
### 3. Row Embeddings Required
|
||||
|
||||
Related to problem 2, the system needs vector embeddings for row data to enable:
|
||||
|
||||
- Semantic search across structured data (finding "Chestnut Street" when data contains "CHESTNUT ST")
|
||||
- Similarity matching for fuzzy queries
|
||||
- Hybrid search combining structured filters with semantic similarity
|
||||
- Better natural language query support
|
||||
|
||||
The embedding service was specified but not implemented.
|
||||
|
||||
### 4. Row Data Ingestion Incomplete
|
||||
|
||||
The structured data ingestion pipeline is not fully operational:
|
||||
|
||||
- Diagnostic prompts exist to classify input formats (CSV, JSON, etc.)
|
||||
- The ingestion service that uses these prompts is not plumbed into the system
|
||||
- No end-to-end path for loading pre-structured data into the row store
|
||||
|
||||
## Goals
|
||||
|
||||
- **Schema Flexibility**: Enable schema evolution without breaking existing data or requiring migrations
|
||||
- **Consistent Naming**: Standardize on "row" terminology throughout the codebase
|
||||
- **Semantic Queryability**: Support fuzzy/semantic matching via row embeddings
|
||||
- **Complete Ingestion Pipeline**: Provide end-to-end path for loading structured data
|
||||
|
||||
## Technical Design
|
||||
|
||||
### Unified Row Storage Schema
|
||||
|
||||
The previous implementation created a separate Cassandra table for each schema. This caused problems when schemas evolved, as table structure changes required migrations.
|
||||
|
||||
The new design uses a single unified table for all row data:
|
||||
|
||||
```sql
|
||||
CREATE TABLE rows (
|
||||
collection text,
|
||||
schema_name text,
|
||||
index_name text,
|
||||
index_value frozen<list<text>>,
|
||||
data map<text, text>,
|
||||
source text,
|
||||
PRIMARY KEY ((collection, schema_name, index_name), index_value)
|
||||
)
|
||||
```
|
||||
|
||||
#### Column Definitions
|
||||
|
||||
| Column | Type | Description |
|
||||
|--------|------|-------------|
|
||||
| `collection` | `text` | Data collection/import identifier (from metadata) |
|
||||
| `schema_name` | `text` | Name of the schema this row conforms to |
|
||||
| `index_name` | `text` | Name of the indexed field(s), comma-joined for composites |
|
||||
| `index_value` | `frozen<list<text>>` | Index value(s) as a list |
|
||||
| `data` | `map<text, text>` | Row data as key-value pairs |
|
||||
| `source` | `text` | Optional URI linking to provenance information in the knowledge graph. Empty string or NULL indicates no source. |
|
||||
|
||||
#### Index Handling
|
||||
|
||||
Each row is written multiple times - one copy per indexed field defined in the schema, so a schema with N indexes produces N inserts per row. The primary key fields are treated as just another index, with no special marker, which keeps the storage model uniform and preserves future flexibility.
|
||||
|
||||
**Single-field index example:**
|
||||
- Schema defines `email` as indexed
|
||||
- `index_name = "email"`
|
||||
- `index_value = ['foo@bar.com']`
|
||||
|
||||
**Composite index example:**
|
||||
- Schema defines composite index on `region` and `status`
|
||||
- `index_name = "region,status"` (field names sorted and comma-joined)
|
||||
- `index_value = ['US', 'active']` (values in same order as field names)
|
||||
|
||||
**Primary key example:**
|
||||
- Schema defines `customer_id` as primary key
|
||||
- `index_name = "customer_id"`
|
||||
- `index_value = ['CUST001']`
|
||||
|
||||
#### Query Patterns
|
||||
|
||||
All queries follow the same pattern regardless of which index is used:
|
||||
|
||||
```sql
|
||||
SELECT * FROM rows
|
||||
WHERE collection = 'import_2024'
|
||||
AND schema_name = 'customers'
|
||||
AND index_name = 'email'
|
||||
AND index_value = ['foo@bar.com']
|
||||
```
|
||||
|
||||
#### Design Trade-offs
|
||||
|
||||
**Advantages:**
|
||||
- Schema changes don't require table structure changes
|
||||
- Row data is opaque to Cassandra - field additions/removals are transparent
|
||||
- Consistent query pattern for all access methods
|
||||
- No Cassandra secondary indexes (which can be slow at scale)
|
||||
- Native Cassandra types throughout (`map`, `frozen<list>`)
|
||||
|
||||
**Trade-offs:**
|
||||
- Write amplification: each row insert = N inserts (one per indexed field)
|
||||
- Storage overhead from duplicated row data
|
||||
- Type information stored in schema config, conversion at application layer
|
||||
|
||||
#### Consistency Model
|
||||
|
||||
The design accepts certain simplifications:
|
||||
|
||||
1. **No row updates**: The system is append-only. This eliminates consistency concerns about updating multiple copies of the same row.
|
||||
|
||||
2. **Schema change tolerance**: When schemas change (e.g., indexes added/removed), existing rows retain their original indexing. Old rows won't be discoverable via new indexes. Users can delete and recreate a schema to ensure consistency if needed.
|
||||
|
||||
### Partition Tracking and Deletion
|
||||
|
||||
#### The Problem
|
||||
|
||||
With the partition key `(collection, schema_name, index_name)`, efficient deletion requires knowing all partition keys to delete. Deleting by just `collection` or `collection + schema_name` requires knowing all the `index_name` values that have data.
|
||||
|
||||
#### Partition Tracking Table
|
||||
|
||||
A secondary lookup table tracks which partitions exist:
|
||||
|
||||
```sql
|
||||
CREATE TABLE row_partitions (
|
||||
collection text,
|
||||
schema_name text,
|
||||
index_name text,
|
||||
PRIMARY KEY ((collection), schema_name, index_name)
|
||||
)
|
||||
```
|
||||
|
||||
This enables efficient discovery of partitions for deletion operations.
|
||||
|
||||
#### Row Writer Behavior
|
||||
|
||||
The row writer maintains an in-memory cache of registered `(collection, schema_name)` pairs. When processing a row:
|
||||
|
||||
1. Check if `(collection, schema_name)` is in the cache
|
||||
2. If not cached (first row for this pair):
|
||||
- Look up the schema config to get all index names
|
||||
- Insert entries into `row_partitions` for each `(collection, schema_name, index_name)`
|
||||
- Add the pair to the cache
|
||||
3. Proceed with writing the row data
|
||||
|
||||
The row writer also monitors schema config change events. When a schema changes, relevant cache entries are cleared so the next row triggers re-registration with the updated index names.
|
||||
|
||||
This approach ensures:
|
||||
- Lookup table writes happen once per `(collection, schema_name)` pair, not per row
|
||||
- The lookup table reflects the indexes that were active when data was written
|
||||
- Schema changes mid-import are picked up correctly
|
||||
|
||||
#### Deletion Operations
|
||||
|
||||
**Delete collection:**
|
||||
```sql
|
||||
-- 1. Discover all partitions
|
||||
SELECT schema_name, index_name FROM row_partitions WHERE collection = 'X';
|
||||
|
||||
-- 2. Delete each partition from rows table
|
||||
DELETE FROM rows WHERE collection = 'X' AND schema_name = '...' AND index_name = '...';
|
||||
-- (repeat for each discovered partition)
|
||||
|
||||
-- 3. Clean up the lookup table
|
||||
DELETE FROM row_partitions WHERE collection = 'X';
|
||||
```
|
||||
|
||||
**Delete collection + schema:**
|
||||
```sql
|
||||
-- 1. Discover partitions for this schema
|
||||
SELECT index_name FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
|
||||
|
||||
-- 2. Delete each partition from rows table
|
||||
DELETE FROM rows WHERE collection = 'X' AND schema_name = 'Y' AND index_name = '...';
|
||||
-- (repeat for each discovered partition)
|
||||
|
||||
-- 3. Clean up the lookup table entries
|
||||
DELETE FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
|
||||
```
|
||||
|
||||
### Row Embeddings
|
||||
|
||||
Row embeddings enable semantic/fuzzy matching on indexed values, solving the natural language mismatch problem (e.g., finding "CHESTNUT ST" when querying for "Chestnut Street").
|
||||
|
||||
#### Design Overview
|
||||
|
||||
Each indexed value is embedded and stored in a vector store (Qdrant). At query time, the query is embedded, similar vectors are found, and the associated metadata is used to look up the actual rows in Cassandra.
|
||||
|
||||
#### Qdrant Collection Structure
|
||||
|
||||
One Qdrant collection per `(user, collection, schema_name, dimension)` tuple:
|
||||
|
||||
- **Collection naming:** `rows_{user}_{collection}_{schema_name}_{dimension}`
|
||||
- Names are sanitized (non-alphanumeric characters replaced with `_`, lowercased, numeric prefixes get `r_` prefix)
|
||||
- **Rationale:** Enables clean deletion of a `(user, collection, schema_name)` instance by dropping matching Qdrant collections; dimension suffix allows different embedding models to coexist
|
||||
|
||||
#### What Gets Embedded
|
||||
|
||||
The text representation of index values:
|
||||
|
||||
| Index Type | Example `index_value` | Text to Embed |
|
||||
|------------|----------------------|---------------|
|
||||
| Single-field | `['foo@bar.com']` | `"foo@bar.com"` |
|
||||
| Composite | `['US', 'active']` | `"US active"` (space-joined) |
|
||||
|
||||
#### Point Structure
|
||||
|
||||
Each Qdrant point contains:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "<uuid>",
|
||||
"vector": [0.1, 0.2, ...],
|
||||
"payload": {
|
||||
"index_name": "street_name",
|
||||
"index_value": ["CHESTNUT ST"],
|
||||
"text": "CHESTNUT ST"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Payload Field | Description |
|
||||
|---------------|-------------|
|
||||
| `index_name` | The indexed field(s) this embedding represents |
|
||||
| `index_value` | The original list of values (for Cassandra lookup) |
|
||||
| `text` | The text that was embedded (for debugging/display) |
|
||||
|
||||
Note: `user`, `collection`, and `schema_name` are implicit from the Qdrant collection name.
|
||||
|
||||
#### Query Flow
|
||||
|
||||
1. User queries for "Chestnut Street" within user U, collection X, schema Y
|
||||
2. Embed the query text
|
||||
3. Determine Qdrant collection name(s) matching prefix `rows_U_X_Y_`
|
||||
4. Search matching Qdrant collection(s) for nearest vectors
|
||||
5. Get matching points with payloads containing `index_name` and `index_value`
|
||||
6. Query Cassandra:
|
||||
```sql
|
||||
SELECT * FROM rows
|
||||
WHERE collection = 'X'
|
||||
AND schema_name = 'Y'
|
||||
AND index_name = '<from payload>'
|
||||
AND index_value = <from payload>
|
||||
```
|
||||
7. Return matched rows
|
||||
|
||||
#### Optional: Filtering by Index Name
|
||||
|
||||
Queries can optionally filter by `index_name` in Qdrant to search only specific fields:
|
||||
|
||||
- **"Find any field matching 'Chestnut'"** → search all vectors in the collection
|
||||
- **"Find street_name matching 'Chestnut'"** → filter where `payload.index_name = 'street_name'`
|
||||
|
||||
#### Architecture
|
||||
|
||||
Row embeddings follow the **two-stage pattern** used by GraphRAG (graph-embeddings, document-embeddings):
|
||||
|
||||
- **Stage 1: Embedding computation** (`trustgraph-flow/trustgraph/embeddings/row_embeddings/`) - Consumes `ExtractedObject`, computes embeddings via the embeddings service, outputs `RowEmbeddings`
|
||||
- **Stage 2: Embedding storage** (`trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/`) - Consumes `RowEmbeddings`, writes vectors to Qdrant
|
||||
|
||||
The Cassandra row writer is a separate parallel consumer:
|
||||
|
||||
- **Cassandra row writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra`) - Consumes `ExtractedObject`, writes rows to Cassandra
|
||||
|
||||
All three services consume from the same flow, keeping them decoupled. This allows:
|
||||
- Independent scaling of Cassandra writes vs embedding generation vs vector storage
|
||||
- Embedding services can be disabled if not needed
|
||||
- Failures in one service don't affect the others
|
||||
- Consistent architecture with GraphRAG pipelines
|
||||
|
||||
#### Write Path
|
||||
|
||||
**Stage 1 (row-embeddings processor):** When receiving an `ExtractedObject`:
|
||||
|
||||
1. Look up the schema to find indexed fields
|
||||
2. For each indexed field:
|
||||
- Build the text representation of the index value
|
||||
- Compute embedding via the embeddings service
|
||||
3. Output a `RowEmbeddings` message containing all computed vectors
|
||||
|
||||
**Stage 2 (row-embeddings-write-qdrant):** When receiving a `RowEmbeddings`:
|
||||
|
||||
1. For each embedding in the message:
|
||||
- Determine Qdrant collection from `(user, collection, schema_name, dimension)`
|
||||
- Create collection if needed (lazy creation on first write)
|
||||
- Upsert point with vector and payload
|
||||
|
||||
#### Message Types
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class RowIndexEmbedding:
|
||||
index_name: str # The indexed field name(s)
|
||||
index_value: list[str] # The field value(s)
|
||||
text: str # Text that was embedded
|
||||
vectors: list[list[float]] # Computed embedding vectors
|
||||
|
||||
@dataclass
|
||||
class RowEmbeddings:
|
||||
metadata: Metadata
|
||||
schema_name: str
|
||||
embeddings: list[RowIndexEmbedding]
|
||||
```
|
||||
|
||||
#### Deletion Integration
|
||||
|
||||
Qdrant collections are discovered by prefix matching on the collection name pattern:
|
||||
|
||||
**Delete `(user, collection)`:**
|
||||
1. List all Qdrant collections matching prefix `rows_{user}_{collection}_`
|
||||
2. Delete each matching collection
|
||||
3. Delete Cassandra rows partitions (as documented above)
|
||||
4. Clean up `row_partitions` entries
|
||||
|
||||
**Delete `(user, collection, schema_name)`:**
|
||||
1. List all Qdrant collections matching prefix `rows_{user}_{collection}_{schema_name}_`
|
||||
2. Delete each matching collection (handles multiple dimensions)
|
||||
3. Delete Cassandra rows partitions
|
||||
4. Clean up `row_partitions`
|
||||
|
||||
#### Module Locations
|
||||
|
||||
| Stage | Module | Entry Point |
|
||||
|-------|--------|-------------|
|
||||
| Stage 1 | `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | `row-embeddings` |
|
||||
| Stage 2 | `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | `row-embeddings-write-qdrant` |
|
||||
|
||||
### Row Embeddings Query API
|
||||
|
||||
The row embeddings query is a **separate API** from the GraphQL row query service:
|
||||
|
||||
| API | Purpose | Backend |
|
||||
|-----|---------|---------|
|
||||
| Row Query (GraphQL) | Exact matching on indexed fields | Cassandra |
|
||||
| Row Embeddings Query | Fuzzy/semantic matching | Qdrant |
|
||||
|
||||
This separation keeps concerns clean:
|
||||
- GraphQL service focuses on exact, structured queries
|
||||
- Embeddings API handles semantic similarity
|
||||
- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
|
||||
|
||||
#### Request/Response Schema
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class RowEmbeddingsRequest:
|
||||
vectors: list[list[float]] # Query vectors (pre-computed embeddings)
|
||||
user: str = ""
|
||||
collection: str = ""
|
||||
schema_name: str = ""
|
||||
index_name: str = "" # Optional: filter to specific index
|
||||
limit: int = 10 # Max results per vector
|
||||
|
||||
@dataclass
|
||||
class RowIndexMatch:
|
||||
index_name: str = "" # The matched index field(s)
|
||||
index_value: list[str] = [] # The matched value(s)
|
||||
text: str = "" # Original text that was embedded
|
||||
score: float = 0.0 # Similarity score
|
||||
|
||||
@dataclass
|
||||
class RowEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
matches: list[RowIndexMatch] = []
|
||||
```
|
||||
|
||||
#### Query Processor
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
|
||||
|
||||
Entry point: `row-embeddings-query-qdrant`
|
||||
|
||||
The processor:
|
||||
1. Receives `RowEmbeddingsRequest` with query vectors
|
||||
2. Finds the appropriate Qdrant collection by prefix matching
|
||||
3. Searches for nearest vectors with optional `index_name` filter
|
||||
4. Returns `RowEmbeddingsResponse` with matching index information
|
||||
|
||||
#### API Gateway Integration
|
||||
|
||||
The gateway exposes row embeddings queries via the standard request/response pattern:
|
||||
|
||||
| Component | Location |
|
||||
|-----------|----------|
|
||||
| Dispatcher | `trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py` |
|
||||
| Registration | Add `"row-embeddings"` to `request_response_dispatchers` in `manager.py` |
|
||||
|
||||
Flow interface name: `row-embeddings`
|
||||
|
||||
Interface definition in flow blueprint:
|
||||
```json
|
||||
{
|
||||
"interfaces": {
|
||||
"row-embeddings": {
|
||||
"request": "non-persistent://tg/request/row-embeddings:{id}",
|
||||
"response": "non-persistent://tg/response/row-embeddings:{id}"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Python SDK Support
|
||||
|
||||
The SDK provides methods for row embeddings queries:
|
||||
|
||||
```python
|
||||
# Flow-scoped query (preferred)
|
||||
api = Api(url)
|
||||
flow = api.flow().id("default")
|
||||
|
||||
# Query with text (SDK computes embeddings)
|
||||
matches = flow.row_embeddings_query(
|
||||
text="Chestnut Street",
|
||||
collection="my_collection",
|
||||
schema_name="addresses",
|
||||
index_name="street_name", # Optional filter
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Query with pre-computed vectors
|
||||
matches = flow.row_embeddings_query(
|
||||
vectors=[[0.1, 0.2, ...]],
|
||||
collection="my_collection",
|
||||
schema_name="addresses"
|
||||
)
|
||||
|
||||
# Each match contains:
|
||||
for match in matches:
|
||||
print(match.index_name) # e.g., "street_name"
|
||||
print(match.index_value) # e.g., ["CHESTNUT ST"]
|
||||
print(match.text) # e.g., "CHESTNUT ST"
|
||||
print(match.score) # e.g., 0.95
|
||||
```
|
||||
|
||||
#### CLI Utility
|
||||
|
||||
Command: `tg-invoke-row-embeddings`
|
||||
|
||||
```bash
|
||||
# Query by text (computes embedding automatically)
|
||||
tg-invoke-row-embeddings \
|
||||
--text "Chestnut Street" \
|
||||
--collection my_collection \
|
||||
--schema addresses \
|
||||
--index street_name \
|
||||
--limit 10
|
||||
|
||||
# Query by vector file
|
||||
tg-invoke-row-embeddings \
|
||||
--vectors vectors.json \
|
||||
--collection my_collection \
|
||||
--schema addresses
|
||||
|
||||
# Output formats
|
||||
tg-invoke-row-embeddings --text "..." --format json
|
||||
tg-invoke-row-embeddings --text "..." --format table
|
||||
```
|
||||
|
||||
#### Typical Usage Pattern
|
||||
|
||||
The row embeddings query is typically used as part of a fuzzy-to-exact lookup flow:
|
||||
|
||||
```python
|
||||
# Step 1: Fuzzy search via embeddings
|
||||
matches = flow.row_embeddings_query(
|
||||
text="chestnut street",
|
||||
collection="geo",
|
||||
schema_name="streets"
|
||||
)
|
||||
|
||||
# Step 2: Exact lookup via GraphQL for full row data
|
||||
for match in matches:
|
||||
query = f'''
|
||||
query {{
|
||||
streets(where: {{ {match.index_name}: {{ eq: "{match.index_value[0]}" }} }}) {{
|
||||
street_name
|
||||
city
|
||||
zip_code
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
rows = flow.rows_query(query, collection="geo")
|
||||
```
|
||||
|
||||
This two-step pattern enables:
|
||||
- Finding "CHESTNUT ST" when user searches for "Chestnut Street"
|
||||
- Retrieving complete row data with all fields
|
||||
- Combining semantic similarity with structured data access
|
||||
|
||||
### Row Data Ingestion
|
||||
|
||||
Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
|
||||
|
||||
## Implementation Impact
|
||||
|
||||
### Current State Analysis
|
||||
|
||||
The existing implementation has two main components:
|
||||
|
||||
| Component | Location | Lines | Description |
|
||||
|-----------|----------|-------|-------------|
|
||||
| Query Service | `trustgraph-flow/trustgraph/query/objects/cassandra/service.py` | ~740 | Monolithic: GraphQL schema generation, filter parsing, Cassandra queries, request handling |
|
||||
| Writer | `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` | ~540 | Per-schema table creation, secondary indexes, insert/delete |
|
||||
|
||||
**Current Query Pattern:**
|
||||
```sql
|
||||
SELECT * FROM {keyspace}.o_{schema_name}
|
||||
WHERE collection = 'X' AND email = 'foo@bar.com'
|
||||
ALLOW FILTERING
|
||||
```
|
||||
|
||||
**New Query Pattern:**
|
||||
```sql
|
||||
SELECT * FROM {keyspace}.rows
|
||||
WHERE collection = 'X' AND schema_name = 'customers'
|
||||
AND index_name = 'email' AND index_value = ['foo@bar.com']
|
||||
```
|
||||
|
||||
### Key Changes
|
||||
|
||||
1. **Query semantics simplify**: The new schema only supports exact matches on `index_value`. The current GraphQL filters (`gt`, `lt`, `contains`, etc.) either:
|
||||
- Become post-filtering on returned data (if still needed)
|
||||
- Are removed in favor of using the embeddings API for fuzzy matching
|
||||
|
||||
2. **GraphQL code is tightly coupled**: The current `service.py` bundles Strawberry type generation, filter parsing, and Cassandra-specific queries. Adding another row store backend would duplicate ~400 lines of GraphQL code.
|
||||
|
||||
### Proposed Refactor
|
||||
|
||||
The refactor has two parts:
|
||||
|
||||
#### 1. Break Out GraphQL Code
|
||||
|
||||
Extract reusable GraphQL components into a shared module:
|
||||
|
||||
```
|
||||
trustgraph-flow/trustgraph/query/graphql/
|
||||
├── __init__.py
|
||||
├── types.py # Filter types (IntFilter, StringFilter, FloatFilter)
|
||||
├── schema.py # Dynamic schema generation from RowSchema
|
||||
└── filters.py # Filter parsing utilities
|
||||
```
|
||||
|
||||
This enables:
|
||||
- Reuse across different row store backends
|
||||
- Cleaner separation of concerns
|
||||
- Easier testing of GraphQL logic independently
|
||||
|
||||
#### 2. Implement New Table Schema
|
||||
|
||||
Refactor the Cassandra-specific code to use the unified table:
|
||||
|
||||
**Writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra/`):
|
||||
- Single `rows` table instead of per-schema tables
|
||||
- Write N copies per row (one per index)
|
||||
- Register to `row_partitions` table
|
||||
- Simpler table creation (one-time setup)
|
||||
|
||||
**Query Service** (`trustgraph-flow/trustgraph/query/rows/cassandra/`):
|
||||
- Query the unified `rows` table
|
||||
- Use extracted GraphQL module for schema generation
|
||||
- Simplified filter handling (exact match only at DB level)
|
||||
|
||||
### Module Renames
|
||||
|
||||
As part of the "object" → "row" naming cleanup:
|
||||
|
||||
| Current | New |
|
||||
|---------|-----|
|
||||
| `storage/objects/cassandra/` | `storage/rows/cassandra/` |
|
||||
| `query/objects/cassandra/` | `query/rows/cassandra/` |
|
||||
| `embeddings/object_embeddings/` | `embeddings/row_embeddings/` |
|
||||
|
||||
### New Modules
|
||||
|
||||
| Module | Purpose |
|
||||
|--------|---------|
|
||||
| `trustgraph-flow/trustgraph/query/graphql/` | Shared GraphQL utilities |
|
||||
| `trustgraph-flow/trustgraph/query/row_embeddings/qdrant/` | Row embeddings query API |
|
||||
| `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | Row embeddings computation (Stage 1) |
|
||||
| `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | Row embeddings storage (Stage 2) |
|
||||
|
||||
## References
|
||||
|
||||
- [Structured Data Technical Specification](structured-data.md)
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
type: object
|
||||
description: |
|
||||
Row embeddings query request - find similar rows by vector similarity on indexed fields.
|
||||
Enables semantic/fuzzy matching on structured data.
|
||||
required:
|
||||
- vectors
|
||||
- schema_name
|
||||
properties:
|
||||
vectors:
|
||||
type: array
|
||||
description: Query embedding vector
|
||||
items:
|
||||
type: number
|
||||
example: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156]
|
||||
schema_name:
|
||||
type: string
|
||||
description: Schema name to search within
|
||||
example: customers
|
||||
index_name:
|
||||
type: string
|
||||
description: Optional index name to filter search to specific index
|
||||
example: full_name
|
||||
limit:
|
||||
type: integer
|
||||
description: Maximum number of matches to return
|
||||
default: 10
|
||||
minimum: 1
|
||||
maximum: 1000
|
||||
example: 20
|
||||
user:
|
||||
type: string
|
||||
description: User identifier
|
||||
default: trustgraph
|
||||
example: alice
|
||||
collection:
|
||||
type: string
|
||||
description: Collection to search
|
||||
default: default
|
||||
example: sales
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
type: object
|
||||
description: Row embeddings query response with matching row index information
|
||||
properties:
|
||||
error:
|
||||
type: object
|
||||
description: Error information if query failed
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: Error type identifier
|
||||
example: row-embeddings-query-error
|
||||
message:
|
||||
type: string
|
||||
description: Human-readable error message
|
||||
example: Schema not found
|
||||
matches:
|
||||
type: array
|
||||
description: List of matching row index entries with similarity scores
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
index_name:
|
||||
type: string
|
||||
description: Name of the indexed field(s)
|
||||
example: full_name
|
||||
index_value:
|
||||
type: array
|
||||
description: Values of the indexed fields for this row
|
||||
items:
|
||||
type: string
|
||||
example: ["John", "Smith"]
|
||||
text:
|
||||
type: string
|
||||
description: The text that was embedded for this index entry
|
||||
example: "John Smith"
|
||||
score:
|
||||
type: number
|
||||
description: Similarity score (higher is more similar)
|
||||
example: 0.89
|
||||
example:
|
||||
matches:
|
||||
- index_name: full_name
|
||||
index_value: ["John", "Smith"]
|
||||
text: "John Smith"
|
||||
score: 0.95
|
||||
- index_name: full_name
|
||||
index_value: ["Jon", "Smythe"]
|
||||
text: "Jon Smythe"
|
||||
score: 0.82
|
||||
- index_name: full_name
|
||||
index_value: ["Jonathan", "Schmidt"]
|
||||
text: "Jonathan Schmidt"
|
||||
score: 0.76
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
type: object
|
||||
description: |
|
||||
Objects query request - GraphQL query over knowledge graph.
|
||||
Rows query request - GraphQL query over structured data.
|
||||
required:
|
||||
- query
|
||||
properties:
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
type: object
|
||||
description: Objects query response (GraphQL format)
|
||||
description: Rows query response (GraphQL format)
|
||||
properties:
|
||||
data:
|
||||
description: GraphQL response data (JSON object or null)
|
||||
|
|
@ -121,8 +121,8 @@ paths:
|
|||
$ref: './paths/flow/mcp-tool.yaml'
|
||||
/api/v1/flow/{flow}/service/triples:
|
||||
$ref: './paths/flow/triples.yaml'
|
||||
/api/v1/flow/{flow}/service/objects:
|
||||
$ref: './paths/flow/objects.yaml'
|
||||
/api/v1/flow/{flow}/service/rows:
|
||||
$ref: './paths/flow/rows.yaml'
|
||||
/api/v1/flow/{flow}/service/nlp-query:
|
||||
$ref: './paths/flow/nlp-query.yaml'
|
||||
/api/v1/flow/{flow}/service/structured-query:
|
||||
|
|
@ -133,6 +133,8 @@ paths:
|
|||
$ref: './paths/flow/graph-embeddings.yaml'
|
||||
/api/v1/flow/{flow}/service/document-embeddings:
|
||||
$ref: './paths/flow/document-embeddings.yaml'
|
||||
/api/v1/flow/{flow}/service/row-embeddings:
|
||||
$ref: './paths/flow/row-embeddings.yaml'
|
||||
/api/v1/flow/{flow}/service/text-load:
|
||||
$ref: './paths/flow/text-load.yaml'
|
||||
/api/v1/flow/{flow}/service/document-load:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ post:
|
|||
```
|
||||
1. User asks: "Who does Alice know?"
|
||||
2. NLP Query generates GraphQL
|
||||
3. Execute via /api/v1/flow/{flow}/service/objects
|
||||
3. Execute via /api/v1/flow/{flow}/service/rows
|
||||
4. Return results to user
|
||||
```
|
||||
|
||||
|
|
|
|||
101
specs/api/paths/flow/row-embeddings.yaml
Normal file
101
specs/api/paths/flow/row-embeddings.yaml
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
post:
|
||||
tags:
|
||||
- Flow Services
|
||||
summary: Row Embeddings Query - semantic search on structured data
|
||||
description: |
|
||||
Query row embeddings to find similar rows by vector similarity on indexed fields.
|
||||
Enables fuzzy/semantic matching on structured data.
|
||||
|
||||
## Row Embeddings Query Overview
|
||||
|
||||
Find rows whose indexed field values are semantically similar to a query:
|
||||
- **Input**: Query embedding vector, schema name, optional index filter
|
||||
- **Search**: Compare against stored row index embeddings
|
||||
- **Output**: Matching rows with index values and similarity scores
|
||||
|
||||
Core component of semantic search on structured data.
|
||||
|
||||
## Use Cases
|
||||
|
||||
- **Fuzzy name matching**: Find customers by approximate name
|
||||
- **Semantic field search**: Find products by description similarity
|
||||
- **Data deduplication**: Identify potential duplicate records
|
||||
- **Entity resolution**: Match records across datasets
|
||||
|
||||
## Process
|
||||
|
||||
1. Obtain query embedding (via embeddings service)
|
||||
2. Query stored row index embeddings for the specified schema
|
||||
3. Calculate cosine similarity
|
||||
4. Return top N most similar index entries
|
||||
5. Use index values to retrieve full rows via GraphQL
|
||||
|
||||
## Response Format
|
||||
|
||||
Each match includes:
|
||||
- `index_name`: The indexed field(s) that matched
|
||||
- `index_value`: The actual values for those fields
|
||||
- `text`: The text that was embedded
|
||||
- `score`: Similarity score (higher = more similar)
|
||||
|
||||
operationId: rowEmbeddingsQueryService
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: flow
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: Flow instance ID
|
||||
example: my-flow
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
|
||||
examples:
|
||||
basicQuery:
|
||||
summary: Find similar customer names
|
||||
value:
|
||||
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
|
||||
schema_name: customers
|
||||
limit: 10
|
||||
user: alice
|
||||
collection: sales
|
||||
filteredQuery:
|
||||
summary: Search specific index
|
||||
value:
|
||||
vectors: [0.1, -0.2, 0.3, -0.4, 0.5]
|
||||
schema_name: products
|
||||
index_name: description
|
||||
limit: 20
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml'
|
||||
examples:
|
||||
similarRows:
|
||||
summary: Similar rows found
|
||||
value:
|
||||
matches:
|
||||
- index_name: full_name
|
||||
index_value: ["John", "Smith"]
|
||||
text: "John Smith"
|
||||
score: 0.95
|
||||
- index_name: full_name
|
||||
index_value: ["Jon", "Smythe"]
|
||||
text: "Jon Smythe"
|
||||
score: 0.82
|
||||
- index_name: full_name
|
||||
index_value: ["Jonathan", "Schmidt"]
|
||||
text: "Jonathan Schmidt"
|
||||
score: 0.76
|
||||
'401':
|
||||
$ref: '../../components/responses/Unauthorized.yaml'
|
||||
'500':
|
||||
$ref: '../../components/responses/Error.yaml'
|
||||
|
|
@ -1,19 +1,19 @@
|
|||
post:
|
||||
tags:
|
||||
- Flow Services
|
||||
summary: Objects query - GraphQL over knowledge graph
|
||||
summary: Rows query - GraphQL over structured data
|
||||
description: |
|
||||
Query knowledge graph using GraphQL for object-oriented data access.
|
||||
Query structured data using GraphQL for row-oriented data access.
|
||||
|
||||
## Objects Query Overview
|
||||
## Rows Query Overview
|
||||
|
||||
GraphQL interface to knowledge graph:
|
||||
GraphQL interface to structured data:
|
||||
- **Schema-driven**: Predefined types and relationships
|
||||
- **Flexible queries**: Request exactly what you need
|
||||
- **Nested data**: Traverse relationships in single query
|
||||
- **Type-safe**: Strong typing with introspection
|
||||
|
||||
Abstracts RDF triples into familiar object model.
|
||||
Abstracts structured rows into familiar object model.
|
||||
|
||||
## GraphQL Benefits
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ post:
|
|||
Schema defines available types via config service.
|
||||
Use introspection query to discover schema.
|
||||
|
||||
operationId: objectsQueryService
|
||||
operationId: rowsQueryService
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
|
|
@ -77,7 +77,7 @@ post:
|
|||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '../../components/schemas/query/ObjectsQueryRequest.yaml'
|
||||
$ref: '../../components/schemas/query/RowsQueryRequest.yaml'
|
||||
examples:
|
||||
simpleQuery:
|
||||
summary: Simple query
|
||||
|
|
@ -129,7 +129,7 @@ post:
|
|||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '../../components/schemas/query/ObjectsQueryResponse.yaml'
|
||||
$ref: '../../components/schemas/query/RowsQueryResponse.yaml'
|
||||
examples:
|
||||
successfulQuery:
|
||||
summary: Successful query
|
||||
|
|
@ -9,7 +9,7 @@ post:
|
|||
|
||||
Combines two operations in one call:
|
||||
1. **NLP Query**: Generate GraphQL from question
|
||||
2. **Objects Query**: Execute generated query
|
||||
2. **Rows Query**: Execute generated query
|
||||
3. **Return Results**: Direct answer data
|
||||
|
||||
Simplest way to query knowledge graph with natural language.
|
||||
|
|
@ -21,7 +21,7 @@ post:
|
|||
- **Output**: Query results (data)
|
||||
- **Use when**: Want simple, direct answers
|
||||
|
||||
### NLP Query + Objects Query (separate calls)
|
||||
### NLP Query + Rows Query (separate calls)
|
||||
- **Step 1**: Convert question → GraphQL
|
||||
- **Step 2**: Execute GraphQL → results
|
||||
- **Use when**: Need to inspect/modify query before execution
|
||||
|
|
|
|||
|
|
@ -25,12 +25,13 @@ payload:
|
|||
- $ref: './requests/EmbeddingsRequest.yaml'
|
||||
- $ref: './requests/McpToolRequest.yaml'
|
||||
- $ref: './requests/TriplesRequest.yaml'
|
||||
- $ref: './requests/ObjectsRequest.yaml'
|
||||
- $ref: './requests/RowsRequest.yaml'
|
||||
- $ref: './requests/NlpQueryRequest.yaml'
|
||||
- $ref: './requests/StructuredQueryRequest.yaml'
|
||||
- $ref: './requests/StructuredDiagRequest.yaml'
|
||||
- $ref: './requests/GraphEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/DocumentEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/RowEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/TextLoadRequest.yaml'
|
||||
- $ref: './requests/DocumentLoadRequest.yaml'
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
type: object
|
||||
description: WebSocket request for row-embeddings service (flow-hosted service)
|
||||
required:
|
||||
- id
|
||||
- service
|
||||
- flow
|
||||
- request
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique request identifier
|
||||
service:
|
||||
type: string
|
||||
const: row-embeddings
|
||||
description: Service identifier for row-embeddings service
|
||||
flow:
|
||||
type: string
|
||||
description: Flow ID
|
||||
request:
|
||||
$ref: '../../../../api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
|
||||
examples:
|
||||
- id: req-1
|
||||
service: row-embeddings
|
||||
flow: my-flow
|
||||
request:
|
||||
vectors: [0.023, -0.142, 0.089, 0.234]
|
||||
schema_name: customers
|
||||
limit: 10
|
||||
user: trustgraph
|
||||
collection: default
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
type: object
|
||||
description: WebSocket request for objects service (flow-hosted service)
|
||||
description: WebSocket request for rows service (flow-hosted service)
|
||||
required:
|
||||
- id
|
||||
- service
|
||||
|
|
@ -11,16 +11,16 @@ properties:
|
|||
description: Unique request identifier
|
||||
service:
|
||||
type: string
|
||||
const: objects
|
||||
description: Service identifier for objects service
|
||||
const: rows
|
||||
description: Service identifier for rows service
|
||||
flow:
|
||||
type: string
|
||||
description: Flow ID
|
||||
request:
|
||||
$ref: '../../../../api/components/schemas/query/ObjectsQueryRequest.yaml'
|
||||
$ref: '../../../../api/components/schemas/query/RowsQueryRequest.yaml'
|
||||
examples:
|
||||
- id: req-1
|
||||
service: objects
|
||||
service: rows
|
||||
flow: my-flow
|
||||
request:
|
||||
query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }"
|
||||
|
|
@ -15,10 +15,10 @@ from trustgraph.schema import (
|
|||
TextCompletionRequest, TextCompletionResponse,
|
||||
DocumentRagQuery, DocumentRagResponse,
|
||||
AgentRequest, AgentResponse, AgentStep,
|
||||
Chunk, Triple, Triples, Value, Error,
|
||||
Chunk, Triple, Triples, Term, Error,
|
||||
EntityContext, EntityContexts,
|
||||
GraphEmbeddings, EntityEmbeddings,
|
||||
Metadata
|
||||
Metadata, IRI, LITERAL
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ def schema_registry():
|
|||
"Chunk": Chunk,
|
||||
"Triple": Triple,
|
||||
"Triples": Triples,
|
||||
"Value": Value,
|
||||
"Term": Term,
|
||||
"Error": Error,
|
||||
"EntityContext": EntityContext,
|
||||
"EntityContexts": EntityContexts,
|
||||
|
|
@ -98,26 +98,22 @@ def sample_message_data():
|
|||
"collection": "test_collection",
|
||||
"metadata": []
|
||||
},
|
||||
"Value": {
|
||||
"value": "http://example.com/entity",
|
||||
"is_uri": True,
|
||||
"type": ""
|
||||
"Term": {
|
||||
"type": IRI,
|
||||
"iri": "http://example.com/entity"
|
||||
},
|
||||
"Triple": {
|
||||
"s": Value(
|
||||
value="http://example.com/subject",
|
||||
is_uri=True,
|
||||
type=""
|
||||
"s": Term(
|
||||
type=IRI,
|
||||
iri="http://example.com/subject"
|
||||
),
|
||||
"p": Value(
|
||||
value="http://example.com/predicate",
|
||||
is_uri=True,
|
||||
type=""
|
||||
"p": Term(
|
||||
type=IRI,
|
||||
iri="http://example.com/predicate"
|
||||
),
|
||||
"o": Value(
|
||||
value="Object value",
|
||||
is_uri=False,
|
||||
type=""
|
||||
"o": Term(
|
||||
type=LITERAL,
|
||||
value="Object value"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -139,10 +135,10 @@ def invalid_message_data():
|
|||
{"query": "test", "user": "test", "collection": "test", "doc_limit": -1}, # Invalid doc_limit
|
||||
{"query": "test"}, # Missing required fields
|
||||
],
|
||||
"Value": [
|
||||
{"value": None, "is_uri": True, "type": ""}, # Invalid value (None)
|
||||
{"value": "test", "is_uri": "not_boolean", "type": ""}, # Invalid is_uri
|
||||
{"value": 123, "is_uri": True, "type": ""}, # Invalid value (not string)
|
||||
"Term": [
|
||||
{"type": IRI, "iri": None}, # Invalid iri (None)
|
||||
{"type": "invalid_type", "value": "test"}, # Invalid type
|
||||
{"type": LITERAL, "value": 123}, # Invalid value (not string)
|
||||
]
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,14 +15,14 @@ from trustgraph.schema import (
|
|||
TextCompletionRequest, TextCompletionResponse,
|
||||
DocumentRagQuery, DocumentRagResponse,
|
||||
AgentRequest, AgentResponse, AgentStep,
|
||||
Chunk, Triple, Triples, Value, Error,
|
||||
Chunk, Triple, Triples, Term, Error,
|
||||
EntityContext, EntityContexts,
|
||||
GraphEmbeddings, EntityEmbeddings,
|
||||
Metadata, Field, RowSchema,
|
||||
StructuredDataSubmission, ExtractedObject,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
StructuredObjectEmbedding
|
||||
StructuredObjectEmbedding, IRI, LITERAL
|
||||
)
|
||||
from .conftest import validate_schema_contract, serialize_deserialize_test
|
||||
|
||||
|
|
@ -271,52 +271,51 @@ class TestAgentMessageContracts:
|
|||
class TestGraphMessageContracts:
|
||||
"""Contract tests for Graph/Knowledge message schemas"""
|
||||
|
||||
def test_value_schema_contract(self, sample_message_data):
|
||||
"""Test Value schema contract"""
|
||||
def test_term_schema_contract(self, sample_message_data):
|
||||
"""Test Term schema contract"""
|
||||
# Arrange
|
||||
value_data = sample_message_data["Value"]
|
||||
term_data = sample_message_data["Term"]
|
||||
|
||||
# Act & Assert
|
||||
assert validate_schema_contract(Value, value_data)
|
||||
|
||||
# Test URI value
|
||||
uri_value = Value(**value_data)
|
||||
assert uri_value.value == "http://example.com/entity"
|
||||
assert uri_value.is_uri is True
|
||||
assert validate_schema_contract(Term, term_data)
|
||||
|
||||
# Test literal value
|
||||
literal_value = Value(
|
||||
value="Literal text value",
|
||||
is_uri=False,
|
||||
type=""
|
||||
# Test URI term
|
||||
uri_term = Term(**term_data)
|
||||
assert uri_term.iri == "http://example.com/entity"
|
||||
assert uri_term.type == IRI
|
||||
|
||||
# Test literal term
|
||||
literal_term = Term(
|
||||
type=LITERAL,
|
||||
value="Literal text value"
|
||||
)
|
||||
assert literal_value.value == "Literal text value"
|
||||
assert literal_value.is_uri is False
|
||||
assert literal_term.value == "Literal text value"
|
||||
assert literal_term.type == LITERAL
|
||||
|
||||
def test_triple_schema_contract(self, sample_message_data):
|
||||
"""Test Triple schema contract"""
|
||||
# Arrange
|
||||
triple_data = sample_message_data["Triple"]
|
||||
|
||||
# Act & Assert - Triple uses Value objects, not dict validation
|
||||
# Act & Assert - Triple uses Term objects, not dict validation
|
||||
triple = Triple(
|
||||
s=triple_data["s"],
|
||||
p=triple_data["p"],
|
||||
p=triple_data["p"],
|
||||
o=triple_data["o"]
|
||||
)
|
||||
assert triple.s.value == "http://example.com/subject"
|
||||
assert triple.p.value == "http://example.com/predicate"
|
||||
assert triple.s.iri == "http://example.com/subject"
|
||||
assert triple.p.iri == "http://example.com/predicate"
|
||||
assert triple.o.value == "Object value"
|
||||
assert triple.s.is_uri is True
|
||||
assert triple.p.is_uri is True
|
||||
assert triple.o.is_uri is False
|
||||
assert triple.s.type == IRI
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.type == LITERAL
|
||||
|
||||
def test_triples_schema_contract(self, sample_message_data):
|
||||
"""Test Triples (batch) schema contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(**sample_message_data["Metadata"])
|
||||
triple = Triple(**sample_message_data["Triple"])
|
||||
|
||||
|
||||
triples_data = {
|
||||
"metadata": metadata,
|
||||
"triples": [triple]
|
||||
|
|
@ -324,11 +323,11 @@ class TestGraphMessageContracts:
|
|||
|
||||
# Act & Assert
|
||||
assert validate_schema_contract(Triples, triples_data)
|
||||
|
||||
|
||||
triples = Triples(**triples_data)
|
||||
assert triples.metadata.id == "test-doc-123"
|
||||
assert len(triples.triples) == 1
|
||||
assert triples.triples[0].s.value == "http://example.com/subject"
|
||||
assert triples.triples[0].s.iri == "http://example.com/subject"
|
||||
|
||||
def test_chunk_schema_contract(self, sample_message_data):
|
||||
"""Test Chunk schema contract"""
|
||||
|
|
@ -349,29 +348,29 @@ class TestGraphMessageContracts:
|
|||
def test_entity_context_schema_contract(self):
|
||||
"""Test EntityContext schema contract"""
|
||||
# Arrange
|
||||
entity_value = Value(value="http://example.com/entity", is_uri=True, type="")
|
||||
entity_term = Term(type=IRI, iri="http://example.com/entity")
|
||||
entity_context_data = {
|
||||
"entity": entity_value,
|
||||
"entity": entity_term,
|
||||
"context": "Context information about the entity"
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
assert validate_schema_contract(EntityContext, entity_context_data)
|
||||
|
||||
|
||||
entity_context = EntityContext(**entity_context_data)
|
||||
assert entity_context.entity.value == "http://example.com/entity"
|
||||
assert entity_context.entity.iri == "http://example.com/entity"
|
||||
assert entity_context.context == "Context information about the entity"
|
||||
|
||||
def test_entity_contexts_batch_schema_contract(self, sample_message_data):
|
||||
"""Test EntityContexts (batch) schema contract"""
|
||||
# Arrange
|
||||
metadata = Metadata(**sample_message_data["Metadata"])
|
||||
entity_value = Value(value="http://example.com/entity", is_uri=True, type="")
|
||||
entity_term = Term(type=IRI, iri="http://example.com/entity")
|
||||
entity_context = EntityContext(
|
||||
entity=entity_value,
|
||||
entity=entity_term,
|
||||
context="Entity context"
|
||||
)
|
||||
|
||||
|
||||
entity_contexts_data = {
|
||||
"metadata": metadata,
|
||||
"entities": [entity_context]
|
||||
|
|
@ -379,7 +378,7 @@ class TestGraphMessageContracts:
|
|||
|
||||
# Act & Assert
|
||||
assert validate_schema_contract(EntityContexts, entity_contexts_data)
|
||||
|
||||
|
||||
entity_contexts = EntityContexts(**entity_contexts_data)
|
||||
assert entity_contexts.metadata.id == "test-doc-123"
|
||||
assert len(entity_contexts.entities) == 1
|
||||
|
|
@ -417,10 +416,10 @@ class TestMetadataMessageContracts:
|
|||
|
||||
# Act & Assert
|
||||
assert validate_schema_contract(Metadata, metadata_data)
|
||||
|
||||
|
||||
metadata = Metadata(**metadata_data)
|
||||
assert len(metadata.metadata) == 1
|
||||
assert metadata.metadata[0].s.value == "http://example.com/subject"
|
||||
assert metadata.metadata[0].s.iri == "http://example.com/subject"
|
||||
|
||||
def test_error_schema_contract(self):
|
||||
"""Test Error schema contract"""
|
||||
|
|
@ -532,7 +531,7 @@ class TestSerializationContracts:
|
|||
# Test each schema in the registry
|
||||
for schema_name, schema_class in schema_registry.items():
|
||||
if schema_name in sample_message_data:
|
||||
# Skip Triple schema as it requires special handling with Value objects
|
||||
# Skip Triple schema as it requires special handling with Term objects
|
||||
if schema_name == "Triple":
|
||||
continue
|
||||
|
||||
|
|
@ -541,36 +540,36 @@ class TestSerializationContracts:
|
|||
assert serialize_deserialize_test(schema_class, data), f"Serialization failed for {schema_name}"
|
||||
|
||||
def test_triple_serialization_contract(self, sample_message_data):
|
||||
"""Test Triple schema serialization contract with Value objects"""
|
||||
"""Test Triple schema serialization contract with Term objects"""
|
||||
# Arrange
|
||||
triple_data = sample_message_data["Triple"]
|
||||
|
||||
|
||||
# Act
|
||||
triple = Triple(
|
||||
s=triple_data["s"],
|
||||
p=triple_data["p"],
|
||||
p=triple_data["p"],
|
||||
o=triple_data["o"]
|
||||
)
|
||||
|
||||
# Assert - Test that Value objects are properly constructed and accessible
|
||||
assert triple.s.value == "http://example.com/subject"
|
||||
assert triple.p.value == "http://example.com/predicate"
|
||||
|
||||
# Assert - Test that Term objects are properly constructed and accessible
|
||||
assert triple.s.iri == "http://example.com/subject"
|
||||
assert triple.p.iri == "http://example.com/predicate"
|
||||
assert triple.o.value == "Object value"
|
||||
assert isinstance(triple.s, Value)
|
||||
assert isinstance(triple.p, Value)
|
||||
assert isinstance(triple.o, Value)
|
||||
assert isinstance(triple.s, Term)
|
||||
assert isinstance(triple.p, Term)
|
||||
assert isinstance(triple.o, Term)
|
||||
|
||||
def test_nested_schema_serialization_contract(self, sample_message_data):
|
||||
"""Test serialization of nested schemas"""
|
||||
# Test Triples (contains Metadata and Triple objects)
|
||||
metadata = Metadata(**sample_message_data["Metadata"])
|
||||
triple = Triple(**sample_message_data["Triple"])
|
||||
|
||||
|
||||
triples = Triples(metadata=metadata, triples=[triple])
|
||||
|
||||
|
||||
# Verify nested objects maintain their contracts
|
||||
assert triples.metadata.id == "test-doc-123"
|
||||
assert triples.triples[0].s.value == "http://example.com/subject"
|
||||
assert triples.triples[0].s.iri == "http://example.com/subject"
|
||||
|
||||
def test_array_field_serialization_contract(self):
|
||||
"""Test serialization of array fields"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""
|
||||
Contract tests for Cassandra Object Storage
|
||||
Contract tests for Cassandra Row Storage
|
||||
|
||||
These tests verify the message contracts and schema compatibility
|
||||
for the objects storage processor.
|
||||
for the rows storage processor.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -10,12 +10,12 @@ import json
|
|||
from pulsar.schema import AvroSchema
|
||||
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.storage.rows.cassandra.write import Processor
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestObjectsCassandraContracts:
|
||||
"""Contract tests for Cassandra object storage messages"""
|
||||
class TestRowsCassandraContracts:
|
||||
"""Contract tests for Cassandra row storage messages"""
|
||||
|
||||
def test_extracted_object_input_contract(self):
|
||||
"""Test that ExtractedObject schema matches expected input format"""
|
||||
|
|
@ -145,50 +145,6 @@ class TestObjectsCassandraContracts:
|
|||
assert required_field_keys.issubset(field.keys())
|
||||
assert set(field.keys()).issubset(required_field_keys | optional_field_keys)
|
||||
|
||||
def test_cassandra_type_mapping_contract(self):
|
||||
"""Test that all supported field types have Cassandra mappings"""
|
||||
processor = Processor.__new__(Processor)
|
||||
|
||||
# All field types that should be supported
|
||||
supported_types = [
|
||||
("string", "text"),
|
||||
("integer", "int"), # or bigint based on size
|
||||
("float", "float"), # or double based on size
|
||||
("boolean", "boolean"),
|
||||
("timestamp", "timestamp"),
|
||||
("date", "date"),
|
||||
("time", "time"),
|
||||
("uuid", "uuid")
|
||||
]
|
||||
|
||||
for field_type, expected_cassandra_type in supported_types:
|
||||
cassandra_type = processor.get_cassandra_type(field_type)
|
||||
# For integer and float, the exact type depends on size
|
||||
if field_type in ["integer", "float"]:
|
||||
assert cassandra_type in ["int", "bigint", "float", "double"]
|
||||
else:
|
||||
assert cassandra_type == expected_cassandra_type
|
||||
|
||||
def test_value_conversion_contract(self):
|
||||
"""Test value conversion for all supported types"""
|
||||
processor = Processor.__new__(Processor)
|
||||
|
||||
# Test conversions maintain data integrity
|
||||
test_cases = [
|
||||
# (input_value, field_type, expected_output, expected_type)
|
||||
("123", "integer", 123, int),
|
||||
("123.45", "float", 123.45, float),
|
||||
("true", "boolean", True, bool),
|
||||
("false", "boolean", False, bool),
|
||||
("test string", "string", "test string", str),
|
||||
(None, "string", None, type(None)),
|
||||
]
|
||||
|
||||
for input_val, field_type, expected_val, expected_type in test_cases:
|
||||
result = processor.convert_value(input_val, field_type)
|
||||
assert result == expected_val
|
||||
assert isinstance(result, expected_type) or result is None
|
||||
|
||||
@pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
|
||||
def test_extracted_object_serialization_contract(self):
|
||||
"""Test that ExtractedObject can be serialized/deserialized correctly"""
|
||||
|
|
@ -222,43 +178,31 @@ class TestObjectsCassandraContracts:
|
|||
assert decoded.confidence == original.confidence
|
||||
assert decoded.source_span == original.source_span
|
||||
|
||||
def test_cassandra_table_naming_contract(self):
|
||||
def test_cassandra_name_sanitization_contract(self):
|
||||
"""Test Cassandra naming conventions and constraints"""
|
||||
processor = Processor.__new__(Processor)
|
||||
|
||||
# Test table naming (always gets o_ prefix)
|
||||
table_test_names = [
|
||||
("simple_name", "o_simple_name"),
|
||||
("Name-With-Dashes", "o_name_with_dashes"),
|
||||
("name.with.dots", "o_name_with_dots"),
|
||||
("123_numbers", "o_123_numbers"),
|
||||
("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores
|
||||
("UPPERCASE", "o_uppercase"),
|
||||
("CamelCase", "o_camelcase"),
|
||||
("", "o_"), # Edge case - empty string becomes o_
|
||||
]
|
||||
|
||||
for input_name, expected_name in table_test_names:
|
||||
result = processor.sanitize_table(input_name)
|
||||
assert result == expected_name
|
||||
# Verify result is valid Cassandra identifier (starts with letter)
|
||||
assert result.startswith('o_')
|
||||
assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_'
|
||||
|
||||
# Test regular name sanitization (only adds o_ prefix if starts with number)
|
||||
|
||||
# Test name sanitization for Cassandra identifiers
|
||||
# - Non-alphanumeric chars (except underscore) become underscores
|
||||
# - Names starting with non-letter get 'r_' prefix
|
||||
# - All names converted to lowercase
|
||||
name_test_cases = [
|
||||
("simple_name", "simple_name"),
|
||||
("Name-With-Dashes", "name_with_dashes"),
|
||||
("name.with.dots", "name_with_dots"),
|
||||
("123_numbers", "o_123_numbers"), # Only this gets o_ prefix
|
||||
("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number)
|
||||
("special!@#chars", "special___chars"), # 3 special chars become 3 underscores
|
||||
("UPPERCASE", "uppercase"),
|
||||
("CamelCase", "camelcase"),
|
||||
("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore)
|
||||
]
|
||||
|
||||
|
||||
for input_name, expected_name in name_test_cases:
|
||||
result = processor.sanitize_name(input_name)
|
||||
assert result == expected_name
|
||||
assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}"
|
||||
# Verify result is valid Cassandra identifier (starts with letter)
|
||||
if result: # Skip empty string case
|
||||
assert result[0].isalpha(), f"Result {result} should start with a letter"
|
||||
|
||||
def test_primary_key_structure_contract(self):
|
||||
"""Test that primary key structure follows Cassandra best practices"""
|
||||
|
|
@ -308,8 +252,8 @@ class TestObjectsCassandraContracts:
|
|||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestObjectsCassandraContractsBatch:
|
||||
"""Contract tests for Cassandra object storage batch processing"""
|
||||
class TestRowsCassandraContractsBatch:
|
||||
"""Contract tests for Cassandra row storage batch processing"""
|
||||
|
||||
def test_extracted_object_batch_input_contract(self):
|
||||
"""Test that batched ExtractedObject schema matches expected input format"""
|
||||
|
|
@ -1,26 +1,26 @@
|
|||
"""
|
||||
Contract tests for Objects GraphQL Query Service
|
||||
Contract tests for Rows GraphQL Query Service
|
||||
|
||||
These tests verify the message contracts and schema compatibility
|
||||
for the objects GraphQL query processor.
|
||||
for the rows GraphQL query processor.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from pulsar.schema import AvroSchema
|
||||
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from trustgraph.query.rows.cassandra.service import Processor
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestObjectsGraphQLQueryContracts:
|
||||
class TestRowsGraphQLQueryContracts:
|
||||
"""Contract tests for GraphQL query service messages"""
|
||||
|
||||
def test_objects_query_request_contract(self):
|
||||
"""Test ObjectsQueryRequest schema structure and required fields"""
|
||||
def test_rows_query_request_contract(self):
|
||||
"""Test RowsQueryRequest schema structure and required fields"""
|
||||
# Create test request with all required fields
|
||||
test_request = ObjectsQueryRequest(
|
||||
test_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name email } }',
|
||||
|
|
@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert test_request.variables["status"] == "active"
|
||||
assert test_request.operation_name == "GetCustomers"
|
||||
|
||||
def test_objects_query_request_minimal(self):
|
||||
"""Test ObjectsQueryRequest with minimal required fields"""
|
||||
def test_rows_query_request_minimal(self):
|
||||
"""Test RowsQueryRequest with minimal required fields"""
|
||||
# Create request with only essential fields
|
||||
minimal_request = ObjectsQueryRequest(
|
||||
minimal_request = RowsQueryRequest(
|
||||
user="user",
|
||||
collection="collection",
|
||||
query='{ test }',
|
||||
|
|
@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert test_error.path == ["customers", "0", "nonexistent"]
|
||||
assert test_error.extensions["code"] == "FIELD_ERROR"
|
||||
|
||||
def test_objects_query_response_success_contract(self):
|
||||
"""Test ObjectsQueryResponse schema for successful queries"""
|
||||
def test_rows_query_response_success_contract(self):
|
||||
"""Test RowsQueryResponse schema for successful queries"""
|
||||
# Create successful response
|
||||
success_response = ObjectsQueryResponse(
|
||||
success_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
|
||||
errors=[],
|
||||
|
|
@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert len(parsed_data["customers"]) == 1
|
||||
assert parsed_data["customers"][0]["id"] == "1"
|
||||
|
||||
def test_objects_query_response_error_contract(self):
|
||||
"""Test ObjectsQueryResponse schema for error cases"""
|
||||
def test_rows_query_response_error_contract(self):
|
||||
"""Test RowsQueryResponse schema for error cases"""
|
||||
# Create GraphQL errors - work around Pulsar Array(Record) validation bug
|
||||
# by creating a response without the problematic errors array first
|
||||
error_response = ObjectsQueryResponse(
|
||||
error_response = RowsQueryResponse(
|
||||
error=None, # System error is None - these are GraphQL errors
|
||||
data=None, # No data due to errors
|
||||
errors=[], # Empty errors array to avoid Pulsar bug
|
||||
|
|
@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert validation_error.path == ["customers", "email"]
|
||||
assert validation_error.extensions["details"] == "Invalid email format"
|
||||
|
||||
def test_objects_query_response_system_error_contract(self):
|
||||
"""Test ObjectsQueryResponse schema for system errors"""
|
||||
def test_rows_query_response_system_error_contract(self):
|
||||
"""Test RowsQueryResponse schema for system errors"""
|
||||
from trustgraph.schema import Error
|
||||
|
||||
# Create system error response
|
||||
system_error_response = ObjectsQueryResponse(
|
||||
system_error_response = RowsQueryResponse(
|
||||
error=Error(
|
||||
type="objects-query-error",
|
||||
type="rows-query-error",
|
||||
message="Failed to connect to Cassandra cluster"
|
||||
),
|
||||
data=None,
|
||||
|
|
@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
|
||||
# Verify system error structure
|
||||
assert system_error_response.error is not None
|
||||
assert system_error_response.error.type == "objects-query-error"
|
||||
assert system_error_response.error.type == "rows-query-error"
|
||||
assert "Cassandra" in system_error_response.error.message
|
||||
assert system_error_response.data is None
|
||||
assert len(system_error_response.errors) == 0
|
||||
|
|
@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
def test_request_response_serialization_contract(self):
|
||||
"""Test that request/response can be serialized/deserialized correctly"""
|
||||
# Create original request
|
||||
original_request = ObjectsQueryRequest(
|
||||
original_request = RowsQueryRequest(
|
||||
user="serialization_test",
|
||||
collection="test_data",
|
||||
query='{ orders(limit: 5) { id total customer { name } } }',
|
||||
|
|
@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
)
|
||||
|
||||
# Test request serialization using Pulsar schema
|
||||
request_schema = AvroSchema(ObjectsQueryRequest)
|
||||
request_schema = AvroSchema(RowsQueryRequest)
|
||||
|
||||
# Encode and decode request
|
||||
encoded_request = request_schema.encode(original_request)
|
||||
|
|
@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert decoded_request.operation_name == original_request.operation_name
|
||||
|
||||
# Create original response - work around Pulsar Array(Record) bug
|
||||
original_response = ObjectsQueryResponse(
|
||||
original_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"orders": []}',
|
||||
errors=[], # Empty to avoid Pulsar validation bug
|
||||
|
|
@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
)
|
||||
|
||||
# Test response serialization
|
||||
response_schema = AvroSchema(ObjectsQueryResponse)
|
||||
response_schema = AvroSchema(RowsQueryResponse)
|
||||
|
||||
# Encode and decode response
|
||||
encoded_response = response_schema.encode(original_response)
|
||||
|
|
@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
def test_graphql_query_format_contract(self):
|
||||
"""Test supported GraphQL query formats"""
|
||||
# Test basic query
|
||||
basic_query = ObjectsQueryRequest(
|
||||
basic_query = RowsQueryRequest(
|
||||
user="test", collection="test", query='{ customers { id } }',
|
||||
variables={}, operation_name=""
|
||||
)
|
||||
|
|
@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert basic_query.query.strip().endswith('}')
|
||||
|
||||
# Test query with variables
|
||||
parameterized_query = ObjectsQueryRequest(
|
||||
parameterized_query = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
|
||||
variables={"status": "active", "limit": "10"},
|
||||
|
|
@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert parameterized_query.operation_name == "GetCustomers"
|
||||
|
||||
# Test complex nested query
|
||||
nested_query = ObjectsQueryRequest(
|
||||
nested_query = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
query='''
|
||||
{
|
||||
|
|
@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
# Note: Current schema uses Map(String()) which only supports string values
|
||||
# This test verifies the current contract, though ideally we'd support all JSON types
|
||||
|
||||
variables_test = ObjectsQueryRequest(
|
||||
variables_test = RowsQueryRequest(
|
||||
user="test", collection="test", query='{ test }',
|
||||
variables={
|
||||
"string_var": "test_value",
|
||||
|
|
@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
def test_cassandra_context_fields_contract(self):
|
||||
"""Test that request contains necessary fields for Cassandra operations"""
|
||||
# Verify request has fields needed for Cassandra keyspace/table targeting
|
||||
request = ObjectsQueryRequest(
|
||||
request = RowsQueryRequest(
|
||||
user="keyspace_name", # Maps to Cassandra keyspace
|
||||
collection="partition_collection", # Used in partition key
|
||||
query='{ objects { id } }',
|
||||
|
|
@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
def test_graphql_extensions_contract(self):
|
||||
"""Test GraphQL extensions field format and usage"""
|
||||
# Extensions should support query metadata
|
||||
response_with_extensions = ObjectsQueryResponse(
|
||||
response_with_extensions = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"test": "data"}',
|
||||
errors=[],
|
||||
|
|
@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
'''
|
||||
|
||||
# Request to execute specific operation
|
||||
multi_op_request = ObjectsQueryRequest(
|
||||
multi_op_request = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
query=multi_op_query,
|
||||
variables={},
|
||||
|
|
@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts:
|
|||
assert "GetOrders" in multi_op_request.query
|
||||
|
||||
# Test single operation (operation_name optional)
|
||||
single_op_request = ObjectsQueryRequest(
|
||||
single_op_request = RowsQueryRequest(
|
||||
user="test", collection="test",
|
||||
query='{ customers { id } }',
|
||||
variables={}, operation_name=""
|
||||
|
|
@ -15,7 +15,7 @@ from trustgraph.schema import (
|
|||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
StructuredObjectEmbedding, Field, RowSchema,
|
||||
Metadata, Error, Value
|
||||
Metadata, Error
|
||||
)
|
||||
from .conftest import serialize_deserialize_test
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import json
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, AgentRequest, AgentResponse
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
|
@ -30,38 +30,16 @@ class TestAgentKgExtractionIntegration:
|
|||
# Mock agent client
|
||||
agent_client = AsyncMock()
|
||||
|
||||
# Mock successful agent response
|
||||
# Mock successful agent response in JSONL format
|
||||
def mock_agent_response(recipient, question):
|
||||
# Simulate agent processing and return structured response
|
||||
# Simulate agent processing and return structured JSONL response
|
||||
mock_response = MagicMock()
|
||||
mock_response.error = None
|
||||
mock_response.answer = '''```json
|
||||
{
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": true
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": true
|
||||
}
|
||||
]
|
||||
}
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."}
|
||||
{"type": "definition", "entity": "Neural Networks", "definition": "Computing systems inspired by biological neural networks that process information."}
|
||||
{"type": "relationship", "subject": "Machine Learning", "predicate": "is_subset_of", "object": "Artificial Intelligence", "object-entity": true}
|
||||
{"type": "relationship", "subject": "Neural Networks", "predicate": "used_in", "object": "Machine Learning", "object-entity": true}
|
||||
```'''
|
||||
return mock_response.answer
|
||||
|
||||
|
|
@ -100,9 +78,9 @@ class TestAgentKgExtractionIntegration:
|
|||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc123", is_uri=True),
|
||||
p=Value(value="http://example.org/type", is_uri=True),
|
||||
o=Value(value="document", is_uri=False)
|
||||
s=Term(type=IRI, iri="doc123"),
|
||||
p=Term(type=IRI, iri="http://example.org/type"),
|
||||
o=Term(type=LITERAL, value="document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -120,7 +98,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
# Copy the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.parse_jsonl = real_extractor.parse_jsonl
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
|
@ -156,7 +134,7 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_response = agent_client.invoke(recipient=lambda x: True, question=prompt)
|
||||
|
||||
# Parse and process
|
||||
extraction_data = extractor.parse_json(agent_response)
|
||||
extraction_data = extractor.parse_jsonl(agent_response)
|
||||
triples, entity_contexts = extractor.process_extraction_data(extraction_data, v.metadata)
|
||||
|
||||
# Add metadata triples
|
||||
|
|
@ -200,15 +178,15 @@ class TestAgentKgExtractionIntegration:
|
|||
assert len(sent_triples.triples) > 0
|
||||
|
||||
# Check that we have definition triples
|
||||
definition_triples = [t for t in sent_triples.triples if t.p.value == DEFINITION]
|
||||
definition_triples = [t for t in sent_triples.triples if t.p.iri == DEFINITION]
|
||||
assert len(definition_triples) >= 2 # Should have definitions for ML and Neural Networks
|
||||
|
||||
|
||||
# Check that we have label triples
|
||||
label_triples = [t for t in sent_triples.triples if t.p.value == RDF_LABEL]
|
||||
label_triples = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL]
|
||||
assert len(label_triples) >= 2 # Should have labels for entities
|
||||
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in sent_triples.triples if t.p.value == SUBJECT_OF]
|
||||
subject_of_triples = [t for t in sent_triples.triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) >= 2 # Entities should be linked to document
|
||||
|
||||
# Verify entity contexts were emitted
|
||||
|
|
@ -220,7 +198,7 @@ class TestAgentKgExtractionIntegration:
|
|||
assert len(sent_contexts.entities) >= 2 # Should have contexts for both entities
|
||||
|
||||
# Verify entity URIs are properly formed
|
||||
entity_uris = [ec.entity.value for ec in sent_contexts.entities]
|
||||
entity_uris = [ec.entity.iri for ec in sent_contexts.entities]
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
|
||||
|
||||
|
|
@ -248,22 +226,28 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_response_handling(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
"""Test handling of invalid JSON responses from agent"""
|
||||
"""Test handling of invalid JSON responses from agent - JSONL is lenient and skips invalid lines"""
|
||||
# Arrange - mock invalid JSON response
|
||||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
|
||||
def mock_invalid_json_response(recipient, question):
|
||||
return "This is not valid JSON at all"
|
||||
|
||||
|
||||
agent_client.invoke = mock_invalid_json_response
|
||||
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises((ValueError, json.JSONDecodeError)):
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
# Act - JSONL parsing is lenient, invalid lines are skipped
|
||||
await configured_agent_extractor.on_message(mock_message, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert - should emit triples (with just metadata) but no entity contexts
|
||||
triples_publisher = mock_flow_context("triples")
|
||||
triples_publisher.send.assert_called_once()
|
||||
|
||||
entity_contexts_publisher = mock_flow_context("entity-contexts")
|
||||
entity_contexts_publisher.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_extraction_results(self, configured_agent_extractor, sample_chunk, mock_flow_context):
|
||||
|
|
@ -272,7 +256,8 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_empty_response(recipient, question):
|
||||
return '{"definitions": [], "relationships": []}'
|
||||
# Return empty JSONL (just empty/whitespace)
|
||||
return ''
|
||||
|
||||
agent_client.invoke = mock_empty_response
|
||||
|
||||
|
|
@ -303,7 +288,8 @@ class TestAgentKgExtractionIntegration:
|
|||
agent_client = mock_flow_context("agent-request")
|
||||
|
||||
def mock_malformed_response(recipient, question):
|
||||
return '''{"definitions": [{"entity": "Missing Definition"}], "relationships": [{"subject": "Missing Object"}]}'''
|
||||
# JSONL with definition missing required field
|
||||
return '{"type": "definition", "entity": "Missing Definition"}'
|
||||
|
||||
agent_client.invoke = mock_malformed_response
|
||||
|
||||
|
|
@ -330,7 +316,7 @@ class TestAgentKgExtractionIntegration:
|
|||
def capture_prompt(recipient, question):
|
||||
# Verify the prompt contains the test text
|
||||
assert test_text in question
|
||||
return '{"definitions": [], "relationships": []}'
|
||||
return '' # Empty JSONL response
|
||||
|
||||
agent_client.invoke = capture_prompt
|
||||
|
||||
|
|
@ -361,7 +347,7 @@ class TestAgentKgExtractionIntegration:
|
|||
responses = []
|
||||
|
||||
def mock_response(recipient, question):
|
||||
response = f'{{"definitions": [{{"entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}], "relationships": []}}'
|
||||
response = f'{{"type": "definition", "entity": "Entity {len(responses)}", "definition": "Definition {len(responses)}"}}'
|
||||
responses.append(response)
|
||||
return response
|
||||
|
||||
|
|
@ -398,7 +384,7 @@ class TestAgentKgExtractionIntegration:
|
|||
# Verify unicode text was properly decoded and included
|
||||
assert "学习机器" in question
|
||||
assert "人工知能" in question
|
||||
return '''{"definitions": [{"entity": "機械学習", "definition": "人工知能の一分野"}], "relationships": []}'''
|
||||
return '{"type": "definition", "entity": "機械学習", "definition": "人工知能の一分野"}'
|
||||
|
||||
agent_client.invoke = mock_unicode_response
|
||||
|
||||
|
|
@ -415,7 +401,7 @@ class TestAgentKgExtractionIntegration:
|
|||
|
||||
sent_triples = triples_publisher.send.call_args[0][0]
|
||||
# Check that unicode entity was properly processed
|
||||
entity_labels = [t for t in sent_triples.triples if t.p.value == RDF_LABEL and t.o.value == "機械学習"]
|
||||
entity_labels = [t for t in sent_triples.triples if t.p.iri == RDF_LABEL and t.o.value == "機械学習"]
|
||||
assert len(entity_labels) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -433,7 +419,7 @@ class TestAgentKgExtractionIntegration:
|
|||
def mock_large_text_response(recipient, question):
|
||||
# Verify large text was included
|
||||
assert len(question) > 10000
|
||||
return '''{"definitions": [{"entity": "Machine Learning", "definition": "Important AI technique"}], "relationships": []}'''
|
||||
return '{"type": "definition", "entity": "Machine Learning", "definition": "Important AI technique"}'
|
||||
|
||||
agent_client.invoke = mock_large_text_response
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from argparse import ArgumentParser
|
|||
|
||||
# Import processors that use Cassandra configuration
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
|
||||
from trustgraph.storage.knowledge.store import Processor as KgStore
|
||||
|
||||
|
|
@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow:
|
|||
assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3']
|
||||
assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
|
||||
"""Test complete flow from environment variables to Cassandra Cluster connection."""
|
||||
env_vars = {
|
||||
|
|
@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
|
||||
# Trigger Cassandra connection
|
||||
processor.connect_cassandra()
|
||||
|
|
@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd:
|
|||
class TestMultipleHostsHandling:
|
||||
"""Test multiple Cassandra hosts handling end-to-end."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
|
||||
"""Test that multiple hosts are correctly passed to Cassandra cluster."""
|
||||
env_vars = {
|
||||
|
|
@ -333,7 +333,7 @@ class TestMultipleHostsHandling:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify all hosts were passed to Cluster
|
||||
|
|
@ -386,8 +386,8 @@ class TestMultipleHostsHandling:
|
|||
class TestAuthenticationFlow:
|
||||
"""Test authentication configuration flow end-to-end."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
|
||||
"""Test that authentication is enabled when both username and password are provided."""
|
||||
env_vars = {
|
||||
|
|
@ -402,7 +402,7 @@ class TestAuthenticationFlow:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Auth provider should be created
|
||||
|
|
@ -416,8 +416,8 @@ class TestAuthenticationFlow:
|
|||
assert 'auth_provider' in call_args.kwargs
|
||||
assert call_args.kwargs['auth_provider'] == mock_auth_instance
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
|
||||
"""Test that authentication is not used when credentials are missing."""
|
||||
env_vars = {
|
||||
|
|
@ -429,7 +429,7 @@ class TestAuthenticationFlow:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Auth provider should not be created
|
||||
|
|
@ -439,11 +439,11 @@ class TestAuthenticationFlow:
|
|||
call_args = mock_cluster.call_args
|
||||
assert 'auth_provider' not in call_args.kwargs
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
|
||||
"""Test that authentication is not used when only username is provided."""
|
||||
processor = ObjectsWriter(
|
||||
processor = RowsWriter(
|
||||
taskgroup=MagicMock(),
|
||||
cassandra_host='partial-auth-host',
|
||||
cassandra_username='partial-user'
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from .cassandra_test_helper import cassandra_container
|
|||
from trustgraph.direct.cassandra_kg import KnowledgeGraph
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
|
||||
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
|
||||
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
|
||||
from trustgraph.schema import Triple, Term, Metadata, Triples, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -118,19 +118,19 @@ class TestCassandraIntegration:
|
|||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value="Alice Smith", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
p=Term(type=IRI, iri="http://example.org/name"),
|
||||
o=Term(type=LITERAL, value="Alice Smith")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value="25", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
p=Term(type=IRI, iri="http://example.org/age"),
|
||||
o=Term(type=LITERAL, value="25")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/person1", is_uri=True),
|
||||
p=Value(value="http://example.org/department", is_uri=True),
|
||||
o=Value(value="Engineering", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/person1"),
|
||||
p=Term(type=IRI, iri="http://example.org/department"),
|
||||
o=Term(type=LITERAL, value="Engineering")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -181,19 +181,19 @@ class TestCassandraIntegration:
|
|||
metadata=Metadata(user="testuser", collection="testcol"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=Value(value="http://example.org/bob", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=Term(type=IRI, iri="http://example.org/bob")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value="30", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
p=Term(type=IRI, iri="http://example.org/age"),
|
||||
o=Term(type=LITERAL, value="30")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://example.org/bob", is_uri=True),
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
o=Value(value="http://example.org/charlie", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://example.org/bob"),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=Term(type=IRI, iri="http://example.org/charlie")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -208,7 +208,7 @@ class TestCassandraIntegration:
|
|||
|
||||
# Test S query (find all relationships for Alice)
|
||||
s_query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.org/alice", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.org/alice"),
|
||||
p=None, # None for wildcard
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
|
|
@ -218,18 +218,18 @@ class TestCassandraIntegration:
|
|||
s_results = await query_processor.query_triples(s_query)
|
||||
print(f"Query processor results: {len(s_results)}")
|
||||
for result in s_results:
|
||||
print(f" S={result.s.value}, P={result.p.value}, O={result.o.value}")
|
||||
print(f" S={result.s.iri}, P={result.p.iri}, O={result.o.iri if result.o.type == IRI else result.o.value}")
|
||||
assert len(s_results) == 2
|
||||
|
||||
s_predicates = [t.p.value for t in s_results]
|
||||
|
||||
s_predicates = [t.p.iri for t in s_results]
|
||||
assert "http://example.org/knows" in s_predicates
|
||||
assert "http://example.org/age" in s_predicates
|
||||
print("✓ Subject queries via processor working")
|
||||
|
||||
|
||||
# Test P query (find all "knows" relationships)
|
||||
p_query = TriplesQueryRequest(
|
||||
s=None, # None for wildcard
|
||||
p=Value(value="http://example.org/knows", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.org/knows"),
|
||||
o=None, # None for wildcard
|
||||
limit=10,
|
||||
user="testuser",
|
||||
|
|
@ -238,8 +238,8 @@ class TestCassandraIntegration:
|
|||
p_results = await query_processor.query_triples(p_query)
|
||||
print(p_results)
|
||||
assert len(p_results) == 2 # Alice knows Bob, Bob knows Charlie
|
||||
|
||||
p_subjects = [t.s.value for t in p_results]
|
||||
|
||||
p_subjects = [t.s.iri for t in p_results]
|
||||
assert "http://example.org/alice" in p_subjects
|
||||
assert "http://example.org/bob" in p_subjects
|
||||
print("✓ Predicate queries via processor working")
|
||||
|
|
@ -262,19 +262,19 @@ class TestCassandraIntegration:
|
|||
metadata=Metadata(user="concurrent_test", collection="people"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value=name, is_uri=False)
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
p=Term(type=IRI, iri="http://example.org/name"),
|
||||
o=Term(type=LITERAL, value=name)
|
||||
),
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/age", is_uri=True),
|
||||
o=Value(value=str(age), is_uri=False)
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
p=Term(type=IRI, iri="http://example.org/age"),
|
||||
o=Term(type=LITERAL, value=str(age))
|
||||
),
|
||||
Triple(
|
||||
s=Value(value=f"http://example.org/{person_id}", is_uri=True),
|
||||
p=Value(value="http://example.org/department", is_uri=True),
|
||||
o=Value(value=department, is_uri=False)
|
||||
s=Term(type=IRI, iri=f"http://example.org/{person_id}"),
|
||||
p=Term(type=IRI, iri="http://example.org/department"),
|
||||
o=Term(type=LITERAL, value=department)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -333,36 +333,36 @@ class TestCassandraIntegration:
|
|||
triples=[
|
||||
# People and their types
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://company.org/Employee", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://company.org/Employee")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/bob", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://company.org/Employee", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/bob"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://company.org/Employee")
|
||||
),
|
||||
# Relationships
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/reportsTo", is_uri=True),
|
||||
o=Value(value="http://company.org/bob", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/reportsTo"),
|
||||
o=Term(type=IRI, iri="http://company.org/bob")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/worksIn", is_uri=True),
|
||||
o=Value(value="http://company.org/engineering", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/worksIn"),
|
||||
o=Term(type=IRI, iri="http://company.org/engineering")
|
||||
),
|
||||
# Personal info
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/fullName", is_uri=True),
|
||||
o=Value(value="Alice Johnson", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/fullName"),
|
||||
o=Term(type=LITERAL, value="Alice Johnson")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://company.org/alice", is_uri=True),
|
||||
p=Value(value="http://company.org/email", is_uri=True),
|
||||
o=Value(value="alice@company.org", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://company.org/alice"),
|
||||
p=Term(type=IRI, iri="http://company.org/email"),
|
||||
o=Term(type=LITERAL, value="alice@company.org")
|
||||
),
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,10 +51,10 @@ class MockWebSocket:
|
|||
"metadata": {
|
||||
"id": "test-id",
|
||||
"metadata": {},
|
||||
"user": "test-user",
|
||||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ async def test_import_graceful_shutdown_integration(mock_backend):
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"object-{i}", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"object-{i}"}}]
|
||||
}
|
||||
messages.append(msg_data)
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ async def test_export_no_message_loss_integration(mock_backend):
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"export-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"export-object-{i}", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"export-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": f"export-object-{i}"}}]
|
||||
}
|
||||
# Create Triples object instead of raw dict
|
||||
from trustgraph.schema import Triples, Metadata
|
||||
|
|
@ -302,7 +302,7 @@ async def test_concurrent_import_export_shutdown():
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"concurrent-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"concurrent-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
await import_handler.receive(msg)
|
||||
|
||||
|
|
@ -359,7 +359,7 @@ async def test_websocket_close_during_message_processing():
|
|||
"user": "test-user",
|
||||
"collection": "test-collection"
|
||||
},
|
||||
"triples": [{"s": {"v": f"slow-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": f"slow-subject-{i}"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
task = asyncio.create_task(import_handler.receive(msg))
|
||||
message_tasks.append(task)
|
||||
|
|
@ -423,7 +423,7 @@ async def test_backpressure_during_shutdown():
|
|||
# Simulate receiving and processing a message
|
||||
msg_data = {
|
||||
"metadata": {"id": f"msg-{i}"},
|
||||
"triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
|
||||
"triples": [{"s": {"t": "l", "v": "subject"}, "p": {"t": "l", "v": "predicate"}, "o": {"t": "l", "v": "object"}}]
|
||||
}
|
||||
await ws.send_json(msg_data)
|
||||
# Check if we should stop
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsProcessor
|
||||
from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor
|
||||
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
|
||||
|
||||
|
|
@ -147,6 +147,8 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
processor.emit_triples = DefinitionsProcessor.emit_triples.__get__(processor, DefinitionsProcessor)
|
||||
processor.emit_ecs = DefinitionsProcessor.emit_ecs.__get__(processor, DefinitionsProcessor)
|
||||
processor.on_message = DefinitionsProcessor.on_message.__get__(processor, DefinitionsProcessor)
|
||||
processor.triples_batch_size = 50
|
||||
processor.entity_batch_size = 5
|
||||
return processor
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -156,6 +158,7 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
processor.to_uri = RelationshipsProcessor.to_uri.__get__(processor, RelationshipsProcessor)
|
||||
processor.emit_triples = RelationshipsProcessor.emit_triples.__get__(processor, RelationshipsProcessor)
|
||||
processor.on_message = RelationshipsProcessor.on_message.__get__(processor, RelationshipsProcessor)
|
||||
processor.triples_batch_size = 50
|
||||
return processor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -253,24 +256,24 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
if s and o:
|
||||
s_uri = definitions_processor.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
s_term = Term(type=IRI, iri=str(s_uri))
|
||||
o_term = Term(type=LITERAL, value=str(o))
|
||||
|
||||
# Generate triples as the processor would
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=s, is_uri=False)
|
||||
s=s_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=s)
|
||||
))
|
||||
|
||||
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=DEFINITION, is_uri=True),
|
||||
o=o_value
|
||||
s=s_term,
|
||||
p=Term(type=IRI, iri=DEFINITION),
|
||||
o=o_term
|
||||
))
|
||||
|
||||
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
entity=s_term,
|
||||
context=defn["definition"]
|
||||
))
|
||||
|
||||
|
|
@ -279,16 +282,16 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
assert len(entities) == 3 # 1 entity context per entity
|
||||
|
||||
# Verify triple structure
|
||||
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
|
||||
definition_triples = [t for t in triples if t.p.value == DEFINITION]
|
||||
|
||||
label_triples = [t for t in triples if t.p.iri == RDF_LABEL]
|
||||
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
|
||||
|
||||
assert len(label_triples) == 3
|
||||
assert len(definition_triples) == 3
|
||||
|
||||
|
||||
# Verify entity contexts
|
||||
for entity in entities:
|
||||
assert entity.entity.is_uri is True
|
||||
assert entity.entity.value.startswith(TRUSTGRAPH_ENTITIES)
|
||||
assert entity.entity.type == IRI
|
||||
assert entity.entity.iri.startswith(TRUSTGRAPH_ENTITIES)
|
||||
assert len(entity.context) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -309,52 +312,52 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
s = rel["subject"]
|
||||
p = rel["predicate"]
|
||||
o = rel["object"]
|
||||
|
||||
|
||||
if s and p and o:
|
||||
s_uri = relationships_processor.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
|
||||
s_term = Term(type=IRI, iri=str(s_uri))
|
||||
|
||||
p_uri = relationships_processor.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
|
||||
p_term = Term(type=IRI, iri=str(p_uri))
|
||||
|
||||
if rel["object-entity"]:
|
||||
o_uri = relationships_processor.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
o_term = Term(type=IRI, iri=str(o_uri))
|
||||
else:
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
o_term = Term(type=LITERAL, value=str(o))
|
||||
|
||||
# Main relationship triple
|
||||
triples.append(Triple(s=s_value, p=p_value, o=o_value))
|
||||
|
||||
triples.append(Triple(s=s_term, p=p_term, o=o_term))
|
||||
|
||||
# Label triples
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(s), is_uri=False)
|
||||
s=s_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=str(s))
|
||||
))
|
||||
|
||||
|
||||
triples.append(Triple(
|
||||
s=p_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(p), is_uri=False)
|
||||
s=p_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=str(p))
|
||||
))
|
||||
|
||||
|
||||
if rel["object-entity"]:
|
||||
triples.append(Triple(
|
||||
s=o_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(o), is_uri=False)
|
||||
s=o_term,
|
||||
p=Term(type=IRI, iri=RDF_LABEL),
|
||||
o=Term(type=LITERAL, value=str(o))
|
||||
))
|
||||
|
||||
# Assert
|
||||
assert len(triples) > 0
|
||||
|
||||
# Verify relationship triples exist
|
||||
relationship_triples = [t for t in triples if t.p.value.endswith("is_subset_of") or t.p.value.endswith("is_used_in")]
|
||||
relationship_triples = [t for t in triples if t.p.iri.endswith("is_subset_of") or t.p.iri.endswith("is_used_in")]
|
||||
assert len(relationship_triples) >= 2
|
||||
|
||||
|
||||
# Verify label triples
|
||||
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
|
||||
label_triples = [t for t in triples if t.p.iri == RDF_LABEL]
|
||||
assert len(label_triples) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -374,9 +377,9 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/e/machine-learning", is_uri=True),
|
||||
p=Value(value=DEFINITION, is_uri=True),
|
||||
o=Value(value="A subset of AI", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/e/machine-learning"),
|
||||
p=Term(type=IRI, iri=DEFINITION),
|
||||
o=Term(type=LITERAL, value="A subset of AI")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -405,9 +408,14 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
entities=[]
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Term(type=IRI, iri="http://example.org/entity"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_embeddings
|
||||
|
||||
|
|
@ -496,12 +504,12 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Should still call producers but with empty results
|
||||
# Should NOT call producers with empty results (avoids Cassandra NULL issues)
|
||||
triples_producer = mock_flow_context("triples")
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
|
||||
triples_producer.send.assert_called_once()
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
triples_producer.send.assert_not_called()
|
||||
entity_contexts_producer.send.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
|
|
@ -602,9 +610,9 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
collection="test_collection",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc:test", is_uri=True),
|
||||
p=Value(value="dc:title", is_uri=True),
|
||||
o=Value(value="Test Document", is_uri=False)
|
||||
s=Term(type=IRI, iri="doc:test"),
|
||||
p=Term(type=IRI, iri="dc:title"),
|
||||
o=Term(type=LITERAL, value="Test Document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import json
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.objects.processor import Processor
|
||||
from trustgraph.extract.kg.rows.processor import Processor
|
||||
from trustgraph.schema import (
|
||||
Chunk, ExtractedObject, Metadata, RowSchema, Field,
|
||||
PromptRequest, PromptResponse
|
||||
|
|
@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Mock flow with failing prompt service
|
||||
|
|
@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration:
|
|||
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
|
||||
|
||||
# Import and bind the convert_values_to_strings function
|
||||
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
|
||||
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
|
||||
processor.convert_values_to_strings = convert_values_to_strings
|
||||
|
||||
# Load configuration
|
||||
|
|
|
|||
|
|
@ -1,608 +0,0 @@
|
|||
"""
|
||||
Integration tests for Cassandra Object Storage
|
||||
|
||||
These tests verify the end-to-end functionality of storing ExtractedObjects
|
||||
in Cassandra, including table creation, data insertion, and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestObjectsCassandraIntegration:
|
||||
"""Integration tests for Cassandra object storage"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_session(self):
|
||||
"""Mock Cassandra session for integration tests"""
|
||||
session = MagicMock()
|
||||
|
||||
# Track if keyspaces have been created
|
||||
created_keyspaces = set()
|
||||
|
||||
# Mock the execute method to return a valid result for keyspace checks
|
||||
def execute_mock(query, *args, **kwargs):
|
||||
result = MagicMock()
|
||||
query_str = str(query)
|
||||
|
||||
# Track keyspace creation
|
||||
if "CREATE KEYSPACE" in query_str:
|
||||
# Extract keyspace name from query
|
||||
import re
|
||||
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
|
||||
if match:
|
||||
created_keyspaces.add(match.group(1))
|
||||
|
||||
# For keyspace existence checks
|
||||
if "system_schema.keyspaces" in query_str:
|
||||
# Check if this keyspace was created
|
||||
if args and args[0] in created_keyspaces:
|
||||
result.one.return_value = MagicMock() # Exists
|
||||
else:
|
||||
result.one.return_value = None # Doesn't exist
|
||||
else:
|
||||
result.one.return_value = None
|
||||
|
||||
return result
|
||||
|
||||
session.execute = MagicMock(side_effect=execute_mock)
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_cluster(self, mock_cassandra_session):
|
||||
"""Mock Cassandra cluster"""
|
||||
cluster = MagicMock()
|
||||
cluster.connect.return_value = mock_cassandra_session
|
||||
cluster.shutdown = MagicMock()
|
||||
return cluster
|
||||
|
||||
@pytest.fixture
|
||||
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
|
||||
"""Create processor with mocked Cassandra dependencies"""
|
||||
processor = MagicMock()
|
||||
processor.graph_host = "localhost"
|
||||
processor.graph_username = None
|
||||
processor.graph_password = None
|
||||
processor.config_key = "schema"
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.cluster = None
|
||||
processor.session = None
|
||||
|
||||
# Bind actual methods
|
||||
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
|
||||
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.create_collection = Processor.create_collection.__get__(processor, Processor)
|
||||
|
||||
return processor, mock_cassandra_cluster, mock_cassandra_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_object_storage(self, processor_with_mocks):
|
||||
"""Test complete flow from schema config to object storage"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
# Mock Cluster creation
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Step 1: Configure schema
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer information",
|
||||
"fields": [
|
||||
{"name": "customer_id", "type": "string", "primary_key": True},
|
||||
{"name": "name", "type": "string", "required": True},
|
||||
{"name": "email", "type": "string", "indexed": True},
|
||||
{"name": "age", "type": "integer"}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
|
||||
# Step 1.5: Create the collection first (simulate tg-set-collection)
|
||||
await processor.create_collection("test_user", "import_2024", {})
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="doc-001",
|
||||
user="test_user",
|
||||
collection="import_2024",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="customer_records",
|
||||
values=[{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": "30"
|
||||
}],
|
||||
confidence=0.95,
|
||||
source_span="Customer: John Doe..."
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
assert mock_cluster.connect.called
|
||||
|
||||
# Verify keyspace creation
|
||||
keyspace_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE KEYSPACE" in str(call)]
|
||||
assert len(keyspace_calls) == 1
|
||||
assert "test_user" in str(keyspace_calls[0])
|
||||
|
||||
# Verify table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 1
|
||||
assert "o_customer_records" in str(table_calls[0]) # Table gets o_ prefix
|
||||
assert "collection text" in str(table_calls[0])
|
||||
assert "PRIMARY KEY ((collection, customer_id))" in str(table_calls[0])
|
||||
|
||||
# Verify index creation
|
||||
index_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE INDEX" in str(call)]
|
||||
assert len(index_calls) == 1
|
||||
assert "email" in str(index_calls[0])
|
||||
|
||||
# Verify data insertion
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
assert len(insert_calls) == 1
|
||||
insert_call = insert_calls[0]
|
||||
assert "test_user.o_customer_records" in str(insert_call) # Table gets o_ prefix
|
||||
|
||||
# Check inserted values
|
||||
values = insert_call[0][1]
|
||||
assert "import_2024" in values # collection
|
||||
assert "CUST001" in values # customer_id
|
||||
assert "John Doe" in values # name
|
||||
assert "john@example.com" in values # email
|
||||
assert 30 in values # age (converted to int)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_schema_handling(self, processor_with_mocks):
|
||||
"""Test handling multiple schemas and objects"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Configure multiple schemas
|
||||
config = {
|
||||
"schema": {
|
||||
"products": json.dumps({
|
||||
"name": "products",
|
||||
"fields": [
|
||||
{"name": "product_id", "type": "string", "primary_key": True},
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "price", "type": "float"}
|
||||
]
|
||||
}),
|
||||
"orders": json.dumps({
|
||||
"name": "orders",
|
||||
"fields": [
|
||||
{"name": "order_id", "type": "string", "primary_key": True},
|
||||
{"name": "customer_id", "type": "string"},
|
||||
{"name": "total", "type": "float"}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
|
||||
# Create collections first
|
||||
await processor.create_collection("shop", "catalog", {})
|
||||
await processor.create_collection("shop", "sales", {})
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||
schema_name="products",
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
source_span="Product..."
|
||||
)
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||
schema_name="orders",
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
source_span="Order..."
|
||||
)
|
||||
|
||||
# Process both objects
|
||||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify separate tables were created
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 2
|
||||
assert any("o_products" in str(call) for call in table_calls) # Tables get o_ prefix
|
||||
assert any("o_orders" in str(call) for call in table_calls) # Tables get o_ prefix
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_required_fields(self, processor_with_mocks):
|
||||
"""Test handling of objects with missing required fields"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Configure schema with required field
|
||||
processor.schemas["test_schema"] = RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True, required=True),
|
||||
Field(name="required_field", type="string", size=100, required=True)
|
||||
]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test", "test", {})
|
||||
|
||||
# Create object missing required field
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="test_schema",
|
||||
values=[{"id": "123"}], # missing required_field
|
||||
confidence=0.8,
|
||||
source_span="Test"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
# Should still process (Cassandra doesn't enforce NOT NULL)
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify insert was attempted
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
assert len(insert_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_without_primary_key(self, processor_with_mocks):
|
||||
"""Test handling schemas without defined primary keys"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Configure schema without primary key
|
||||
processor.schemas["events"] = RowSchema(
|
||||
name="events",
|
||||
description="Event log",
|
||||
fields=[
|
||||
Field(name="event_type", type="string", size=50),
|
||||
Field(name="timestamp", type="timestamp", size=0)
|
||||
]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("logger", "app_events", {})
|
||||
|
||||
# Process object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
|
||||
schema_name="events",
|
||||
values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
|
||||
confidence=1.0,
|
||||
source_span="Event"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify synthetic_id was added
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 1
|
||||
assert "synthetic_id uuid" in str(table_calls[0])
|
||||
|
||||
# Verify insert includes UUID
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
assert len(insert_calls) == 1
|
||||
values = insert_calls[0][0][1]
|
||||
# Check that a UUID was generated (will be in values list)
|
||||
uuid_found = any(isinstance(v, uuid.UUID) for v in values)
|
||||
assert uuid_found
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_handling(self, processor_with_mocks):
|
||||
"""Test Cassandra authentication"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
processor.cassandra_username = "cassandra_user"
|
||||
processor.cassandra_password = "cassandra_pass"
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class:
|
||||
with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
|
||||
mock_cluster_class.return_value = mock_cluster
|
||||
|
||||
# Trigger connection
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify authentication was configured
|
||||
mock_auth.assert_called_once_with(
|
||||
username="cassandra_user",
|
||||
password="cassandra_pass"
|
||||
)
|
||||
mock_cluster_class.assert_called_once()
|
||||
call_kwargs = mock_cluster_class.call_args[1]
|
||||
assert 'auth_provider' in call_kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_during_insert(self, processor_with_mocks):
|
||||
"""Test error handling when insertion fails"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["test"] = RowSchema(
|
||||
name="test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
|
||||
# Make insert fail
|
||||
mock_result = MagicMock()
|
||||
mock_result.one.return_value = MagicMock() # Keyspace exists
|
||||
mock_session.execute.side_effect = [
|
||||
mock_result, # keyspace existence check succeeds
|
||||
None, # table creation succeeds
|
||||
Exception("Connection timeout") # insert fails
|
||||
]
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="test",
|
||||
values=[{"id": "123"}],
|
||||
confidence=0.9,
|
||||
source_span="Test"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Connection timeout"):
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collection_partitioning(self, processor_with_mocks):
|
||||
"""Test that objects are properly partitioned by collection"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["data"] = RowSchema(
|
||||
name="data",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
|
||||
# Process objects from different collections
|
||||
collections = ["import_jan", "import_feb", "import_mar"]
|
||||
|
||||
# Create all collections first
|
||||
for coll in collections:
|
||||
await processor.create_collection("analytics", coll, {})
|
||||
|
||||
for coll in collections:
|
||||
obj = ExtractedObject(
|
||||
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
|
||||
schema_name="data",
|
||||
values=[{"id": f"ID-{coll}"}],
|
||||
confidence=0.9,
|
||||
source_span="Data"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify all inserts include collection in values
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
assert len(insert_calls) == 3
|
||||
|
||||
# Check each insert has the correct collection
|
||||
for i, call in enumerate(insert_calls):
|
||||
values = call[0][1]
|
||||
assert collections[i] in values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_object_processing(self, processor_with_mocks):
|
||||
"""Test processing objects with batched values"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Configure schema
|
||||
config = {
|
||||
"schema": {
|
||||
"batch_customers": json.dumps({
|
||||
"name": "batch_customers",
|
||||
"description": "Customer batch data",
|
||||
"fields": [
|
||||
{"name": "customer_id", "type": "string", "primary_key": True},
|
||||
{"name": "name", "type": "string", "required": True},
|
||||
{"name": "email", "type": "string", "indexed": True}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Process batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_import",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="batch_customers",
|
||||
values=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com"
|
||||
},
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Smith",
|
||||
"email": "jane@example.com"
|
||||
},
|
||||
{
|
||||
"customer_id": "CUST003",
|
||||
"name": "Bob Johnson",
|
||||
"email": "bob@example.com"
|
||||
}
|
||||
],
|
||||
confidence=0.92,
|
||||
source_span="Multiple customers extracted from document"
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test_user", "batch_import", {})
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 1
|
||||
assert "o_batch_customers" in str(table_calls[0])
|
||||
|
||||
# Verify multiple inserts for batch values
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
# Should have 3 separate inserts for the 3 objects in the batch
|
||||
assert len(insert_calls) == 3
|
||||
|
||||
# Check each insert has correct data
|
||||
for i, call in enumerate(insert_calls):
|
||||
values = call[0][1]
|
||||
assert "batch_import" in values # collection
|
||||
assert f"CUST00{i+1}" in values # customer_id
|
||||
if i == 0:
|
||||
assert "John Doe" in values
|
||||
assert "john@example.com" in values
|
||||
elif i == 1:
|
||||
assert "Jane Smith" in values
|
||||
assert "jane@example.com" in values
|
||||
elif i == 2:
|
||||
assert "Bob Johnson" in values
|
||||
assert "bob@example.com" in values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_processing(self, processor_with_mocks):
|
||||
"""Test processing objects with empty values array"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["empty_test"] = RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test", "empty", {})
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
|
||||
schema_name="empty_test",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
source_span="No objects found"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should still create table
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 1
|
||||
|
||||
# Should not create any insert statements for empty batch
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
assert len(insert_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
|
||||
"""Test processing mix of single and batch objects"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["mixed_test"] = RowSchema(
|
||||
name="mixed_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="data", type="string", size=100)
|
||||
]
|
||||
)
|
||||
|
||||
# Create collection first
|
||||
await processor.create_collection("test", "mixed", {})
|
||||
|
||||
# Single object (backward compatibility)
|
||||
single_obj = ExtractedObject(
|
||||
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
|
||||
schema_name="mixed_test",
|
||||
values=[{"id": "single-1", "data": "single data"}], # Array with single item
|
||||
confidence=0.9,
|
||||
source_span="Single object"
|
||||
)
|
||||
|
||||
# Batch object
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
|
||||
schema_name="mixed_test",
|
||||
values=[
|
||||
{"id": "batch-1", "data": "batch data 1"},
|
||||
{"id": "batch-2", "data": "batch data 2"}
|
||||
],
|
||||
confidence=0.85,
|
||||
source_span="Batch objects"
|
||||
)
|
||||
|
||||
# Process both
|
||||
for obj in [single_obj, batch_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should have 3 total inserts (1 + 2)
|
||||
insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call)]
|
||||
assert len(insert_calls) == 3
|
||||
492
tests/integration/test_rows_cassandra_integration.py
Normal file
492
tests/integration/test_rows_cassandra_integration.py
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
"""
|
||||
Integration tests for Cassandra Row Storage (Unified Table Implementation)
|
||||
|
||||
These tests verify the end-to-end functionality of storing ExtractedObjects
|
||||
in the unified Cassandra rows table, including table creation, data insertion,
|
||||
and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.rows.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestRowsCassandraIntegration:
|
||||
"""Integration tests for Cassandra row storage with unified table"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_session(self):
|
||||
"""Mock Cassandra session for integration tests"""
|
||||
session = MagicMock()
|
||||
|
||||
# Track if keyspaces have been created
|
||||
created_keyspaces = set()
|
||||
|
||||
# Mock the execute method to return a valid result for keyspace checks
|
||||
def execute_mock(query, *args, **kwargs):
|
||||
result = MagicMock()
|
||||
query_str = str(query)
|
||||
|
||||
# Track keyspace creation
|
||||
if "CREATE KEYSPACE" in query_str:
|
||||
import re
|
||||
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
|
||||
if match:
|
||||
created_keyspaces.add(match.group(1))
|
||||
|
||||
# For keyspace existence checks
|
||||
if "system_schema.keyspaces" in query_str:
|
||||
if args and args[0] in created_keyspaces:
|
||||
result.one.return_value = MagicMock() # Exists
|
||||
else:
|
||||
result.one.return_value = None # Doesn't exist
|
||||
else:
|
||||
result.one.return_value = None
|
||||
|
||||
return result
|
||||
|
||||
session.execute = MagicMock(side_effect=execute_mock)
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_cluster(self, mock_cassandra_session):
|
||||
"""Mock Cassandra cluster"""
|
||||
cluster = MagicMock()
|
||||
cluster.connect.return_value = mock_cassandra_session
|
||||
cluster.shutdown = MagicMock()
|
||||
return cluster
|
||||
|
||||
@pytest.fixture
|
||||
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
|
||||
"""Create processor with mocked Cassandra dependencies"""
|
||||
processor = MagicMock()
|
||||
processor.cassandra_host = ["localhost"]
|
||||
processor.cassandra_username = None
|
||||
processor.cassandra_password = None
|
||||
processor.config_key = "schema"
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.tables_initialized = set()
|
||||
processor.registered_partitions = set()
|
||||
processor.cluster = None
|
||||
processor.session = None
|
||||
|
||||
# Bind actual methods from the new unified table implementation
|
||||
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
|
||||
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
|
||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
|
||||
return processor, mock_cassandra_cluster, mock_cassandra_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_object_storage(self, processor_with_mocks):
|
||||
"""Test complete flow from schema config to object storage"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Step 1: Configure schema
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer information",
|
||||
"fields": [
|
||||
{"name": "customer_id", "type": "string", "primary_key": True},
|
||||
{"name": "name", "type": "string", "required": True},
|
||||
{"name": "email", "type": "string", "indexed": True},
|
||||
{"name": "age", "type": "integer"}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert "customer_records" in processor.schemas
|
||||
|
||||
# Step 2: Process an ExtractedObject
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="doc-001",
|
||||
user="test_user",
|
||||
collection="import_2024",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="customer_records",
|
||||
values=[{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": "30"
|
||||
}],
|
||||
confidence=0.95,
|
||||
source_span="Customer: John Doe..."
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify Cassandra interactions
|
||||
assert mock_cluster.connect.called
|
||||
|
||||
# Verify keyspace creation
|
||||
keyspace_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE KEYSPACE" in str(call)]
|
||||
assert len(keyspace_calls) == 1
|
||||
assert "test_user" in str(keyspace_calls[0])
|
||||
|
||||
# Verify unified table creation (rows table, not per-schema table)
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 2 # rows table + row_partitions table
|
||||
assert any("rows" in str(call) for call in table_calls)
|
||||
assert any("row_partitions" in str(call) for call in table_calls)
|
||||
|
||||
# Verify the rows table has correct structure
|
||||
rows_table_call = [call for call in table_calls if ".rows" in str(call)][0]
|
||||
assert "collection text" in str(rows_table_call)
|
||||
assert "schema_name text" in str(rows_table_call)
|
||||
assert "index_name text" in str(rows_table_call)
|
||||
assert "data map<text, text>" in str(rows_table_call)
|
||||
|
||||
# Verify data insertion into unified table
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||
and "row_partitions" not in str(call)]
|
||||
# Should have 2 data inserts: one for customer_id (primary), one for email (indexed)
|
||||
assert len(rows_insert_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_schema_handling(self, processor_with_mocks):
|
||||
"""Test handling multiple schemas stored in unified table"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Configure multiple schemas
|
||||
config = {
|
||||
"schema": {
|
||||
"products": json.dumps({
|
||||
"name": "products",
|
||||
"fields": [
|
||||
{"name": "product_id", "type": "string", "primary_key": True},
|
||||
{"name": "name", "type": "string"},
|
||||
{"name": "price", "type": "float"}
|
||||
]
|
||||
}),
|
||||
"orders": json.dumps({
|
||||
"name": "orders",
|
||||
"fields": [
|
||||
{"name": "order_id", "type": "string", "primary_key": True},
|
||||
{"name": "customer_id", "type": "string"},
|
||||
{"name": "total", "type": "float"}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
assert len(processor.schemas) == 2
|
||||
|
||||
# Process objects for different schemas
|
||||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||
schema_name="products",
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
source_span="Product..."
|
||||
)
|
||||
|
||||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||
schema_name="orders",
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
source_span="Order..."
|
||||
)
|
||||
|
||||
# Process both objects
|
||||
for obj in [product_obj, order_obj]:
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = obj
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# All data goes into the same unified rows table
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
# Should only create 2 tables: rows + row_partitions (not per-schema tables)
|
||||
assert len(table_calls) == 2
|
||||
|
||||
# Verify data inserts go to unified rows table
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||
and "row_partitions" not in str(call)]
|
||||
assert len(rows_insert_calls) > 0
|
||||
for call in rows_insert_calls:
|
||||
assert ".rows" in str(call)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_index_storage(self, processor_with_mocks):
|
||||
"""Test that rows are stored with multiple indexes"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Schema with multiple indexed fields
|
||||
processor.schemas["indexed_data"] = RowSchema(
|
||||
name="indexed_data",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True),
|
||||
Field(name="status", type="string", size=50, indexed=True),
|
||||
Field(name="description", type="string", size=200) # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="indexed_data",
|
||||
values=[{
|
||||
"id": "123",
|
||||
"category": "electronics",
|
||||
"status": "active",
|
||||
"description": "A product"
|
||||
}],
|
||||
confidence=0.9,
|
||||
source_span="Test"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should have 3 data inserts (one per indexed field: id, category, status)
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||
and "row_partitions" not in str(call)]
|
||||
assert len(rows_insert_calls) == 3
|
||||
|
||||
# Verify different index names were used
|
||||
index_names = set()
|
||||
for call in rows_insert_calls:
|
||||
values = call[0][1]
|
||||
index_names.add(values[2]) # index_name is 3rd parameter
|
||||
|
||||
assert index_names == {"id", "category", "status"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_handling(self, processor_with_mocks):
|
||||
"""Test Cassandra authentication"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
processor.cassandra_username = "cassandra_user"
|
||||
processor.cassandra_password = "cassandra_pass"
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class:
|
||||
with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth:
|
||||
mock_cluster_class.return_value = mock_cluster
|
||||
|
||||
# Trigger connection
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify authentication was configured
|
||||
mock_auth.assert_called_once_with(
|
||||
username="cassandra_user",
|
||||
password="cassandra_pass"
|
||||
)
|
||||
mock_cluster_class.assert_called_once()
|
||||
call_kwargs = mock_cluster_class.call_args[1]
|
||||
assert 'auth_provider' in call_kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_object_processing(self, processor_with_mocks):
|
||||
"""Test processing objects with batched values"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
# Configure schema
|
||||
config = {
|
||||
"schema": {
|
||||
"batch_customers": json.dumps({
|
||||
"name": "batch_customers",
|
||||
"description": "Customer batch data",
|
||||
"fields": [
|
||||
{"name": "customer_id", "type": "string", "primary_key": True},
|
||||
{"name": "name", "type": "string", "required": True},
|
||||
{"name": "email", "type": "string", "indexed": True}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Process batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_import",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="batch_customers",
|
||||
values=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com"
|
||||
},
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Smith",
|
||||
"email": "jane@example.com"
|
||||
},
|
||||
{
|
||||
"customer_id": "CUST003",
|
||||
"name": "Bob Johnson",
|
||||
"email": "bob@example.com"
|
||||
}
|
||||
],
|
||||
confidence=0.92,
|
||||
source_span="Multiple customers extracted from document"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify unified table creation
|
||||
table_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "CREATE TABLE" in str(call)]
|
||||
assert len(table_calls) == 2 # rows + row_partitions
|
||||
|
||||
# Each row in batch gets 2 data inserts (customer_id primary + email indexed)
|
||||
# 3 rows * 2 indexes = 6 data inserts
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||
and "row_partitions" not in str(call)]
|
||||
assert len(rows_insert_calls) == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_processing(self, processor_with_mocks):
|
||||
"""Test processing objects with empty values array"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["empty_test"] = RowSchema(
|
||||
name="empty_test",
|
||||
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||
)
|
||||
|
||||
# Process empty batch object
|
||||
empty_obj = ExtractedObject(
|
||||
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
|
||||
schema_name="empty_test",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
source_span="No objects found"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = empty_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should not create any data insert statements for empty batch
|
||||
# (partition registration may still happen)
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||
and "row_partitions" not in str(call)]
|
||||
assert len(rows_insert_calls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_stored_as_map(self, processor_with_mocks):
|
||||
"""Test that data is stored as map<text, text>"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["map_test"] = RowSchema(
|
||||
name="map_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="name", type="string", size=100),
|
||||
Field(name="count", type="integer", size=0)
|
||||
]
|
||||
)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="map_test",
|
||||
values=[{"id": "123", "name": "Test Item", "count": "42"}],
|
||||
confidence=0.9,
|
||||
source_span="Test"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify insert uses map for data
|
||||
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||
and "row_partitions" not in str(call)]
|
||||
assert len(rows_insert_calls) >= 1
|
||||
|
||||
# Check that data is passed as a dict (will be map in Cassandra)
|
||||
insert_call = rows_insert_calls[0]
|
||||
values = insert_call[0][1]
|
||||
# Values are: (collection, schema_name, index_name, index_value, data, source)
|
||||
# values[4] should be the data map
|
||||
data_map = values[4]
|
||||
assert isinstance(data_map, dict)
|
||||
assert data_map["id"] == "123"
|
||||
assert data_map["name"] == "Test Item"
|
||||
assert data_map["count"] == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partition_registration(self, processor_with_mocks):
|
||||
"""Test that partitions are registered for efficient querying"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
|
||||
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||
processor.schemas["partition_test"] = RowSchema(
|
||||
name="partition_test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="category", type="string", size=50, indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
|
||||
schema_name="partition_test",
|
||||
values=[{"id": "123", "category": "test"}],
|
||||
confidence=0.9,
|
||||
source_span="Test"
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify partition registration
|
||||
partition_inserts = [call for call in mock_session.execute.call_args_list
|
||||
if "INSERT INTO" in str(call) and "row_partitions" in str(call)]
|
||||
# Should register partitions for each index (id, category)
|
||||
assert len(partition_inserts) == 2
|
||||
|
||||
# Verify cache was updated
|
||||
assert ("my_collection", "partition_test") in processor.registered_partitions
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Integration tests for Objects GraphQL Query Service
|
||||
Integration tests for Rows GraphQL Query Service
|
||||
|
||||
These tests verify end-to-end functionality including:
|
||||
- Real Cassandra database operations
|
||||
|
|
@ -24,8 +24,8 @@ except Exception:
|
|||
DOCKER_AVAILABLE = False
|
||||
CassandraContainer = None
|
||||
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.query.rows.cassandra.service import Processor
|
||||
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata
|
||||
|
||||
|
||||
|
|
@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
processor.connect_cassandra()
|
||||
|
||||
# Create mock message
|
||||
request = ObjectsQueryRequest(
|
||||
request = RowsQueryRequest(
|
||||
user="msg_test_user",
|
||||
collection="msg_test_collection",
|
||||
query='{ customer_objects { customer_id name } }',
|
||||
|
|
@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration:
|
|||
|
||||
# Verify response structure
|
||||
sent_response = mock_response_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, ObjectsQueryResponse)
|
||||
assert isinstance(sent_response, RowsQueryResponse)
|
||||
|
||||
# Should have no system error (even if no data)
|
||||
assert sent_response.error is None
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
Integration tests for Structured Query Service
|
||||
|
||||
These tests verify the end-to-end functionality of the structured query service,
|
||||
testing orchestration between nlp-query and objects-query services.
|
||||
testing orchestration between nlp-query and rows-query services.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.schema import (
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
ObjectsQueryRequest, ObjectsQueryResponse,
|
||||
RowsQueryRequest, RowsQueryResponse,
|
||||
Error, GraphQLError
|
||||
)
|
||||
from trustgraph.retrieval.structured_query.service import Processor
|
||||
|
|
@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock Objects Query Service Response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
|
||||
errors=None,
|
||||
|
|
@ -99,7 +99,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
# Verify Objects service call
|
||||
mock_objects_client.request.assert_called_once()
|
||||
objects_call_args = mock_objects_client.request.call_args[0][0]
|
||||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert "customers" in objects_call_args.query
|
||||
assert "orders" in objects_call_args.query
|
||||
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
|
||||
|
|
@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock Objects service failure
|
||||
objects_error_response = ObjectsQueryResponse(
|
||||
objects_error_response = RowsQueryResponse(
|
||||
error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
|
||||
data=None,
|
||||
errors=None,
|
||||
|
|
@ -237,7 +237,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -255,7 +255,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
|
||||
assert response.error is not None
|
||||
assert response.error.type == "structured-query-error"
|
||||
assert "Objects query service error" in response.error.message
|
||||
assert "Rows query service error" in response.error.message
|
||||
assert "nonexistent_table" in response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
]
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None, # No data when validation fails
|
||||
errors=validation_errors,
|
||||
|
|
@ -315,7 +315,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -422,7 +422,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
]
|
||||
}
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=json.dumps(complex_data),
|
||||
errors=None,
|
||||
|
|
@ -443,7 +443,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -503,7 +503,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock empty Objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": []}', # Empty result set
|
||||
errors=None,
|
||||
|
|
@ -520,7 +520,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -577,7 +577,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
confidence=0.9
|
||||
)
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
|
||||
errors=None,
|
||||
|
|
@ -599,7 +599,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
if service_name == "nlp-query-request":
|
||||
service_call_count += 1
|
||||
return nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
service_call_count += 1
|
||||
return objects_client
|
||||
elif service_name == "response":
|
||||
|
|
@ -700,7 +700,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Mock Objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
|
||||
errors=None,
|
||||
|
|
@ -717,7 +717,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
|
|||
|
|
@ -19,4 +19,5 @@ markers =
|
|||
integration: marks tests as integration tests
|
||||
unit: marks tests as unit tests
|
||||
contract: marks tests as contract tests (service interface validation)
|
||||
vertexai: marks tests as vertex ai specific tests
|
||||
vertexai: marks tests as vertex ai specific tests
|
||||
asyncio: marks tests that use asyncio
|
||||
|
|
@ -88,8 +88,13 @@ async def test_subscriber_deferred_acknowledgment_success():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscriber_deferred_acknowledgment_failure():
|
||||
"""Verify Subscriber negative acks on delivery failure."""
|
||||
async def test_subscriber_dropped_message_still_acks():
|
||||
"""Verify Subscriber acks even when message is dropped (backpressure).
|
||||
|
||||
This prevents redelivery storms on shared topics - if we negative_ack
|
||||
a dropped message, it gets redelivered to all subscribers, none of
|
||||
whom can handle it either, causing a tight redelivery loop.
|
||||
"""
|
||||
mock_backend = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_backend.create_consumer.return_value = mock_consumer
|
||||
|
|
@ -103,24 +108,66 @@ async def test_subscriber_deferred_acknowledgment_failure():
|
|||
max_size=1, # Very small queue
|
||||
backpressure_strategy="drop_new"
|
||||
)
|
||||
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
|
||||
|
||||
# Create queue and fill it
|
||||
queue = await subscriber.subscribe("test-queue")
|
||||
await queue.put({"existing": "data"})
|
||||
|
||||
# Create mock message - should be dropped
|
||||
msg = create_mock_message("msg-1", {"data": "test"})
|
||||
|
||||
# Process message (should fail due to full queue + drop_new strategy)
|
||||
|
||||
# Create mock message - should be dropped due to full queue
|
||||
msg = create_mock_message("test-queue", {"data": "test"})
|
||||
|
||||
# Process message (should be dropped due to full queue + drop_new strategy)
|
||||
await subscriber._process_message(msg)
|
||||
|
||||
# Should negative acknowledge failed delivery
|
||||
mock_consumer.negative_acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.acknowledge.assert_not_called()
|
||||
|
||||
|
||||
# Should acknowledge even though delivery failed - prevents redelivery storm
|
||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.negative_acknowledge.assert_not_called()
|
||||
|
||||
# Clean up
|
||||
await subscriber.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscriber_orphaned_message_acks():
|
||||
"""Verify Subscriber acks orphaned messages (no matching waiter).
|
||||
|
||||
On shared response topics, if a message arrives for a waiter that
|
||||
no longer exists (e.g., client disconnected, request timed out),
|
||||
we must acknowledge it to prevent redelivery storms.
|
||||
"""
|
||||
mock_backend = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_backend.create_consumer.return_value = mock_consumer
|
||||
|
||||
subscriber = Subscriber(
|
||||
backend=mock_backend,
|
||||
topic="test-topic",
|
||||
subscription="test-subscription",
|
||||
consumer_name="test-consumer",
|
||||
schema=dict,
|
||||
max_size=10,
|
||||
backpressure_strategy="block"
|
||||
)
|
||||
|
||||
# Start subscriber to initialize consumer
|
||||
await subscriber.start()
|
||||
|
||||
# Don't create any queues - message will be orphaned
|
||||
# This simulates a response arriving after the waiter has unsubscribed
|
||||
|
||||
# Create mock message with an ID that has no matching waiter
|
||||
msg = create_mock_message("non-existent-waiter-id", {"data": "orphaned"})
|
||||
|
||||
# Process message (should be orphaned - no matching waiter)
|
||||
await subscriber._process_message(msg)
|
||||
|
||||
# Should acknowledge orphaned message - prevents redelivery storm
|
||||
mock_consumer.acknowledge.assert_called_once_with(msg)
|
||||
mock_consumer.negative_acknowledge.assert_not_called()
|
||||
|
||||
# Clean up
|
||||
await subscriber.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,9 @@ class TestSetToolStructuredQuery:
|
|||
mcp_tool=None,
|
||||
collection="sales_data",
|
||||
template=None,
|
||||
schema_name=None,
|
||||
index_name=None,
|
||||
limit=None,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
|
|
@ -92,6 +95,9 @@ class TestSetToolStructuredQuery:
|
|||
mcp_tool=None,
|
||||
collection=None, # No collection specified
|
||||
template=None,
|
||||
schema_name=None,
|
||||
index_name=None,
|
||||
limit=None,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
|
|
@ -132,6 +138,9 @@ class TestSetToolStructuredQuery:
|
|||
mcp_tool=None,
|
||||
collection='sales_data',
|
||||
template=None,
|
||||
schema_name=None,
|
||||
index_name=None,
|
||||
limit=None,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
|
|
@ -201,6 +210,144 @@ class TestSetToolStructuredQuery:
|
|||
assert 'Exception:' in printed_output or 'invalid choice:' in printed_output.lower()
|
||||
|
||||
|
||||
class TestSetToolRowEmbeddingsQuery:
|
||||
"""Test the set_tool function with row-embeddings-query type."""
|
||||
|
||||
@patch('trustgraph.cli.set_tool.Api')
|
||||
def test_set_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
|
||||
"""Test setting a row-embeddings-query tool with all parameters."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = []
|
||||
|
||||
set_tool(
|
||||
url="http://test.com",
|
||||
id="customer_search",
|
||||
name="find_customer",
|
||||
description="Find customers by name using semantic search",
|
||||
type="row-embeddings-query",
|
||||
mcp_tool=None,
|
||||
collection="sales",
|
||||
template=None,
|
||||
schema_name="customers",
|
||||
index_name="full_name",
|
||||
limit=20,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
applicable_states=None
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Tool set." in captured.out
|
||||
|
||||
# Verify the tool was stored correctly
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
config_value = call_args[0]
|
||||
assert config_value.type == "tool"
|
||||
assert config_value.key == "customer_search"
|
||||
|
||||
stored_tool = json.loads(config_value.value)
|
||||
assert stored_tool["name"] == "find_customer"
|
||||
assert stored_tool["type"] == "row-embeddings-query"
|
||||
assert stored_tool["collection"] == "sales"
|
||||
assert stored_tool["schema-name"] == "customers"
|
||||
assert stored_tool["index-name"] == "full_name"
|
||||
assert stored_tool["limit"] == 20
|
||||
|
||||
@patch('trustgraph.cli.set_tool.Api')
|
||||
def test_set_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
|
||||
"""Test setting row-embeddings-query tool with minimal parameters."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = []
|
||||
|
||||
set_tool(
|
||||
url="http://test.com",
|
||||
id="product_search",
|
||||
name="find_product",
|
||||
description="Find products using semantic search",
|
||||
type="row-embeddings-query",
|
||||
mcp_tool=None,
|
||||
collection=None,
|
||||
template=None,
|
||||
schema_name="products",
|
||||
index_name=None, # No index filter
|
||||
limit=None, # Use default
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
applicable_states=None
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Tool set." in captured.out
|
||||
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
stored_tool = json.loads(call_args[0].value)
|
||||
assert stored_tool["type"] == "row-embeddings-query"
|
||||
assert stored_tool["schema-name"] == "products"
|
||||
assert "index-name" not in stored_tool # Should not be included if None
|
||||
assert "limit" not in stored_tool # Should not be included if None
|
||||
assert "collection" not in stored_tool # Should not be included if None
|
||||
|
||||
def test_set_main_row_embeddings_query_with_all_options(self):
|
||||
"""Test set main() with row-embeddings-query tool type and all options."""
|
||||
test_args = [
|
||||
'tg-set-tool',
|
||||
'--id', 'customer_search',
|
||||
'--name', 'find_customer',
|
||||
'--type', 'row-embeddings-query',
|
||||
'--description', 'Find customers by name',
|
||||
'--schema-name', 'customers',
|
||||
'--collection', 'sales',
|
||||
'--index-name', 'full_name',
|
||||
'--limit', '25',
|
||||
'--api-url', 'http://custom.com'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.set_tool.set_tool') as mock_set:
|
||||
|
||||
set_main()
|
||||
|
||||
mock_set.assert_called_once_with(
|
||||
url='http://custom.com',
|
||||
id='customer_search',
|
||||
name='find_customer',
|
||||
description='Find customers by name',
|
||||
type='row-embeddings-query',
|
||||
mcp_tool=None,
|
||||
collection='sales',
|
||||
template=None,
|
||||
schema_name='customers',
|
||||
index_name='full_name',
|
||||
limit=25,
|
||||
arguments=[],
|
||||
group=None,
|
||||
state=None,
|
||||
applicable_states=None,
|
||||
token=None
|
||||
)
|
||||
|
||||
def test_valid_types_includes_row_embeddings_query(self):
|
||||
"""Test that 'row-embeddings-query' is included in valid tool types."""
|
||||
test_args = [
|
||||
'tg-set-tool',
|
||||
'--id', 'test_tool',
|
||||
'--name', 'test_tool',
|
||||
'--type', 'row-embeddings-query',
|
||||
'--description', 'Test tool',
|
||||
'--schema-name', 'test_schema'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.set_tool.set_tool') as mock_set:
|
||||
|
||||
# Should not raise an exception about invalid type
|
||||
set_main()
|
||||
mock_set.assert_called_once()
|
||||
|
||||
|
||||
class TestShowToolsStructuredQuery:
|
||||
"""Test the show_tools function with structured-query tools."""
|
||||
|
||||
|
|
@ -259,9 +406,9 @@ class TestShowToolsStructuredQuery:
|
|||
|
||||
@patch('trustgraph.cli.show_tools.Api')
|
||||
def test_show_mixed_tool_types(self, mock_api_class, mock_api, capsys):
|
||||
"""Test displaying multiple tool types including structured-query."""
|
||||
"""Test displaying multiple tool types including structured-query and row-embeddings-query."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": "ask_knowledge",
|
||||
|
|
@ -270,37 +417,47 @@ class TestShowToolsStructuredQuery:
|
|||
"collection": "docs"
|
||||
},
|
||||
{
|
||||
"name": "query_data",
|
||||
"name": "query_data",
|
||||
"description": "Query structured data",
|
||||
"type": "structured-query",
|
||||
"collection": "sales"
|
||||
},
|
||||
{
|
||||
"name": "find_customer",
|
||||
"description": "Find customers by semantic search",
|
||||
"type": "row-embeddings-query",
|
||||
"schema-name": "customers",
|
||||
"collection": "crm"
|
||||
},
|
||||
{
|
||||
"name": "complete_text",
|
||||
"description": "Generate text",
|
||||
"type": "text-completion"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
config_values = [
|
||||
ConfigValue(type="tool", key=f"tool_{i}", value=json.dumps(tool))
|
||||
for i, tool in enumerate(tools)
|
||||
]
|
||||
mock_config.get_values.return_value = config_values
|
||||
|
||||
|
||||
show_config("http://test.com")
|
||||
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
|
||||
# All tool types should be displayed
|
||||
assert "knowledge-query" in output
|
||||
assert "structured-query" in output
|
||||
assert "structured-query" in output
|
||||
assert "row-embeddings-query" in output
|
||||
assert "text-completion" in output
|
||||
|
||||
|
||||
# Collections should be shown for appropriate tools
|
||||
assert "docs" in output # knowledge-query collection
|
||||
assert "sales" in output # structured-query collection
|
||||
assert "crm" in output # row-embeddings-query collection
|
||||
assert "customers" in output # row-embeddings-query schema-name
|
||||
|
||||
def test_show_main_parses_args_correctly(self):
|
||||
"""Test that show main() parses arguments correctly."""
|
||||
|
|
@ -317,6 +474,76 @@ class TestShowToolsStructuredQuery:
|
|||
mock_show.assert_called_once_with(url='http://custom.com', token=None)
|
||||
|
||||
|
||||
class TestShowToolsRowEmbeddingsQuery:
|
||||
"""Test the show_tools function with row-embeddings-query tools."""
|
||||
|
||||
@patch('trustgraph.cli.show_tools.Api')
|
||||
def test_show_row_embeddings_query_tool_full(self, mock_api_class, mock_api, capsys):
|
||||
"""Test displaying a row-embeddings-query tool with all fields."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
tool_config = {
|
||||
"name": "find_customer",
|
||||
"description": "Find customers by name using semantic search",
|
||||
"type": "row-embeddings-query",
|
||||
"collection": "sales",
|
||||
"schema-name": "customers",
|
||||
"index-name": "full_name",
|
||||
"limit": 20
|
||||
}
|
||||
|
||||
config_value = ConfigValue(
|
||||
type="tool",
|
||||
key="customer_search",
|
||||
value=json.dumps(tool_config)
|
||||
)
|
||||
mock_config.get_values.return_value = [config_value]
|
||||
|
||||
show_config("http://test.com")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
# Check that tool information is displayed
|
||||
assert "customer_search" in output
|
||||
assert "find_customer" in output
|
||||
assert "row-embeddings-query" in output
|
||||
assert "sales" in output # Collection
|
||||
assert "customers" in output # Schema name
|
||||
assert "full_name" in output # Index name
|
||||
assert "20" in output # Limit
|
||||
|
||||
@patch('trustgraph.cli.show_tools.Api')
|
||||
def test_show_row_embeddings_query_tool_minimal(self, mock_api_class, mock_api, capsys):
|
||||
"""Test displaying row-embeddings-query tool with minimal fields."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
tool_config = {
|
||||
"name": "find_product",
|
||||
"description": "Find products using semantic search",
|
||||
"type": "row-embeddings-query",
|
||||
"schema-name": "products"
|
||||
# No collection, index-name, or limit
|
||||
}
|
||||
|
||||
config_value = ConfigValue(
|
||||
type="tool",
|
||||
key="product_search",
|
||||
value=json.dumps(tool_config)
|
||||
)
|
||||
mock_config.get_values.return_value = [config_value]
|
||||
|
||||
show_config("http://test.com")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
# Should display the tool with schema-name
|
||||
assert "product_search" in output
|
||||
assert "row-embeddings-query" in output
|
||||
assert "products" in output # Schema name
|
||||
|
||||
|
||||
class TestStructuredQueryToolValidation:
|
||||
"""Test validation specific to structured-query tools."""
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
|||
from unittest.mock import call
|
||||
|
||||
from trustgraph.cores.knowledge import KnowledgeManager
|
||||
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Value, EntityEmbeddings
|
||||
from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -71,15 +71,15 @@ def sample_triples():
|
|||
return Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc-id",
|
||||
user="test-user",
|
||||
user="test-user",
|
||||
collection="default", # This should be overridden
|
||||
metadata=[]
|
||||
),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.org/john", is_uri=True),
|
||||
p=Value(value="http://example.org/name", is_uri=True),
|
||||
o=Value(value="John Smith", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.org/john"),
|
||||
p=Term(type=IRI, iri="http://example.org/name"),
|
||||
o=Term(type=LITERAL, value="John Smith")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -97,7 +97,7 @@ def sample_graph_embeddings():
|
|||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=Value(value="http://example.org/john", is_uri=True),
|
||||
entity=Term(type=IRI, iri="http://example.org/john"),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
]
|
||||
|
|
|
|||
599
tests/unit/test_direct/test_entity_centric_kg.py
Normal file
599
tests/unit/test_direct/test_entity_centric_kg.py
Normal file
|
|
@ -0,0 +1,599 @@
|
|||
"""
|
||||
Unit tests for EntityCentricKnowledgeGraph class
|
||||
|
||||
Tests the entity-centric knowledge graph implementation without requiring
|
||||
an actual Cassandra connection. Uses mocking to verify correct behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
import os
|
||||
|
||||
|
||||
class TestEntityCentricKnowledgeGraph:
|
||||
"""Test cases for EntityCentricKnowledgeGraph"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cluster(self):
|
||||
"""Create a mock Cassandra cluster"""
|
||||
with patch('trustgraph.direct.cassandra_kg.Cluster') as mock_cluster_cls:
|
||||
mock_cluster = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_cluster.connect.return_value = mock_session
|
||||
mock_cluster_cls.return_value = mock_cluster
|
||||
yield mock_cluster_cls, mock_cluster, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def entity_kg(self, mock_cluster):
|
||||
"""Create an EntityCentricKnowledgeGraph instance with mocked Cassandra"""
|
||||
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
|
||||
mock_cluster_cls, mock_cluster, mock_session = mock_cluster
|
||||
|
||||
# Create instance
|
||||
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
|
||||
return kg, mock_session
|
||||
|
||||
def test_init_creates_entity_centric_schema(self, mock_cluster):
|
||||
"""Test that initialization creates the 2-table entity-centric schema"""
|
||||
from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph
|
||||
mock_cluster_cls, mock_cluster, mock_session = mock_cluster
|
||||
|
||||
kg = EntityCentricKnowledgeGraph(hosts=['localhost'], keyspace='test_keyspace')
|
||||
|
||||
# Verify schema tables were created
|
||||
execute_calls = mock_session.execute.call_args_list
|
||||
executed_statements = [str(c) for c in execute_calls]
|
||||
|
||||
# Check for keyspace creation
|
||||
keyspace_created = any('create keyspace' in str(c).lower() for c in execute_calls)
|
||||
assert keyspace_created
|
||||
|
||||
# Check for quads_by_entity table
|
||||
entity_table_created = any('quads_by_entity' in str(c) for c in execute_calls)
|
||||
assert entity_table_created
|
||||
|
||||
# Check for quads_by_collection table
|
||||
collection_table_created = any('quads_by_collection' in str(c) for c in execute_calls)
|
||||
assert collection_table_created
|
||||
|
||||
# Check for collection_metadata table
|
||||
metadata_table_created = any('collection_metadata' in str(c) for c in execute_calls)
|
||||
assert metadata_table_created
|
||||
|
||||
def test_prepare_statements_initialized(self, entity_kg):
|
||||
"""Test that prepared statements are initialized"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Verify prepare was called for various statements
|
||||
assert mock_session.prepare.called
|
||||
prepare_calls = mock_session.prepare.call_args_list
|
||||
|
||||
# Check that key prepared statements exist
|
||||
prepared_queries = [str(c) for c in prepare_calls]
|
||||
|
||||
# Insert statements
|
||||
insert_entity_stmt = any('INSERT INTO' in str(c) and 'quads_by_entity' in str(c)
|
||||
for c in prepare_calls)
|
||||
assert insert_entity_stmt
|
||||
|
||||
insert_collection_stmt = any('INSERT INTO' in str(c) and 'quads_by_collection' in str(c)
|
||||
for c in prepare_calls)
|
||||
assert insert_collection_stmt
|
||||
|
||||
def test_insert_uri_object_creates_4_entity_rows(self, entity_kg):
|
||||
"""Test that inserting a quad with URI object creates 4 entity rows"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Reset mocks to track only insert-related calls
|
||||
mock_session.reset_mock()
|
||||
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob',
|
||||
g='http://example.org/graph1',
|
||||
otype='u'
|
||||
)
|
||||
|
||||
# Verify batch was executed
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_insert_literal_object_creates_3_entity_rows(self, entity_kg):
|
||||
"""Test that inserting a quad with literal object creates 3 entity rows"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://www.w3.org/2000/01/rdf-schema#label',
|
||||
o='Alice Smith',
|
||||
g=None,
|
||||
otype='l',
|
||||
dtype='xsd:string',
|
||||
lang='en'
|
||||
)
|
||||
|
||||
# Verify batch was executed
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_insert_default_graph(self, entity_kg):
|
||||
"""Test that None graph is stored as empty string"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob',
|
||||
g=None,
|
||||
otype='u'
|
||||
)
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_insert_auto_detects_otype(self, entity_kg):
|
||||
"""Test that otype is auto-detected when not provided"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
# URI should be auto-detected
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob'
|
||||
)
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
mock_session.reset_mock()
|
||||
|
||||
# Literal should be auto-detected
|
||||
kg.insert(
|
||||
collection='test_collection',
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/name',
|
||||
o='Alice'
|
||||
)
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_get_s_returns_quads_for_subject(self, entity_kg):
|
||||
"""Test get_s queries by subject"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Mock the query result
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', o='http://example.org/Bob',
|
||||
d='', otype='u', dtype='', lang='', s='http://example.org/Alice')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice')
|
||||
|
||||
# Verify query was executed
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
# Results should be QuadResult objects
|
||||
assert len(results) == 1
|
||||
assert results[0].s == 'http://example.org/Alice'
|
||||
assert results[0].p == 'http://example.org/knows'
|
||||
assert results[0].o == 'http://example.org/Bob'
|
||||
|
||||
def test_get_p_returns_quads_for_predicate(self, entity_kg):
|
||||
"""Test get_p queries by predicate"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', o='http://example.org/Bob',
|
||||
d='', otype='u', dtype='', lang='', p='http://example.org/knows')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_p('test_collection', 'http://example.org/knows')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_o_returns_quads_for_object(self, entity_kg):
|
||||
"""Test get_o queries by object"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
|
||||
d='', otype='u', dtype='', lang='', o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_o('test_collection', 'http://example.org/Bob')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_sp_returns_quads_for_subject_predicate(self, entity_kg):
|
||||
"""Test get_sp queries by subject and predicate"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(o='http://example.org/Bob', d='', otype='u', dtype='', lang='')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_sp('test_collection', 'http://example.org/Alice',
|
||||
'http://example.org/knows')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_po_returns_quads_for_predicate_object(self, entity_kg):
|
||||
"""Test get_po queries by predicate and object"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', d='', otype='u', dtype='', lang='',
|
||||
o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_po('test_collection', 'http://example.org/knows',
|
||||
'http://example.org/Bob')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_os_returns_quads_for_object_subject(self, entity_kg):
|
||||
"""Test get_os queries by object and subject"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', d='', otype='u', dtype='', lang='',
|
||||
s='http://example.org/Alice', o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_os('test_collection', 'http://example.org/Bob',
|
||||
'http://example.org/Alice')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_spo_returns_quads_for_subject_predicate_object(self, entity_kg):
|
||||
"""Test get_spo queries by subject, predicate, and object"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(d='', otype='u', dtype='', lang='',
|
||||
o='http://example.org/Bob')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_spo('test_collection', 'http://example.org/Alice',
|
||||
'http://example.org/knows', 'http://example.org/Bob')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_g_returns_quads_for_graph(self, entity_kg):
|
||||
"""Test get_g queries by graph"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(s='http://example.org/Alice', p='http://example.org/knows',
|
||||
o='http://example.org/Bob', otype='u', dtype='', lang='')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_g('test_collection', 'http://example.org/graph1')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_get_all_returns_all_quads_in_collection(self, entity_kg):
|
||||
"""Test get_all returns all quads"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
|
||||
o='http://example.org/Bob', otype='u', dtype='', lang='')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_all('test_collection')
|
||||
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_graph_wildcard_returns_all_graphs(self, entity_kg):
|
||||
"""Test that g='*' returns quads from all graphs"""
|
||||
from trustgraph.direct.cassandra_kg import GRAPH_WILDCARD
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Bob'),
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Charlie')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice', g=GRAPH_WILDCARD)
|
||||
|
||||
# Should return quads from both graphs
|
||||
assert len(results) == 2
|
||||
|
||||
def test_specific_graph_filters_results(self, entity_kg):
|
||||
"""Test that specifying a graph filters results"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph1',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Bob'),
|
||||
MagicMock(p='http://example.org/knows', d='http://example.org/graph2',
|
||||
otype='u', dtype='', lang='', s='http://example.org/Alice',
|
||||
o='http://example.org/Charlie')
|
||||
]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
results = kg.get_s('test_collection', 'http://example.org/Alice',
|
||||
g='http://example.org/graph1')
|
||||
|
||||
# Should only return quads from graph1
|
||||
assert len(results) == 1
|
||||
assert results[0].g == 'http://example.org/graph1'
|
||||
|
||||
def test_collection_exists_returns_true_when_exists(self, entity_kg):
|
||||
"""Test collection_exists returns True for existing collection"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_result = [MagicMock(collection='test_collection')]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
exists = kg.collection_exists('test_collection')
|
||||
|
||||
assert exists is True
|
||||
|
||||
def test_collection_exists_returns_false_when_not_exists(self, entity_kg):
|
||||
"""Test collection_exists returns False for non-existing collection"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.execute.return_value = []
|
||||
|
||||
exists = kg.collection_exists('nonexistent_collection')
|
||||
|
||||
assert exists is False
|
||||
|
||||
def test_create_collection_inserts_metadata(self, entity_kg):
|
||||
"""Test create_collection inserts metadata row"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
mock_session.reset_mock()
|
||||
kg.create_collection('test_collection')
|
||||
|
||||
# Verify INSERT was executed for collection_metadata
|
||||
mock_session.execute.assert_called()
|
||||
|
||||
def test_delete_collection_removes_all_data(self, entity_kg):
|
||||
"""Test delete_collection removes entity partitions and collection rows"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
# Mock reading quads from collection
|
||||
mock_quads = [
|
||||
MagicMock(d='', s='http://example.org/Alice', p='http://example.org/knows',
|
||||
o='http://example.org/Bob', otype='u')
|
||||
]
|
||||
mock_session.execute.return_value = mock_quads
|
||||
|
||||
mock_session.reset_mock()
|
||||
kg.delete_collection('test_collection')
|
||||
|
||||
# Verify delete operations were executed
|
||||
assert mock_session.execute.called
|
||||
|
||||
def test_close_shuts_down_connections(self, entity_kg):
|
||||
"""Test close shuts down session and cluster"""
|
||||
kg, mock_session = entity_kg
|
||||
|
||||
kg.close()
|
||||
|
||||
mock_session.shutdown.assert_called_once()
|
||||
kg.cluster.shutdown.assert_called_once()
|
||||
|
||||
|
||||
class TestQuadResult:
|
||||
"""Test cases for QuadResult class"""
|
||||
|
||||
def test_quad_result_stores_all_fields(self):
|
||||
"""Test QuadResult stores all quad fields"""
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
result = QuadResult(
|
||||
s='http://example.org/Alice',
|
||||
p='http://example.org/knows',
|
||||
o='http://example.org/Bob',
|
||||
g='http://example.org/graph1',
|
||||
otype='u',
|
||||
dtype='',
|
||||
lang=''
|
||||
)
|
||||
|
||||
assert result.s == 'http://example.org/Alice'
|
||||
assert result.p == 'http://example.org/knows'
|
||||
assert result.o == 'http://example.org/Bob'
|
||||
assert result.g == 'http://example.org/graph1'
|
||||
assert result.otype == 'u'
|
||||
assert result.dtype == ''
|
||||
assert result.lang == ''
|
||||
|
||||
def test_quad_result_defaults(self):
|
||||
"""Test QuadResult default values"""
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
result = QuadResult(
|
||||
s='http://example.org/s',
|
||||
p='http://example.org/p',
|
||||
o='literal value',
|
||||
g=''
|
||||
)
|
||||
|
||||
assert result.otype == 'u' # Default otype
|
||||
assert result.dtype == ''
|
||||
assert result.lang == ''
|
||||
|
||||
def test_quad_result_with_literal_metadata(self):
|
||||
"""Test QuadResult with literal metadata"""
|
||||
from trustgraph.direct.cassandra_kg import QuadResult
|
||||
|
||||
result = QuadResult(
|
||||
s='http://example.org/Alice',
|
||||
p='http://www.w3.org/2000/01/rdf-schema#label',
|
||||
o='Alice Smith',
|
||||
g='',
|
||||
otype='l',
|
||||
dtype='xsd:string',
|
||||
lang='en'
|
||||
)
|
||||
|
||||
assert result.otype == 'l'
|
||||
assert result.dtype == 'xsd:string'
|
||||
assert result.lang == 'en'
|
||||
|
||||
|
||||
class TestWriteHelperFunctions:
|
||||
"""Test cases for helper functions in write.py"""
|
||||
|
||||
def test_get_term_otype_for_iri(self):
|
||||
"""Test get_term_otype returns 'u' for IRI terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, IRI
|
||||
|
||||
term = Term(type=IRI, iri='http://example.org/Alice')
|
||||
assert get_term_otype(term) == 'u'
|
||||
|
||||
def test_get_term_otype_for_literal(self):
|
||||
"""Test get_term_otype returns 'l' for LITERAL terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, LITERAL
|
||||
|
||||
term = Term(type=LITERAL, value='Alice Smith')
|
||||
assert get_term_otype(term) == 'l'
|
||||
|
||||
def test_get_term_otype_for_blank(self):
|
||||
"""Test get_term_otype returns 'u' for BLANK terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, BLANK
|
||||
|
||||
term = Term(type=BLANK, id='_:b1')
|
||||
assert get_term_otype(term) == 'u'
|
||||
|
||||
def test_get_term_otype_for_triple(self):
|
||||
"""Test get_term_otype returns 't' for TRIPLE terms"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
from trustgraph.schema import Term, TRIPLE
|
||||
|
||||
term = Term(type=TRIPLE)
|
||||
assert get_term_otype(term) == 't'
|
||||
|
||||
def test_get_term_otype_for_none(self):
|
||||
"""Test get_term_otype returns 'u' for None"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_otype
|
||||
|
||||
assert get_term_otype(None) == 'u'
|
||||
|
||||
def test_get_term_dtype_for_literal(self):
|
||||
"""Test get_term_dtype extracts datatype from LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_dtype
|
||||
from trustgraph.schema import Term, LITERAL
|
||||
|
||||
term = Term(type=LITERAL, value='42', datatype='xsd:integer')
|
||||
assert get_term_dtype(term) == 'xsd:integer'
|
||||
|
||||
def test_get_term_dtype_for_non_literal(self):
|
||||
"""Test get_term_dtype returns empty string for non-LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_dtype
|
||||
from trustgraph.schema import Term, IRI
|
||||
|
||||
term = Term(type=IRI, iri='http://example.org/Alice')
|
||||
assert get_term_dtype(term) == ''
|
||||
|
||||
def test_get_term_dtype_for_none(self):
|
||||
"""Test get_term_dtype returns empty string for None"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_dtype
|
||||
|
||||
assert get_term_dtype(None) == ''
|
||||
|
||||
def test_get_term_lang_for_literal(self):
|
||||
"""Test get_term_lang extracts language from LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_lang
|
||||
from trustgraph.schema import Term, LITERAL
|
||||
|
||||
term = Term(type=LITERAL, value='Alice Smith', language='en')
|
||||
assert get_term_lang(term) == 'en'
|
||||
|
||||
def test_get_term_lang_for_non_literal(self):
|
||||
"""Test get_term_lang returns empty string for non-LITERAL"""
|
||||
from trustgraph.storage.triples.cassandra.write import get_term_lang
|
||||
from trustgraph.schema import Term, IRI
|
||||
|
||||
term = Term(type=IRI, iri='http://example.org/Alice')
|
||||
assert get_term_lang(term) == ''
|
||||
|
||||
|
||||
class TestServiceHelperFunctions:
|
||||
"""Test cases for helper functions in service.py"""
|
||||
|
||||
def test_create_term_with_uri_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='u'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice', otype='u')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_with_literal_otype(self):
|
||||
"""Test create_term creates LITERAL Term for otype='l'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith', otype='l', dtype='xsd:string', lang='en')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
assert term.datatype == 'xsd:string'
|
||||
assert term.language == 'en'
|
||||
|
||||
def test_create_term_with_triple_otype(self):
|
||||
"""Test create_term creates IRI Term for otype='t'"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/statement1', otype='t')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/statement1'
|
||||
|
||||
def test_create_term_heuristic_fallback_uri(self):
|
||||
"""Test create_term uses URL heuristic when otype not provided"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import IRI
|
||||
|
||||
term = create_term('http://example.org/Alice')
|
||||
|
||||
assert term.type == IRI
|
||||
assert term.iri == 'http://example.org/Alice'
|
||||
|
||||
def test_create_term_heuristic_fallback_literal(self):
|
||||
"""Test create_term uses literal heuristic when otype not provided"""
|
||||
from trustgraph.query.triples.cassandra.service import create_term
|
||||
from trustgraph.schema import LITERAL
|
||||
|
||||
term = create_term('Alice Smith')
|
||||
|
||||
assert term.type == LITERAL
|
||||
assert term.value == 'Alice Smith'
|
||||
380
tests/unit/test_embeddings/test_row_embeddings_processor.py
Normal file
380
tests/unit/test_embeddings/test_row_embeddings_processor.py
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
"""
|
||||
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
|
||||
Tests the Stage 1 processor that computes embeddings for row index fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
|
||||
"""Test row embeddings processor functionality"""
|
||||
|
||||
async def test_processor_initialization(self):
|
||||
"""Test basic processor initialization"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-row-embeddings'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
assert hasattr(processor, 'schemas')
|
||||
assert processor.schemas == {}
|
||||
assert processor.batch_size == 10 # default
|
||||
|
||||
async def test_processor_initialization_with_custom_batch_size(self):
|
||||
"""Test processor initialization with custom batch size"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-row-embeddings',
|
||||
'batch_size': 25
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
assert processor.batch_size == 25
|
||||
|
||||
async def test_get_index_names_single_index(self):
|
||||
"""Test getting index names with single indexed field"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
schema = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
Field(name='email', type='text', indexed=False),
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
# Should include primary key and indexed field
|
||||
assert 'id' in index_names
|
||||
assert 'name' in index_names
|
||||
assert 'email' not in index_names
|
||||
|
||||
async def test_get_index_names_no_indexes(self):
|
||||
"""Test getting index names when no fields are indexed"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
schema = RowSchema(
|
||||
name='logs',
|
||||
description='Log records',
|
||||
fields=[
|
||||
Field(name='timestamp', type='text'),
|
||||
Field(name='message', type='text'),
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
assert index_names == []
|
||||
|
||||
async def test_build_index_value_single_field(self):
|
||||
"""Test building index value for single field"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
value_map = {
|
||||
'id': 'CUST001',
|
||||
'name': 'John Doe',
|
||||
'email': 'john@example.com'
|
||||
}
|
||||
|
||||
result = processor.build_index_value(value_map, 'name')
|
||||
|
||||
assert result == ['John Doe']
|
||||
|
||||
async def test_build_index_value_composite_index(self):
|
||||
"""Test building index value for composite index"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
value_map = {
|
||||
'first_name': 'John',
|
||||
'last_name': 'Doe',
|
||||
'city': 'New York'
|
||||
}
|
||||
|
||||
result = processor.build_index_value(value_map, 'first_name, last_name')
|
||||
|
||||
assert result == ['John', 'Doe']
|
||||
|
||||
async def test_build_index_value_missing_field(self):
|
||||
"""Test building index value when field is missing"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
value_map = {
|
||||
'name': 'John Doe'
|
||||
}
|
||||
|
||||
result = processor.build_index_value(value_map, 'missing_field')
|
||||
|
||||
assert result == ['']
|
||||
|
||||
async def test_build_text_for_embedding_single_value(self):
|
||||
"""Test building text representation for single value"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
result = processor.build_text_for_embedding(['John Doe'])
|
||||
|
||||
assert result == 'John Doe'
|
||||
|
||||
async def test_build_text_for_embedding_multiple_values(self):
|
||||
"""Test building text representation for multiple values"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])
|
||||
|
||||
assert result == 'John Doe NYC'
|
||||
|
||||
async def test_on_schema_config_loads_schemas(self):
|
||||
"""Test that schema configuration is loaded correctly"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
import json
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor',
|
||||
'config_type': 'schema'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
schema_def = {
|
||||
'name': 'customers',
|
||||
'description': 'Customer records',
|
||||
'fields': [
|
||||
{'name': 'id', 'type': 'text', 'primary_key': True},
|
||||
{'name': 'name', 'type': 'text', 'indexed': True},
|
||||
{'name': 'email', 'type': 'text'}
|
||||
]
|
||||
}
|
||||
|
||||
config_data = {
|
||||
'schema': {
|
||||
'customers': json.dumps(schema_def)
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
|
||||
assert 'customers' in processor.schemas
|
||||
assert processor.schemas['customers'].name == 'customers'
|
||||
assert len(processor.schemas['customers'].fields) == 3
|
||||
|
||||
async def test_on_schema_config_handles_missing_type(self):
|
||||
"""Test that missing schema type is handled gracefully"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor',
|
||||
'config_type': 'schema'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
config_data = {
|
||||
'other_type': {}
|
||||
}
|
||||
|
||||
await processor.on_schema_config(config_data, 1)
|
||||
|
||||
assert processor.schemas == {}
|
||||
|
||||
async def test_on_message_drops_unknown_collection(self):
|
||||
"""Test that messages for unknown collections are dropped"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import ExtractedObject
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
# No collections registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'unknown_user'
|
||||
metadata.collection = 'unknown_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name='customers',
|
||||
values=[{'id': '123', 'name': 'Test'}]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = obj
|
||||
|
||||
mock_flow = MagicMock()
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Flow should not be called for output
|
||||
mock_flow.assert_not_called()
|
||||
|
||||
async def test_on_message_drops_unknown_schema(self):
|
||||
"""Test that messages for unknown schemas are dropped"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import ExtractedObject
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
# No schemas registered
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name='unknown_schema',
|
||||
values=[{'id': '123', 'name': 'Test'}]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = obj
|
||||
|
||||
mock_flow = MagicMock()
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Flow should not be called for output
|
||||
mock_flow.assert_not_called()
|
||||
|
||||
async def test_on_message_processes_embeddings(self):
|
||||
"""Test processing a message and computing embeddings"""
|
||||
from trustgraph.embeddings.row_embeddings.embeddings import Processor
|
||||
from trustgraph.schema import ExtractedObject, RowSchema, Field
|
||||
import json
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor',
|
||||
'config_type': 'schema'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||
|
||||
# Set up schema
|
||||
processor.schemas['customers'] = RowSchema(
|
||||
name='customers',
|
||||
description='Customer records',
|
||||
fields=[
|
||||
Field(name='id', type='text', primary=True),
|
||||
Field(name='name', type='text', indexed=True),
|
||||
]
|
||||
)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.user = 'test_user'
|
||||
metadata.collection = 'test_collection'
|
||||
metadata.id = 'doc-123'
|
||||
|
||||
obj = ExtractedObject(
|
||||
metadata=metadata,
|
||||
schema_name='customers',
|
||||
values=[
|
||||
{'id': 'CUST001', 'name': 'John Doe'},
|
||||
{'id': 'CUST002', 'name': 'Jane Smith'}
|
||||
]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = obj
|
||||
|
||||
# Mock the flow
|
||||
mock_embeddings_request = AsyncMock()
|
||||
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_output = AsyncMock()
|
||||
|
||||
def flow_factory(name):
|
||||
if name == 'embeddings-request':
|
||||
return mock_embeddings_request
|
||||
elif name == 'output':
|
||||
return mock_output
|
||||
return MagicMock()
|
||||
|
||||
mock_flow = MagicMock(side_effect=flow_factory)
|
||||
|
||||
await processor.on_message(mock_msg, MagicMock(), mock_flow)
|
||||
|
||||
# Should have called embed for each unique text
|
||||
# 4 values: CUST001, John Doe, CUST002, Jane Smith
|
||||
assert mock_embeddings_request.embed.call_count == 4
|
||||
|
||||
# Should have sent output
|
||||
mock_output.send.assert_called()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -7,7 +7,7 @@ collecting labels and definitions for entity embedding and retrieval.
|
|||
|
||||
import pytest
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.schema.core.primitives import Triple, Value
|
||||
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
|
||||
from trustgraph.schema.knowledge.graph import EntityContext
|
||||
|
||||
|
||||
|
|
@ -25,9 +25,9 @@ class TestEntityContextBuilding:
|
|||
"""Test that entity context is built from rdfs:label."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/cornish-pasty", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Cornish Pasty", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/cornish-pasty"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Cornish Pasty")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -35,16 +35,16 @@ class TestEntityContextBuilding:
|
|||
|
||||
assert len(contexts) == 1, "Should create one entity context"
|
||||
assert isinstance(contexts[0], EntityContext)
|
||||
assert contexts[0].entity.value == "https://example.com/entity/cornish-pasty"
|
||||
assert contexts[0].entity.iri == "https://example.com/entity/cornish-pasty"
|
||||
assert "Label: Cornish Pasty" in contexts[0].context
|
||||
|
||||
def test_builds_context_from_definition(self, processor):
|
||||
"""Test that entity context includes definitions."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/pasty", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="A baked pastry filled with savory ingredients", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/pasty"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="A baked pastry filled with savory ingredients")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -57,14 +57,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that label and definition are combined in context."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Pasty Recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Pasty Recipe")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="Traditional Cornish pastry recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="Traditional Cornish pastry recipe")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -80,14 +80,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that only the first label is used in context."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="First Label", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="First Label")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Second Label", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Second Label")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -101,14 +101,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that all definitions are included in context."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="First definition", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="First definition")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="Second definition", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="Second definition")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -123,9 +123,9 @@ class TestEntityContextBuilding:
|
|||
"""Test that schema.org description is treated as definition."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="https://schema.org/description", is_uri=True),
|
||||
o=Value(value="A delicious food item", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="https://schema.org/description"),
|
||||
o=Term(type=LITERAL, value="A delicious food item")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -138,26 +138,26 @@ class TestEntityContextBuilding:
|
|||
"""Test that contexts are created for multiple entities."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity One", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity One")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity2", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity Two", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity Two")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity3", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity Three", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity3"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity Three")
|
||||
)
|
||||
]
|
||||
|
||||
contexts = processor.build_entity_contexts(triples)
|
||||
|
||||
assert len(contexts) == 3, "Should create context for each entity"
|
||||
entity_uris = [ctx.entity.value for ctx in contexts]
|
||||
entity_uris = [ctx.entity.iri for ctx in contexts]
|
||||
assert "https://example.com/entity/entity1" in entity_uris
|
||||
assert "https://example.com/entity/entity2" in entity_uris
|
||||
assert "https://example.com/entity/entity3" in entity_uris
|
||||
|
|
@ -166,9 +166,9 @@ class TestEntityContextBuilding:
|
|||
"""Test that URI objects are ignored (only literal labels/definitions)."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="https://example.com/some/uri", is_uri=True) # URI, not literal
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=IRI, iri="https://example.com/some/uri") # URI, not literal
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -181,14 +181,14 @@ class TestEntityContextBuilding:
|
|||
"""Test that other predicates are ignored."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://example.com/Food", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://example.com/Food")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/food1", is_uri=True),
|
||||
p=Value(value="http://example.com/produces", is_uri=True),
|
||||
o=Value(value="https://example.com/entity/food2", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/food1"),
|
||||
p=Term(type=IRI, iri="http://example.com/produces"),
|
||||
o=Term(type=IRI, iri="https://example.com/entity/food2")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -205,29 +205,29 @@ class TestEntityContextBuilding:
|
|||
|
||||
assert len(contexts) == 0, "Empty triple list should return empty contexts"
|
||||
|
||||
def test_entity_context_has_value_object(self, processor):
|
||||
"""Test that EntityContext.entity is a Value object."""
|
||||
def test_entity_context_has_term_object(self, processor):
|
||||
"""Test that EntityContext.entity is a Term object."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Test Entity", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Test Entity")
|
||||
)
|
||||
]
|
||||
|
||||
contexts = processor.build_entity_contexts(triples)
|
||||
|
||||
assert len(contexts) == 1
|
||||
assert isinstance(contexts[0].entity, Value), "Entity should be Value object"
|
||||
assert contexts[0].entity.is_uri, "Entity should be marked as URI"
|
||||
assert isinstance(contexts[0].entity, Term), "Entity should be Term object"
|
||||
assert contexts[0].entity.type == IRI, "Entity should be IRI type"
|
||||
|
||||
def test_entity_context_text_is_string(self, processor):
|
||||
"""Test that EntityContext.context is a string."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Test Entity", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Test Entity")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -241,22 +241,22 @@ class TestEntityContextBuilding:
|
|||
triples = [
|
||||
# Entity with label - should create context
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Entity One", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Entity One")
|
||||
),
|
||||
# Entity with only rdf:type - should NOT create context
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/entity2", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://example.com/Food", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/entity2"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://example.com/Food")
|
||||
)
|
||||
]
|
||||
|
||||
contexts = processor.build_entity_contexts(triples)
|
||||
|
||||
assert len(contexts) == 1, "Should only create context for entity with label/definition"
|
||||
assert contexts[0].entity.value == "https://example.com/entity/entity1"
|
||||
assert contexts[0].entity.iri == "https://example.com/entity/entity1"
|
||||
|
||||
|
||||
class TestEntityContextEdgeCases:
|
||||
|
|
@ -266,9 +266,9 @@ class TestEntityContextEdgeCases:
|
|||
"""Test handling of unicode characters in labels."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/café", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Café Spécial", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/café"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Café Spécial")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -282,9 +282,9 @@ class TestEntityContextEdgeCases:
|
|||
long_def = "This is a very long definition " * 50
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value=long_def, is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value=long_def)
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -297,9 +297,9 @@ class TestEntityContextEdgeCases:
|
|||
"""Test handling of special characters in context text."""
|
||||
triples = [
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/test", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Test & Entity <with> \"quotes\"", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/test"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Test & Entity <with> \"quotes\"")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -313,27 +313,27 @@ class TestEntityContextEdgeCases:
|
|||
triples = [
|
||||
# Label - relevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
|
||||
o=Value(value="Cornish Pasty Recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2000/01/rdf-schema#label"),
|
||||
o=Term(type=LITERAL, value="Cornish Pasty Recipe")
|
||||
),
|
||||
# Type - irrelevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
|
||||
o=Value(value="http://example.com/Recipe", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri="http://example.com/Recipe")
|
||||
),
|
||||
# Property - irrelevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://example.com/produces", is_uri=True),
|
||||
o=Value(value="https://example.com/entity/pasty", is_uri=True)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://example.com/produces"),
|
||||
o=Term(type=IRI, iri="https://example.com/entity/pasty")
|
||||
),
|
||||
# Definition - relevant
|
||||
Triple(
|
||||
s=Value(value="https://example.com/entity/recipe1", is_uri=True),
|
||||
p=Value(value="http://www.w3.org/2004/02/skos/core#definition", is_uri=True),
|
||||
o=Value(value="Traditional British pastry recipe", is_uri=False)
|
||||
s=Term(type=IRI, iri="https://example.com/entity/recipe1"),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/2004/02/skos/core#definition"),
|
||||
o=Term(type=LITERAL, value="Traditional British pastry recipe")
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ the knowledge graph.
|
|||
import pytest
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
|
||||
from trustgraph.schema.core.primitives import Triple, Value
|
||||
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -92,12 +92,12 @@ class TestOntologyTripleGeneration:
|
|||
# Find type triples for Recipe class
|
||||
recipe_type_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
]
|
||||
|
||||
assert len(recipe_type_triples) == 1, "Should generate exactly one type triple per class"
|
||||
assert recipe_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#Class", \
|
||||
assert recipe_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#Class", \
|
||||
"Class type should be owl:Class"
|
||||
|
||||
def test_generates_class_labels(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -107,14 +107,14 @@ class TestOntologyTripleGeneration:
|
|||
# Find label triples for Recipe class
|
||||
recipe_label_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
]
|
||||
|
||||
assert len(recipe_label_triples) == 1, "Should generate label triple for class"
|
||||
assert recipe_label_triples[0].o.value == "Recipe", \
|
||||
"Label should match class label from ontology"
|
||||
assert not recipe_label_triples[0].o.is_uri, \
|
||||
assert recipe_label_triples[0].o.type == LITERAL, \
|
||||
"Label should be a literal, not URI"
|
||||
|
||||
def test_generates_class_comments(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -124,8 +124,8 @@ class TestOntologyTripleGeneration:
|
|||
# Find comment triples for Recipe class
|
||||
recipe_comment_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#comment"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#comment"
|
||||
]
|
||||
|
||||
assert len(recipe_comment_triples) == 1, "Should generate comment triple for class"
|
||||
|
|
@ -139,13 +139,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find type triples for ingredients property
|
||||
ingredients_type_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
]
|
||||
|
||||
assert len(ingredients_type_triples) == 1, \
|
||||
"Should generate exactly one type triple per object property"
|
||||
assert ingredients_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#ObjectProperty", \
|
||||
assert ingredients_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#ObjectProperty", \
|
||||
"Object property type should be owl:ObjectProperty"
|
||||
|
||||
def test_generates_object_property_labels(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -155,8 +155,8 @@ class TestOntologyTripleGeneration:
|
|||
# Find label triples for ingredients property
|
||||
ingredients_label_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
]
|
||||
|
||||
assert len(ingredients_label_triples) == 1, \
|
||||
|
|
@ -171,15 +171,15 @@ class TestOntologyTripleGeneration:
|
|||
# Find domain triples for ingredients property
|
||||
ingredients_domain_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#domain"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/ingredients"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#domain"
|
||||
]
|
||||
|
||||
assert len(ingredients_domain_triples) == 1, \
|
||||
"Should generate domain triple for object property"
|
||||
assert ingredients_domain_triples[0].o.value == "http://purl.org/ontology/fo/Recipe", \
|
||||
assert ingredients_domain_triples[0].o.iri == "http://purl.org/ontology/fo/Recipe", \
|
||||
"Domain should be Recipe class URI"
|
||||
assert ingredients_domain_triples[0].o.is_uri, \
|
||||
assert ingredients_domain_triples[0].o.type == IRI, \
|
||||
"Domain should be a URI reference"
|
||||
|
||||
def test_generates_object_property_range(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -189,13 +189,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find range triples for produces property
|
||||
produces_range_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/produces"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/produces"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
]
|
||||
|
||||
assert len(produces_range_triples) == 1, \
|
||||
"Should generate range triple for object property"
|
||||
assert produces_range_triples[0].o.value == "http://purl.org/ontology/fo/Food", \
|
||||
assert produces_range_triples[0].o.iri == "http://purl.org/ontology/fo/Food", \
|
||||
"Range should be Food class URI"
|
||||
|
||||
def test_generates_datatype_property_type_triples(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -205,13 +205,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find type triples for serves property
|
||||
serves_type_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
]
|
||||
|
||||
assert len(serves_type_triples) == 1, \
|
||||
"Should generate exactly one type triple per datatype property"
|
||||
assert serves_type_triples[0].o.value == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
|
||||
assert serves_type_triples[0].o.iri == "http://www.w3.org/2002/07/owl#DatatypeProperty", \
|
||||
"Datatype property type should be owl:DatatypeProperty"
|
||||
|
||||
def test_generates_datatype_property_range(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -221,13 +221,13 @@ class TestOntologyTripleGeneration:
|
|||
# Find range triples for serves property
|
||||
serves_range_triples = [
|
||||
t for t in triples
|
||||
if t.s.value == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.value == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
if t.s.iri == "http://purl.org/ontology/fo/serves"
|
||||
and t.p.iri == "http://www.w3.org/2000/01/rdf-schema#range"
|
||||
]
|
||||
|
||||
assert len(serves_range_triples) == 1, \
|
||||
"Should generate range triple for datatype property"
|
||||
assert serves_range_triples[0].o.value == "http://www.w3.org/2001/XMLSchema#string", \
|
||||
assert serves_range_triples[0].o.iri == "http://www.w3.org/2001/XMLSchema#string", \
|
||||
"Range should be XSD type URI (xsd:string expanded)"
|
||||
|
||||
def test_generates_triples_for_all_classes(self, extractor, sample_ontology_subset):
|
||||
|
|
@ -236,9 +236,9 @@ class TestOntologyTripleGeneration:
|
|||
|
||||
# Count unique class subjects
|
||||
class_subjects = set(
|
||||
t.s.value for t in triples
|
||||
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and t.o.value == "http://www.w3.org/2002/07/owl#Class"
|
||||
t.s.iri for t in triples
|
||||
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and t.o.iri == "http://www.w3.org/2002/07/owl#Class"
|
||||
)
|
||||
|
||||
assert len(class_subjects) == 3, \
|
||||
|
|
@ -250,9 +250,9 @@ class TestOntologyTripleGeneration:
|
|||
|
||||
# Count unique property subjects (object + datatype properties)
|
||||
property_subjects = set(
|
||||
t.s.value for t in triples
|
||||
if t.p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and ("ObjectProperty" in t.o.value or "DatatypeProperty" in t.o.value)
|
||||
t.s.iri for t in triples
|
||||
if t.p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
and ("ObjectProperty" in t.o.iri or "DatatypeProperty" in t.o.iri)
|
||||
)
|
||||
|
||||
assert len(property_subjects) == 3, \
|
||||
|
|
@ -276,7 +276,7 @@ class TestOntologyTripleGeneration:
|
|||
# Should still generate proper RDF triples despite dict field names
|
||||
label_triples = [
|
||||
t for t in triples
|
||||
if t.p.value == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
if t.p.iri == "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
]
|
||||
assert len(label_triples) > 0, \
|
||||
"Should generate rdfs:label triples from dict 'labels' field"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ and extracts/validates triples from LLM responses.
|
|||
import pytest
|
||||
from trustgraph.extract.kg.ontology.extract import Processor
|
||||
from trustgraph.extract.kg.ontology.ontology_selector import OntologySubset
|
||||
from trustgraph.schema.core.primitives import Triple, Value
|
||||
from trustgraph.schema.core.primitives import Triple, Term, IRI, LITERAL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -248,9 +248,9 @@ class TestTripleParsing:
|
|||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert len(validated) == 1, "Should parse one valid triple"
|
||||
assert validated[0].s.value == "https://trustgraph.ai/food/cornish-pasty"
|
||||
assert validated[0].p.value == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
|
||||
assert validated[0].s.iri == "https://trustgraph.ai/food/cornish-pasty"
|
||||
assert validated[0].p.iri == "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
|
||||
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
|
||||
def test_parse_multiple_triples(self, extractor, sample_ontology_subset):
|
||||
"""Test parsing multiple triples."""
|
||||
|
|
@ -307,11 +307,11 @@ class TestTripleParsing:
|
|||
|
||||
assert len(validated) == 1
|
||||
# Subject should be expanded to entity URI
|
||||
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
|
||||
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
|
||||
# Predicate should be expanded to ontology URI
|
||||
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
|
||||
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
|
||||
# Object should be expanded to class URI
|
||||
assert validated[0].o.value == "http://purl.org/ontology/fo/Food"
|
||||
assert validated[0].o.iri == "http://purl.org/ontology/fo/Food"
|
||||
|
||||
def test_creates_proper_triple_objects(self, extractor, sample_ontology_subset):
|
||||
"""Test that Triple objects are properly created."""
|
||||
|
|
@ -324,12 +324,12 @@ class TestTripleParsing:
|
|||
assert len(validated) == 1
|
||||
triple = validated[0]
|
||||
assert isinstance(triple, Triple), "Should create Triple objects"
|
||||
assert isinstance(triple.s, Value), "Subject should be Value object"
|
||||
assert isinstance(triple.p, Value), "Predicate should be Value object"
|
||||
assert isinstance(triple.o, Value), "Object should be Value object"
|
||||
assert triple.s.is_uri, "Subject should be marked as URI"
|
||||
assert triple.p.is_uri, "Predicate should be marked as URI"
|
||||
assert not triple.o.is_uri, "Object literal should not be marked as URI"
|
||||
assert isinstance(triple.s, Term), "Subject should be Term object"
|
||||
assert isinstance(triple.p, Term), "Predicate should be Term object"
|
||||
assert isinstance(triple.o, Term), "Object should be Term object"
|
||||
assert triple.s.type == IRI, "Subject should be IRI type"
|
||||
assert triple.p.type == IRI, "Predicate should be IRI type"
|
||||
assert triple.o.type == LITERAL, "Object literal should be LITERAL type"
|
||||
|
||||
|
||||
class TestURIExpansionInExtraction:
|
||||
|
|
@ -343,8 +343,8 @@ class TestURIExpansionInExtraction:
|
|||
|
||||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert validated[0].o.value == "http://purl.org/ontology/fo/Recipe"
|
||||
assert validated[0].o.is_uri, "Class reference should be URI"
|
||||
assert validated[0].o.iri == "http://purl.org/ontology/fo/Recipe"
|
||||
assert validated[0].o.type == IRI, "Class reference should be URI"
|
||||
|
||||
def test_expands_property_names(self, extractor, sample_ontology_subset):
|
||||
"""Test that property names are expanded to full URIs."""
|
||||
|
|
@ -354,7 +354,7 @@ class TestURIExpansionInExtraction:
|
|||
|
||||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert validated[0].p.value == "http://purl.org/ontology/fo/produces"
|
||||
assert validated[0].p.iri == "http://purl.org/ontology/fo/produces"
|
||||
|
||||
def test_expands_entity_instances(self, extractor, sample_ontology_subset):
|
||||
"""Test that entity instances get constructed URIs."""
|
||||
|
|
@ -364,8 +364,8 @@ class TestURIExpansionInExtraction:
|
|||
|
||||
validated = extractor.parse_and_validate_triples(triples_response, sample_ontology_subset)
|
||||
|
||||
assert validated[0].s.value.startswith("https://trustgraph.ai/food/")
|
||||
assert "my-special-recipe" in validated[0].s.value
|
||||
assert validated[0].s.iri.startswith("https://trustgraph.ai/food/")
|
||||
assert "my-special-recipe" in validated[0].s.iri
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Term, Triple, IRI, LITERAL
|
||||
|
||||
|
||||
class TestDispatchSerialize:
|
||||
|
|
@ -14,55 +14,55 @@ class TestDispatchSerialize:
|
|||
|
||||
def test_to_value_with_uri(self):
|
||||
"""Test to_value function with URI"""
|
||||
input_data = {"v": "http://example.com/resource", "e": True}
|
||||
|
||||
input_data = {"t": "i", "i": "http://example.com/resource"}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_to_value_with_literal(self):
|
||||
"""Test to_value function with literal value"""
|
||||
input_data = {"v": "literal string", "e": False}
|
||||
|
||||
input_data = {"t": "l", "v": "literal string"}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_to_subgraph_with_multiple_triples(self):
|
||||
"""Test to_subgraph function with multiple triples"""
|
||||
input_data = [
|
||||
{
|
||||
"s": {"v": "subject1", "e": True},
|
||||
"p": {"v": "predicate1", "e": True},
|
||||
"o": {"v": "object1", "e": False}
|
||||
"s": {"t": "i", "i": "subject1"},
|
||||
"p": {"t": "i", "i": "predicate1"},
|
||||
"o": {"t": "l", "v": "object1"}
|
||||
},
|
||||
{
|
||||
"s": {"v": "subject2", "e": False},
|
||||
"p": {"v": "predicate2", "e": True},
|
||||
"o": {"v": "object2", "e": True}
|
||||
"s": {"t": "l", "v": "subject2"},
|
||||
"p": {"t": "i", "i": "predicate2"},
|
||||
"o": {"t": "i", "i": "object2"}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
result = to_subgraph(input_data)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(triple, Triple) for triple in result)
|
||||
|
||||
|
||||
# Check first triple
|
||||
assert result[0].s.value == "subject1"
|
||||
assert result[0].s.is_uri is True
|
||||
assert result[0].p.value == "predicate1"
|
||||
assert result[0].p.is_uri is True
|
||||
assert result[0].s.iri == "subject1"
|
||||
assert result[0].s.type == IRI
|
||||
assert result[0].p.iri == "predicate1"
|
||||
assert result[0].p.type == IRI
|
||||
assert result[0].o.value == "object1"
|
||||
assert result[0].o.is_uri is False
|
||||
|
||||
assert result[0].o.type == LITERAL
|
||||
|
||||
# Check second triple
|
||||
assert result[1].s.value == "subject2"
|
||||
assert result[1].s.is_uri is False
|
||||
assert result[1].s.type == LITERAL
|
||||
|
||||
def test_to_subgraph_with_empty_list(self):
|
||||
"""Test to_subgraph function with empty input"""
|
||||
|
|
@ -74,16 +74,16 @@ class TestDispatchSerialize:
|
|||
|
||||
def test_serialize_value_with_uri(self):
|
||||
"""Test serialize_value function with URI value"""
|
||||
value = Value(value="http://example.com/test", is_uri=True)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "http://example.com/test", "e": True}
|
||||
term = Term(type=IRI, iri="http://example.com/test")
|
||||
|
||||
result = serialize_value(term)
|
||||
|
||||
assert result == {"t": "i", "i": "http://example.com/test"}
|
||||
|
||||
def test_serialize_value_with_literal(self):
|
||||
"""Test serialize_value function with literal value"""
|
||||
value = Value(value="test literal", is_uri=False)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "test literal", "e": False}
|
||||
term = Term(type=LITERAL, value="test literal")
|
||||
|
||||
result = serialize_value(term)
|
||||
|
||||
assert result == {"t": "l", "v": "test literal"}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Unit tests for objects import dispatcher.
|
||||
Unit tests for rows import dispatcher.
|
||||
|
||||
Tests the business logic of objects import dispatcher
|
||||
Tests the business logic of rows import dispatcher
|
||||
while mocking the Publisher and websocket components.
|
||||
"""
|
||||
|
||||
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
|
||||
from trustgraph.gateway.dispatch.rows_import import RowsImport
|
||||
from trustgraph.schema import Metadata, ExtractedObject
|
||||
|
||||
|
||||
|
|
@ -92,16 +92,16 @@ def minimal_objects_message():
|
|||
}
|
||||
|
||||
|
||||
class TestObjectsImportInitialization:
|
||||
"""Test ObjectsImport initialization."""
|
||||
class TestRowsImportInitialization:
|
||||
"""Test RowsImport initialization."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport creates Publisher with correct parameters."""
|
||||
"""Test that RowsImport creates Publisher with correct parameters."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
|
|||
)
|
||||
|
||||
# Verify instance variables are set correctly
|
||||
assert objects_import.ws == mock_websocket
|
||||
assert objects_import.running == mock_running
|
||||
assert objects_import.publisher == mock_publisher_instance
|
||||
assert rows_import.ws == mock_websocket
|
||||
assert rows_import.running == mock_running
|
||||
assert rows_import.publisher == mock_publisher_instance
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that ObjectsImport stores all required references."""
|
||||
objects_import = ObjectsImport(
|
||||
"""Test that RowsImport stores all required references."""
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="objects-queue"
|
||||
)
|
||||
|
||||
assert objects_import.ws is mock_websocket
|
||||
assert objects_import.running is mock_running
|
||||
assert rows_import.ws is mock_websocket
|
||||
assert rows_import.running is mock_running
|
||||
|
||||
|
||||
class TestObjectsImportLifecycle:
|
||||
"""Test ObjectsImport lifecycle methods."""
|
||||
class TestRowsImportLifecycle:
|
||||
"""Test RowsImport lifecycle methods."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that start() calls publisher.start()."""
|
||||
|
|
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.start()
|
||||
await rows_import.start()
|
||||
|
||||
mock_publisher_instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that destroy() properly stops publisher and closes websocket."""
|
||||
|
|
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.destroy()
|
||||
await rows_import.destroy()
|
||||
|
||||
# Verify sequence of operations
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that destroy() handles None websocket gracefully."""
|
||||
|
|
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
|
|||
mock_publisher_instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
|
|||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.destroy()
|
||||
await rows_import.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_publisher_instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestObjectsImportMessageProcessing:
|
||||
"""Test ObjectsImport message processing."""
|
||||
class TestRowsImportMessageProcessing:
|
||||
"""Test RowsImport message processing."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() processes complete message correctly."""
|
||||
|
|
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.metadata.collection == "testcollection"
|
||||
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
|
||||
"""Test that receive() handles message with minimal required fields."""
|
||||
|
|
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = minimal_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.source_span == "" # Default value
|
||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() uses appropriate default values for optional fields."""
|
||||
|
|
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = message_data
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Get the sent object and verify defaults
|
||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||
|
|
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
|
|||
assert sent_object.source_span == ""
|
||||
|
||||
|
||||
class TestObjectsImportRunMethod:
|
||||
"""Test ObjectsImport run method."""
|
||||
class TestRowsImportRunMethod:
|
||||
"""Test RowsImport run method."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that run() loops while running.get() returns True."""
|
||||
|
|
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
|
|||
# Set up running state to return True twice, then False
|
||||
mock_running.get.side_effect = [True, True, False]
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
await objects_import.run()
|
||||
await rows_import.run()
|
||||
|
||||
# Verify sleep was called twice (for the two True iterations)
|
||||
assert mock_sleep.call_count == 2
|
||||
|
|
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
|
|||
mock_websocket.close.assert_called_once()
|
||||
|
||||
# Verify websocket was set to None
|
||||
assert objects_import.ws is None
|
||||
assert rows_import.ws is None
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that run() handles None websocket gracefully."""
|
||||
|
|
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
|
|||
|
||||
mock_running.get.return_value = False # Exit immediately
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
|
|||
)
|
||||
|
||||
# Should not raise exception
|
||||
await objects_import.run()
|
||||
await rows_import.run()
|
||||
|
||||
# Verify websocket remains None
|
||||
assert objects_import.ws is None
|
||||
assert rows_import.ws is None
|
||||
|
||||
|
||||
class TestObjectsImportBatchProcessing:
|
||||
"""Test ObjectsImport batch processing functionality."""
|
||||
class TestRowsImportBatchProcessing:
|
||||
"""Test RowsImport batch processing functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_objects_message(self):
|
||||
|
|
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
|
|||
"source_span": "Multiple people found in document"
|
||||
}
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
|
||||
"""Test that receive() processes batch message correctly."""
|
||||
|
|
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = batch_objects_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Verify publisher.send was called
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
|
|||
assert sent_object.confidence == 0.85
|
||||
assert sent_object.source_span == "Multiple people found in document"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles empty batch correctly."""
|
||||
|
|
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_publisher_instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
|
|||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_batch_message
|
||||
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
# Should still send the message
|
||||
mock_publisher_instance.send.assert_called_once()
|
||||
|
|
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
|
|||
assert len(sent_object.values) == 0
|
||||
|
||||
|
||||
class TestObjectsImportErrorHandling:
|
||||
"""Test error handling in ObjectsImport."""
|
||||
class TestRowsImportErrorHandling:
|
||||
"""Test error handling in RowsImport."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() propagates publisher send errors."""
|
||||
|
|
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
|
|||
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
||||
mock_publisher_class.return_value = mock_publisher_instance
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
|
|||
mock_msg.json.return_value = sample_objects_message
|
||||
|
||||
with pytest.raises(Exception, match="Publisher error"):
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles malformed JSON appropriately."""
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
rows_import = RowsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
|
|
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
|
|||
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
await objects_import.receive(mock_msg)
|
||||
await rows_import.receive(mock_msg)
|
||||
|
|
@ -6,11 +6,21 @@ import pytest
|
|||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
# Mock schema classes for testing
|
||||
class Value:
|
||||
def __init__(self, value, is_uri, type):
|
||||
self.value = value
|
||||
self.is_uri = is_uri
|
||||
# Term type constants
|
||||
IRI = "i"
|
||||
LITERAL = "l"
|
||||
BLANK = "b"
|
||||
TRIPLE = "t"
|
||||
|
||||
class Term:
|
||||
def __init__(self, type, iri=None, value=None, id=None, datatype=None, language=None, triple=None):
|
||||
self.type = type
|
||||
self.iri = iri
|
||||
self.value = value
|
||||
self.id = id
|
||||
self.datatype = datatype
|
||||
self.language = language
|
||||
self.triple = triple
|
||||
|
||||
class Triple:
|
||||
def __init__(self, s, p, o):
|
||||
|
|
@ -66,32 +76,30 @@ def sample_relationships():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_value_uri():
|
||||
"""Sample URI Value object"""
|
||||
return Value(
|
||||
value="http://example.com/person/john-smith",
|
||||
is_uri=True,
|
||||
type=""
|
||||
def sample_term_uri():
|
||||
"""Sample URI Term object"""
|
||||
return Term(
|
||||
type=IRI,
|
||||
iri="http://example.com/person/john-smith"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_value_literal():
|
||||
"""Sample literal Value object"""
|
||||
return Value(
|
||||
value="John Smith",
|
||||
is_uri=False,
|
||||
type="string"
|
||||
def sample_term_literal():
|
||||
"""Sample literal Term object"""
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value="John Smith"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_triple(sample_value_uri, sample_value_literal):
|
||||
def sample_triple(sample_term_uri, sample_term_literal):
|
||||
"""Sample Triple object"""
|
||||
return Triple(
|
||||
s=sample_value_uri,
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=sample_value_literal
|
||||
s=sample_term_uri,
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=sample_term_literal
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import json
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from trustgraph.template.prompt_manager import PromptManager
|
||||
|
|
@ -33,7 +33,7 @@ class TestAgentKgExtractor:
|
|||
|
||||
# Set up the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.parse_jsonl = real_extractor.parse_jsonl
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
|
@ -53,48 +53,49 @@ class TestAgentKgExtractor:
|
|||
id="doc123",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc123", is_uri=True),
|
||||
p=Value(value="http://example.org/type", is_uri=True),
|
||||
o=Value(value="document", is_uri=False)
|
||||
s=Term(type=IRI, iri="doc123"),
|
||||
p=Term(type=IRI, iri="http://example.org/type"),
|
||||
o=Term(type=LITERAL, value="document")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_extraction_data(self):
|
||||
"""Sample extraction data in expected format"""
|
||||
return {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
}
|
||||
"""Sample extraction data in JSONL format (list with type discriminators)"""
|
||||
return [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
|
||||
def test_to_uri_conversion(self, agent_extractor):
|
||||
"""Test URI conversion for entities"""
|
||||
|
|
@ -113,148 +114,147 @@ class TestAgentKgExtractor:
|
|||
expected = f"{TRUSTGRAPH_ENTITIES}"
|
||||
assert uri == expected
|
||||
|
||||
def test_parse_json_with_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing from code blocks"""
|
||||
# Test JSON in code blocks
|
||||
def test_parse_jsonl_with_code_blocks(self, agent_extractor):
|
||||
"""Test JSONL parsing from code blocks"""
|
||||
# Test JSONL in code blocks - note: JSON uses lowercase true/false
|
||||
response = '''```json
|
||||
{
|
||||
"definitions": [{"entity": "AI", "definition": "Artificial Intelligence"}],
|
||||
"relationships": []
|
||||
}
|
||||
```'''
|
||||
|
||||
result = agent_extractor.parse_json(response)
|
||||
|
||||
assert result["definitions"][0]["entity"] == "AI"
|
||||
assert result["definitions"][0]["definition"] == "Artificial Intelligence"
|
||||
assert result["relationships"] == []
|
||||
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}
|
||||
{"type": "relationship", "subject": "AI", "predicate": "is", "object": "technology", "object-entity": false}
|
||||
```'''
|
||||
|
||||
def test_parse_json_without_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing without code blocks"""
|
||||
response = '''{"definitions": [{"entity": "ML", "definition": "Machine Learning"}], "relationships": []}'''
|
||||
|
||||
result = agent_extractor.parse_json(response)
|
||||
|
||||
assert result["definitions"][0]["entity"] == "ML"
|
||||
assert result["definitions"][0]["definition"] == "Machine Learning"
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
def test_parse_json_invalid_format(self, agent_extractor):
|
||||
"""Test JSON parsing with invalid format"""
|
||||
invalid_response = "This is not JSON at all"
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
agent_extractor.parse_json(invalid_response)
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "AI"
|
||||
assert result[0]["definition"] == "Artificial Intelligence"
|
||||
assert result[1]["type"] == "relationship"
|
||||
|
||||
def test_parse_json_malformed_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing with malformed code blocks"""
|
||||
# Missing closing backticks
|
||||
response = '''```json
|
||||
{"definitions": [], "relationships": []}
|
||||
'''
|
||||
|
||||
# Should still parse the JSON content
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
agent_extractor.parse_json(response)
|
||||
def test_parse_jsonl_without_code_blocks(self, agent_extractor):
|
||||
"""Test JSONL parsing without code blocks"""
|
||||
response = '''{"type": "definition", "entity": "ML", "definition": "Machine Learning"}
|
||||
{"type": "definition", "entity": "AI", "definition": "Artificial Intelligence"}'''
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "ML"
|
||||
assert result[1]["entity"] == "AI"
|
||||
|
||||
def test_parse_jsonl_invalid_lines_skipped(self, agent_extractor):
|
||||
"""Test JSONL parsing skips invalid lines gracefully"""
|
||||
response = '''{"type": "definition", "entity": "Valid", "definition": "Valid def"}
|
||||
This is not JSON at all
|
||||
{"type": "definition", "entity": "Also Valid", "definition": "Another def"}'''
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
# Should get 2 valid objects, skipping the invalid line
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "Valid"
|
||||
assert result[1]["entity"] == "Also Valid"
|
||||
|
||||
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
|
||||
"""Test JSONL parsing handles truncated responses"""
|
||||
# Simulates output cut off mid-line
|
||||
response = '''{"type": "definition", "entity": "Complete", "definition": "Full def"}
|
||||
{"type": "definition", "entity": "Trunca'''
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
# Should get 1 valid object, the truncated line is skipped
|
||||
assert len(result) == 1
|
||||
assert result[0]["entity"] == "Complete"
|
||||
|
||||
def test_process_extraction_data_definitions(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of definition data"""
|
||||
data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of AI that enables learning from data."
|
||||
}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of AI that enables learning from data."
|
||||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check entity label triple
|
||||
label_triple = next((t for t in triples if t.p.value == RDF_LABEL and t.o.value == "Machine Learning"), None)
|
||||
label_triple = next((t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "Machine Learning"), None)
|
||||
assert label_triple is not None
|
||||
assert label_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert label_triple.s.is_uri == True
|
||||
assert label_triple.o.is_uri == False
|
||||
|
||||
assert label_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert label_triple.s.type == IRI
|
||||
assert label_triple.o.type == LITERAL
|
||||
|
||||
# Check definition triple
|
||||
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
|
||||
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
|
||||
assert def_triple is not None
|
||||
assert def_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert def_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert def_triple.o.value == "A subset of AI that enables learning from data."
|
||||
|
||||
|
||||
# Check subject-of triple
|
||||
subject_of_triple = next((t for t in triples if t.p.value == SUBJECT_OF), None)
|
||||
subject_of_triple = next((t for t in triples if t.p.iri == SUBJECT_OF), None)
|
||||
assert subject_of_triple is not None
|
||||
assert subject_of_triple.s.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert subject_of_triple.o.value == "doc123"
|
||||
|
||||
assert subject_of_triple.s.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert subject_of_triple.o.iri == "doc123"
|
||||
|
||||
# Check entity context
|
||||
assert len(entity_contexts) == 1
|
||||
assert entity_contexts[0].entity.value == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert entity_contexts[0].entity.iri == f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
assert entity_contexts[0].context == "A subset of AI that enables learning from data."
|
||||
|
||||
def test_process_extraction_data_relationships(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of relationship data"""
|
||||
data = {
|
||||
"definitions": [],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Check that subject, predicate, and object labels are created
|
||||
subject_uri = f"{TRUSTGRAPH_ENTITIES}Machine%20Learning"
|
||||
predicate_uri = f"{TRUSTGRAPH_ENTITIES}is_subset_of"
|
||||
|
||||
|
||||
# Find label triples
|
||||
subject_label = next((t for t in triples if t.s.value == subject_uri and t.p.value == RDF_LABEL), None)
|
||||
subject_label = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == RDF_LABEL), None)
|
||||
assert subject_label is not None
|
||||
assert subject_label.o.value == "Machine Learning"
|
||||
|
||||
predicate_label = next((t for t in triples if t.s.value == predicate_uri and t.p.value == RDF_LABEL), None)
|
||||
|
||||
predicate_label = next((t for t in triples if t.s.iri == predicate_uri and t.p.iri == RDF_LABEL), None)
|
||||
assert predicate_label is not None
|
||||
assert predicate_label.o.value == "is_subset_of"
|
||||
|
||||
# Check main relationship triple
|
||||
# NOTE: Current implementation has bugs:
|
||||
# 1. Uses data.get("object-entity") instead of rel.get("object-entity")
|
||||
# 2. Sets object_value to predicate_uri instead of actual object URI
|
||||
# This test documents the current buggy behavior
|
||||
rel_triple = next((t for t in triples if t.s.value == subject_uri and t.p.value == predicate_uri), None)
|
||||
|
||||
# Check main relationship triple
|
||||
object_uri = f"{TRUSTGRAPH_ENTITIES}Artificial%20Intelligence"
|
||||
rel_triple = next((t for t in triples if t.s.iri == subject_uri and t.p.iri == predicate_uri), None)
|
||||
assert rel_triple is not None
|
||||
# Due to bug, object value is set to predicate_uri
|
||||
assert rel_triple.o.value == predicate_uri
|
||||
|
||||
assert rel_triple.o.iri == object_uri
|
||||
assert rel_triple.o.type == IRI
|
||||
|
||||
# Check subject-of relationships
|
||||
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF and t.o.value == "doc123"]
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF and t.o.iri == "doc123"]
|
||||
assert len(subject_of_triples) >= 2 # At least subject and predicate should have subject-of relations
|
||||
|
||||
def test_process_extraction_data_literal_object(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of relationships with literal objects"""
|
||||
data = {
|
||||
"definitions": [],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "accuracy",
|
||||
"object": "95%",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
|
||||
# Check that object labels are not created for literal objects
|
||||
object_labels = [t for t in triples if t.p.value == RDF_LABEL and t.o.value == "95%"]
|
||||
object_labels = [t for t in triples if t.p.iri == RDF_LABEL and t.o.value == "95%"]
|
||||
# Based on the code logic, it should not create object labels for non-entity objects
|
||||
# But there might be a bug in the original implementation
|
||||
|
||||
|
|
@ -263,75 +263,62 @@ class TestAgentKgExtractor:
|
|||
triples, entity_contexts = agent_extractor.process_extraction_data(sample_extraction_data, sample_metadata)
|
||||
|
||||
# Check that we have both definition and relationship triples
|
||||
definition_triples = [t for t in triples if t.p.value == DEFINITION]
|
||||
definition_triples = [t for t in triples if t.p.iri == DEFINITION]
|
||||
assert len(definition_triples) == 2 # Two definitions
|
||||
|
||||
|
||||
# Check entity contexts are created for definitions
|
||||
assert len(entity_contexts) == 2
|
||||
entity_uris = [ec.entity.value for ec in entity_contexts]
|
||||
entity_uris = [ec.entity.iri for ec in entity_contexts]
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Machine%20Learning" in entity_uris
|
||||
assert f"{TRUSTGRAPH_ENTITIES}Neural%20Networks" in entity_uris
|
||||
|
||||
def test_process_extraction_data_no_metadata_id(self, agent_extractor):
|
||||
"""Test processing when metadata has no ID"""
|
||||
metadata = Metadata(id=None, metadata=[])
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Test Entity", "definition": "Test definition"}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test Entity", "definition": "Test definition"}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should not create subject-of relationships when no metadata ID
|
||||
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) == 0
|
||||
|
||||
|
||||
# Should still create entity contexts
|
||||
assert len(entity_contexts) == 1
|
||||
|
||||
def test_process_extraction_data_empty_data(self, agent_extractor, sample_metadata):
|
||||
"""Test processing of empty extraction data"""
|
||||
data = {"definitions": [], "relationships": []}
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Should only have metadata triples
|
||||
assert len(entity_contexts) == 0
|
||||
# Triples should only contain metadata triples if any
|
||||
data = []
|
||||
|
||||
def test_process_extraction_data_missing_keys(self, agent_extractor, sample_metadata):
|
||||
"""Test processing data with missing keys"""
|
||||
# Test missing definitions key
|
||||
data = {"relationships": []}
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
||||
# Should have no entity contexts
|
||||
assert len(entity_contexts) == 0
|
||||
|
||||
# Test missing relationships key
|
||||
data = {"definitions": []}
|
||||
# Triples should be empty
|
||||
assert len(triples) == 0
|
||||
|
||||
def test_process_extraction_data_unknown_types_ignored(self, agent_extractor, sample_metadata):
|
||||
"""Test processing data with unknown type values"""
|
||||
data = [
|
||||
{"type": "definition", "entity": "Valid", "definition": "Valid def"},
|
||||
{"type": "unknown_type", "foo": "bar"}, # Unknown type - should be ignored
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel", "object": "B", "object-entity": True}
|
||||
]
|
||||
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
assert len(entity_contexts) == 0
|
||||
|
||||
# Test completely missing keys
|
||||
data = {}
|
||||
triples, entity_contexts = agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
assert len(entity_contexts) == 0
|
||||
|
||||
# Should process valid items and ignore unknown types
|
||||
assert len(entity_contexts) == 1 # Only the definition creates entity context
|
||||
|
||||
def test_process_extraction_data_malformed_entries(self, agent_extractor, sample_metadata):
|
||||
"""Test processing data with malformed entries"""
|
||||
# Test definition missing required fields
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Test"}, # Missing definition
|
||||
{"definition": "Test def"} # Missing entity
|
||||
],
|
||||
"relationships": [
|
||||
{"subject": "A", "predicate": "rel"}, # Missing object
|
||||
{"subject": "B", "object": "C"} # Missing predicate
|
||||
]
|
||||
}
|
||||
|
||||
# Test items missing required fields - should raise KeyError
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test"}, # Missing definition
|
||||
]
|
||||
|
||||
# Should handle gracefully or raise appropriate errors
|
||||
with pytest.raises(KeyError):
|
||||
agent_extractor.process_extraction_data(data, sample_metadata)
|
||||
|
|
@ -340,17 +327,17 @@ class TestAgentKgExtractor:
|
|||
async def test_emit_triples(self, agent_extractor, sample_metadata):
|
||||
"""Test emitting triples to publisher"""
|
||||
mock_publisher = AsyncMock()
|
||||
|
||||
|
||||
test_triples = [
|
||||
Triple(
|
||||
s=Value(value="test:subject", is_uri=True),
|
||||
p=Value(value="test:predicate", is_uri=True),
|
||||
o=Value(value="test object", is_uri=False)
|
||||
s=Term(type=IRI, iri="test:subject"),
|
||||
p=Term(type=IRI, iri="test:predicate"),
|
||||
o=Term(type=LITERAL, value="test object")
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
await agent_extractor.emit_triples(mock_publisher, sample_metadata, test_triples)
|
||||
|
||||
|
||||
mock_publisher.send.assert_called_once()
|
||||
sent_triples = mock_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_triples, Triples)
|
||||
|
|
@ -361,22 +348,22 @@ class TestAgentKgExtractor:
|
|||
# Note: metadata.metadata is now empty array in the new implementation
|
||||
assert sent_triples.metadata.metadata == []
|
||||
assert len(sent_triples.triples) == 1
|
||||
assert sent_triples.triples[0].s.value == "test:subject"
|
||||
assert sent_triples.triples[0].s.iri == "test:subject"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_entity_contexts(self, agent_extractor, sample_metadata):
|
||||
"""Test emitting entity contexts to publisher"""
|
||||
mock_publisher = AsyncMock()
|
||||
|
||||
|
||||
test_contexts = [
|
||||
EntityContext(
|
||||
entity=Value(value="test:entity", is_uri=True),
|
||||
entity=Term(type=IRI, iri="test:entity"),
|
||||
context="Test context"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
await agent_extractor.emit_entity_contexts(mock_publisher, sample_metadata, test_contexts)
|
||||
|
||||
|
||||
mock_publisher.send.assert_called_once()
|
||||
sent_contexts = mock_publisher.send.call_args[0][0]
|
||||
assert isinstance(sent_contexts, EntityContexts)
|
||||
|
|
@ -387,7 +374,7 @@ class TestAgentKgExtractor:
|
|||
# Note: metadata.metadata is now empty array in the new implementation
|
||||
assert sent_contexts.metadata.metadata == []
|
||||
assert len(sent_contexts.entities) == 1
|
||||
assert sent_contexts.entities[0].entity.value == "test:entity"
|
||||
assert sent_contexts.entities[0].entity.iri == "test:entity"
|
||||
|
||||
def test_agent_extractor_initialization_params(self):
|
||||
"""Test agent extractor parameter validation"""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import urllib.parse
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.extract.kg.agent.extract import Processor as AgentKgExtractor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
|
||||
|
|
@ -32,11 +32,11 @@ class TestAgentKgExtractionEdgeCases:
|
|||
|
||||
# Set up the methods we want to test
|
||||
extractor.to_uri = real_extractor.to_uri
|
||||
extractor.parse_json = real_extractor.parse_json
|
||||
extractor.parse_jsonl = real_extractor.parse_jsonl
|
||||
extractor.process_extraction_data = real_extractor.process_extraction_data
|
||||
extractor.emit_triples = real_extractor.emit_triples
|
||||
extractor.emit_entity_contexts = real_extractor.emit_entity_contexts
|
||||
|
||||
|
||||
return extractor
|
||||
|
||||
def test_to_uri_special_characters(self, agent_extractor):
|
||||
|
|
@ -85,146 +85,116 @@ class TestAgentKgExtractionEdgeCases:
|
|||
# Verify the URI is properly encoded
|
||||
assert unicode_text not in uri # Original unicode should be encoded
|
||||
|
||||
def test_parse_json_whitespace_variations(self, agent_extractor):
|
||||
"""Test JSON parsing with various whitespace patterns"""
|
||||
# Test JSON with different whitespace patterns
|
||||
def test_parse_jsonl_whitespace_variations(self, agent_extractor):
|
||||
"""Test JSONL parsing with various whitespace patterns"""
|
||||
# Test JSONL with different whitespace patterns
|
||||
test_cases = [
|
||||
# Extra whitespace around code blocks
|
||||
" ```json\n{\"test\": true}\n``` ",
|
||||
# Tabs and mixed whitespace
|
||||
"\t\t```json\n\t{\"test\": true}\n\t```\t",
|
||||
# Multiple newlines
|
||||
"\n\n\n```json\n\n{\"test\": true}\n\n```\n\n",
|
||||
# JSON without code blocks but with whitespace
|
||||
" {\"test\": true} ",
|
||||
# Mixed line endings
|
||||
"```json\r\n{\"test\": true}\r\n```",
|
||||
' ```json\n{"type": "definition", "entity": "test", "definition": "def"}\n``` ',
|
||||
# Multiple newlines between lines
|
||||
'{"type": "definition", "entity": "A", "definition": "def A"}\n\n\n{"type": "definition", "entity": "B", "definition": "def B"}',
|
||||
# JSONL without code blocks but with whitespace
|
||||
' {"type": "definition", "entity": "test", "definition": "def"} ',
|
||||
]
|
||||
|
||||
for response in test_cases:
|
||||
result = agent_extractor.parse_json(response)
|
||||
assert result == {"test": True}
|
||||
|
||||
def test_parse_json_code_block_variations(self, agent_extractor):
|
||||
"""Test JSON parsing with different code block formats"""
|
||||
for response in test_cases:
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
assert len(result) >= 1
|
||||
assert result[0].get("type") == "definition"
|
||||
|
||||
def test_parse_jsonl_code_block_variations(self, agent_extractor):
|
||||
"""Test JSONL parsing with different code block formats"""
|
||||
test_cases = [
|
||||
# Standard json code block
|
||||
"```json\n{\"valid\": true}\n```",
|
||||
'```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
|
||||
# jsonl code block
|
||||
'```jsonl\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
|
||||
# Code block without language
|
||||
"```\n{\"valid\": true}\n```",
|
||||
# Uppercase JSON
|
||||
"```JSON\n{\"valid\": true}\n```",
|
||||
# Mixed case
|
||||
"```Json\n{\"valid\": true}\n```",
|
||||
# Multiple code blocks (should take first one)
|
||||
"```json\n{\"first\": true}\n```\n```json\n{\"second\": true}\n```",
|
||||
# Code block with extra content
|
||||
"Here's the result:\n```json\n{\"valid\": true}\n```\nDone!",
|
||||
'```\n{"type": "definition", "entity": "A", "definition": "def"}\n```',
|
||||
# Code block with extra content before/after
|
||||
'Here\'s the result:\n```json\n{"type": "definition", "entity": "A", "definition": "def"}\n```\nDone!',
|
||||
]
|
||||
|
||||
|
||||
for i, response in enumerate(test_cases):
|
||||
try:
|
||||
result = agent_extractor.parse_json(response)
|
||||
assert result.get("valid") == True or result.get("first") == True
|
||||
except json.JSONDecodeError:
|
||||
# Some cases may fail due to regex extraction issues
|
||||
# This documents current behavior - the regex may not match all cases
|
||||
print(f"Case {i} failed JSON parsing: {response[:50]}...")
|
||||
pass
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
assert len(result) >= 1, f"Case {i} failed"
|
||||
assert result[0].get("entity") == "A"
|
||||
|
||||
def test_parse_json_malformed_code_blocks(self, agent_extractor):
|
||||
"""Test JSON parsing with malformed code block formats"""
|
||||
# These should still work by falling back to treating entire text as JSON
|
||||
test_cases = [
|
||||
# Unclosed code block
|
||||
"```json\n{\"test\": true}",
|
||||
# No opening backticks
|
||||
"{\"test\": true}\n```",
|
||||
# Wrong number of backticks
|
||||
"`json\n{\"test\": true}\n`",
|
||||
# Nested backticks (should handle gracefully)
|
||||
"```json\n{\"code\": \"```\", \"test\": true}\n```",
|
||||
]
|
||||
|
||||
for response in test_cases:
|
||||
try:
|
||||
result = agent_extractor.parse_json(response)
|
||||
assert "test" in result # Should successfully parse
|
||||
except json.JSONDecodeError:
|
||||
# This is also acceptable for malformed cases
|
||||
pass
|
||||
def test_parse_jsonl_truncation_resilience(self, agent_extractor):
|
||||
"""Test JSONL parsing with truncated responses"""
|
||||
# Simulates LLM output being cut off mid-line
|
||||
response = '''{"type": "definition", "entity": "Complete1", "definition": "Full definition"}
|
||||
{"type": "definition", "entity": "Complete2", "definition": "Another full def"}
|
||||
{"type": "definition", "entity": "Trunca'''
|
||||
|
||||
def test_parse_json_large_responses(self, agent_extractor):
|
||||
"""Test JSON parsing with very large responses"""
|
||||
# Create a large JSON structure
|
||||
large_data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": f"Entity {i}",
|
||||
"definition": f"Definition {i} " + "with more content " * 100
|
||||
}
|
||||
for i in range(100)
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": f"Subject {i}",
|
||||
"predicate": f"predicate_{i}",
|
||||
"object": f"Object {i}",
|
||||
"object-entity": i % 2 == 0
|
||||
}
|
||||
for i in range(50)
|
||||
]
|
||||
}
|
||||
|
||||
large_json_str = json.dumps(large_data)
|
||||
response = f"```json\n{large_json_str}\n```"
|
||||
|
||||
result = agent_extractor.parse_json(response)
|
||||
|
||||
assert len(result["definitions"]) == 100
|
||||
assert len(result["relationships"]) == 50
|
||||
assert result["definitions"][0]["entity"] == "Entity 0"
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
# Should get 2 valid objects, the truncated line is skipped
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "Complete1"
|
||||
assert result[1]["entity"] == "Complete2"
|
||||
|
||||
def test_parse_jsonl_large_responses(self, agent_extractor):
|
||||
"""Test JSONL parsing with very large responses"""
|
||||
# Create a large JSONL response
|
||||
lines = []
|
||||
for i in range(100):
|
||||
lines.append(json.dumps({
|
||||
"type": "definition",
|
||||
"entity": f"Entity {i}",
|
||||
"definition": f"Definition {i} " + "with more content " * 100
|
||||
}))
|
||||
for i in range(50):
|
||||
lines.append(json.dumps({
|
||||
"type": "relationship",
|
||||
"subject": f"Subject {i}",
|
||||
"predicate": f"predicate_{i}",
|
||||
"object": f"Object {i}",
|
||||
"object-entity": i % 2 == 0
|
||||
}))
|
||||
|
||||
response = f"```json\n{chr(10).join(lines)}\n```"
|
||||
|
||||
result = agent_extractor.parse_jsonl(response)
|
||||
|
||||
definitions = [r for r in result if r.get("type") == "definition"]
|
||||
relationships = [r for r in result if r.get("type") == "relationship"]
|
||||
|
||||
assert len(definitions) == 100
|
||||
assert len(relationships) == 50
|
||||
assert definitions[0]["entity"] == "Entity 0"
|
||||
|
||||
def test_process_extraction_data_empty_metadata(self, agent_extractor):
|
||||
"""Test processing with empty or minimal metadata"""
|
||||
# Test with None metadata - may not raise AttributeError depending on implementation
|
||||
try:
|
||||
triples, contexts = agent_extractor.process_extraction_data(
|
||||
{"definitions": [], "relationships": []},
|
||||
None
|
||||
)
|
||||
triples, contexts = agent_extractor.process_extraction_data([], None)
|
||||
# If it doesn't raise, check the results
|
||||
assert len(triples) == 0
|
||||
assert len(contexts) == 0
|
||||
except (AttributeError, TypeError):
|
||||
# This is expected behavior when metadata is None
|
||||
pass
|
||||
|
||||
|
||||
# Test with metadata without ID
|
||||
metadata = Metadata(id=None, metadata=[])
|
||||
triples, contexts = agent_extractor.process_extraction_data(
|
||||
{"definitions": [], "relationships": []},
|
||||
metadata
|
||||
)
|
||||
triples, contexts = agent_extractor.process_extraction_data([], metadata)
|
||||
assert len(triples) == 0
|
||||
assert len(contexts) == 0
|
||||
|
||||
|
||||
# Test with metadata with empty string ID
|
||||
metadata = Metadata(id="", metadata=[])
|
||||
data = {
|
||||
"definitions": [{"entity": "Test", "definition": "Test def"}],
|
||||
"relationships": []
|
||||
}
|
||||
data = [{"type": "definition", "entity": "Test", "definition": "Test def"}]
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should not create subject-of triples when ID is empty string
|
||||
subject_of_triples = [t for t in triples if t.p.value == SUBJECT_OF]
|
||||
subject_of_triples = [t for t in triples if t.p.iri == SUBJECT_OF]
|
||||
assert len(subject_of_triples) == 0
|
||||
|
||||
def test_process_extraction_data_special_entity_names(self, agent_extractor):
|
||||
"""Test processing with special characters in entity names"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
|
||||
special_entities = [
|
||||
"Entity with spaces",
|
||||
"Entity & Co.",
|
||||
|
|
@ -237,71 +207,62 @@ class TestAgentKgExtractionEdgeCases:
|
|||
"Quotes: \"test\"",
|
||||
"Parentheses: (test)",
|
||||
]
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": entity, "definition": f"Definition for {entity}"}
|
||||
for entity in special_entities
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": entity, "definition": f"Definition for {entity}"}
|
||||
for entity in special_entities
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Verify all entities were processed
|
||||
assert len(contexts) == len(special_entities)
|
||||
|
||||
|
||||
# Verify URIs were properly encoded
|
||||
for i, entity in enumerate(special_entities):
|
||||
expected_uri = f"{TRUSTGRAPH_ENTITIES}{urllib.parse.quote(entity)}"
|
||||
assert contexts[i].entity.value == expected_uri
|
||||
assert contexts[i].entity.iri == expected_uri
|
||||
|
||||
def test_process_extraction_data_very_long_definitions(self, agent_extractor):
|
||||
"""Test processing with very long entity definitions"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
|
||||
# Create very long definition
|
||||
long_definition = "This is a very long definition. " * 1000
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Test Entity", "definition": long_definition}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Test Entity", "definition": long_definition}
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should handle long definitions without issues
|
||||
assert len(contexts) == 1
|
||||
assert contexts[0].context == long_definition
|
||||
|
||||
|
||||
# Find definition triple
|
||||
def_triple = next((t for t in triples if t.p.value == DEFINITION), None)
|
||||
def_triple = next((t for t in triples if t.p.iri == DEFINITION), None)
|
||||
assert def_triple is not None
|
||||
assert def_triple.o.value == long_definition
|
||||
|
||||
def test_process_extraction_data_duplicate_entities(self, agent_extractor):
|
||||
"""Test processing with duplicate entity names"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "Machine Learning", "definition": "First definition"},
|
||||
{"entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
|
||||
{"entity": "AI", "definition": "AI definition"},
|
||||
{"entity": "AI", "definition": "Another AI definition"}, # Duplicate
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "First definition"},
|
||||
{"type": "definition", "entity": "Machine Learning", "definition": "Second definition"}, # Duplicate
|
||||
{"type": "definition", "entity": "AI", "definition": "AI definition"},
|
||||
{"type": "definition", "entity": "AI", "definition": "Another AI definition"}, # Duplicate
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should process all entries (including duplicates)
|
||||
assert len(contexts) == 4
|
||||
|
||||
|
||||
# Check that both definitions for "Machine Learning" are present
|
||||
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.value]
|
||||
ml_contexts = [ec for ec in contexts if "Machine%20Learning" in ec.entity.iri]
|
||||
assert len(ml_contexts) == 2
|
||||
assert ml_contexts[0].context == "First definition"
|
||||
assert ml_contexts[1].context == "Second definition"
|
||||
|
|
@ -309,49 +270,44 @@ class TestAgentKgExtractionEdgeCases:
|
|||
def test_process_extraction_data_empty_strings(self, agent_extractor):
|
||||
"""Test processing with empty strings in data"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{"entity": "", "definition": "Definition for empty entity"},
|
||||
{"entity": "Valid Entity", "definition": ""},
|
||||
{"entity": " ", "definition": " "}, # Whitespace only
|
||||
],
|
||||
"relationships": [
|
||||
{"subject": "", "predicate": "test", "object": "test", "object-entity": True},
|
||||
{"subject": "test", "predicate": "", "object": "test", "object-entity": True},
|
||||
{"subject": "test", "predicate": "test", "object": "", "object-entity": True},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{"type": "definition", "entity": "", "definition": "Definition for empty entity"},
|
||||
{"type": "definition", "entity": "Valid Entity", "definition": ""},
|
||||
{"type": "definition", "entity": " ", "definition": " "}, # Whitespace only
|
||||
{"type": "relationship", "subject": "", "predicate": "test", "object": "test", "object-entity": True},
|
||||
{"type": "relationship", "subject": "test", "predicate": "", "object": "test", "object-entity": True},
|
||||
{"type": "relationship", "subject": "test", "predicate": "test", "object": "", "object-entity": True},
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should handle empty strings by creating URIs (even if empty)
|
||||
assert len(contexts) == 3
|
||||
|
||||
|
||||
# Empty entity should create empty URI after encoding
|
||||
empty_entity_context = next((ec for ec in contexts if ec.entity.value == TRUSTGRAPH_ENTITIES), None)
|
||||
empty_entity_context = next((ec for ec in contexts if ec.entity.iri == TRUSTGRAPH_ENTITIES), None)
|
||||
assert empty_entity_context is not None
|
||||
|
||||
def test_process_extraction_data_nested_json_in_strings(self, agent_extractor):
|
||||
"""Test processing when definitions contain JSON-like strings"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": "JSON Entity",
|
||||
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
|
||||
},
|
||||
{
|
||||
"entity": "Array Entity",
|
||||
"definition": 'Contains array: [1, 2, 3, "string"]'
|
||||
}
|
||||
],
|
||||
"relationships": []
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "JSON Entity",
|
||||
"definition": 'Definition with JSON: {"key": "value", "nested": {"inner": true}}'
|
||||
},
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": "Array Entity",
|
||||
"definition": 'Contains array: [1, 2, 3, "string"]'
|
||||
}
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should handle JSON strings in definitions without parsing them
|
||||
assert len(contexts) == 2
|
||||
assert '{"key": "value"' in contexts[0].context
|
||||
|
|
@ -360,32 +316,29 @@ class TestAgentKgExtractionEdgeCases:
|
|||
def test_process_extraction_data_boolean_object_entity_variations(self, agent_extractor):
|
||||
"""Test processing with various boolean values for object-entity"""
|
||||
metadata = Metadata(id="doc123", metadata=[])
|
||||
|
||||
data = {
|
||||
"definitions": [],
|
||||
"relationships": [
|
||||
# Explicit True
|
||||
{"subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
|
||||
# Explicit False
|
||||
{"subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
|
||||
# Missing object-entity (should default to True based on code)
|
||||
{"subject": "A", "predicate": "rel3", "object": "C"},
|
||||
# String "true" (should be treated as truthy)
|
||||
{"subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
|
||||
# String "false" (should be treated as truthy in Python)
|
||||
{"subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
|
||||
# Number 0 (falsy)
|
||||
{"subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
|
||||
# Number 1 (truthy)
|
||||
{"subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
data = [
|
||||
# Explicit True
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel1", "object": "B", "object-entity": True},
|
||||
# Explicit False
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel2", "object": "literal", "object-entity": False},
|
||||
# Missing object-entity (should default to True based on code)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel3", "object": "C"},
|
||||
# String "true" (should be treated as truthy)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel4", "object": "D", "object-entity": "true"},
|
||||
# String "false" (should be treated as truthy in Python)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel5", "object": "E", "object-entity": "false"},
|
||||
# Number 0 (falsy)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel6", "object": "literal2", "object-entity": 0},
|
||||
# Number 1 (truthy)
|
||||
{"type": "relationship", "subject": "A", "predicate": "rel7", "object": "F", "object-entity": 1},
|
||||
]
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(data, metadata)
|
||||
|
||||
|
||||
# Should process all relationships
|
||||
# Note: The current implementation has some logic issues that these tests document
|
||||
assert len([t for t in triples if t.p.value != RDF_LABEL and t.p.value != SUBJECT_OF]) >= 7
|
||||
assert len([t for t in triples if t.p.iri != RDF_LABEL and t.p.iri != SUBJECT_OF]) >= 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_empty_collections(self, agent_extractor):
|
||||
|
|
@ -437,41 +390,40 @@ class TestAgentKgExtractionEdgeCases:
|
|||
def test_process_extraction_data_performance_large_dataset(self, agent_extractor):
|
||||
"""Test performance with large extraction datasets"""
|
||||
metadata = Metadata(id="large-doc", metadata=[])
|
||||
|
||||
# Create large dataset
|
||||
|
||||
# Create large dataset in JSONL format
|
||||
num_definitions = 1000
|
||||
num_relationships = 2000
|
||||
|
||||
large_data = {
|
||||
"definitions": [
|
||||
{
|
||||
"entity": f"Entity_{i:04d}",
|
||||
"definition": f"Definition for entity {i} with some detailed explanation."
|
||||
}
|
||||
for i in range(num_definitions)
|
||||
],
|
||||
"relationships": [
|
||||
{
|
||||
"subject": f"Entity_{i % num_definitions:04d}",
|
||||
"predicate": f"predicate_{i % 10}",
|
||||
"object": f"Entity_{(i + 1) % num_definitions:04d}",
|
||||
"object-entity": True
|
||||
}
|
||||
for i in range(num_relationships)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
large_data = [
|
||||
{
|
||||
"type": "definition",
|
||||
"entity": f"Entity_{i:04d}",
|
||||
"definition": f"Definition for entity {i} with some detailed explanation."
|
||||
}
|
||||
for i in range(num_definitions)
|
||||
] + [
|
||||
{
|
||||
"type": "relationship",
|
||||
"subject": f"Entity_{i % num_definitions:04d}",
|
||||
"predicate": f"predicate_{i % 10}",
|
||||
"object": f"Entity_{(i + 1) % num_definitions:04d}",
|
||||
"object-entity": True
|
||||
}
|
||||
for i in range(num_relationships)
|
||||
]
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
triples, contexts = agent_extractor.process_extraction_data(large_data, metadata)
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
|
||||
# Should complete within reasonable time (adjust threshold as needed)
|
||||
assert processing_time < 10.0 # 10 seconds threshold
|
||||
|
||||
|
||||
# Verify results
|
||||
assert len(contexts) == num_definitions
|
||||
# Triples include labels, definitions, relationships, and subject-of relations
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ processing graph structures, and performing graph operations.
|
|||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from .conftest import Triple, Value, Metadata
|
||||
from .conftest import Triple, Metadata
|
||||
from collections import defaultdict, deque
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ def cities_schema():
|
|||
def validator():
|
||||
"""Create a mock processor with just the validation method"""
|
||||
from unittest.mock import MagicMock
|
||||
from trustgraph.extract.kg.objects.processor import Processor
|
||||
from trustgraph.extract.kg.rows.processor import Processor
|
||||
|
||||
# Create a mock processor
|
||||
mock_processor = MagicMock()
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@
|
|||
Unit tests for triple construction logic
|
||||
|
||||
Tests the core business logic for constructing RDF triples from extracted
|
||||
entities and relationships, including URI generation, Value object creation,
|
||||
entities and relationships, including URI generation, Term object creation,
|
||||
and triple validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from .conftest import Triple, Triples, Value, Metadata
|
||||
from .conftest import Triple, Triples, Term, Metadata, IRI, LITERAL
|
||||
import re
|
||||
import hashlib
|
||||
|
||||
|
|
@ -48,80 +48,82 @@ class TestTripleConstructionLogic:
|
|||
generated_uri = generate_uri(text, entity_type)
|
||||
assert generated_uri == expected_uri, f"URI generation failed for '{text}'"
|
||||
|
||||
def test_value_object_creation(self):
|
||||
"""Test creation of Value objects for subjects, predicates, and objects"""
|
||||
def test_term_object_creation(self):
|
||||
"""Test creation of Term objects for subjects, predicates, and objects"""
|
||||
# Arrange
|
||||
def create_value_object(text, is_uri, value_type=""):
|
||||
return Value(
|
||||
value=text,
|
||||
is_uri=is_uri,
|
||||
type=value_type
|
||||
)
|
||||
|
||||
def create_term_object(text, is_uri, datatype=""):
|
||||
if is_uri:
|
||||
return Term(type=IRI, iri=text)
|
||||
else:
|
||||
return Term(type=LITERAL, value=text, datatype=datatype if datatype else None)
|
||||
|
||||
test_cases = [
|
||||
("http://trustgraph.ai/kg/person/john-smith", True, ""),
|
||||
("John Smith", False, "string"),
|
||||
("42", False, "integer"),
|
||||
("http://schema.org/worksFor", True, "")
|
||||
]
|
||||
|
||||
|
||||
# Act & Assert
|
||||
for value_text, is_uri, value_type in test_cases:
|
||||
value_obj = create_value_object(value_text, is_uri, value_type)
|
||||
|
||||
assert isinstance(value_obj, Value)
|
||||
assert value_obj.value == value_text
|
||||
assert value_obj.is_uri == is_uri
|
||||
assert value_obj.type == value_type
|
||||
for value_text, is_uri, datatype in test_cases:
|
||||
term_obj = create_term_object(value_text, is_uri, datatype)
|
||||
|
||||
assert isinstance(term_obj, Term)
|
||||
if is_uri:
|
||||
assert term_obj.type == IRI
|
||||
assert term_obj.iri == value_text
|
||||
else:
|
||||
assert term_obj.type == LITERAL
|
||||
assert term_obj.value == value_text
|
||||
|
||||
def test_triple_construction_from_relationship(self):
|
||||
"""Test constructing Triple objects from relationships"""
|
||||
# Arrange
|
||||
relationship = {
|
||||
"subject": "John Smith",
|
||||
"predicate": "works_for",
|
||||
"predicate": "works_for",
|
||||
"object": "OpenAI",
|
||||
"subject_type": "PERSON",
|
||||
"object_type": "ORG"
|
||||
}
|
||||
|
||||
|
||||
def construct_triple(relationship, uri_base="http://trustgraph.ai/kg"):
|
||||
# Generate URIs
|
||||
subject_uri = f"{uri_base}/person/{relationship['subject'].lower().replace(' ', '-')}"
|
||||
object_uri = f"{uri_base}/org/{relationship['object'].lower().replace(' ', '-')}"
|
||||
|
||||
|
||||
# Map predicate to schema.org URI
|
||||
predicate_mappings = {
|
||||
"works_for": "http://schema.org/worksFor",
|
||||
"located_in": "http://schema.org/location",
|
||||
"developed": "http://schema.org/creator"
|
||||
}
|
||||
predicate_uri = predicate_mappings.get(relationship["predicate"],
|
||||
predicate_uri = predicate_mappings.get(relationship["predicate"],
|
||||
f"{uri_base}/predicate/{relationship['predicate']}")
|
||||
|
||||
# Create Value objects
|
||||
subject_value = Value(value=subject_uri, is_uri=True, type="")
|
||||
predicate_value = Value(value=predicate_uri, is_uri=True, type="")
|
||||
object_value = Value(value=object_uri, is_uri=True, type="")
|
||||
|
||||
|
||||
# Create Term objects
|
||||
subject_term = Term(type=IRI, iri=subject_uri)
|
||||
predicate_term = Term(type=IRI, iri=predicate_uri)
|
||||
object_term = Term(type=IRI, iri=object_uri)
|
||||
|
||||
# Create Triple
|
||||
return Triple(
|
||||
s=subject_value,
|
||||
p=predicate_value,
|
||||
o=object_value
|
||||
s=subject_term,
|
||||
p=predicate_term,
|
||||
o=object_term
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
triple = construct_triple(relationship)
|
||||
|
||||
|
||||
# Assert
|
||||
assert isinstance(triple, Triple)
|
||||
assert triple.s.value == "http://trustgraph.ai/kg/person/john-smith"
|
||||
assert triple.s.is_uri is True
|
||||
assert triple.p.value == "http://schema.org/worksFor"
|
||||
assert triple.p.is_uri is True
|
||||
assert triple.o.value == "http://trustgraph.ai/kg/org/openai"
|
||||
assert triple.o.is_uri is True
|
||||
assert triple.s.iri == "http://trustgraph.ai/kg/person/john-smith"
|
||||
assert triple.s.type == IRI
|
||||
assert triple.p.iri == "http://schema.org/worksFor"
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == "http://trustgraph.ai/kg/org/openai"
|
||||
assert triple.o.type == IRI
|
||||
|
||||
def test_literal_value_handling(self):
|
||||
"""Test handling of literal values vs URI values"""
|
||||
|
|
@ -132,10 +134,10 @@ class TestTripleConstructionLogic:
|
|||
("John Smith", "email", "john@example.com", False), # Literal email
|
||||
("John Smith", "worksFor", "http://trustgraph.ai/kg/org/openai", True) # URI reference
|
||||
]
|
||||
|
||||
|
||||
def create_triple_with_literal(subject_uri, predicate, object_value, object_is_uri):
|
||||
subject_val = Value(value=subject_uri, is_uri=True, type="")
|
||||
|
||||
subject_term = Term(type=IRI, iri=subject_uri)
|
||||
|
||||
# Determine predicate URI
|
||||
predicate_mappings = {
|
||||
"name": "http://schema.org/name",
|
||||
|
|
@ -144,32 +146,37 @@ class TestTripleConstructionLogic:
|
|||
"worksFor": "http://schema.org/worksFor"
|
||||
}
|
||||
predicate_uri = predicate_mappings.get(predicate, f"http://trustgraph.ai/kg/predicate/{predicate}")
|
||||
predicate_val = Value(value=predicate_uri, is_uri=True, type="")
|
||||
|
||||
# Create object value with appropriate type
|
||||
object_type = ""
|
||||
if not object_is_uri:
|
||||
predicate_term = Term(type=IRI, iri=predicate_uri)
|
||||
|
||||
# Create object term with appropriate type
|
||||
if object_is_uri:
|
||||
object_term = Term(type=IRI, iri=object_value)
|
||||
else:
|
||||
datatype = None
|
||||
if predicate == "age":
|
||||
object_type = "integer"
|
||||
datatype = "integer"
|
||||
elif predicate in ["name", "email"]:
|
||||
object_type = "string"
|
||||
|
||||
object_val = Value(value=object_value, is_uri=object_is_uri, type=object_type)
|
||||
|
||||
return Triple(s=subject_val, p=predicate_val, o=object_val)
|
||||
|
||||
datatype = "string"
|
||||
object_term = Term(type=LITERAL, value=object_value, datatype=datatype)
|
||||
|
||||
return Triple(s=subject_term, p=predicate_term, o=object_term)
|
||||
|
||||
# Act & Assert
|
||||
for subject_uri, predicate, object_value, object_is_uri in test_data:
|
||||
subject_full_uri = "http://trustgraph.ai/kg/person/john-smith"
|
||||
triple = create_triple_with_literal(subject_full_uri, predicate, object_value, object_is_uri)
|
||||
|
||||
assert triple.o.is_uri == object_is_uri
|
||||
assert triple.o.value == object_value
|
||||
|
||||
|
||||
if object_is_uri:
|
||||
assert triple.o.type == IRI
|
||||
assert triple.o.iri == object_value
|
||||
else:
|
||||
assert triple.o.type == LITERAL
|
||||
assert triple.o.value == object_value
|
||||
|
||||
if predicate == "age":
|
||||
assert triple.o.type == "integer"
|
||||
assert triple.o.datatype == "integer"
|
||||
elif predicate in ["name", "email"]:
|
||||
assert triple.o.type == "string"
|
||||
assert triple.o.datatype == "string"
|
||||
|
||||
def test_namespace_management(self):
|
||||
"""Test namespace prefix management and expansion"""
|
||||
|
|
@ -216,63 +223,74 @@ class TestTripleConstructionLogic:
|
|||
def test_triple_validation(self):
|
||||
"""Test triple validation rules"""
|
||||
# Arrange
|
||||
def get_term_value(term):
|
||||
"""Extract value from a Term"""
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
else:
|
||||
return term.value
|
||||
|
||||
def validate_triple(triple):
|
||||
errors = []
|
||||
|
||||
|
||||
# Check required components
|
||||
if not triple.s or not triple.s.value:
|
||||
s_val = get_term_value(triple.s) if triple.s else None
|
||||
p_val = get_term_value(triple.p) if triple.p else None
|
||||
o_val = get_term_value(triple.o) if triple.o else None
|
||||
|
||||
if not triple.s or not s_val:
|
||||
errors.append("Missing or empty subject")
|
||||
|
||||
if not triple.p or not triple.p.value:
|
||||
|
||||
if not triple.p or not p_val:
|
||||
errors.append("Missing or empty predicate")
|
||||
|
||||
if not triple.o or not triple.o.value:
|
||||
|
||||
if not triple.o or not o_val:
|
||||
errors.append("Missing or empty object")
|
||||
|
||||
|
||||
# Check URI validity for URI values
|
||||
uri_pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||||
|
||||
if triple.s.is_uri and not re.match(uri_pattern, triple.s.value):
|
||||
|
||||
if triple.s.type == IRI and not re.match(uri_pattern, triple.s.iri or ""):
|
||||
errors.append("Invalid subject URI format")
|
||||
|
||||
if triple.p.is_uri and not re.match(uri_pattern, triple.p.value):
|
||||
|
||||
if triple.p.type == IRI and not re.match(uri_pattern, triple.p.iri or ""):
|
||||
errors.append("Invalid predicate URI format")
|
||||
|
||||
if triple.o.is_uri and not re.match(uri_pattern, triple.o.value):
|
||||
|
||||
if triple.o.type == IRI and not re.match(uri_pattern, triple.o.iri or ""):
|
||||
errors.append("Invalid object URI format")
|
||||
|
||||
|
||||
# Predicates should typically be URIs
|
||||
if not triple.p.is_uri:
|
||||
if triple.p.type != IRI:
|
||||
errors.append("Predicate should be a URI")
|
||||
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
# Test valid triple
|
||||
valid_triple = Triple(
|
||||
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John Smith", is_uri=False, type="string")
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John Smith", datatype="string")
|
||||
)
|
||||
|
||||
|
||||
# Test invalid triples
|
||||
invalid_triples = [
|
||||
Triple(s=Value(value="", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John", is_uri=False, type="")), # Empty subject
|
||||
|
||||
Triple(s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="name", is_uri=False, type=""), # Non-URI predicate
|
||||
o=Value(value="John", is_uri=False, type="")),
|
||||
|
||||
Triple(s=Value(value="invalid-uri", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John", is_uri=False, type="")) # Invalid URI format
|
||||
Triple(s=Term(type=IRI, iri=""),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John")), # Empty subject
|
||||
|
||||
Triple(s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=LITERAL, value="name"), # Non-URI predicate
|
||||
o=Term(type=LITERAL, value="John")),
|
||||
|
||||
Triple(s=Term(type=IRI, iri="invalid-uri"),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John")) # Invalid URI format
|
||||
]
|
||||
|
||||
|
||||
# Act & Assert
|
||||
is_valid, errors = validate_triple(valid_triple)
|
||||
assert is_valid, f"Valid triple failed validation: {errors}"
|
||||
|
||||
|
||||
for invalid_triple in invalid_triples:
|
||||
is_valid, errors = validate_triple(invalid_triple)
|
||||
assert not is_valid, f"Invalid triple passed validation: {invalid_triple}"
|
||||
|
|
@ -286,97 +304,97 @@ class TestTripleConstructionLogic:
|
|||
{"text": "OpenAI", "type": "ORG"},
|
||||
{"text": "San Francisco", "type": "PLACE"}
|
||||
]
|
||||
|
||||
|
||||
relationships = [
|
||||
{"subject": "John Smith", "predicate": "works_for", "object": "OpenAI"},
|
||||
{"subject": "OpenAI", "predicate": "located_in", "object": "San Francisco"}
|
||||
]
|
||||
|
||||
|
||||
def construct_triple_batch(entities, relationships, document_id="doc-1"):
|
||||
triples = []
|
||||
|
||||
|
||||
# Create type triples for entities
|
||||
for entity in entities:
|
||||
entity_uri = f"http://trustgraph.ai/kg/{entity['type'].lower()}/{entity['text'].lower().replace(' ', '-')}"
|
||||
type_uri = f"http://trustgraph.ai/kg/type/{entity['type']}"
|
||||
|
||||
|
||||
type_triple = Triple(
|
||||
s=Value(value=entity_uri, is_uri=True, type=""),
|
||||
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True, type=""),
|
||||
o=Value(value=type_uri, is_uri=True, type="")
|
||||
s=Term(type=IRI, iri=entity_uri),
|
||||
p=Term(type=IRI, iri="http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
|
||||
o=Term(type=IRI, iri=type_uri)
|
||||
)
|
||||
triples.append(type_triple)
|
||||
|
||||
|
||||
# Create relationship triples
|
||||
for rel in relationships:
|
||||
subject_uri = f"http://trustgraph.ai/kg/entity/{rel['subject'].lower().replace(' ', '-')}"
|
||||
object_uri = f"http://trustgraph.ai/kg/entity/{rel['object'].lower().replace(' ', '-')}"
|
||||
predicate_uri = f"http://schema.org/{rel['predicate'].replace('_', '')}"
|
||||
|
||||
|
||||
rel_triple = Triple(
|
||||
s=Value(value=subject_uri, is_uri=True, type=""),
|
||||
p=Value(value=predicate_uri, is_uri=True, type=""),
|
||||
o=Value(value=object_uri, is_uri=True, type="")
|
||||
s=Term(type=IRI, iri=subject_uri),
|
||||
p=Term(type=IRI, iri=predicate_uri),
|
||||
o=Term(type=IRI, iri=object_uri)
|
||||
)
|
||||
triples.append(rel_triple)
|
||||
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
# Act
|
||||
triples = construct_triple_batch(entities, relationships)
|
||||
|
||||
|
||||
# Assert
|
||||
assert len(triples) == len(entities) + len(relationships) # Type triples + relationship triples
|
||||
|
||||
|
||||
# Check that all triples are valid Triple objects
|
||||
for triple in triples:
|
||||
assert isinstance(triple, Triple)
|
||||
assert triple.s.value != ""
|
||||
assert triple.p.value != ""
|
||||
assert triple.o.value != ""
|
||||
assert triple.s.iri != ""
|
||||
assert triple.p.iri != ""
|
||||
assert triple.o.iri != ""
|
||||
|
||||
def test_triples_batch_object_creation(self):
|
||||
"""Test creating Triples batch objects with metadata"""
|
||||
# Arrange
|
||||
sample_triples = [
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/name", is_uri=True, type=""),
|
||||
o=Value(value="John Smith", is_uri=False, type="string")
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=IRI, iri="http://schema.org/name"),
|
||||
o=Term(type=LITERAL, value="John Smith", datatype="string")
|
||||
),
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/kg/person/john", is_uri=True, type=""),
|
||||
p=Value(value="http://schema.org/worksFor", is_uri=True, type=""),
|
||||
o=Value(value="http://trustgraph.ai/kg/org/openai", is_uri=True, type="")
|
||||
s=Term(type=IRI, iri="http://trustgraph.ai/kg/person/john"),
|
||||
p=Term(type=IRI, iri="http://schema.org/worksFor"),
|
||||
o=Term(type=IRI, iri="http://trustgraph.ai/kg/org/openai")
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
|
||||
# Act
|
||||
triples_batch = Triples(
|
||||
metadata=metadata,
|
||||
triples=sample_triples
|
||||
)
|
||||
|
||||
|
||||
# Assert
|
||||
assert isinstance(triples_batch, Triples)
|
||||
assert triples_batch.metadata.id == "test-doc-123"
|
||||
assert triples_batch.metadata.user == "test_user"
|
||||
assert triples_batch.metadata.collection == "test_collection"
|
||||
assert len(triples_batch.triples) == 2
|
||||
|
||||
|
||||
# Check that triples are properly embedded
|
||||
for triple in triples_batch.triples:
|
||||
assert isinstance(triple, Triple)
|
||||
assert isinstance(triple.s, Value)
|
||||
assert isinstance(triple.p, Value)
|
||||
assert isinstance(triple.o, Value)
|
||||
assert isinstance(triple.s, Term)
|
||||
assert isinstance(triple.p, Term)
|
||||
assert isinstance(triple.o, Term)
|
||||
|
||||
def test_uri_collision_handling(self):
|
||||
"""Test handling of URI collisions and duplicate detection"""
|
||||
|
|
|
|||
|
|
@ -339,7 +339,250 @@ class TestPromptManager:
|
|||
"""Test PromptManager with minimal configuration"""
|
||||
pm = PromptManager()
|
||||
pm.load_config({}) # Empty config
|
||||
|
||||
|
||||
assert pm.config.system_template == "Be helpful." # Default system
|
||||
assert pm.terms == {} # Default empty terms
|
||||
assert len(pm.prompts) == 0
|
||||
assert len(pm.prompts) == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPromptManagerJsonl:
|
||||
"""Unit tests for PromptManager JSONL functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def jsonl_config(self):
|
||||
"""Configuration with JSONL response type prompts"""
|
||||
return {
|
||||
"system": json.dumps("You are an extraction assistant."),
|
||||
"template-index": json.dumps(["extract_simple", "extract_with_schema", "extract_mixed"]),
|
||||
"template.extract_simple": json.dumps({
|
||||
"prompt": "Extract entities from: {{ text }}",
|
||||
"response-type": "jsonl"
|
||||
}),
|
||||
"template.extract_with_schema": json.dumps({
|
||||
"prompt": "Extract definitions from: {{ text }}",
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity": {"type": "string"},
|
||||
"definition": {"type": "string"}
|
||||
},
|
||||
"required": ["entity", "definition"]
|
||||
}
|
||||
}),
|
||||
"template.extract_mixed": json.dumps({
|
||||
"prompt": "Extract knowledge from: {{ text }}",
|
||||
"response-type": "jsonl",
|
||||
"schema": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "definition"},
|
||||
"entity": {"type": "string"},
|
||||
"definition": {"type": "string"}
|
||||
},
|
||||
"required": ["type", "entity", "definition"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "relationship"},
|
||||
"subject": {"type": "string"},
|
||||
"predicate": {"type": "string"},
|
||||
"object": {"type": "string"}
|
||||
},
|
||||
"required": ["type", "subject", "predicate", "object"]
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_manager(self, jsonl_config):
|
||||
"""Create a PromptManager with JSONL configuration"""
|
||||
pm = PromptManager()
|
||||
pm.load_config(jsonl_config)
|
||||
return pm
|
||||
|
||||
def test_parse_jsonl_basic(self, prompt_manager):
|
||||
"""Test basic JSONL parsing"""
|
||||
text = '{"entity": "cat", "definition": "A small furry animal"}\n{"entity": "dog", "definition": "A loyal pet"}'
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "cat"
|
||||
assert result[1]["entity"] == "dog"
|
||||
|
||||
def test_parse_jsonl_with_empty_lines(self, prompt_manager):
|
||||
"""Test JSONL parsing skips empty lines"""
|
||||
text = '{"entity": "cat"}\n\n\n{"entity": "dog"}\n'
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_parse_jsonl_with_markdown_fences(self, prompt_manager):
|
||||
"""Test JSONL parsing strips markdown code fences"""
|
||||
text = '''```json
|
||||
{"entity": "cat", "definition": "A furry animal"}
|
||||
{"entity": "dog", "definition": "A loyal pet"}
|
||||
```'''
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "cat"
|
||||
assert result[1]["entity"] == "dog"
|
||||
|
||||
def test_parse_jsonl_with_jsonl_fence(self, prompt_manager):
|
||||
"""Test JSONL parsing strips jsonl-marked code fences"""
|
||||
text = '''```jsonl
|
||||
{"entity": "cat"}
|
||||
{"entity": "dog"}
|
||||
```'''
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_parse_jsonl_truncation_resilience(self, prompt_manager):
|
||||
"""Test JSONL parsing handles truncated final line"""
|
||||
text = '{"entity": "cat", "definition": "Complete"}\n{"entity": "dog", "defi'
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
# Should get the first valid object, skip the truncated one
|
||||
assert len(result) == 1
|
||||
assert result[0]["entity"] == "cat"
|
||||
|
||||
def test_parse_jsonl_invalid_lines_skipped(self, prompt_manager):
|
||||
"""Test JSONL parsing skips invalid JSON lines"""
|
||||
text = '''{"entity": "valid1"}
|
||||
not json at all
|
||||
{"entity": "valid2"}
|
||||
{broken json
|
||||
{"entity": "valid3"}'''
|
||||
|
||||
result = prompt_manager.parse_jsonl(text)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0]["entity"] == "valid1"
|
||||
assert result[1]["entity"] == "valid2"
|
||||
assert result[2]["entity"] == "valid3"
|
||||
|
||||
def test_parse_jsonl_empty_input(self, prompt_manager):
|
||||
"""Test JSONL parsing with empty input"""
|
||||
result = prompt_manager.parse_jsonl("")
|
||||
assert result == []
|
||||
|
||||
result = prompt_manager.parse_jsonl("\n\n\n")
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_response(self, prompt_manager):
|
||||
"""Test invoking a prompt with JSONL response"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"entity": "photosynthesis", "definition": "Plant process"}\n{"entity": "mitosis", "definition": "Cell division"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_simple",
|
||||
{"text": "Biology text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "photosynthesis"
|
||||
assert result[1]["entity"] == "mitosis"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_with_schema_validation(self, prompt_manager):
|
||||
"""Test JSONL response with schema validation"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"entity": "cat", "definition": "A pet"}\n{"entity": "dog", "definition": "Another pet"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_with_schema",
|
||||
{"text": "Animal text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all("entity" in obj and "definition" in obj for obj in result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_schema_filters_invalid(self, prompt_manager):
|
||||
"""Test JSONL schema validation filters out invalid objects"""
|
||||
mock_llm = AsyncMock()
|
||||
# Second object is missing required 'definition' field
|
||||
mock_llm.return_value = '{"entity": "valid", "definition": "Has both fields"}\n{"entity": "invalid_missing_definition"}\n{"entity": "also_valid", "definition": "Complete"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_with_schema",
|
||||
{"text": "Test text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
# Only the two valid objects should be returned
|
||||
assert len(result) == 2
|
||||
assert result[0]["entity"] == "valid"
|
||||
assert result[1]["entity"] == "also_valid"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_mixed_types(self, prompt_manager):
|
||||
"""Test JSONL with discriminated union schema (oneOf)"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '''{"type": "definition", "entity": "DNA", "definition": "Genetic material"}
|
||||
{"type": "relationship", "subject": "DNA", "predicate": "found_in", "object": "nucleus"}
|
||||
{"type": "definition", "entity": "RNA", "definition": "Messenger molecule"}'''
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_mixed",
|
||||
{"text": "Biology text"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
# Check definitions
|
||||
definitions = [r for r in result if r.get("type") == "definition"]
|
||||
assert len(definitions) == 2
|
||||
|
||||
# Check relationships
|
||||
relationships = [r for r in result if r.get("type") == "relationship"]
|
||||
assert len(relationships) == 1
|
||||
assert relationships[0]["subject"] == "DNA"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_empty_result(self, prompt_manager):
|
||||
"""Test JSONL response that yields no valid objects"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = "No JSON here at all"
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_simple",
|
||||
{"text": "Test"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_jsonl_without_schema(self, prompt_manager):
|
||||
"""Test JSONL response without schema validation"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"any": "structure"}\n{"completely": "different"}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"extract_simple",
|
||||
{"text": "Test"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == {"any": "structure"}
|
||||
assert result[1] == {"completely": "different"}
|
||||
|
|
@ -167,7 +167,7 @@ class TestFlowClient:
|
|||
expected_methods = [
|
||||
'text_completion', 'agent', 'graph_rag', 'document_rag',
|
||||
'graph_embeddings_query', 'embeddings', 'prompt',
|
||||
'triples_query', 'objects_query'
|
||||
'triples_query', 'rows_query'
|
||||
]
|
||||
|
||||
for method in expected_methods:
|
||||
|
|
@ -216,7 +216,7 @@ class TestSocketClient:
|
|||
expected_methods = [
|
||||
'agent', 'text_completion', 'graph_rag', 'document_rag',
|
||||
'prompt', 'graph_embeddings_query', 'embeddings',
|
||||
'triples_query', 'objects_query', 'mcp_tool'
|
||||
'triples_query', 'rows_query', 'mcp_tool'
|
||||
]
|
||||
|
||||
for method in expected_methods:
|
||||
|
|
@ -243,7 +243,7 @@ class TestBulkClient:
|
|||
'import_graph_embeddings',
|
||||
'import_document_embeddings',
|
||||
'import_entity_contexts',
|
||||
'import_objects'
|
||||
'import_rows'
|
||||
]
|
||||
|
||||
for method in import_methods:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus.service import Processor
|
||||
from trustgraph.schema import Value, GraphEmbeddingsRequest
|
||||
from trustgraph.schema import Term, GraphEmbeddingsRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -68,50 +68,50 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
|
|
@ -138,17 +138,17 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
[0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10
|
||||
)
|
||||
|
||||
# Verify results are converted to Value objects
|
||||
# Verify results are converted to Term objects
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], Value)
|
||||
assert result[0].value == "http://example.com/entity1"
|
||||
assert result[0].is_uri is True
|
||||
assert isinstance(result[1], Value)
|
||||
assert result[1].value == "http://example.com/entity2"
|
||||
assert result[1].is_uri is True
|
||||
assert isinstance(result[2], Value)
|
||||
assert isinstance(result[0], Term)
|
||||
assert result[0].iri == "http://example.com/entity1"
|
||||
assert result[0].type == IRI
|
||||
assert isinstance(result[1], Term)
|
||||
assert result[1].iri == "http://example.com/entity2"
|
||||
assert result[1].type == IRI
|
||||
assert isinstance(result[2], Term)
|
||||
assert result[2].value == "literal entity"
|
||||
assert result[2].is_uri is False
|
||||
assert result[2].type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor):
|
||||
|
|
@ -186,7 +186,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify results are deduplicated and limited
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
assert "http://example.com/entity3" in entity_values
|
||||
|
|
@ -246,7 +246,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify duplicates are removed
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert len(set(entity_values)) == 3 # All unique
|
||||
assert "http://example.com/entity1" in entity_values
|
||||
assert "http://example.com/entity2" in entity_values
|
||||
|
|
@ -346,14 +346,14 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
assert len(result) == 4
|
||||
|
||||
# Check URI entities
|
||||
uri_results = [r for r in result if r.is_uri]
|
||||
uri_results = [r for r in result if r.type == IRI]
|
||||
assert len(uri_results) == 2
|
||||
uri_values = [r.value for r in uri_results]
|
||||
uri_values = [r.iri for r in uri_results]
|
||||
assert "http://example.com/uri_entity" in uri_values
|
||||
assert "https://example.com/another_uri" in uri_values
|
||||
|
||||
# Check literal entities
|
||||
literal_results = [r for r in result if not r.is_uri]
|
||||
literal_results = [r for r in result if not r.type == IRI]
|
||||
assert len(literal_results) == 2
|
||||
literal_values = [r.value for r in literal_results]
|
||||
assert "literal entity text" in literal_values
|
||||
|
|
@ -486,7 +486,7 @@ class TestMilvusGraphEmbeddingsQueryProcessor:
|
|||
|
||||
# Verify results from all dimensions
|
||||
assert len(result) == 3
|
||||
entity_values = [r.value for r in result]
|
||||
entity_values = [r.iri if r.type == IRI else r.value for r in result]
|
||||
assert "entity_2d" in entity_values
|
||||
assert "entity_4d" in entity_values
|
||||
assert "entity_3d" in entity_values
|
||||
|
|
@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch
|
|||
pytest.skip("Pinecone library missing protoc_gen_openapiv2 dependency", allow_module_level=True)
|
||||
|
||||
from trustgraph.query.graph_embeddings.pinecone.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestPineconeGraphEmbeddingsQueryProcessor:
|
||||
|
|
@ -105,27 +105,27 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
uri_entity = "http://example.org/entity"
|
||||
value = processor.create_value(uri_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert isinstance(value, Term)
|
||||
assert value.value == uri_entity
|
||||
assert value.is_uri == True
|
||||
assert value.type == IRI
|
||||
|
||||
def test_create_value_https_uri(self, processor):
|
||||
"""Test create_value method for HTTPS URI entities"""
|
||||
uri_entity = "https://example.org/entity"
|
||||
value = processor.create_value(uri_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert isinstance(value, Term)
|
||||
assert value.value == uri_entity
|
||||
assert value.is_uri == True
|
||||
assert value.type == IRI
|
||||
|
||||
def test_create_value_literal(self, processor):
|
||||
"""Test create_value method for literal entities"""
|
||||
literal_entity = "literal_entity"
|
||||
value = processor.create_value(literal_entity)
|
||||
|
||||
assert isinstance(value, Value)
|
||||
assert isinstance(value, Term)
|
||||
assert value.value == literal_entity
|
||||
assert value.is_uri == False
|
||||
assert value.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_single_vector(self, processor):
|
||||
|
|
@ -165,11 +165,11 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
|
|||
# Verify results
|
||||
assert len(entities) == 3
|
||||
assert entities[0].value == 'http://example.org/entity1'
|
||||
assert entities[0].is_uri == True
|
||||
assert entities[0].type == IRI
|
||||
assert entities[1].value == 'entity2'
|
||||
assert entities[1].is_uri == False
|
||||
assert entities[1].type == LITERAL
|
||||
assert entities[2].value == 'http://example.org/entity3'
|
||||
assert entities[2].is_uri == True
|
||||
assert entities[2].type == IRI
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, processor, mock_query_message):
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
|
|
@ -85,10 +86,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
value = processor.create_value('http://example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'http://example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
assert hasattr(value, 'iri')
|
||||
assert value.iri == 'http://example.com/entity'
|
||||
assert hasattr(value, 'type')
|
||||
assert value.type == IRI
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -109,10 +110,10 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
value = processor.create_value('https://secure.example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
assert hasattr(value, 'iri')
|
||||
assert value.iri == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'type')
|
||||
assert value.type == IRI
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -135,8 +136,8 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'regular entity name'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == False
|
||||
assert hasattr(value, 'type')
|
||||
assert value.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
|
|
@ -428,14 +429,14 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
|||
assert len(result) == 3
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
|
||||
uri_entities = [entity for entity in result if entity.type == IRI]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.value for entity in uri_entities]
|
||||
uri_values = [entity.iri for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
|
||||
regular_entities = [entity for entity in result if entity.type == LITERAL]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.memgraph.service import Processor
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryUserCollectionIsolation:
|
||||
|
|
@ -24,9 +24,9 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="test_object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="test_object"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -65,8 +65,8 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
|
@ -105,9 +105,9 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=Value(value="http://example.com/o", is_uri=True),
|
||||
o=Term(type=IRI, iri="http://example.com/o"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
|
|
@ -185,8 +185,8 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="literal", is_uri=False),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="literal"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -225,7 +225,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
|
@ -265,7 +265,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="test_value", is_uri=False),
|
||||
o=Term(type=LITERAL, value="test_value"),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -355,7 +355,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
|
||||
# Query without user/collection fields
|
||||
query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
|
|
@ -385,7 +385,7 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
|
|
@ -416,17 +416,17 @@ class TestMemgraphQueryUserCollectionIsolation:
|
|||
assert len(result) == 2
|
||||
|
||||
# First triple (literal object)
|
||||
assert result[0].s.value == "http://example.com/s"
|
||||
assert result[0].s.is_uri == True
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].p.is_uri == True
|
||||
assert result[0].s.iri == "http://example.com/s"
|
||||
assert result[0].s.type == IRI
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].p.type == IRI
|
||||
assert result[0].o.value == "literal_value"
|
||||
assert result[0].o.is_uri == False
|
||||
|
||||
assert result[0].o.type == LITERAL
|
||||
|
||||
# Second triple (URI object)
|
||||
assert result[1].s.value == "http://example.com/s"
|
||||
assert result[1].s.is_uri == True
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].p.is_uri == True
|
||||
assert result[1].o.value == "http://example.com/o"
|
||||
assert result[1].o.is_uri == True
|
||||
assert result[1].s.iri == "http://example.com/s"
|
||||
assert result[1].s.type == IRI
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].p.type == IRI
|
||||
assert result[1].o.iri == "http://example.com/o"
|
||||
assert result[1].o.type == IRI
|
||||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.neo4j.service import Processor
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryUserCollectionIsolation:
|
||||
|
|
@ -24,21 +24,23 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="test_object", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="test_object"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify SPO query for literal includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN $src as src"
|
||||
"RETURN $src as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -63,23 +65,25 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=None
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify SP query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN dest.value as dest"
|
||||
"RETURN dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
src="http://example.com/s",
|
||||
|
|
@ -88,13 +92,14 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
|
||||
# Verify SP query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN dest.uri as dest"
|
||||
"RETURN dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -118,21 +123,23 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=Value(value="http://example.com/o", is_uri=True)
|
||||
o=Term(type=IRI, iri="http://example.com/o"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify SO query for nodes includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel"
|
||||
"RETURN rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -156,23 +163,25 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify S query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN rel.uri as rel, dest.value as dest"
|
||||
"RETURN rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_query,
|
||||
src="http://example.com/s",
|
||||
|
|
@ -194,20 +203,22 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="literal", is_uri=False)
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="literal"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify PO query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src"
|
||||
"RETURN src.uri as src "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -232,20 +243,22 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=None,
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=None
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify P query includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, dest.value as dest"
|
||||
"RETURN src.uri as src, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -270,19 +283,21 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="test_value", is_uri=False)
|
||||
o=Term(type=LITERAL, value="test_value"),
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify O query for literals includes user/collection
|
||||
expected_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel"
|
||||
"RETURN src.uri as src, rel.uri as rel "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -307,34 +322,37 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
collection="test_collection",
|
||||
s=None,
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify wildcard query for literals includes user/collection
|
||||
expected_literal_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Literal {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
expected_literal_query,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
database_='neo4j'
|
||||
)
|
||||
|
||||
|
||||
# Verify wildcard query for nodes includes user/collection
|
||||
expected_node_query = (
|
||||
"MATCH (src:Node {user: $user, collection: $collection})-"
|
||||
"[rel:Rel {user: $user, collection: $collection}]->"
|
||||
"(dest:Node {user: $user, collection: $collection}) "
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest"
|
||||
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
|
||||
"LIMIT 10"
|
||||
)
|
||||
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -355,9 +373,10 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
|
||||
# Query without user/collection fields
|
||||
query = TriplesQueryRequest(
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
|
||||
|
|
@ -384,47 +403,48 @@ class TestNeo4jQueryUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/s", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/s"),
|
||||
p=None,
|
||||
o=None
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
# Mock some results
|
||||
mock_record1 = MagicMock()
|
||||
mock_record1.data.return_value = {
|
||||
"rel": "http://example.com/p1",
|
||||
"dest": "literal_value"
|
||||
}
|
||||
|
||||
|
||||
mock_record2 = MagicMock()
|
||||
mock_record2.data.return_value = {
|
||||
"rel": "http://example.com/p2",
|
||||
"dest": "http://example.com/o"
|
||||
}
|
||||
|
||||
|
||||
# Return results for literal query, empty for node query
|
||||
mock_driver.execute_query.side_effect = [
|
||||
([mock_record1], MagicMock(), MagicMock()), # Literal query
|
||||
([mock_record2], MagicMock(), MagicMock()) # Node query
|
||||
]
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify results are proper Triple objects
|
||||
assert len(result) == 2
|
||||
|
||||
# First triple (literal object)
|
||||
assert result[0].s.value == "http://example.com/s"
|
||||
assert result[0].s.is_uri == True
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].p.is_uri == True
|
||||
assert result[0].s.iri == "http://example.com/s"
|
||||
assert result[0].s.type == IRI
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].p.type == IRI
|
||||
assert result[0].o.value == "literal_value"
|
||||
assert result[0].o.is_uri == False
|
||||
|
||||
assert result[0].o.type == LITERAL
|
||||
|
||||
# Second triple (URI object)
|
||||
assert result[1].s.value == "http://example.com/s"
|
||||
assert result[1].s.is_uri == True
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].p.is_uri == True
|
||||
assert result[1].o.value == "http://example.com/o"
|
||||
assert result[1].o.is_uri == True
|
||||
assert result[1].s.iri == "http://example.com/s"
|
||||
assert result[1].s.type == IRI
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].p.type == IRI
|
||||
assert result[1].o.iri == "http://example.com/o"
|
||||
assert result[1].o.type == IRI
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
"""
|
||||
Unit tests for Cassandra Objects GraphQL Query Processor
|
||||
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
|
||||
|
||||
Tests the business logic of the GraphQL query processor including:
|
||||
- GraphQL schema generation from RowSchema
|
||||
- Query execution and validation
|
||||
- CQL translation logic
|
||||
- Schema configuration handling
|
||||
- Query execution using unified rows table
|
||||
- Name sanitization
|
||||
- GraphQL query execution
|
||||
- Message processing logic
|
||||
"""
|
||||
|
||||
|
|
@ -12,119 +13,91 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
import strawberry
|
||||
from strawberry import Schema
|
||||
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.query.rows.cassandra.service import Processor
|
||||
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsGraphQLQueryLogic:
|
||||
"""Test business logic without external dependencies"""
|
||||
|
||||
def test_get_python_type_mapping(self):
|
||||
"""Test schema field type conversion to Python types"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
|
||||
# Basic type mappings
|
||||
assert processor.get_python_type("string") == str
|
||||
assert processor.get_python_type("integer") == int
|
||||
assert processor.get_python_type("float") == float
|
||||
assert processor.get_python_type("boolean") == bool
|
||||
assert processor.get_python_type("timestamp") == str
|
||||
assert processor.get_python_type("date") == str
|
||||
assert processor.get_python_type("time") == str
|
||||
assert processor.get_python_type("uuid") == str
|
||||
|
||||
# Unknown type defaults to str
|
||||
assert processor.get_python_type("unknown_type") == str
|
||||
|
||||
def test_create_graphql_type_basic_fields(self):
|
||||
"""Test GraphQL type creation for basic field types"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="test_table",
|
||||
description="Test table",
|
||||
fields=[
|
||||
Field(
|
||||
name="id",
|
||||
type="string",
|
||||
primary=True,
|
||||
required=True,
|
||||
description="Primary key"
|
||||
),
|
||||
Field(
|
||||
name="name",
|
||||
type="string",
|
||||
required=True,
|
||||
description="Name field"
|
||||
),
|
||||
Field(
|
||||
name="age",
|
||||
type="integer",
|
||||
required=False,
|
||||
description="Optional age"
|
||||
),
|
||||
Field(
|
||||
name="active",
|
||||
type="boolean",
|
||||
required=False,
|
||||
description="Status flag"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Create GraphQL type
|
||||
graphql_type = processor.create_graphql_type("test_table", schema)
|
||||
|
||||
# Verify type was created
|
||||
assert graphql_type is not None
|
||||
assert hasattr(graphql_type, '__name__')
|
||||
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
|
||||
class TestRowsGraphQLQueryLogic:
|
||||
"""Test business logic for unified table query implementation"""
|
||||
|
||||
def test_sanitize_name_cassandra_compatibility(self):
|
||||
"""Test name sanitization for Cassandra field names"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test field name sanitization (matches storage processor)
|
||||
|
||||
# Test field name sanitization (uses r_ prefix like storage processor)
|
||||
assert processor.sanitize_name("simple_field") == "simple_field"
|
||||
assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
|
||||
assert processor.sanitize_name("field.with.dots") == "field_with_dots"
|
||||
assert processor.sanitize_name("123_field") == "o_123_field"
|
||||
assert processor.sanitize_name("123_field") == "r_123_field"
|
||||
assert processor.sanitize_name("field with spaces") == "field_with_spaces"
|
||||
assert processor.sanitize_name("special!@#chars") == "special___chars"
|
||||
assert processor.sanitize_name("UPPERCASE") == "uppercase"
|
||||
assert processor.sanitize_name("CamelCase") == "camelcase"
|
||||
|
||||
def test_sanitize_table_name(self):
|
||||
"""Test table name sanitization (always gets o_ prefix)"""
|
||||
def test_get_index_names(self):
|
||||
"""Test extraction of index names from schema"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
|
||||
# Table names always get o_ prefix
|
||||
assert processor.sanitize_table("simple_table") == "o_simple_table"
|
||||
assert processor.sanitize_table("Table-Name") == "o_table_name"
|
||||
assert processor.sanitize_table("123table") == "o_123table"
|
||||
assert processor.sanitize_table("") == "o_"
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
index_names = processor.get_index_names(schema)
|
||||
|
||||
assert "id" in index_names
|
||||
assert "category" in index_names
|
||||
assert "status" in index_names
|
||||
assert "name" not in index_names
|
||||
assert len(index_names) == 3
|
||||
|
||||
def test_find_matching_index_exact_match(self):
|
||||
"""Test finding matching index for exact match query"""
|
||||
processor = MagicMock()
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string") # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
# Filter on indexed field should return match
|
||||
filters = {"category": "electronics"}
|
||||
result = processor.find_matching_index(schema, filters)
|
||||
assert result is not None
|
||||
assert result[0] == "category"
|
||||
assert result[1] == ["electronics"]
|
||||
|
||||
# Filter on non-indexed field should return None
|
||||
filters = {"name": "test"}
|
||||
result = processor.find_matching_index(schema, filters)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configuration"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.graphql_types = {}
|
||||
processor.graphql_schema = None
|
||||
processor.config_key = "schema" # Set the config key
|
||||
processor.generate_graphql_schema = AsyncMock()
|
||||
processor.config_key = "schema"
|
||||
processor.schema_builder = MagicMock()
|
||||
processor.schema_builder.clear = MagicMock()
|
||||
processor.schema_builder.add_schema = MagicMock()
|
||||
processor.schema_builder.build = MagicMock(return_value=MagicMock())
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Create test config
|
||||
schema_config = {
|
||||
"schema": {
|
||||
|
|
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Process config
|
||||
await processor.on_schema_config(schema_config, version=1)
|
||||
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer" in processor.schemas
|
||||
schema = processor.schemas["customer"]
|
||||
assert schema.name == "customer"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
|
||||
# Verify fields
|
||||
id_field = next(f for f in schema.fields if f.name == "id")
|
||||
assert id_field.primary is True
|
||||
# The field should have been created correctly from JSON
|
||||
# Let's test what we can verify - that the field has the right attributes
|
||||
assert hasattr(id_field, 'required') # Has the required attribute
|
||||
assert hasattr(id_field, 'primary') # Has the primary attribute
|
||||
|
||||
|
||||
email_field = next(f for f in schema.fields if f.name == "email")
|
||||
assert email_field.indexed is True
|
||||
|
||||
|
||||
status_field = next(f for f in schema.fields if f.name == "status")
|
||||
assert status_field.enum_values == ["active", "inactive"]
|
||||
|
||||
# Verify GraphQL schema regeneration was called
|
||||
processor.generate_graphql_schema.assert_called_once()
|
||||
|
||||
def test_cql_query_building_basic(self):
|
||||
"""Test basic CQL query construction"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to capture the query
|
||||
mock_result = []
|
||||
processor.session.execute.return_value = mock_result
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="test_table",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string", indexed=True),
|
||||
Field(name="status", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
# Test query building
|
||||
asyncio = pytest.importorskip("asyncio")
|
||||
|
||||
async def run_test():
|
||||
await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="test_table",
|
||||
row_schema=schema,
|
||||
filters={"name": "John", "invalid_filter": "ignored"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Run the async test
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(run_test())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Verify Cassandra connection and query execution
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
# Verify the query structure (can't easily test exact query without complex mocking)
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0] # First positional argument is the query
|
||||
params = call_args[0][1] # Second positional argument is parameters
|
||||
|
||||
# Basic query structure checks
|
||||
assert "SELECT * FROM test_user.o_test_table" in query
|
||||
assert "WHERE" in query
|
||||
assert "collection = %s" in query
|
||||
assert "LIMIT 10" in query
|
||||
|
||||
# Parameters should include collection and name filter
|
||||
assert "test_collection" in params
|
||||
assert "John" in params
|
||||
# Verify schema builder was called
|
||||
processor.schema_builder.add_schema.assert_called_once()
|
||||
processor.schema_builder.build.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphql_context_handling(self):
|
||||
|
|
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock schema execution
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
|
||||
mock_result.errors = None
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
|
|
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify schema.execute was called with correct context
|
||||
processor.graphql_schema.execute.assert_called_once()
|
||||
call_args = processor.graphql_schema.execute.call_args
|
||||
|
||||
|
||||
# Verify context was passed
|
||||
context = call_args[1]['context_value'] # keyword argument
|
||||
context = call_args[1]['context_value']
|
||||
assert context["processor"] == processor
|
||||
assert context["user"] == "test_user"
|
||||
assert context["collection"] == "test_collection"
|
||||
|
||||
|
||||
# Verify result structure
|
||||
assert "data" in result
|
||||
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
|
||||
|
|
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.graphql_schema = AsyncMock()
|
||||
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
|
||||
|
||||
# Create a simple object to simulate GraphQL error instead of MagicMock
|
||||
|
||||
# Create a simple object to simulate GraphQL error
|
||||
class MockError:
|
||||
def __init__(self, message, path, extensions):
|
||||
self.message = message
|
||||
self.path = path
|
||||
self.extensions = extensions
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
mock_error = MockError(
|
||||
message="Field 'invalid_field' doesn't exist",
|
||||
path=["customers", "0", "invalid_field"],
|
||||
extensions={"code": "FIELD_NOT_FOUND"}
|
||||
)
|
||||
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.data = None
|
||||
mock_result.errors = [mock_error]
|
||||
processor.graphql_schema.execute.return_value = mock_result
|
||||
|
||||
|
||||
result = await processor.execute_graphql_query(
|
||||
query='{ customers { invalid_field } }',
|
||||
variables={},
|
||||
operation_name=None,
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify error handling
|
||||
assert "errors" in result
|
||||
assert len(result["errors"]) == 1
|
||||
|
||||
|
||||
error = result["errors"][0]
|
||||
assert error["message"] == "Field 'invalid_field' doesn't exist"
|
||||
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
|
||||
assert error["path"] == ["customers", "0", "invalid_field"]
|
||||
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
|
||||
|
||||
def test_schema_generation_basic_structure(self):
|
||||
"""Test basic GraphQL schema generation structure"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"customer": RowSchema(
|
||||
name="customer",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.graphql_types = {}
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Test individual type creation (avoiding the full schema generation which has annotation issues)
|
||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
||||
processor.graphql_types["customer"] = graphql_type
|
||||
|
||||
# Verify type was created
|
||||
assert len(processor.graphql_types) == 1
|
||||
assert "customer" in processor.graphql_types
|
||||
assert processor.graphql_types["customer"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_processing_success(self):
|
||||
"""Test successful message processing flow"""
|
||||
processor = MagicMock()
|
||||
processor.execute_graphql_query = AsyncMock()
|
||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock successful query result
|
||||
processor.execute_graphql_query.return_value = {
|
||||
"data": {"customers": [{"id": "1", "name": "John"}]},
|
||||
"errors": [],
|
||||
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
|
||||
"extensions": {}
|
||||
}
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = ObjectsQueryRequest(
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
collection="test_collection",
|
||||
query='{ customers { id name } }',
|
||||
variables={},
|
||||
operation_name=None
|
||||
)
|
||||
mock_msg.value.return_value = mock_request
|
||||
mock_msg.properties.return_value = {"id": "test-123"}
|
||||
|
||||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
||||
# Process message
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
||||
# Verify query was executed
|
||||
processor.execute_graphql_query.assert_called_once_with(
|
||||
query='{ customers { id name } }',
|
||||
|
|
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
|
||||
# Verify response was sent
|
||||
mock_response_flow.send.assert_called_once()
|
||||
response_call = mock_response_flow.send.call_args[0][0]
|
||||
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response_call, ObjectsQueryResponse)
|
||||
assert isinstance(response_call, RowsQueryResponse)
|
||||
assert response_call.error is None
|
||||
assert '"customers"' in response_call.data # JSON encoded
|
||||
assert len(response_call.errors) == 0
|
||||
|
|
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
|
|||
processor = MagicMock()
|
||||
processor.execute_graphql_query = AsyncMock()
|
||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Mock query execution error
|
||||
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_request = ObjectsQueryRequest(
|
||||
mock_request = RowsQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
query='{ invalid_query }',
|
||||
|
|
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
|
|||
)
|
||||
mock_msg.value.return_value = mock_request
|
||||
mock_msg.properties.return_value = {"id": "test-456"}
|
||||
|
||||
|
||||
# Mock flow
|
||||
mock_flow = MagicMock()
|
||||
mock_response_flow = AsyncMock()
|
||||
mock_flow.return_value = mock_response_flow
|
||||
|
||||
|
||||
# Process message
|
||||
await processor.on_message(mock_msg, None, mock_flow)
|
||||
|
||||
|
||||
# Verify error response was sent
|
||||
mock_response_flow.send.assert_called_once()
|
||||
response_call = mock_response_flow.send.call_args[0][0]
|
||||
|
||||
|
||||
# Verify error response structure
|
||||
assert isinstance(response_call, ObjectsQueryResponse)
|
||||
assert isinstance(response_call, RowsQueryResponse)
|
||||
assert response_call.error is not None
|
||||
assert response_call.error.type == "objects-query-error"
|
||||
assert response_call.error.type == "rows-query-error"
|
||||
assert "No schema available" in response_call.error.message
|
||||
assert response_call.data is None
|
||||
|
||||
|
||||
class TestCQLQueryGeneration:
|
||||
"""Test CQL query generation logic in isolation"""
|
||||
|
||||
def test_partition_key_inclusion(self):
|
||||
"""Test that collection is always included in queries"""
|
||||
class TestUnifiedTableQueries:
|
||||
"""Test queries against the unified rows table"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_index_match(self):
|
||||
"""Test query execution with matching index"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
|
||||
# Mock the query building (simplified version)
|
||||
keyspace = processor.sanitize_name("test_user")
|
||||
table = processor.sanitize_table("test_table")
|
||||
|
||||
query = f"SELECT * FROM {keyspace}.{table}"
|
||||
where_clauses = ["collection = %s"]
|
||||
|
||||
assert "collection = %s" in where_clauses
|
||||
assert keyspace == "test_user"
|
||||
assert table == "o_test_table"
|
||||
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
mock_row = MagicMock()
|
||||
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
||||
processor.session.execute.return_value = [mock_row]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
# Query with filter on indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
filters={"category": "electronics"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Verify Cassandra was connected and queried
|
||||
processor.connect_cassandra.assert_called_once()
|
||||
processor.session.execute.assert_called_once()
|
||||
|
||||
# Verify query structure - should query unified rows table
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
params = call_args[0][1]
|
||||
|
||||
assert "SELECT data, source FROM test_user.rows" in query
|
||||
assert "collection = %s" in query
|
||||
assert "schema_name = %s" in query
|
||||
assert "index_name = %s" in query
|
||||
assert "index_value = %s" in query
|
||||
|
||||
assert params[0] == "test_collection"
|
||||
assert params[1] == "products"
|
||||
assert params[2] == "category"
|
||||
assert params[3] == ["electronics"]
|
||||
|
||||
# Verify results
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "123"
|
||||
assert results[0]["category"] == "electronics"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_without_index_match(self):
|
||||
"""Test query execution without matching index (scan mode)"""
|
||||
processor = MagicMock()
|
||||
processor.session = MagicMock()
|
||||
processor.connect_cassandra = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||
|
||||
# Mock session execute to return test data
|
||||
mock_row1 = MagicMock()
|
||||
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
||||
mock_row2 = MagicMock()
|
||||
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
|
||||
processor.session.execute.return_value = [mock_row1, mock_row2]
|
||||
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string"), # Not indexed
|
||||
Field(name="price", type="string") # Not indexed
|
||||
]
|
||||
)
|
||||
|
||||
# Query with filter on non-indexed field
|
||||
results = await processor.query_cassandra(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
schema_name="products",
|
||||
row_schema=schema,
|
||||
filters={"name": "Product A"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Query should use ALLOW FILTERING for scan
|
||||
call_args = processor.session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
|
||||
assert "ALLOW FILTERING" in query
|
||||
|
||||
# Should post-filter results
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "Product A"
|
||||
|
||||
|
||||
class TestFilterMatching:
|
||||
"""Test filter matching logic"""
|
||||
|
||||
def test_matches_filters_exact_match(self):
|
||||
"""Test exact match filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
|
||||
|
||||
row = {"status": "active", "name": "test"}
|
||||
assert processor._matches_filters(row, {"status": "active"}, schema) is True
|
||||
assert processor._matches_filters(row, {"status": "inactive"}, schema) is False
|
||||
|
||||
def test_matches_filters_comparison_operators(self):
|
||||
"""Test comparison operators in filters"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
|
||||
|
||||
row = {"price": "100.0"}
|
||||
|
||||
# Greater than
|
||||
assert processor._matches_filters(row, {"price_gt": 50}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_gt": 150}, schema) is False
|
||||
|
||||
# Less than
|
||||
assert processor._matches_filters(row, {"price_lt": 150}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_lt": 50}, schema) is False
|
||||
|
||||
# Greater than or equal
|
||||
assert processor._matches_filters(row, {"price_gte": 100}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_gte": 101}, schema) is False
|
||||
|
||||
# Less than or equal
|
||||
assert processor._matches_filters(row, {"price_lte": 100}, schema) is True
|
||||
assert processor._matches_filters(row, {"price_lte": 99}, schema) is False
|
||||
|
||||
def test_matches_filters_contains(self):
|
||||
"""Test contains filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
|
||||
|
||||
row = {"description": "A great product for everyone"}
|
||||
|
||||
assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True
|
||||
assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False
|
||||
|
||||
def test_matches_filters_in_list(self):
|
||||
"""Test in-list filter"""
|
||||
processor = MagicMock()
|
||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||
|
||||
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
|
||||
|
||||
row = {"status": "active"}
|
||||
|
||||
assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
|
||||
assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
|
||||
|
||||
|
||||
class TestIndexedFieldFiltering:
|
||||
"""Test that only indexed or primary key fields can be directly filtered"""
|
||||
|
||||
def test_indexed_field_filtering(self):
|
||||
"""Test that only indexed or primary key fields can be filtered"""
|
||||
# Create schema with mixed field types
|
||||
schema = RowSchema(
|
||||
name="test",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="indexed_field", type="string", indexed=True),
|
||||
Field(name="indexed_field", type="string", indexed=True),
|
||||
Field(name="normal_field", type="string", indexed=False),
|
||||
Field(name="another_field", type="string")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
filters = {
|
||||
"id": "test123", # Primary key - should be included
|
||||
"indexed_field": "value", # Indexed - should be included
|
||||
"normal_field": "ignored", # Not indexed - should be ignored
|
||||
"another_field": "also_ignored" # Not indexed - should be ignored
|
||||
}
|
||||
|
||||
|
||||
# Simulate the filtering logic from the processor
|
||||
valid_filters = []
|
||||
for field_name, value in filters.items():
|
||||
|
|
@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
|
|||
schema_field = next((f for f in schema.fields if f.name == field_name), None)
|
||||
if schema_field and (schema_field.indexed or schema_field.primary):
|
||||
valid_filters.append((field_name, value))
|
||||
|
||||
|
||||
# Only id and indexed_field should be included
|
||||
assert len(valid_filters) == 2
|
||||
field_names = [f[0] for f in valid_filters]
|
||||
|
|
@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
|
|||
assert "indexed_field" in field_names
|
||||
assert "normal_field" not in field_names
|
||||
assert "another_field" not in field_names
|
||||
|
||||
|
||||
class TestGraphQLSchemaGeneration:
|
||||
"""Test GraphQL schema generation in detail"""
|
||||
|
||||
def test_field_type_annotations(self):
|
||||
"""Test that GraphQL types have correct field annotations"""
|
||||
processor = MagicMock()
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create schema with various field types
|
||||
schema = RowSchema(
|
||||
name="test",
|
||||
fields=[
|
||||
Field(name="id", type="string", required=True, primary=True),
|
||||
Field(name="count", type="integer", required=True),
|
||||
Field(name="price", type="float", required=False),
|
||||
Field(name="active", type="boolean", required=False),
|
||||
Field(name="optional_text", type="string", required=False)
|
||||
]
|
||||
)
|
||||
|
||||
# Create GraphQL type
|
||||
graphql_type = processor.create_graphql_type("test", schema)
|
||||
|
||||
# Verify type was created successfully
|
||||
assert graphql_type is not None
|
||||
|
||||
def test_basic_type_creation(self):
|
||||
"""Test that GraphQL types are created correctly"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"customer": RowSchema(
|
||||
name="customer",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
processor.graphql_types = {}
|
||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
||||
|
||||
# Create GraphQL type directly
|
||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
||||
processor.graphql_types["customer"] = graphql_type
|
||||
|
||||
# Verify customer type was created
|
||||
assert "customer" in processor.graphql_types
|
||||
assert processor.graphql_types["customer"] is not None
|
||||
|
|
@ -5,8 +5,8 @@ Tests for Cassandra triples query service
|
|||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
from trustgraph.query.triples.cassandra.service import Processor, create_term
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
class TestCassandraQueryProcessor:
|
||||
|
|
@ -21,94 +21,101 @@ class TestCassandraQueryProcessor:
|
|||
graph_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
def test_create_term_with_http_uri(self, processor):
|
||||
"""Test create_term with HTTP URI"""
|
||||
result = create_term("http://example.com/resource")
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_https_uri(self, processor):
|
||||
"""Test create_term with HTTPS URI"""
|
||||
result = create_term("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_term_with_literal(self, processor):
|
||||
"""Test create_term with literal value"""
|
||||
result = create_term("just a literal string")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_empty_string(self, processor):
|
||||
"""Test create_term with empty string"""
|
||||
result = create_term("")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_partial_uri(self, processor):
|
||||
"""Test create_term with string that looks like URI but isn't complete"""
|
||||
result = create_term("http")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
def test_create_term_with_ftp_uri(self, processor):
|
||||
"""Test create_term with FTP URI (should not be detected as URI)"""
|
||||
result = create_term("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_spo_query(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_spo_query(self, mock_kg_class):
|
||||
"""Test querying triples with subject, predicate, and object specified"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
# Setup mock TrustGraph via factory function
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
# SPO query returns a list of results (with mock graph attribute)
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
cassandra_host='localhost'
|
||||
)
|
||||
|
||||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was created with correct parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user'
|
||||
)
|
||||
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100
|
||||
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
|
||||
)
|
||||
|
||||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
|
|
@ -143,154 +150,174 @@ class TestCassandraQueryProcessor:
|
|||
assert processor.table is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_kg_class):
|
||||
"""Test SP query pattern (subject and predicate, no object)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph and response
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
# Setup mock TrustGraph via factory function
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.o = 'result_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=None,
|
||||
limit=50
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_kg_class):
|
||||
"""Test S query pattern (subject only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.o = 'result_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=25
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_kg_class):
|
||||
"""Test P query pattern (predicate only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.o = 'result_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_kg_class):
|
||||
"""Test O query pattern (object only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=75
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_kg_class):
|
||||
"""Test query pattern with no constraints (get all)"""
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
|
|
@ -299,9 +326,9 @@ class TestCassandraQueryProcessor:
|
|||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
|
|
@ -372,37 +399,44 @@ class TestCassandraQueryProcessor:
|
|||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o, g) quad pattern, some values may be\nnull. Output is a list of quads.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_kg_class):
|
||||
"""Test querying with username and password authentication"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
# SPO query returns a list of results
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
cassandra_username='authuser',
|
||||
cassandra_password='authpass'
|
||||
)
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was created with authentication
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'], # Updated default
|
||||
keyspace='test_user',
|
||||
username='authuser',
|
||||
|
|
@ -410,128 +444,154 @@ class TestCassandraQueryProcessor:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_kg_class):
|
||||
"""Test that TrustGraph is reused for same table"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
# SPO query returns a list of results
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.o = 'test_object'
|
||||
mock_tg_instance.get_spo.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_table_switching(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_table_switching(self, mock_kg_class):
|
||||
"""Test table switching creates new TrustGraph"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
# Setup mock results for both instances
|
||||
mock_result = MagicMock()
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_result.p = 'p'
|
||||
mock_result.o = 'o'
|
||||
mock_tg_instance1.get_s.return_value = [mock_result]
|
||||
mock_tg_instance2.get_s.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
await processor.query_triples(query1)
|
||||
assert processor.table == 'user1'
|
||||
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
await processor.query_triples(query2)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
assert mock_trustgraph.call_count == 2
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_kg_class):
|
||||
"""Test exception handling during query execution"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_kg_class):
|
||||
"""Test query returning multiple results"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple results
|
||||
mock_result1 = MagicMock()
|
||||
mock_result1.o = 'object1'
|
||||
mock_result1.g = ''
|
||||
mock_result1.otype = None
|
||||
mock_result1.dtype = None
|
||||
mock_result1.lang = None
|
||||
mock_result2 = MagicMock()
|
||||
mock_result2.o = 'object2'
|
||||
mock_result2.g = ''
|
||||
mock_result2.otype = None
|
||||
mock_result2.dtype = None
|
||||
mock_result2.lang = None
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
assert result[1].o.value == 'object2'
|
||||
|
|
@ -541,16 +601,20 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
"""Test cases for multi-table performance optimizations in query service"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_get_po_query_optimization(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_get_po_query_optimization(self, mock_kg_class):
|
||||
"""Test that get_po queries use optimized table (no ALLOW FILTERING)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_po.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
|
@ -560,8 +624,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
p=Term(type=LITERAL, value='test_predicate'),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=50
|
||||
)
|
||||
|
||||
|
|
@ -569,7 +633,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# Verify get_po was called (should use optimized po_table)
|
||||
mock_tg_instance.get_po.assert_called_once_with(
|
||||
'test_collection', 'test_predicate', 'test_object', limit=50
|
||||
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
|
@ -578,16 +642,20 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_get_os_query_optimization(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_get_os_query_optimization(self, mock_kg_class):
|
||||
"""Test that get_os queries use optimized table (no ALLOW FILTERING)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_tg_instance.get_os.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
|
@ -596,9 +664,9 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
s=Term(type=LITERAL, value='test_subject'),
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
o=Term(type=LITERAL, value='test_object'),
|
||||
limit=25
|
||||
)
|
||||
|
||||
|
|
@ -606,7 +674,7 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
|
||||
# Verify get_os was called (should use optimized subject_table with clustering)
|
||||
mock_tg_instance.get_os.assert_called_once_with(
|
||||
'test_collection', 'test_object', 'test_subject', limit=25
|
||||
'test_collection', 'test_object', 'test_subject', g=None, limit=25
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
|
|
@ -615,13 +683,13 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_all_query_patterns_use_correct_tables(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_all_query_patterns_use_correct_tables(self, mock_kg_class):
|
||||
"""Test that all query patterns route to their optimal tables"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock empty results for all queries
|
||||
mock_tg_instance.get_all.return_value = []
|
||||
|
|
@ -655,9 +723,9 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value=s, is_uri=False) if s else None,
|
||||
p=Value(value=p, is_uri=False) if p else None,
|
||||
o=Value(value=o, is_uri=False) if o else None,
|
||||
s=Term(type=LITERAL, value=s) if s else None,
|
||||
p=Term(type=LITERAL, value=p) if p else None,
|
||||
o=Term(type=LITERAL, value=o) if o else None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
|
|
@ -687,19 +755,23 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
# Mode is determined in KnowledgeGraph initialization
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
|
||||
async def test_performance_critical_po_query_no_filtering(self, mock_trustgraph):
|
||||
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
|
||||
async def test_performance_critical_po_query_no_filtering(self, mock_kg_class):
|
||||
"""Test the performance-critical PO query that eliminates ALLOW FILTERING"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
from trustgraph.schema import TriplesQueryRequest, Term, IRI, LITERAL
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple subjects for the same predicate-object pair
|
||||
mock_results = []
|
||||
for i in range(5):
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = f'subject_{i}'
|
||||
mock_result.g = ''
|
||||
mock_result.otype = None
|
||||
mock_result.dtype = None
|
||||
mock_result.lang = None
|
||||
mock_results.append(mock_result)
|
||||
|
||||
mock_tg_instance.get_po.return_value = mock_results
|
||||
|
|
@ -711,8 +783,8 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
user='large_dataset_user',
|
||||
collection='massive_collection',
|
||||
s=None,
|
||||
p=Value(value='http://www.w3.org/1999/02/22-rdf-syntax-ns#type', is_uri=True),
|
||||
o=Value(value='http://example.com/Person', is_uri=True),
|
||||
p=Term(type=IRI, iri='http://www.w3.org/1999/02/22-rdf-syntax-ns#type'),
|
||||
o=Term(type=IRI, iri='http://example.com/Person'),
|
||||
limit=1000
|
||||
)
|
||||
|
||||
|
|
@ -723,14 +795,15 @@ class TestCassandraQueryPerformanceOptimizations:
|
|||
'massive_collection',
|
||||
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
|
||||
'http://example.com/Person',
|
||||
g=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
# Verify all results were returned
|
||||
assert len(result) == 5
|
||||
for i, triple in enumerate(result):
|
||||
assert triple.s.value == f'subject_{i}'
|
||||
assert triple.p.value == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.is_uri is True
|
||||
assert triple.o.value == 'http://example.com/Person'
|
||||
assert triple.o.is_uri is True
|
||||
assert triple.s.value == f'subject_{i}' # Mock returns literal values
|
||||
assert triple.p.iri == 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
|
||||
assert triple.p.type == IRI
|
||||
assert triple.o.iri == 'http://example.com/Person' # URIs use .iri
|
||||
assert triple.o.type == IRI
|
||||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.falkordb.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestFalkorDBQueryProcessor:
|
||||
|
|
@ -25,50 +25,50 @@ class TestFalkorDBQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
def test_processor_initialization_with_defaults(self, mock_falkordb):
|
||||
|
|
@ -125,9 +125,9 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -138,8 +138,8 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple (appears twice - once from each query)
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -166,8 +166,8 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different objects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal result"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri_result"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri_result"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -211,9 +211,9 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -224,12 +224,12 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different predicates
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -256,7 +256,7 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
@ -269,13 +269,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different predicate-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -302,8 +302,8 @@ class TestFalkorDBQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -314,12 +314,12 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different subjects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -347,7 +347,7 @@ class TestFalkorDBQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -359,13 +359,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -393,7 +393,7 @@ class TestFalkorDBQueryProcessor:
|
|||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -404,12 +404,12 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-predicate pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
|
|
@ -449,13 +449,13 @@ class TestFalkorDBQueryProcessor:
|
|||
|
||||
# Verify results contain different triples
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/s1"
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].s.iri == "http://example.com/s1"
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/s2"
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].o.value == "http://example.com/o2"
|
||||
assert result[1].s.iri == "http://example.com/s2"
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].o.iri == "http://example.com/o2"
|
||||
|
||||
@patch('trustgraph.query.triples.falkordb.service.FalkorDB')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -476,7 +476,7 @@ class TestFalkorDBQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.memgraph.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphQueryProcessor:
|
||||
|
|
@ -25,50 +25,50 @@ class TestMemgraphQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
def test_processor_initialization_with_defaults(self, mock_graph_db):
|
||||
|
|
@ -124,9 +124,9 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -137,8 +137,8 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple (appears twice - once from each query)
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -166,8 +166,8 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different objects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal result"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri_result"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri_result"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -212,9 +212,9 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -225,12 +225,12 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different predicates
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -258,7 +258,7 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
@ -271,13 +271,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different predicate-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -305,8 +305,8 @@ class TestMemgraphQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -317,12 +317,12 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different subjects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -351,7 +351,7 @@ class TestMemgraphQueryProcessor:
|
|||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -363,13 +363,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-object pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri2"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -398,7 +398,7 @@ class TestMemgraphQueryProcessor:
|
|||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -409,12 +409,12 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different subject-predicate pairs
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subj1"
|
||||
assert result[0].p.value == "http://example.com/pred1"
|
||||
assert result[0].s.iri == "http://example.com/subj1"
|
||||
assert result[0].p.iri == "http://example.com/pred1"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subj2"
|
||||
assert result[1].p.value == "http://example.com/pred2"
|
||||
assert result[1].s.iri == "http://example.com/subj2"
|
||||
assert result[1].p.iri == "http://example.com/pred2"
|
||||
assert result[1].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
|
|
@ -455,13 +455,13 @@ class TestMemgraphQueryProcessor:
|
|||
|
||||
# Verify results contain different triples
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/s1"
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].s.iri == "http://example.com/s1"
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/s2"
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].o.value == "http://example.com/o2"
|
||||
assert result[1].s.iri == "http://example.com/s2"
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].o.iri == "http://example.com/o2"
|
||||
|
||||
@patch('trustgraph.query.triples.memgraph.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -480,7 +480,7 @@ class TestMemgraphQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.neo4j.service import Processor
|
||||
from trustgraph.schema import Value, TriplesQueryRequest
|
||||
from trustgraph.schema import Term, TriplesQueryRequest, IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jQueryProcessor:
|
||||
|
|
@ -25,50 +25,50 @@ class TestNeo4jQueryProcessor:
|
|||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "http://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.iri == "https://example.com/resource"
|
||||
assert result.type == IRI
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
|
||||
assert isinstance(result, Term)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
assert result.type == LITERAL
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
def test_processor_initialization_with_defaults(self, mock_graph_db):
|
||||
|
|
@ -124,9 +124,9 @@ class TestNeo4jQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal object", is_uri=False),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal object"),
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
|
@ -137,8 +137,8 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Verify result contains the queried triple (appears twice - once from each query)
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal object"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
|
|
@ -166,8 +166,8 @@ class TestNeo4jQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
|
@ -179,13 +179,13 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Verify results contain different objects
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/subject"
|
||||
assert result[0].p.value == "http://example.com/predicate"
|
||||
assert result[0].s.iri == "http://example.com/subject"
|
||||
assert result[0].p.iri == "http://example.com/predicate"
|
||||
assert result[0].o.value == "literal result"
|
||||
|
||||
assert result[1].s.value == "http://example.com/subject"
|
||||
assert result[1].p.value == "http://example.com/predicate"
|
||||
assert result[1].o.value == "http://example.com/uri_result"
|
||||
|
||||
assert result[1].s.iri == "http://example.com/subject"
|
||||
assert result[1].p.iri == "http://example.com/predicate"
|
||||
assert result[1].o.iri == "http://example.com/uri_result"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -225,13 +225,13 @@ class TestNeo4jQueryProcessor:
|
|||
|
||||
# Verify results contain different triples
|
||||
assert len(result) == 2
|
||||
assert result[0].s.value == "http://example.com/s1"
|
||||
assert result[0].p.value == "http://example.com/p1"
|
||||
assert result[0].s.iri == "http://example.com/s1"
|
||||
assert result[0].p.iri == "http://example.com/p1"
|
||||
assert result[0].o.value == "literal1"
|
||||
|
||||
assert result[1].s.value == "http://example.com/s2"
|
||||
assert result[1].p.value == "http://example.com/p2"
|
||||
assert result[1].o.value == "http://example.com/o2"
|
||||
|
||||
assert result[1].s.iri == "http://example.com/s2"
|
||||
assert result[1].p.iri == "http://example.com/p2"
|
||||
assert result[1].o.iri == "http://example.com/o2"
|
||||
|
||||
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -250,12 +250,12 @@ class TestNeo4jQueryProcessor:
|
|||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
|
||||
# Should raise the exception
|
||||
with pytest.raises(Exception, match="Database connection failed"):
|
||||
await processor.query_triples(query)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
from trustgraph.schema import (
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
ObjectsQueryRequest, ObjectsQueryResponse,
|
||||
RowsQueryRequest, RowsQueryResponse,
|
||||
Error, GraphQLError
|
||||
)
|
||||
from trustgraph.retrieval.structured_query.service import Processor
|
||||
|
|
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects query service response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
|
||||
errors=None,
|
||||
|
|
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
|
|||
# Verify objects query service was called correctly
|
||||
mock_objects_client.request.assert_called_once()
|
||||
objects_call_args = mock_objects_client.request.call_args[0][0]
|
||||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert isinstance(objects_call_args, RowsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
|
|
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
|
|||
assert response.error is not None
|
||||
assert "empty GraphQL query" in response.error.message
|
||||
|
||||
async def test_objects_query_service_error(self, processor):
|
||||
async def test_rows_query_service_error(self, processor):
|
||||
"""Test handling of objects query service errors"""
|
||||
# Arrange
|
||||
request = StructuredQueryRequest(
|
||||
|
|
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects query service error
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
|
||||
data=None,
|
||||
errors=None,
|
||||
|
|
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
|
|||
response = response_call[0][0]
|
||||
|
||||
assert response.error is not None
|
||||
assert "Objects query service error" in response.error.message
|
||||
assert "Rows query service error" in response.error.message
|
||||
assert "Table 'customers' not found" in response.error.message
|
||||
|
||||
async def test_graphql_errors_handling(self, processor):
|
||||
|
|
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
]
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None,
|
||||
errors=graphql_errors,
|
||||
|
|
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
|
|||
)
|
||||
|
||||
# Mock objects response
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
|
||||
errors=None
|
||||
|
|
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
|
|||
confidence=0.9
|
||||
)
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
objects_response = RowsQueryResponse(
|
||||
error=None,
|
||||
data=None, # Null data
|
||||
errors=None,
|
||||
|
|
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
|
|||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
elif service_name == "rows-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
|
||||
from trustgraph.storage.knowledge.store import Processor as KgStore
|
||||
|
||||
|
|
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
|
|||
assert processor.cassandra_password is None
|
||||
|
||||
|
||||
class TestObjectsWriterConfiguration:
|
||||
class TestRowsWriterConfiguration:
|
||||
"""Test Cassandra configuration in objects writer processor."""
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_environment_variable_configuration(self, mock_cluster):
|
||||
"""Test processor picks up configuration from environment variables."""
|
||||
env_vars = {
|
||||
|
|
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
|
||||
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
|
||||
assert processor.cassandra_username == 'obj-env-user'
|
||||
assert processor.cassandra_password == 'obj-env-pass'
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
|
||||
"""Test that Cassandra connection uses hosts list correctly."""
|
||||
env_vars = {
|
||||
|
|
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify cluster was called with hosts list
|
||||
|
|
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
|
|||
assert 'contact_points' in call_args.kwargs
|
||||
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
|
||||
|
||||
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
|
||||
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
|
||||
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
|
||||
"""Test authentication is configured when credentials are provided."""
|
||||
env_vars = {
|
||||
|
|
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
|
|||
mock_cluster.return_value = mock_cluster_instance
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
||||
processor = RowsWriter(taskgroup=MagicMock())
|
||||
processor.connect_cassandra()
|
||||
|
||||
# Verify auth provider was created with correct credentials
|
||||
|
|
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
|
|||
def test_objects_writer_add_args(self):
|
||||
"""Test that objects writer adds standard Cassandra arguments."""
|
||||
import argparse
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
ObjectsWriter.add_args(parser)
|
||||
RowsWriter.add_args(parser)
|
||||
|
||||
# Parse empty args to check that arguments exist
|
||||
args = parser.parse_args([])
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.graph_embeddings.milvus.write import Processor
|
||||
from trustgraph.schema import Value, EntityEmbeddings
|
||||
from trustgraph.schema import Term, EntityEmbeddings, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMilvusGraphEmbeddingsStorageProcessor:
|
||||
|
|
@ -22,11 +22,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
|
||||
# Create test entities with embeddings
|
||||
entity1 = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity1', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity1'),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
entity2 = EntityEmbeddings(
|
||||
entity=Value(value='literal entity', is_uri=False),
|
||||
entity=Term(type=LITERAL, value='literal entity'),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [entity1, entity2]
|
||||
|
|
@ -84,7 +84,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -136,7 +136,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
entity=Term(type=LITERAL, value=''),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -155,7 +155,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
entity=Term(type=LITERAL, value=None),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -174,15 +174,15 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
valid_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/valid', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/valid'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
empty_entity = EntityEmbeddings(
|
||||
entity=Value(value='', is_uri=False),
|
||||
entity=Term(type=LITERAL, value=''),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
none_entity = EntityEmbeddings(
|
||||
entity=Value(value=None, is_uri=False),
|
||||
entity=Term(type=LITERAL, value=None),
|
||||
vectors=[[0.7, 0.8, 0.9]]
|
||||
)
|
||||
message.entities = [valid_entity, empty_entity, none_entity]
|
||||
|
|
@ -217,7 +217,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
|
@ -236,7 +236,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/entity'),
|
||||
vectors=[
|
||||
[0.1, 0.2], # 2D vector
|
||||
[0.3, 0.4, 0.5, 0.6], # 4D vector
|
||||
|
|
@ -269,11 +269,11 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
uri_entity = EntityEmbeddings(
|
||||
entity=Value(value='http://example.com/uri_entity', is_uri=True),
|
||||
entity=Term(type=IRI, iri='http://example.com/uri_entity'),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
literal_entity = EntityEmbeddings(
|
||||
entity=Value(value='literal entity text', is_uri=False),
|
||||
entity=Term(type=LITERAL, value='literal entity text'),
|
||||
vectors=[[0.4, 0.5, 0.6]]
|
||||
)
|
||||
message.entities = [uri_entity, literal_entity]
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest import IsolatedAsyncioTestCase
|
|||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
|
|
@ -67,7 +68,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.entity.type = IRI
|
||||
mock_entity.entity.iri = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
|
@ -120,11 +122,13 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.entity.type = IRI
|
||||
mock_entity1.entity.iri = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.entity.type = IRI
|
||||
mock_entity2.entity.iri = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
|
@ -179,7 +183,8 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'multi_vector_entity'
|
||||
mock_entity.entity.type = IRI
|
||||
mock_entity.entity.iri = 'multi_vector_entity'
|
||||
mock_entity.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
|
|
@ -231,11 +236,12 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
mock_entity_empty.entity.type = LITERAL
|
||||
mock_entity_empty.entity.value = "" # Empty string
|
||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_entity_none = MagicMock()
|
||||
mock_entity_none.entity.value = None # None value
|
||||
mock_entity_none.entity = None # None entity
|
||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch, call
|
|||
|
||||
from trustgraph.storage.triples.neo4j.write import Processor as StorageProcessor
|
||||
from trustgraph.query.triples.neo4j.service import Processor as QueryProcessor
|
||||
from trustgraph.schema import Triples, Triple, Value, Metadata
|
||||
from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
|
||||
|
|
@ -60,9 +60,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
)
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="literal_value", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="literal_value")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
|
|
@ -128,9 +128,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
metadata = Metadata(id="test-id")
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="http://example.com/object", is_uri=True)
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=IRI, iri="http://example.com/object")
|
||||
)
|
||||
|
||||
message = Triples(
|
||||
|
|
@ -170,8 +170,8 @@ class TestNeo4jUserCollectionIsolation:
|
|||
query = TriplesQueryRequest(
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
s=Value(value="http://example.com/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
s=Term(type=IRI, iri="http://example.com/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=None
|
||||
)
|
||||
|
||||
|
|
@ -254,9 +254,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.com/user1/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="user1_data", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/user1/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user1_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -265,9 +265,9 @@ class TestNeo4jUserCollectionIsolation:
|
|||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://example.com/user2/subject", is_uri=True),
|
||||
p=Value(value="http://example.com/predicate", is_uri=True),
|
||||
o=Value(value="user2_data", is_uri=False)
|
||||
s=Term(type=IRI, iri="http://example.com/user2/subject"),
|
||||
p=Term(type=IRI, iri="http://example.com/predicate"),
|
||||
o=Term(type=LITERAL, value="user2_data")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -429,9 +429,9 @@ class TestNeo4jUserCollectionRegression:
|
|||
metadata=Metadata(user="user1", collection="coll1"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=shared_uri, is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="user1_value", is_uri=False)
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user1_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -440,9 +440,9 @@ class TestNeo4jUserCollectionRegression:
|
|||
metadata=Metadata(user="user2", collection="coll2"),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value=shared_uri, is_uri=True),
|
||||
p=Value(value="http://example.com/p", is_uri=True),
|
||||
o=Value(value="user2_value", is_uri=False)
|
||||
s=Term(type=IRI, iri=shared_uri),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="user2_value")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,533 +0,0 @@
|
|||
"""
|
||||
Unit tests for Cassandra Object Storage Processor
|
||||
|
||||
Tests the business logic of the object storage processor including:
|
||||
- Schema configuration handling
|
||||
- Type conversions
|
||||
- Name sanitization
|
||||
- Table structure generation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.objects.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestObjectsCassandraStorageLogic:
|
||||
"""Test business logic without FlowProcessor dependencies"""
|
||||
|
||||
def test_sanitize_name(self):
|
||||
"""Test name sanitization for Cassandra compatibility"""
|
||||
processor = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
|
||||
# Test various name patterns (back to original logic)
|
||||
assert processor.sanitize_name("simple_name") == "simple_name"
|
||||
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
|
||||
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
|
||||
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
|
||||
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
|
||||
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
|
||||
|
||||
def test_get_cassandra_type(self):
|
||||
"""Test field type conversion to Cassandra types"""
|
||||
processor = MagicMock()
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
|
||||
# Basic type mappings
|
||||
assert processor.get_cassandra_type("string") == "text"
|
||||
assert processor.get_cassandra_type("boolean") == "boolean"
|
||||
assert processor.get_cassandra_type("timestamp") == "timestamp"
|
||||
assert processor.get_cassandra_type("uuid") == "uuid"
|
||||
|
||||
# Integer types with size hints
|
||||
assert processor.get_cassandra_type("integer", size=2) == "int"
|
||||
assert processor.get_cassandra_type("integer", size=8) == "bigint"
|
||||
|
||||
# Float types with size hints
|
||||
assert processor.get_cassandra_type("float", size=2) == "float"
|
||||
assert processor.get_cassandra_type("float", size=8) == "double"
|
||||
|
||||
# Unknown type defaults to text
|
||||
assert processor.get_cassandra_type("unknown_type") == "text"
|
||||
|
||||
def test_convert_value(self):
|
||||
"""Test value conversion for different field types"""
|
||||
processor = MagicMock()
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
|
||||
# Integer conversions
|
||||
assert processor.convert_value("123", "integer") == 123
|
||||
assert processor.convert_value(123.5, "integer") == 123
|
||||
assert processor.convert_value(None, "integer") is None
|
||||
|
||||
# Float conversions
|
||||
assert processor.convert_value("123.45", "float") == 123.45
|
||||
assert processor.convert_value(123, "float") == 123.0
|
||||
|
||||
# Boolean conversions
|
||||
assert processor.convert_value("true", "boolean") is True
|
||||
assert processor.convert_value("false", "boolean") is False
|
||||
assert processor.convert_value("1", "boolean") is True
|
||||
assert processor.convert_value("0", "boolean") is False
|
||||
assert processor.convert_value("yes", "boolean") is True
|
||||
assert processor.convert_value("no", "boolean") is False
|
||||
|
||||
# String conversions
|
||||
assert processor.convert_value(123, "string") == "123"
|
||||
assert processor.convert_value(True, "string") == "True"
|
||||
|
||||
def test_table_creation_cql_generation(self):
|
||||
"""Test CQL generation for table creation"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create test schema
|
||||
schema = RowSchema(
|
||||
name="customer_records",
|
||||
description="Test customer schema",
|
||||
fields=[
|
||||
Field(
|
||||
name="customer_id",
|
||||
type="string",
|
||||
size=50,
|
||||
primary=True,
|
||||
required=True,
|
||||
indexed=False
|
||||
),
|
||||
Field(
|
||||
name="email",
|
||||
type="string",
|
||||
size=100,
|
||||
required=True,
|
||||
indexed=True
|
||||
),
|
||||
Field(
|
||||
name="age",
|
||||
type="integer",
|
||||
size=4,
|
||||
required=False,
|
||||
indexed=False
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "customer_records", schema)
|
||||
|
||||
# Verify keyspace was ensured (check that it was added to known_keyspaces)
|
||||
assert "test_user" in processor.known_keyspaces
|
||||
|
||||
# Check the CQL that was executed (first call should be table creation)
|
||||
all_calls = processor.session.execute.call_args_list
|
||||
table_creation_cql = all_calls[0][0][0] # First call
|
||||
|
||||
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
|
||||
assert "collection text" in table_creation_cql
|
||||
assert "customer_id text" in table_creation_cql
|
||||
assert "email text" in table_creation_cql
|
||||
assert "age int" in table_creation_cql
|
||||
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
|
||||
|
||||
def test_table_creation_without_primary_key(self):
|
||||
"""Test table creation when no primary key is defined"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
# Create schema without primary key
|
||||
schema = RowSchema(
|
||||
name="events",
|
||||
description="Event log",
|
||||
fields=[
|
||||
Field(name="event_type", type="string", size=50),
|
||||
Field(name="timestamp", type="timestamp", size=0)
|
||||
]
|
||||
)
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "events", schema)
|
||||
|
||||
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
|
||||
executed_cql = processor.session.execute.call_args[0][0]
|
||||
assert "synthetic_id uuid" in executed_cql
|
||||
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer data",
|
||||
"fields": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "string",
|
||||
"primary_key": True,
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "balance",
|
||||
"type": "float",
|
||||
"size": 8
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
# Check field properties
|
||||
id_field = schema.fields[0]
|
||||
assert id_field.name == "id"
|
||||
assert id_field.type == "string"
|
||||
assert id_field.primary is True
|
||||
# Note: Field.required always returns False due to Pulsar schema limitations
|
||||
# The actual required value is tracked during schema parsing
|
||||
|
||||
@pytest.mark.asyncio
async def test_object_processing_logic(self):
    """Test the logic for processing ExtractedObject"""
    processor = MagicMock()
    processor.schemas = {
        "test_schema": RowSchema(
            name="test_schema",
            description="Test",
            fields=[
                Field(name="id", type="string", size=50, primary=True),
                Field(name="value", type="integer", size=4),
            ],
        )
    }
    processor.ensure_table = MagicMock()
    # Bind the real (unmocked) Processor helpers under test onto the mock.
    for method in ("sanitize_name", "sanitize_table", "convert_value", "on_object"):
        setattr(processor, method,
                getattr(Processor, method).__get__(processor, Processor))
    processor.session = MagicMock()
    processor.known_keyspaces = {"test_user"}        # pre-populate to skip validation query
    processor.known_tables = {"test_user": set()}    # pre-populate

    # Build the incoming ExtractedObject wrapped in a mock Pulsar message.
    test_obj = ExtractedObject(
        metadata=Metadata(
            id="test-001",
            user="test_user",
            collection="test_collection",
            metadata=[],
        ),
        schema_name="test_schema",
        values=[{"id": "123", "value": "456"}],
        confidence=0.9,
        source_span="test source",
    )
    msg = MagicMock()
    msg.value.return_value = test_obj

    # Process object
    await processor.on_object(msg, None, None)

    # Table creation was requested for the schema.
    processor.ensure_table.assert_called_once_with(
        "test_user", "test_schema", processor.schemas["test_schema"])

    # Exactly one INSERT (keyspace plain, table carries the o_ prefix).
    processor.session.execute.assert_called_once()
    insert_cql = processor.session.execute.call_args[0][0]
    row_values = processor.session.execute.call_args[0][1]

    assert "INSERT INTO test_user.o_test_schema" in insert_cql
    assert "collection" in insert_cql
    assert row_values[0] == "test_collection"   # collection value
    assert row_values[1] == "123"               # id passes through as a string
    assert row_values[2] == 456                 # "456" converted to integer
||||
def test_secondary_index_creation(self):
    """Test that secondary indexes are created for indexed fields"""
    processor = MagicMock()
    processor.schemas = {}
    processor.known_keyspaces = {"test_user"}        # pre-populate to skip validation query
    processor.known_tables = {"test_user": set()}    # pre-populate
    processor.session = MagicMock()
    # Bind the real (unmocked) Processor helpers under test onto the mock.
    for method in ("sanitize_name", "sanitize_table", "get_cassandra_type"):
        setattr(processor, method,
                getattr(Processor, method).__get__(processor, Processor))

    def fake_ensure_keyspace(keyspace):
        # Record the keyspace the way the real implementation would.
        processor.known_keyspaces.add(keyspace)
        processor.known_tables.setdefault(keyspace, set())

    processor.ensure_keyspace = fake_ensure_keyspace
    processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)

    # Schema with two indexed fields alongside the primary key.
    schema = RowSchema(
        name="products",
        description="Product catalog",
        fields=[
            Field(name="product_id", type="string", size=50, primary=True),
            Field(name="category", type="string", size=30, indexed=True),
            Field(name="price", type="float", size=8, indexed=True),
        ],
    )

    processor.ensure_table("test_user", "products", schema)

    # One CREATE TABLE plus one CREATE INDEX per indexed field.
    assert processor.session.execute.call_count == 3

    executed = [c[0][0] for c in processor.session.execute.call_args_list]
    index_cql = [cql for cql in executed if "CREATE INDEX" in cql]
    assert len(index_cql) == 2
    # Index names carry the o_-prefixed table; field names stay unprefixed.
    assert any("o_products_category_idx" in cql for cql in index_cql)
    assert any("o_products_price_idx" in cql for cql in index_cql)
||||
|
||||
class TestObjectsCassandraStorageBatchLogic:
    """Test batch processing logic in Cassandra storage"""

    @staticmethod
    def _bind_real_methods(processor):
        """Attach the real (unmocked) Processor helpers onto the mock instance."""
        for method in ("sanitize_name", "sanitize_table", "convert_value", "on_object"):
            setattr(processor, method,
                    getattr(Processor, method).__get__(processor, Processor))

    @staticmethod
    def _message(obj):
        """Wrap an ExtractedObject in a mock Pulsar message."""
        msg = MagicMock()
        msg.value.return_value = obj
        return msg

    @pytest.mark.asyncio
    async def test_batch_object_processing_logic(self):
        """Test processing of batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "batch_schema": RowSchema(
                name="batch_schema",
                description="Test batch schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="name", type="string", size=100),
                    Field(name="value", type="integer", size=4),
                ],
            )
        }
        processor.known_keyspaces = {"test_user"}   # pre-populate to skip validation query
        processor.ensure_table = MagicMock()
        self._bind_real_methods(processor)
        processor.session = MagicMock()

        # Batch object carrying three rows.
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[],
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First", "value": "100"},
                {"id": "002", "name": "Second", "value": "200"},
                {"id": "003", "name": "Third", "value": "300"},
            ],
            confidence=0.95,
            source_span="batch source",
        )

        await processor.on_object(self._message(batch_obj), None, None)

        # Table is ensured once for the whole batch.
        processor.ensure_table.assert_called_once_with(
            "test_user", "batch_schema", processor.schemas["batch_schema"])

        # One INSERT per batch item.
        assert processor.session.execute.call_count == 3

        # BUG FIX: the original assertion used an unparenthesised conditional
        # expression (`assert values[2] == "First" if i == 0 else "Second" ...`),
        # which for i > 0 only asserted the truthiness of a non-empty string and
        # never actually checked the stored name. Compare against explicit
        # expected rows instead.
        expected_rows = [
            ("001", "First", 100),
            ("002", "Second", 200),
            ("003", "Third", 300),
        ]
        calls = processor.session.execute.call_args_list
        for call, (exp_id, exp_name, exp_value) in zip(calls, expected_rows):
            insert_cql = call[0][0]
            values = call[0][1]
            assert "INSERT INTO test_user.o_batch_schema" in insert_cql
            assert "collection" in insert_cql
            assert values[0] == "batch_collection"   # collection
            assert values[1] == exp_id               # id from batch item
            assert values[2] == exp_name             # name from batch item
            assert values[3] == exp_value            # converted integer value

    @pytest.mark.asyncio
    async def test_empty_batch_processing_logic(self):
        """Test processing of empty batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "empty_schema": RowSchema(
                name="empty_schema",
                fields=[Field(name="id", type="string", size=50, primary=True)],
            )
        }
        processor.ensure_table = MagicMock()
        self._bind_real_methods(processor)
        processor.session = MagicMock()
        processor.known_keyspaces = {"test_user"}        # pre-populate to skip validation query
        processor.known_tables = {"test_user": set()}    # pre-populate

        empty_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="empty-001",
                user="test_user",
                collection="empty_collection",
                metadata=[],
            ),
            schema_name="empty_schema",
            values=[],  # empty batch
            confidence=1.0,
            source_span="empty source",
        )

        await processor.on_object(self._message(empty_batch_obj), None, None)

        # Table is still ensured, but nothing is inserted for an empty batch.
        processor.ensure_table.assert_called_once()
        processor.session.execute.assert_not_called()

    @pytest.mark.asyncio
    async def test_single_item_batch_processing_logic(self):
        """Test processing of single-item batch (backward compatibility)"""
        processor = MagicMock()
        processor.schemas = {
            "single_schema": RowSchema(
                name="single_schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="data", type="string", size=100),
                ],
            )
        }
        processor.ensure_table = MagicMock()
        self._bind_real_methods(processor)
        processor.session = MagicMock()
        processor.known_keyspaces = {"test_user"}        # pre-populate to skip validation query
        processor.known_tables = {"test_user": set()}    # pre-populate

        # Single-item batch (backward compatibility case).
        single_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="single-001",
                user="test_user",
                collection="single_collection",
                metadata=[],
            ),
            schema_name="single_schema",
            values=[{"id": "single-1", "data": "single data"}],  # array with one item
            confidence=0.8,
            source_span="single source",
        )

        await processor.on_object(self._message(single_batch_obj), None, None)

        # Table ensured and exactly one INSERT issued.
        processor.ensure_table.assert_called_once()
        processor.session.execute.assert_called_once()

        insert_cql = processor.session.execute.call_args[0][0]
        values = processor.session.execute.call_args[0][1]

        assert "INSERT INTO test_user.o_single_schema" in insert_cql
        assert values[0] == "single_collection"  # collection
        assert values[1] == "single-1"           # id value
        assert values[2] == "single data"        # data value

    def test_batch_value_conversion_logic(self):
        """Test value conversion works correctly for batch items"""
        processor = MagicMock()
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)

        # Conversion scenarios that occur during batch processing:
        # (input, declared field type, expected converted value)
        test_cases = [
            # Integer conversions
            ("123", "integer", 123),
            ("456", "integer", 456),
            ("789", "integer", 789),
            # Float conversions
            ("12.5", "float", 12.5),
            ("34.7", "float", 34.7),
            # Boolean conversions
            ("true", "boolean", True),
            ("false", "boolean", False),
            ("1", "boolean", True),
            ("0", "boolean", False),
            # String conversions
            (123, "string", "123"),
            (45.6, "string", "45.6"),
        ]

        for input_val, field_type, expected_output in test_cases:
            result = processor.convert_value(input_val, field_type)
            assert result == expected_output, (
                f"Failed for {input_val} -> {field_type}: "
                f"got {result}, expected {expected_output}"
            )
||||
435
tests/unit/test_storage/test_row_embeddings_qdrant_storage.py
Normal file
435
tests/unit/test_storage/test_row_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
|
||||
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
    """Test Qdrant row embeddings storage functionality"""

    # --- shared construction helpers -------------------------------------

    @staticmethod
    def _config(proc_id='test-processor', **extra):
        """Standard processor kwargs used by every test."""
        cfg = {'taskgroup': AsyncMock(), 'id': proc_id}
        cfg.update(extra)
        return cfg

    @staticmethod
    def _metadata(user='test_user', collection='test_collection', doc_id='doc-123'):
        """Mock message metadata carrying user/collection/id."""
        md = MagicMock()
        md.user = user
        md.collection = collection
        md.id = doc_id
        return md

    @staticmethod
    def _message(payload):
        """Wrap a schema object in a mock Pulsar message."""
        msg = MagicMock()
        msg.value.return_value = payload
        return msg

    @staticmethod
    def _named_collection(name):
        """Mock Qdrant collection descriptor exposing a .name attribute."""
        coll = MagicMock()
        coll.name = name
        return coll

    # --- tests ------------------------------------------------------------

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_processor_initialization_basic(self, mock_qdrant_client):
        """Test basic Qdrant processor initialization"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        qdrant = MagicMock()
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config(
            'test-qdrant-processor',
            store_uri='http://localhost:6333',
            api_key='test-api-key',
        ))

        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key='test-api-key'
        )
        assert hasattr(processor, 'qdrant')
        assert processor.qdrant == qdrant

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
        """Test processor initialization with default values"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()

        Processor(**self._config('test-qdrant-processor'))

        # Defaults: local Qdrant URL, no API key.
        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key=None
        )

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_sanitize_name(self, mock_qdrant_client):
        """Test name sanitization for Qdrant collections"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()
        processor = Processor(**self._config())

        cases = {
            # Basic sanitization
            "simple": "simple",
            "with-dash": "with_dash",
            "with.dot": "with_dot",
            "UPPERCASE": "uppercase",
            # Numeric/underscore prefixes gain an r_ prefix
            "123start": "r_123start",
            "_underscore": "r__underscore",
        }
        for raw, expected in cases.items():
            assert processor.sanitize_name(raw) == expected

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_get_collection_name(self, mock_qdrant_client):
        """Test Qdrant collection name generation"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()
        processor = Processor(**self._config())

        collection_name = processor.get_collection_name(
            user="test_user",
            collection="test_collection",
            schema_name="customer_data",
            dimension=384,
        )

        assert collection_name == "rows_test_user_test_collection_customer_data_384"

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_creates_new(self, mock_qdrant_client):
        """Test that ensure_collection creates a new collection when needed"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        qdrant = MagicMock()
        qdrant.collection_exists.return_value = False
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())
        processor.ensure_collection("test_collection", 384)

        qdrant.collection_exists.assert_called_once_with("test_collection")
        qdrant.create_collection.assert_called_once()
        # The freshly created collection is cached.
        assert "test_collection" in processor.created_collections

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
        """Test that ensure_collection skips creation when collection exists"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        qdrant = MagicMock()
        qdrant.collection_exists.return_value = True
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())
        processor.ensure_collection("existing_collection", 384)

        qdrant.collection_exists.assert_called_once()
        qdrant.create_collection.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
        """Test that ensure_collection uses cache for previously created collections"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        qdrant = MagicMock()
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())
        processor.created_collections.add("cached_collection")

        processor.ensure_collection("cached_collection", 384)

        # Cached collection: no existence check, no creation.
        qdrant.collection_exists.assert_not_called()
        qdrant.create_collection.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
    async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
        """Test processing basic row embeddings message"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata

        qdrant = MagicMock()
        qdrant.collection_exists.return_value = True
        mock_qdrant_client.return_value = qdrant
        mock_uuid.uuid4.return_value = 'test-uuid-123'

        processor = Processor(**self._config())
        processor.known_collections[('test_user', 'test_collection')] = {}

        embedding = RowIndexEmbedding(
            index_name='customer_id',
            index_value=['CUST001'],
            text='CUST001',
            vectors=[[0.1, 0.2, 0.3]],
        )
        embeddings_msg = RowEmbeddings(
            metadata=self._metadata(),
            schema_name='customers',
            embeddings=[embedding],
        )

        await processor.on_embeddings(
            self._message(embeddings_msg), MagicMock(), MagicMock())

        qdrant.upsert.assert_called_once()

        kwargs = qdrant.upsert.call_args[1]
        # Collection name embeds user, collection, schema, and vector dimension.
        assert kwargs['collection_name'] == 'rows_test_user_test_collection_customers_3'

        point = kwargs['points'][0]
        assert point.vector == [0.1, 0.2, 0.3]
        assert point.payload['index_name'] == 'customer_id'
        assert point.payload['index_value'] == ['CUST001']
        assert point.payload['text'] == 'CUST001'

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
    async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
        """Test processing embeddings with multiple vectors"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        qdrant = MagicMock()
        qdrant.collection_exists.return_value = True
        mock_qdrant_client.return_value = qdrant
        mock_uuid.uuid4.return_value = 'test-uuid'

        processor = Processor(**self._config())
        processor.known_collections[('test_user', 'test_collection')] = {}

        # Single index entry with three vectors.
        embedding = RowIndexEmbedding(
            index_name='name',
            index_value=['John Doe'],
            text='John Doe',
            vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        )
        embeddings_msg = RowEmbeddings(
            metadata=self._metadata(),
            schema_name='people',
            embeddings=[embedding],
        )

        await processor.on_embeddings(
            self._message(embeddings_msg), MagicMock(), MagicMock())

        # One upsert per vector.
        assert qdrant.upsert.call_count == 3

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
        """Test that embeddings with no vectors are skipped"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        qdrant = MagicMock()
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())
        processor.known_collections[('test_user', 'test_collection')] = {}

        embedding = RowIndexEmbedding(
            index_name='id',
            index_value=['123'],
            text='123',
            vectors=[],  # empty vectors
        )
        embeddings_msg = RowEmbeddings(
            metadata=self._metadata(),
            schema_name='items',
            embeddings=[embedding],
        )

        await processor.on_embeddings(
            self._message(embeddings_msg), MagicMock(), MagicMock())

        # No vectors, so nothing is upserted.
        qdrant.upsert.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
        """Test that messages for unknown collections are dropped"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        qdrant = MagicMock()
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())
        # No collections registered.

        embedding = RowIndexEmbedding(
            index_name='id',
            index_value=['123'],
            text='123',
            vectors=[[0.1, 0.2]],
        )
        embeddings_msg = RowEmbeddings(
            metadata=self._metadata('unknown_user', 'unknown_collection'),
            schema_name='items',
            embeddings=[embedding],
        )

        await processor.on_embeddings(
            self._message(embeddings_msg), MagicMock(), MagicMock())

        # Unknown user/collection: message dropped, nothing upserted.
        qdrant.upsert.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_delete_collection(self, mock_qdrant_client):
        """Test deleting all collections for a user/collection"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        qdrant = MagicMock()

        # Two matching collections plus one belonging to another user.
        listed = MagicMock()
        listed.collections = [
            self._named_collection('rows_test_user_test_collection_schema1_384'),
            self._named_collection('rows_test_user_test_collection_schema2_384'),
            self._named_collection('rows_other_user_other_collection_schema_384'),
        ]
        qdrant.get_collections.return_value = listed
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())
        processor.created_collections.add('rows_test_user_test_collection_schema1_384')

        await processor.delete_collection('test_user', 'test_collection')

        # Only the two matching collections are deleted.
        assert qdrant.delete_collection.call_count == 2

        # The cached entry for a deleted collection is evicted.
        assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_delete_collection_schema(self, mock_qdrant_client):
        """Test deleting collections for a specific schema"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        qdrant = MagicMock()

        listed = MagicMock()
        listed.collections = [
            self._named_collection('rows_test_user_test_collection_customers_384'),
            self._named_collection('rows_test_user_test_collection_orders_384'),
        ]
        qdrant.get_collections.return_value = listed
        mock_qdrant_client.return_value = qdrant

        processor = Processor(**self._config())

        await processor.delete_collection_schema(
            'test_user', 'test_collection', 'customers'
        )

        # Only the customers schema collection is removed.
        qdrant.delete_collection.assert_called_once()
        assert qdrant.delete_collection.call_args[0][0] == \
            'rows_test_user_test_collection_customers_384'
||||
|
||||
# Allow running this test module directly (outside a pytest invocation).
if __name__ == '__main__':
    pytest.main([__file__])
|
||||
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
"""
|
||||
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
|
||||
|
||||
Tests the business logic of the row storage processor including:
|
||||
- Schema configuration handling
|
||||
- Name sanitization
|
||||
- Unified table structure
|
||||
- Index management
|
||||
- Row storage with multi-index support
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import json
|
||||
|
||||
from trustgraph.storage.rows.cassandra.write import Processor
|
||||
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
|
||||
|
||||
|
||||
class TestRowsCassandraStorageLogic:
|
||||
"""Test business logic for unified table implementation"""
|
||||
|
||||
def test_sanitize_name(self):
    """Test name sanitization for Cassandra compatibility"""
    processor = MagicMock()
    processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)

    # (input, expected) pairs covering the various name patterns.
    cases = {
        "simple_name": "simple_name",
        "Name-With-Dashes": "name_with_dashes",
        "name.with.dots": "name_with_dots",
        "123_starts_with_number": "r_123_starts_with_number",
        "name with spaces": "name_with_spaces",
        "special!@#$%^chars": "special______chars",
        "UPPERCASE": "uppercase",
        "CamelCase": "camelcase",
        "_underscore_start": "r__underscore_start",
    }
    for raw, expected in cases.items():
        assert processor.sanitize_name(raw) == expected
|
||||
def test_get_index_names(self):
    """Test extraction of index names from schema"""
    processor = MagicMock()
    processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

    # Schema mixing a primary key, indexed fields, and a plain field.
    schema = RowSchema(
        name="test_schema",
        description="Test",
        fields=[
            Field(name="id", type="string", primary=True),
            Field(name="category", type="string", indexed=True),
            Field(name="name", type="string"),  # not indexed
            Field(name="status", type="string", indexed=True),
        ],
    )

    index_names = processor.get_index_names(schema)

    # Primary key and indexed fields are included; the plain field is not.
    for expected in ("id", "category", "status"):
        assert expected in index_names
    assert "name" not in index_names
    assert len(index_names) == 3
||||
def test_get_index_names_no_indexes(self):
    """Test schema with no indexed fields"""
    processor = MagicMock()
    processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

    # Neither field is primary or indexed, so no index names result.
    schema = RowSchema(
        name="no_index_schema",
        fields=[
            Field(name="data1", type="string"),
            Field(name="data2", type="string"),
        ],
    )

    assert len(processor.get_index_names(schema)) == 0
||||
def test_build_index_value(self):
|
||||
"""Test building index values from row data"""
|
||||
processor = MagicMock()
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
|
||||
value_map = {"id": "123", "category": "electronics", "name": "Widget"}
|
||||
|
||||
# Single field index
|
||||
result = processor.build_index_value(value_map, "id")
|
||||
assert result == ["123"]
|
||||
|
||||
result = processor.build_index_value(value_map, "category")
|
||||
assert result == ["electronics"]
|
||||
|
||||
# Missing field returns empty string
|
||||
result = processor.build_index_value(value_map, "missing")
|
||||
assert result == [""]
|
||||
|
||||
def test_build_index_value_composite(self):
|
||||
"""Test building composite index values"""
|
||||
processor = MagicMock()
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
|
||||
value_map = {"region": "us-west", "category": "electronics", "id": "123"}
|
||||
|
||||
# Composite index (comma-separated field names)
|
||||
result = processor.build_index_value(value_map, "region,category")
|
||||
assert result == ["us-west", "electronics"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_config_parsing(self):
|
||||
"""Test parsing of schema configurations"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.config_key = "schema"
|
||||
processor.registered_partitions = set()
|
||||
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
|
||||
|
||||
# Create test configuration
|
||||
config = {
|
||||
"schema": {
|
||||
"customer_records": json.dumps({
|
||||
"name": "customer_records",
|
||||
"description": "Customer data",
|
||||
"fields": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "string",
|
||||
"primary_key": True,
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "category",
|
||||
"type": "string",
|
||||
"indexed": True
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
# Process configuration
|
||||
await processor.on_schema_config(config, version=1)
|
||||
|
||||
# Verify schema was loaded
|
||||
assert "customer_records" in processor.schemas
|
||||
schema = processor.schemas["customer_records"]
|
||||
assert schema.name == "customer_records"
|
||||
assert len(schema.fields) == 3
|
||||
|
||||
# Check field properties
|
||||
id_field = schema.fields[0]
|
||||
assert id_field.name == "id"
|
||||
assert id_field.type == "string"
|
||||
assert id_field.primary is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_processing_stores_data_map(self):
|
||||
"""Test that row processing stores data as map<text, text>"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
description="Test",
|
||||
fields=[
|
||||
Field(name="id", type="string", size=50, primary=True),
|
||||
Field(name="value", type="string", size=100)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="test_schema",
|
||||
values=[{"id": "123", "value": "test_data"}],
|
||||
confidence=0.9,
|
||||
source_span="test source"
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
# Process object
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify insert was executed
|
||||
processor.session.execute.assert_called()
|
||||
insert_call = processor.session.execute.call_args
|
||||
insert_cql = insert_call[0][0]
|
||||
values = insert_call[0][1]
|
||||
|
||||
# Verify using unified rows table
|
||||
assert "INSERT INTO test_user.rows" in insert_cql
|
||||
|
||||
# Values should be: (collection, schema_name, index_name, index_value, data, source)
|
||||
assert values[0] == "test_collection" # collection
|
||||
assert values[1] == "test_schema" # schema_name
|
||||
assert values[2] == "id" # index_name (primary key field)
|
||||
assert values[3] == ["123"] # index_value as list
|
||||
assert values[4] == {"id": "123", "value": "test_data"} # data map
|
||||
assert values[5] == "" # source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_processing_multiple_indexes(self):
|
||||
"""Test that row is written once per indexed field"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"multi_index_schema": RowSchema(
|
||||
name="multi_index_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True),
|
||||
Field(name="status", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="test-001",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="multi_index_schema",
|
||||
values=[{"id": "123", "category": "electronics", "status": "active"}],
|
||||
confidence=0.9,
|
||||
source_span=""
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = test_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should have 3 inserts (one per indexed field: id, category, status)
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check that different index_names were used
|
||||
index_names_used = set()
|
||||
for call in processor.session.execute.call_args_list:
|
||||
values = call[0][1]
|
||||
index_names_used.add(values[2]) # index_name is 3rd value
|
||||
|
||||
assert index_names_used == {"id", "category", "status"}
|
||||
|
||||
|
||||
class TestRowsCassandraStorageBatchLogic:
|
||||
"""Test batch processing logic for unified table implementation"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_object_processing(self):
|
||||
"""Test processing of batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"batch_schema": RowSchema(
|
||||
name="batch_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="name", type="string")
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create batch object with multiple values
|
||||
batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="batch-001",
|
||||
user="test_user",
|
||||
collection="batch_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="batch_schema",
|
||||
values=[
|
||||
{"id": "001", "name": "First"},
|
||||
{"id": "002", "name": "Second"},
|
||||
{"id": "003", "name": "Third"}
|
||||
],
|
||||
confidence=0.95,
|
||||
source_span=""
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Should have 3 inserts (one per row, one index per row since only primary key)
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
# Check each insert has different id
|
||||
ids_inserted = set()
|
||||
for call in processor.session.execute.call_args_list:
|
||||
values = call[0][1]
|
||||
ids_inserted.add(tuple(values[3])) # index_value is 4th value
|
||||
|
||||
assert ids_inserted == {("001",), ("002",), ("003",)}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_batch_processing(self):
|
||||
"""Test processing of empty batch ExtractedObjects"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {
|
||||
"empty_schema": RowSchema(
|
||||
name="empty_schema",
|
||||
fields=[Field(name="id", type="string", primary=True)]
|
||||
)
|
||||
}
|
||||
processor.tables_initialized = {"test_user"}
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
|
||||
processor.ensure_tables = MagicMock()
|
||||
processor.register_partitions = MagicMock()
|
||||
processor.collection_exists = MagicMock(return_value=True)
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
||||
# Create empty batch object
|
||||
empty_batch_obj = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id="empty-001",
|
||||
user="test_user",
|
||||
collection="empty_collection",
|
||||
metadata=[]
|
||||
),
|
||||
schema_name="empty_schema",
|
||||
values=[], # Empty batch
|
||||
confidence=1.0,
|
||||
source_span=""
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.value.return_value = empty_batch_obj
|
||||
|
||||
await processor.on_object(msg, None, None)
|
||||
|
||||
# Verify no insert calls for empty batch
|
||||
processor.session.execute.assert_not_called()
|
||||
|
||||
|
||||
class TestUnifiedTableStructure:
|
||||
"""Test the unified rows table structure"""
|
||||
|
||||
def test_ensure_tables_creates_unified_structure(self):
|
||||
"""Test that ensure_tables creates the unified rows table"""
|
||||
processor = MagicMock()
|
||||
processor.known_keyspaces = {"test_user"}
|
||||
processor.tables_initialized = set()
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.ensure_keyspace = MagicMock()
|
||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||
|
||||
processor.ensure_tables("test_user")
|
||||
|
||||
# Should have 2 calls: create rows table + create row_partitions table
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
||||
# Check rows table creation
|
||||
rows_cql = processor.session.execute.call_args_list[0][0][0]
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
|
||||
assert "collection text" in rows_cql
|
||||
assert "schema_name text" in rows_cql
|
||||
assert "index_name text" in rows_cql
|
||||
assert "index_value frozen<list<text>>" in rows_cql
|
||||
assert "data map<text, text>" in rows_cql
|
||||
assert "source text" in rows_cql
|
||||
assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql
|
||||
|
||||
# Check row_partitions table creation
|
||||
partitions_cql = processor.session.execute.call_args_list[1][0][0]
|
||||
assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
|
||||
assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql
|
||||
|
||||
# Verify keyspace added to initialized set
|
||||
assert "test_user" in processor.tables_initialized
|
||||
|
||||
def test_ensure_tables_idempotent(self):
|
||||
"""Test that ensure_tables is idempotent"""
|
||||
processor = MagicMock()
|
||||
processor.tables_initialized = {"test_user"} # Already initialized
|
||||
processor.session = MagicMock()
|
||||
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
|
||||
|
||||
processor.ensure_tables("test_user")
|
||||
|
||||
# Should not execute any CQL since already initialized
|
||||
processor.session.execute.assert_not_called()
|
||||
|
||||
|
||||
class TestPartitionRegistration:
|
||||
"""Test partition registration for tracking what's stored"""
|
||||
|
||||
def test_register_partitions(self):
|
||||
"""Test registering partitions for a collection/schema pair"""
|
||||
processor = MagicMock()
|
||||
processor.registered_partitions = set()
|
||||
processor.session = MagicMock()
|
||||
processor.schemas = {
|
||||
"test_schema": RowSchema(
|
||||
name="test_schema",
|
||||
fields=[
|
||||
Field(name="id", type="string", primary=True),
|
||||
Field(name="category", type="string", indexed=True)
|
||||
]
|
||||
)
|
||||
}
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
|
||||
# Should have 2 inserts (one per index: id, category)
|
||||
assert processor.session.execute.call_count == 2
|
||||
|
||||
# Verify cache was updated
|
||||
assert ("test_collection", "test_schema") in processor.registered_partitions
|
||||
|
||||
def test_register_partitions_idempotent(self):
|
||||
"""Test that partition registration is idempotent"""
|
||||
processor = MagicMock()
|
||||
processor.registered_partitions = {("test_collection", "test_schema")} # Already registered
|
||||
processor.session = MagicMock()
|
||||
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
|
||||
|
||||
processor.register_partitions("test_user", "test_collection", "test_schema")
|
||||
|
||||
# Should not execute any CQL since already registered
|
||||
processor.session.execute.assert_not_called()
|
||||
|
|
@ -6,7 +6,8 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Triple, LITERAL, IRI
|
||||
from trustgraph.direct.cassandra_kg import DEFAULT_GRAPH
|
||||
|
||||
|
||||
class TestCassandraStorageProcessor:
|
||||
|
|
@ -86,29 +87,29 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.cassandra_username == 'new-user' # Only cassandra_* params work
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_table_switching_with_auth(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_table_switching_with_auth(self, mock_kg_class):
|
||||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
cassandra_username='testuser',
|
||||
cassandra_password='testpass'
|
||||
)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was called with auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user1',
|
||||
username='testuser',
|
||||
|
|
@ -117,128 +118,150 @@ class TestCassandraStorageProcessor:
|
|||
assert processor.table == 'user1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_table_switching_without_auth(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_table_switching_without_auth(self, mock_kg_class):
|
||||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify KnowledgeGraph was called without auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
mock_kg_class.assert_called_once_with(
|
||||
hosts=['cassandra'], # Updated default
|
||||
keyspace='user2'
|
||||
)
|
||||
assert processor.table == 'user2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_table_reuse_when_same(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_table_reuse_when_same(self, mock_kg_class):
|
||||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
assert mock_kg_class.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
assert mock_kg_class.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_triple_insertion(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_triple_insertion(self, mock_kg_class):
|
||||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
|
||||
# Create mock triples with proper Term structure
|
||||
triple1 = MagicMock()
|
||||
triple1.s.type = LITERAL
|
||||
triple1.s.value = 'subject1'
|
||||
triple1.s.datatype = ''
|
||||
triple1.s.language = ''
|
||||
triple1.p.type = LITERAL
|
||||
triple1.p.value = 'predicate1'
|
||||
triple1.o.type = LITERAL
|
||||
triple1.o.value = 'object1'
|
||||
|
||||
triple1.o.datatype = ''
|
||||
triple1.o.language = ''
|
||||
triple1.g = None
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.type = LITERAL
|
||||
triple2.s.value = 'subject2'
|
||||
triple2.s.datatype = ''
|
||||
triple2.s.language = ''
|
||||
triple2.p.type = LITERAL
|
||||
triple2.p.value = 'predicate2'
|
||||
triple2.o.type = LITERAL
|
||||
triple2.o.value = 'object2'
|
||||
|
||||
triple2.o.datatype = ''
|
||||
triple2.o.language = ''
|
||||
triple2.g = None
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify both triples were inserted
|
||||
|
||||
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1')
|
||||
mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2')
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
'collection1', 'subject1', 'predicate1', 'object1',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
mock_tg_instance.insert.assert_any_call(
|
||||
'collection1', 'subject2', 'predicate2', 'object2',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_kg_class):
|
||||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
|
|
@ -326,92 +349,104 @@ class TestCassandraStorageProcessor:
|
|||
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_kg_class):
|
||||
"""Test table switching when different tables are used in sequence"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
mock_kg_class.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
assert processor.table == 'user1'
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
assert processor.table == 'user2'
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
assert mock_trustgraph.call_count == 2
|
||||
assert mock_kg_class.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_kg_class):
|
||||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
|
||||
# Create triple with special characters and proper Term structure
|
||||
triple = MagicMock()
|
||||
triple.s.type = LITERAL
|
||||
triple.s.value = 'subject with spaces & symbols'
|
||||
triple.s.datatype = ''
|
||||
triple.s.language = ''
|
||||
triple.p.type = LITERAL
|
||||
triple.p.value = 'predicate:with/colons'
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
|
||||
|
||||
triple.o.datatype = ''
|
||||
triple.o.language = ''
|
||||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'test_collection',
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
'object with "quotes" and unicode: ñáéíóú'
|
||||
'object with "quotes" and unicode: ñáéíóú',
|
||||
g=DEFAULT_GRAPH,
|
||||
otype='l',
|
||||
dtype='',
|
||||
lang=''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_kg_class.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
|
|
@ -422,12 +457,12 @@ class TestCassandraPerformanceOptimizations:
|
|||
"""Test cases for multi-table performance optimizations"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_legacy_mode_uses_single_table(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_legacy_mode_uses_single_table(self, mock_kg_class):
|
||||
"""Test that legacy mode still works with single table"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -440,16 +475,15 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance uses legacy mode
|
||||
kg_instance = mock_trustgraph.return_value
|
||||
assert kg_instance is not None
|
||||
assert mock_tg_instance is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_optimized_mode_uses_multi_table(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_optimized_mode_uses_multi_table(self, mock_kg_class):
|
||||
"""Test that optimized mode uses multi-table schema"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
|
@ -462,24 +496,31 @@ class TestCassandraPerformanceOptimizations:
|
|||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify KnowledgeGraph instance is in optimized mode
|
||||
kg_instance = mock_trustgraph.return_value
|
||||
assert kg_instance is not None
|
||||
assert mock_tg_instance is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
|
||||
async def test_batch_write_consistency(self, mock_trustgraph):
|
||||
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
|
||||
async def test_batch_write_consistency(self, mock_kg_class):
|
||||
"""Test that all tables stay consistent during batch writes"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_kg_class.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create test triple
|
||||
# Create test triple with proper Term structure
|
||||
triple = MagicMock()
|
||||
triple.s.type = LITERAL
|
||||
triple.s.value = 'test_subject'
|
||||
triple.s.datatype = ''
|
||||
triple.s.language = ''
|
||||
triple.p.type = LITERAL
|
||||
triple.p.value = 'test_predicate'
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = 'test_object'
|
||||
triple.o.datatype = ''
|
||||
triple.o.language = ''
|
||||
triple.g = None
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
|
|
@ -490,7 +531,8 @@ class TestCassandraPerformanceOptimizations:
|
|||
|
||||
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object'
|
||||
'collection1', 'test_subject', 'test_predicate', 'test_object',
|
||||
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
|
||||
)
|
||||
|
||||
def test_environment_variable_controls_mode(self):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.falkordb.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Term, Triple, IRI, LITERAL
|
||||
|
||||
|
||||
class TestFalkorDBStorageProcessor:
|
||||
|
|
@ -22,9 +22,9 @@ class TestFalkorDBStorageProcessor:
|
|||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
|
|
@ -183,9 +183,9 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='http://example.com/object', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=IRI, iri='http://example.com/object')
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
|
|
@ -269,14 +269,14 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object1', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject1'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate1'),
|
||||
o=Term(type=LITERAL, value='literal object1')
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject2'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate2'),
|
||||
o=Term(type=IRI, iri='http://example.com/object2')
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
|
|
@ -337,14 +337,14 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject1'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate1'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject2'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate2'),
|
||||
o=Term(type=IRI, iri='http://example.com/object2')
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.storage.triples.memgraph.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
from trustgraph.schema import Term, Triple, IRI, LITERAL
|
||||
|
||||
|
||||
class TestMemgraphStorageProcessor:
|
||||
|
|
@ -22,9 +22,9 @@ class TestMemgraphStorageProcessor:
|
|||
|
||||
# Create a test triple
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
message.triples = [triple]
|
||||
|
||||
|
|
@ -231,9 +231,9 @@ class TestMemgraphStorageProcessor:
|
|||
mock_tx = MagicMock()
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='http://example.com/object', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=IRI, iri='http://example.com/object')
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
|
||||
|
|
@ -265,9 +265,9 @@ class TestMemgraphStorageProcessor:
|
|||
mock_tx = MagicMock()
|
||||
|
||||
triple = Triple(
|
||||
s=Value(value='http://example.com/subject', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate', is_uri=True),
|
||||
o=Value(value='literal object', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate'),
|
||||
o=Term(type=LITERAL, value='literal object')
|
||||
)
|
||||
|
||||
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
|
||||
|
|
@ -347,14 +347,14 @@ class TestMemgraphStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
|
||||
triple1 = Triple(
|
||||
s=Value(value='http://example.com/subject1', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate1', is_uri=True),
|
||||
o=Value(value='literal object1', is_uri=False)
|
||||
s=Term(type=IRI, iri='http://example.com/subject1'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate1'),
|
||||
o=Term(type=LITERAL, value='literal object1')
|
||||
)
|
||||
triple2 = Triple(
|
||||
s=Value(value='http://example.com/subject2', is_uri=True),
|
||||
p=Value(value='http://example.com/predicate2', is_uri=True),
|
||||
o=Value(value='http://example.com/object2', is_uri=True)
|
||||
s=Term(type=IRI, iri='http://example.com/subject2'),
|
||||
p=Term(type=IRI, iri='http://example.com/predicate2'),
|
||||
o=Term(type=IRI, iri='http://example.com/object2')
|
||||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.neo4j.write import Processor
|
||||
from trustgraph.schema import IRI, LITERAL
|
||||
|
||||
|
||||
class TestNeo4jStorageProcessor:
|
||||
|
|
@ -257,10 +258,12 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create mock triple with URI object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = IRI
|
||||
triple.o.iri = "http://example.com/object"
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -327,10 +330,12 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create mock triple with literal object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = "literal value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -398,16 +403,20 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = "http://example.com/subject1"
|
||||
triple1.p.value = "http://example.com/predicate1"
|
||||
triple1.o.value = "http://example.com/object1"
|
||||
triple1.o.is_uri = True
|
||||
|
||||
triple1.s.type = IRI
|
||||
triple1.s.iri = "http://example.com/subject1"
|
||||
triple1.p.type = IRI
|
||||
triple1.p.iri = "http://example.com/predicate1"
|
||||
triple1.o.type = IRI
|
||||
triple1.o.iri = "http://example.com/object1"
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = "http://example.com/subject2"
|
||||
triple2.p.value = "http://example.com/predicate2"
|
||||
triple2.s.type = IRI
|
||||
triple2.s.iri = "http://example.com/subject2"
|
||||
triple2.p.type = IRI
|
||||
triple2.p.iri = "http://example.com/predicate2"
|
||||
triple2.o.type = LITERAL
|
||||
triple2.o.value = "literal value"
|
||||
triple2.o.is_uri = False
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
|
|
@ -550,10 +559,12 @@ class TestNeo4jStorageProcessor:
|
|||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject with spaces"
|
||||
triple.p.value = "http://example.com/predicate:with/symbols"
|
||||
triple.s.type = IRI
|
||||
triple.s.iri = "http://example.com/subject with spaces"
|
||||
triple.p.type = IRI
|
||||
triple.p.iri = "http://example.com/predicate:with/symbols"
|
||||
triple.o.type = LITERAL
|
||||
triple.o.value = 'literal with "quotes" and unicode: ñáéíóú'
|
||||
triple.o.is_uri = False
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert hasattr(processor, 'client')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
# Verify Google AI Studio client was called with correct API key
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.vertexai
|
||||
Starting simple with one test to get the basics working
|
||||
Updated for google-genai SDK
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
|
|||
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Simple test for processor initialization"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test basic processor initialization with mocked dependencies"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
# Mock the parent class initialization to avoid taskgroup requirement
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
assert hasattr(processor, 'generation_configs') # Cache dictionary
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'model_clients') # LLM clients are now cached
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
assert hasattr(processor, 'client') # genai.Client
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||
mock_genai.Client.assert_called_once_with(
|
||||
vertexai=True,
|
||||
project="test-project-123",
|
||||
location="us-central1",
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Gemini"
|
||||
mock_response.usage_metadata.prompt_token_count = 15
|
||||
mock_response.usage_metadata.candidates_token_count = 8
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
# Check that the method was called (actual prompt format may vary)
|
||||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
# Generation config is now created dynamically per model
|
||||
assert 'generation_config' in call_args[1]
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test handling of blocked content (safety filters)"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = None # Blocked content returns None
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 0
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
|
||||
"""Test processor initialization without private key (uses default credentials)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
||||
# Mock google.auth.default() to return credentials and project ID
|
||||
mock_credentials = MagicMock()
|
||||
mock_auth_default.return_value = (mock_credentials, "test-project-123")
|
||||
|
||||
# Mock GenerativeModel
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
|
|
@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
mock_auth_default.assert_called_once()
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='us-central1',
|
||||
project='test-project-123'
|
||||
mock_genai.Client.assert_called_once_with(
|
||||
vertexai=True,
|
||||
project="test-project-123",
|
||||
location="us-central1",
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = Exception("Network error")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.side_effect = Exception("Network error")
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(Exception, match="Network error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
|
||||
# Verify that generation_config cache exists
|
||||
assert hasattr(processor, 'generation_configs')
|
||||
assert processor.generation_configs == {} # Empty cache initially
|
||||
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
||||
|
||||
# Verify service account was called with custom key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
||||
|
||||
# Verify that api_params dict has the correct values (this is accessible)
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||
|
||||
# Verify that api_params dict has the correct values
|
||||
assert processor.api_params["temperature"] == 0.7
|
||||
assert processor.api_params["max_output_tokens"] == 4096
|
||||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test that VertexAI is initialized correctly with credentials"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify VertexAI init was called with correct parameters
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
# Verify genai.Client was called with correct parameters
|
||||
mock_genai.Client.assert_called_once_with(
|
||||
vertexai=True,
|
||||
project='test-project-123',
|
||||
location='europe-west1',
|
||||
credentials=mock_credentials,
|
||||
project='test-project-123'
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
# GenerativeModel is now created lazily on first use, not at initialization
|
||||
mock_generative_model.assert_not_called()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
|
@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the model was called with the combined empty prompts
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
assert call_args[0][0] == "\n\n"
|
||||
|
||||
# Verify the client was called
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
|
||||
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
|
||||
"""Test Anthropic processor initialization with private key credentials"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-456"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
# Mock AnthropicVertex
|
||||
mock_anthropic_client = MagicMock()
|
||||
mock_anthropic_vertex.return_value = mock_anthropic_client
|
||||
|
|
@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'claude-3-sonnet@20240229'
|
||||
# is_anthropic logic is now determined dynamically per request
|
||||
|
||||
|
||||
# Verify service account was called with private key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
|
||||
|
||||
# Verify AnthropicVertex was initialized with credentials
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once()
|
||||
|
||||
# Verify AnthropicVertex was initialized with credentials (because model contains 'claude')
|
||||
mock_anthropic_vertex.assert_called_once_with(
|
||||
region='us-west1',
|
||||
project_id='test-project-456',
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
|
||||
# Verify api_params are set correctly
|
||||
assert processor.api_params["temperature"] == 0.5
|
||||
assert processor.api_params["max_output_tokens"] == 2048
|
||||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom temperature"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
# Verify Gemini API was called with overridden temperature
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
# Mock different models
|
||||
mock_model_default = MagicMock()
|
||||
mock_model_override = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom model"
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
mock_model_override.generate_content.return_value = mock_response
|
||||
|
||||
# GenerativeModel should return different models based on input
|
||||
def model_factory(model_name):
|
||||
if model_name == 'gemini-1.5-pro':
|
||||
return mock_model_override
|
||||
return mock_model_default
|
||||
|
||||
mock_generative_model.side_effect = model_factory
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.2, # Default temperature
|
||||
'temperature': 0.2,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
|
|
@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the overridden model was used
|
||||
mock_model_override.generate_content.assert_called_once()
|
||||
# Verify GenerativeModel was called with the override model
|
||||
mock_generative_model.assert_called_with('gemini-1.5-pro')
|
||||
# Verify the call was made with the override model
|
||||
call_args = mock_client.models.generate_content.call_args
|
||||
assert call_args.kwargs['model'] == "gemini-1.5-pro"
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with both overrides"
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
|
@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
mock_client.models.generate_content.assert_called_once()
|
||||
|
||||
# Verify both overrides were used
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Verify model override
|
||||
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
|
||||
|
||||
# Verify temperature override (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
# Verify the model override was used
|
||||
call_args = mock_client.models.generate_content.call_args
|
||||
assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
|
|
|||
|
|
@ -73,6 +73,8 @@ from .async_metrics import AsyncMetrics
|
|||
# Types
|
||||
from .types import (
|
||||
Triple,
|
||||
Uri,
|
||||
Literal,
|
||||
ConfigKey,
|
||||
ConfigValue,
|
||||
DocumentMetadata,
|
||||
|
|
@ -99,7 +101,7 @@ from .exceptions import (
|
|||
LoadError,
|
||||
LookupError,
|
||||
NLPQueryError,
|
||||
ObjectsQueryError,
|
||||
RowsQueryError,
|
||||
RequestError,
|
||||
StructuredQueryError,
|
||||
UnexpectedError,
|
||||
|
|
@ -133,6 +135,8 @@ __all__ = [
|
|||
|
||||
# Types
|
||||
"Triple",
|
||||
"Uri",
|
||||
"Literal",
|
||||
"ConfigKey",
|
||||
"ConfigValue",
|
||||
"DocumentMetadata",
|
||||
|
|
@ -157,7 +161,7 @@ __all__ = [
|
|||
"LoadError",
|
||||
"LookupError",
|
||||
"NLPQueryError",
|
||||
"ObjectsQueryError",
|
||||
"RowsQueryError",
|
||||
"RequestError",
|
||||
"StructuredQueryError",
|
||||
"UnexpectedError",
|
||||
|
|
|
|||
|
|
@ -115,15 +115,15 @@ class AsyncBulkClient:
|
|||
async for raw_message in websocket:
|
||||
yield json.loads(raw_message)
|
||||
|
||||
async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import objects via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
|
||||
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import rows via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for obj in objects:
|
||||
await websocket.send(json.dumps(obj))
|
||||
async for row in rows:
|
||||
await websocket.send(json.dumps(row))
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close connections"""
|
||||
|
|
|
|||
|
|
@ -612,8 +612,12 @@ class AsyncFlowInstance:
|
|||
print(f"{entity['name']}: {entity['score']}")
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request_data = {
|
||||
"text": text,
|
||||
"vectors": vectors,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -704,18 +708,18 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("triples", request_data)
|
||||
|
||||
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
|
||||
operation_name: Optional[str] = None, **kwargs: Any):
|
||||
async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
|
||||
operation_name: Optional[str] = None, **kwargs: Any):
|
||||
"""
|
||||
Execute a GraphQL query on stored objects.
|
||||
Execute a GraphQL query on stored rows.
|
||||
|
||||
Queries structured data objects using GraphQL syntax. Supports complex
|
||||
Queries structured data rows using GraphQL syntax. Supports complex
|
||||
queries with variables and named operations.
|
||||
|
||||
Args:
|
||||
query: GraphQL query string
|
||||
user: User identifier
|
||||
collection: Collection identifier containing objects
|
||||
collection: Collection identifier containing rows
|
||||
variables: Optional GraphQL query variables
|
||||
operation_name: Optional operation name for multi-operation queries
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
|
@ -739,7 +743,7 @@ class AsyncFlowInstance:
|
|||
}
|
||||
'''
|
||||
|
||||
result = await flow.objects_query(
|
||||
result = await flow.rows_query(
|
||||
query=query,
|
||||
user="trustgraph",
|
||||
collection="users",
|
||||
|
|
@ -761,4 +765,64 @@ class AsyncFlowInstance:
|
|||
request_data["operationName"] = operation_name
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("objects", request_data)
|
||||
return await self.request("rows", request_data)
|
||||
|
||||
async def row_embeddings_query(
|
||||
self, text: str, schema_name: str, user: str = "trustgraph",
|
||||
collection: str = "default", index_name: Optional[str] = None,
|
||||
limit: int = 10, **kwargs: Any
|
||||
):
|
||||
"""
|
||||
Query row embeddings for semantic search on structured data.
|
||||
|
||||
Performs semantic search over row index embeddings to find rows whose
|
||||
indexed field values are most similar to the input text. Enables
|
||||
fuzzy/semantic matching on structured data.
|
||||
|
||||
Args:
|
||||
text: Query text for semantic search
|
||||
schema_name: Schema name to search within
|
||||
user: User identifier (default: "trustgraph")
|
||||
collection: Collection identifier (default: "default")
|
||||
index_name: Optional index name to filter search to specific index
|
||||
limit: Maximum number of results to return (default: 10)
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
dict: Response containing matches with index_name, index_value,
|
||||
text, and score
|
||||
|
||||
Example:
|
||||
```python
|
||||
async_flow = await api.async_flow()
|
||||
flow = async_flow.id("default")
|
||||
|
||||
# Search for customers by name similarity
|
||||
results = await flow.row_embeddings_query(
|
||||
text="John Smith",
|
||||
schema_name="customers",
|
||||
user="trustgraph",
|
||||
collection="sales",
|
||||
limit=5
|
||||
)
|
||||
|
||||
for match in results.get("matches", []):
|
||||
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request_data = {
|
||||
"vectors": vectors,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
}
|
||||
if index_name:
|
||||
request_data["index_name"] = index_name
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("row-embeddings", request_data)
|
||||
|
|
|
|||
|
|
@ -282,8 +282,12 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||
"""Query graph embeddings for semantic search"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request = {
|
||||
"text": text,
|
||||
"vectors": vectors,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -316,9 +320,9 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
return await self.client._send_request("triples", self.flow_id, request)
|
||||
|
||||
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
|
||||
operation_name: Optional[str] = None, **kwargs):
|
||||
"""GraphQL query"""
|
||||
async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
|
||||
operation_name: Optional[str] = None, **kwargs):
|
||||
"""GraphQL query against structured rows"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
|
|
@ -330,7 +334,7 @@ class AsyncSocketFlowInstance:
|
|||
request["operationName"] = operation_name
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("objects", self.flow_id, request)
|
||||
return await self.client._send_request("rows", self.flow_id, request)
|
||||
|
||||
async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
|
||||
"""Execute MCP tool"""
|
||||
|
|
@ -341,3 +345,26 @@ class AsyncSocketFlowInstance:
|
|||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("mcp-tool", self.flow_id, request)
|
||||
|
||||
async def row_embeddings_query(
|
||||
self, text: str, schema_name: str, user: str = "trustgraph",
|
||||
collection: str = "default", index_name: Optional[str] = None,
|
||||
limit: int = 10, **kwargs
|
||||
):
|
||||
"""Query row embeddings for semantic search on structured data"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = await self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
}
|
||||
if index_name:
|
||||
request["index_name"] = index_name
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("row-embeddings", self.flow_id, request)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,15 @@ from . types import Triple
|
|||
from . exceptions import ProtocolException
|
||||
|
||||
|
||||
def _string_to_term(value: str) -> Dict[str, Any]:
|
||||
"""Convert a string value to Term format for the gateway."""
|
||||
# Treat URIs as IRI type, otherwise as literal
|
||||
if value.startswith("http://") or value.startswith("https://") or "://" in value:
|
||||
return {"t": "i", "i": value}
|
||||
else:
|
||||
return {"t": "l", "v": value}
|
||||
|
||||
|
||||
class BulkClient:
|
||||
"""
|
||||
Synchronous bulk operations client for import/export.
|
||||
|
|
@ -62,7 +71,12 @@ class BulkClient:
|
|||
|
||||
return loop.run_until_complete(coro)
|
||||
|
||||
def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None:
|
||||
def import_triples(
|
||||
self, flow: str, triples: Iterator[Triple],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
batch_size: int = 100,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""
|
||||
Bulk import RDF triples into a flow.
|
||||
|
||||
|
|
@ -71,6 +85,8 @@ class BulkClient:
|
|||
Args:
|
||||
flow: Flow identifier
|
||||
triples: Iterator yielding Triple objects
|
||||
metadata: Metadata dict with id, metadata, user, collection
|
||||
batch_size: Number of triples per batch (default 100)
|
||||
**kwargs: Additional parameters (reserved for future use)
|
||||
|
||||
Example:
|
||||
|
|
@ -86,23 +102,47 @@ class BulkClient:
|
|||
# ... more triples
|
||||
|
||||
# Import triples
|
||||
bulk.import_triples(flow="default", triples=triple_generator())
|
||||
bulk.import_triples(
|
||||
flow="default",
|
||||
triples=triple_generator(),
|
||||
metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"}
|
||||
)
|
||||
```
|
||||
"""
|
||||
self._run_async(self._import_triples_async(flow, triples))
|
||||
self._run_async(self._import_triples_async(flow, triples, metadata, batch_size))
|
||||
|
||||
async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None:
|
||||
async def _import_triples_async(
|
||||
self, flow: str, triples: Iterator[Triple],
|
||||
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||
) -> None:
|
||||
"""Async implementation of triple import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"}
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
batch = []
|
||||
for triple in triples:
|
||||
batch.append({
|
||||
"s": _string_to_term(triple.s),
|
||||
"p": _string_to_term(triple.p),
|
||||
"o": _string_to_term(triple.o)
|
||||
})
|
||||
if len(batch) >= batch_size:
|
||||
message = {
|
||||
"metadata": metadata,
|
||||
"triples": batch
|
||||
}
|
||||
await websocket.send(json.dumps(message))
|
||||
batch = []
|
||||
# Send remaining items
|
||||
if batch:
|
||||
message = {
|
||||
"s": triple.s,
|
||||
"p": triple.p,
|
||||
"o": triple.o
|
||||
"metadata": metadata,
|
||||
"triples": batch
|
||||
}
|
||||
await websocket.send(json.dumps(message))
|
||||
|
||||
|
|
@ -362,7 +402,12 @@ class BulkClient:
|
|||
async for raw_message in websocket:
|
||||
yield json.loads(raw_message)
|
||||
|
||||
def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
def import_entity_contexts(
|
||||
self, flow: str, contexts: Iterator[Dict[str, Any]],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
batch_size: int = 100,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
"""
|
||||
Bulk import entity contexts into a flow.
|
||||
|
||||
|
|
@ -373,6 +418,8 @@ class BulkClient:
|
|||
Args:
|
||||
flow: Flow identifier
|
||||
contexts: Iterator yielding context dictionaries
|
||||
metadata: Metadata dict with id, metadata, user, collection
|
||||
batch_size: Number of contexts per batch (default 100)
|
||||
**kwargs: Additional parameters (reserved for future use)
|
||||
|
||||
Example:
|
||||
|
|
@ -381,27 +428,49 @@ class BulkClient:
|
|||
|
||||
# Generate entity contexts to import
|
||||
def context_generator():
|
||||
yield {"entity": "entity1", "context": "Description of entity1..."}
|
||||
yield {"entity": "entity2", "context": "Description of entity2..."}
|
||||
yield {"entity": {"v": "entity1", "e": True}, "context": "Description..."}
|
||||
yield {"entity": {"v": "entity2", "e": True}, "context": "Description..."}
|
||||
# ... more contexts
|
||||
|
||||
bulk.import_entity_contexts(
|
||||
flow="default",
|
||||
contexts=context_generator()
|
||||
contexts=context_generator(),
|
||||
metadata={"id": "doc1", "metadata": [], "user": "user1", "collection": "default"}
|
||||
)
|
||||
```
|
||||
"""
|
||||
self._run_async(self._import_entity_contexts_async(flow, contexts))
|
||||
self._run_async(self._import_entity_contexts_async(flow, contexts, metadata, batch_size))
|
||||
|
||||
async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None:
|
||||
async def _import_entity_contexts_async(
|
||||
self, flow: str, contexts: Iterator[Dict[str, Any]],
|
||||
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||
) -> None:
|
||||
"""Async implementation of entity contexts import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"id": "", "metadata": [], "user": "trustgraph", "collection": "default"}
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
batch = []
|
||||
for context in contexts:
|
||||
await websocket.send(json.dumps(context))
|
||||
batch.append(context)
|
||||
if len(batch) >= batch_size:
|
||||
message = {
|
||||
"metadata": metadata,
|
||||
"entities": batch
|
||||
}
|
||||
await websocket.send(json.dumps(message))
|
||||
batch = []
|
||||
# Send remaining items
|
||||
if batch:
|
||||
message = {
|
||||
"metadata": metadata,
|
||||
"entities": batch
|
||||
}
|
||||
await websocket.send(json.dumps(message))
|
||||
|
||||
def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -461,45 +530,45 @@ class BulkClient:
|
|||
async for raw_message in websocket:
|
||||
yield json.loads(raw_message)
|
||||
|
||||
def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
def import_rows(self, flow: str, rows: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""
|
||||
Bulk import structured objects into a flow.
|
||||
Bulk import structured rows into a flow.
|
||||
|
||||
Efficiently uploads structured data objects via WebSocket streaming
|
||||
Efficiently uploads structured data rows via WebSocket streaming
|
||||
for use in GraphQL queries.
|
||||
|
||||
Args:
|
||||
flow: Flow identifier
|
||||
objects: Iterator yielding object dictionaries
|
||||
rows: Iterator yielding row dictionaries
|
||||
**kwargs: Additional parameters (reserved for future use)
|
||||
|
||||
Example:
|
||||
```python
|
||||
bulk = api.bulk()
|
||||
|
||||
# Generate objects to import
|
||||
def object_generator():
|
||||
yield {"id": "obj1", "name": "Object 1", "value": 100}
|
||||
yield {"id": "obj2", "name": "Object 2", "value": 200}
|
||||
# ... more objects
|
||||
# Generate rows to import
|
||||
def row_generator():
|
||||
yield {"id": "row1", "name": "Row 1", "value": 100}
|
||||
yield {"id": "row2", "name": "Row 2", "value": 200}
|
||||
# ... more rows
|
||||
|
||||
bulk.import_objects(
|
||||
bulk.import_rows(
|
||||
flow="default",
|
||||
objects=object_generator()
|
||||
rows=row_generator()
|
||||
)
|
||||
```
|
||||
"""
|
||||
self._run_async(self._import_objects_async(flow, objects))
|
||||
self._run_async(self._import_rows_async(flow, rows))
|
||||
|
||||
async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of objects import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
|
||||
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of rows import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for obj in objects:
|
||||
await websocket.send(json.dumps(obj))
|
||||
for row in rows:
|
||||
await websocket.send(json.dumps(row))
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close connections"""
|
||||
|
|
|
|||
|
|
@ -71,8 +71,8 @@ class NLPQueryError(TrustGraphException):
|
|||
pass
|
||||
|
||||
|
||||
class ObjectsQueryError(TrustGraphException):
|
||||
"""Objects query service error"""
|
||||
class RowsQueryError(TrustGraphException):
|
||||
"""Rows query service error"""
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -103,7 +103,7 @@ ERROR_TYPE_MAPPING = {
|
|||
"load-error": LoadError,
|
||||
"lookup-error": LookupError,
|
||||
"nlp-query-error": NLPQueryError,
|
||||
"objects-query-error": ObjectsQueryError,
|
||||
"rows-query-error": RowsQueryError,
|
||||
"request-error": RequestError,
|
||||
"structured-query-error": StructuredQueryError,
|
||||
"unexpected-error": UnexpectedError,
|
||||
|
|
|
|||
|
|
@ -10,12 +10,27 @@ import json
|
|||
import base64
|
||||
|
||||
from .. knowledge import hash, Uri, Literal
|
||||
from .. schema import IRI, LITERAL
|
||||
from . types import Triple
|
||||
from . exceptions import ProtocolException
|
||||
|
||||
|
||||
def to_value(x):
|
||||
if x["e"]: return Uri(x["v"])
|
||||
return Literal(x["v"])
|
||||
"""Convert wire format to Uri or Literal."""
|
||||
if x.get("t") == IRI:
|
||||
return Uri(x.get("i", ""))
|
||||
elif x.get("t") == LITERAL:
|
||||
return Literal(x.get("v", ""))
|
||||
# Fallback for any other type
|
||||
return Literal(x.get("v", x.get("i", "")))
|
||||
|
||||
|
||||
def from_value(v):
|
||||
"""Convert Uri or Literal to wire format."""
|
||||
if isinstance(v, Uri):
|
||||
return {"t": IRI, "i": str(v)}
|
||||
else:
|
||||
return {"t": LITERAL, "v": str(v)}
|
||||
|
||||
class Flow:
|
||||
"""
|
||||
|
|
@ -569,9 +584,13 @@ class FlowInstance:
|
|||
```
|
||||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
# Query graph embeddings for semantic search
|
||||
input = {
|
||||
"text": text,
|
||||
"vectors": vectors,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -582,6 +601,51 @@ class FlowInstance:
|
|||
input
|
||||
)
|
||||
|
||||
def document_embeddings_query(self, text, user, collection, limit=10):
|
||||
"""
|
||||
Query document chunks using semantic similarity.
|
||||
|
||||
Finds document chunks whose content is semantically similar to the
|
||||
input text, using vector embeddings.
|
||||
|
||||
Args:
|
||||
text: Query text for semantic search
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
limit: Maximum number of results (default: 10)
|
||||
|
||||
Returns:
|
||||
dict: Query results with similar document chunks
|
||||
|
||||
Example:
|
||||
```python
|
||||
flow = api.flow().id("default")
|
||||
results = flow.document_embeddings_query(
|
||||
text="machine learning algorithms",
|
||||
user="trustgraph",
|
||||
collection="research-papers",
|
||||
limit=5
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
# Query document embeddings for semantic search
|
||||
input = {
|
||||
"vectors": vectors,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
return self.request(
|
||||
"service/document-embeddings",
|
||||
input
|
||||
)
|
||||
|
||||
def prompt(self, id, variables):
|
||||
"""
|
||||
Execute a prompt template with variable substitution.
|
||||
|
|
@ -751,17 +815,17 @@ class FlowInstance:
|
|||
if s:
|
||||
if not isinstance(s, Uri):
|
||||
raise RuntimeError("s must be Uri")
|
||||
input["s"] = { "v": str(s), "e": isinstance(s, Uri), }
|
||||
|
||||
input["s"] = from_value(s)
|
||||
|
||||
if p:
|
||||
if not isinstance(p, Uri):
|
||||
raise RuntimeError("p must be Uri")
|
||||
input["p"] = { "v": str(p), "e": isinstance(p, Uri), }
|
||||
input["p"] = from_value(p)
|
||||
|
||||
if o:
|
||||
if not isinstance(o, Uri) and not isinstance(o, Literal):
|
||||
raise RuntimeError("o must be Uri or Literal")
|
||||
input["o"] = { "v": str(o), "e": isinstance(o, Uri), }
|
||||
input["o"] = from_value(o)
|
||||
|
||||
object = self.request(
|
||||
"service/triples",
|
||||
|
|
@ -834,9 +898,9 @@ class FlowInstance:
|
|||
if metadata:
|
||||
metadata.emit(
|
||||
lambda t: triples.append({
|
||||
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
|
||||
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
|
||||
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
|
||||
"s": from_value(t["s"]),
|
||||
"p": from_value(t["p"]),
|
||||
"o": from_value(t["o"]),
|
||||
})
|
||||
)
|
||||
|
||||
|
|
@ -913,9 +977,9 @@ class FlowInstance:
|
|||
if metadata:
|
||||
metadata.emit(
|
||||
lambda t: triples.append({
|
||||
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
|
||||
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
|
||||
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
|
||||
"s": from_value(t["s"]),
|
||||
"p": from_value(t["p"]),
|
||||
"o": from_value(t["o"]),
|
||||
})
|
||||
)
|
||||
|
||||
|
|
@ -937,12 +1001,12 @@ class FlowInstance:
|
|||
input
|
||||
)
|
||||
|
||||
def objects_query(
|
||||
def rows_query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
variables=None, operation_name=None
|
||||
):
|
||||
"""
|
||||
Execute a GraphQL query against structured objects in the knowledge graph.
|
||||
Execute a GraphQL query against structured rows in the knowledge graph.
|
||||
|
||||
Queries structured data using GraphQL syntax, allowing complex queries
|
||||
with filtering, aggregation, and relationship traversal.
|
||||
|
|
@ -974,7 +1038,7 @@ class FlowInstance:
|
|||
}
|
||||
}
|
||||
'''
|
||||
result = flow.objects_query(
|
||||
result = flow.rows_query(
|
||||
query=query,
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
|
|
@ -989,7 +1053,7 @@ class FlowInstance:
|
|||
}
|
||||
}
|
||||
'''
|
||||
result = flow.objects_query(
|
||||
result = flow.rows_query(
|
||||
query=query,
|
||||
variables={"name": "Marie Curie"}
|
||||
)
|
||||
|
|
@ -1010,7 +1074,7 @@ class FlowInstance:
|
|||
input["operation_name"] = operation_name
|
||||
|
||||
response = self.request(
|
||||
"service/objects",
|
||||
"service/rows",
|
||||
input
|
||||
)
|
||||
|
||||
|
|
@ -1233,3 +1297,78 @@ class FlowInstance:
|
|||
|
||||
return response["schema-matches"]
|
||||
|
||||
def row_embeddings_query(
|
||||
self, text, schema_name, user="trustgraph", collection="default",
|
||||
index_name=None, limit=10
|
||||
):
|
||||
"""
|
||||
Query row data using semantic similarity on indexed fields.
|
||||
|
||||
Finds rows whose indexed field values are semantically similar to the
|
||||
input text, using vector embeddings. This enables fuzzy/semantic matching
|
||||
on structured data.
|
||||
|
||||
Args:
|
||||
text: Query text for semantic search
|
||||
schema_name: Schema name to search within
|
||||
user: User/keyspace identifier (default: "trustgraph")
|
||||
collection: Collection identifier (default: "default")
|
||||
index_name: Optional index name to filter search to specific index
|
||||
limit: Maximum number of results (default: 10)
|
||||
|
||||
Returns:
|
||||
dict: Query results with matches containing index_name, index_value,
|
||||
text, and score
|
||||
|
||||
Example:
|
||||
```python
|
||||
flow = api.flow().id("default")
|
||||
|
||||
# Search for customers by name similarity
|
||||
results = flow.row_embeddings_query(
|
||||
text="John Smith",
|
||||
schema_name="customers",
|
||||
user="trustgraph",
|
||||
collection="sales",
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Filter to specific index
|
||||
results = flow.row_embeddings_query(
|
||||
text="machine learning engineer",
|
||||
schema_name="employees",
|
||||
index_name="job_title",
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
# Query row embeddings for semantic search
|
||||
input = {
|
||||
"vectors": vectors,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
if index_name:
|
||||
input["index_name"] = index_name
|
||||
|
||||
response = self.request(
|
||||
"service/row-embeddings",
|
||||
input
|
||||
)
|
||||
|
||||
# Check for system-level error
|
||||
if "error" in response and response["error"]:
|
||||
error_type = response["error"].get("type", "unknown")
|
||||
error_message = response["error"].get("message", "Unknown error")
|
||||
raise ProtocolException(f"{error_type}: {error_message}")
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -10,11 +10,18 @@ import json
|
|||
import base64
|
||||
|
||||
from .. knowledge import hash, Uri, Literal
|
||||
from .. schema import IRI, LITERAL
|
||||
from . types import Triple
|
||||
|
||||
|
||||
def to_value(x):
|
||||
if x["e"]: return Uri(x["v"])
|
||||
return Literal(x["v"])
|
||||
"""Convert wire format to Uri or Literal."""
|
||||
if x.get("t") == IRI:
|
||||
return Uri(x.get("i", ""))
|
||||
elif x.get("t") == LITERAL:
|
||||
return Literal(x.get("v", ""))
|
||||
# Fallback for any other type
|
||||
return Literal(x.get("v", x.get("i", "")))
|
||||
|
||||
class Knowledge:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,13 +12,28 @@ import logging
|
|||
|
||||
from . types import DocumentMetadata, ProcessingMetadata, Triple
|
||||
from .. knowledge import hash, Uri, Literal
|
||||
from .. schema import IRI, LITERAL
|
||||
from . exceptions import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_value(x):
|
||||
if x["e"]: return Uri(x["v"])
|
||||
return Literal(x["v"])
|
||||
"""Convert wire format to Uri or Literal."""
|
||||
if x.get("t") == IRI:
|
||||
return Uri(x.get("i", ""))
|
||||
elif x.get("t") == LITERAL:
|
||||
return Literal(x.get("v", ""))
|
||||
# Fallback for any other type
|
||||
return Literal(x.get("v", x.get("i", "")))
|
||||
|
||||
|
||||
def from_value(v):
|
||||
"""Convert Uri or Literal to wire format."""
|
||||
if isinstance(v, Uri):
|
||||
return {"t": IRI, "i": str(v)}
|
||||
else:
|
||||
return {"t": LITERAL, "v": str(v)}
|
||||
|
||||
class Library:
|
||||
"""
|
||||
|
|
@ -118,18 +133,18 @@ class Library:
|
|||
if isinstance(metadata, list):
|
||||
triples = [
|
||||
{
|
||||
"s": { "v": t.s, "e": isinstance(t.s, Uri) },
|
||||
"p": { "v": t.p, "e": isinstance(t.p, Uri) },
|
||||
"o": { "v": t.o, "e": isinstance(t.o, Uri) }
|
||||
"s": from_value(t.s),
|
||||
"p": from_value(t.p),
|
||||
"o": from_value(t.o),
|
||||
}
|
||||
for t in metadata
|
||||
]
|
||||
elif hasattr(metadata, "emit"):
|
||||
metadata.emit(
|
||||
lambda t: triples.append({
|
||||
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
|
||||
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
|
||||
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
|
||||
"s": from_value(t["s"]),
|
||||
"p": from_value(t["p"]),
|
||||
"o": from_value(t["o"]),
|
||||
})
|
||||
)
|
||||
else:
|
||||
|
|
@ -315,9 +330,9 @@ class Library:
|
|||
"comments": metadata.comments,
|
||||
"metadata": [
|
||||
{
|
||||
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
|
||||
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
|
||||
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
|
||||
"s": from_value(t["s"]),
|
||||
"p": from_value(t["p"]),
|
||||
"o": from_value(t["o"]),
|
||||
}
|
||||
for t in metadata.metadata
|
||||
],
|
||||
|
|
|
|||
|
|
@ -649,8 +649,12 @@ class SocketFlowInstance:
|
|||
)
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request = {
|
||||
"text": text,
|
||||
"vectors": vectors,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -659,6 +663,54 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False)
|
||||
|
||||
def document_embeddings_query(
|
||||
self,
|
||||
text: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
limit: int = 10,
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Query document chunks using semantic similarity.
|
||||
|
||||
Args:
|
||||
text: Query text for semantic search
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
limit: Maximum number of results (default: 10)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Returns:
|
||||
dict: Query results with similar document chunks
|
||||
|
||||
Example:
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
results = flow.document_embeddings_query(
|
||||
text="machine learning algorithms",
|
||||
user="trustgraph",
|
||||
collection="research-papers",
|
||||
limit=5
|
||||
)
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("document-embeddings", self.flow_id, request, False)
|
||||
|
||||
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate vector embeddings for text.
|
||||
|
|
@ -737,7 +789,7 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||
|
||||
def objects_query(
|
||||
def rows_query(
|
||||
self,
|
||||
query: str,
|
||||
user: str,
|
||||
|
|
@ -747,7 +799,7 @@ class SocketFlowInstance:
|
|||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a GraphQL query against structured objects.
|
||||
Execute a GraphQL query against structured rows.
|
||||
|
||||
Args:
|
||||
query: GraphQL query string
|
||||
|
|
@ -774,7 +826,7 @@ class SocketFlowInstance:
|
|||
}
|
||||
}
|
||||
'''
|
||||
result = flow.objects_query(
|
||||
result = flow.rows_query(
|
||||
query=query,
|
||||
user="trustgraph",
|
||||
collection="scientists"
|
||||
|
|
@ -792,7 +844,7 @@ class SocketFlowInstance:
|
|||
request["operationName"] = operation_name
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("objects", self.flow_id, request, False)
|
||||
return self.client._send_request_sync("rows", self.flow_id, request, False)
|
||||
|
||||
def mcp_tool(
|
||||
self,
|
||||
|
|
@ -829,3 +881,73 @@ class SocketFlowInstance:
|
|||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
||||
|
||||
def row_embeddings_query(
|
||||
self,
|
||||
text: str,
|
||||
schema_name: str,
|
||||
user: str = "trustgraph",
|
||||
collection: str = "default",
|
||||
index_name: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Query row data using semantic similarity on indexed fields.
|
||||
|
||||
Finds rows whose indexed field values are semantically similar to the
|
||||
input text, using vector embeddings. This enables fuzzy/semantic matching
|
||||
on structured data.
|
||||
|
||||
Args:
|
||||
text: Query text for semantic search
|
||||
schema_name: Schema name to search within
|
||||
user: User/keyspace identifier (default: "trustgraph")
|
||||
collection: Collection identifier (default: "default")
|
||||
index_name: Optional index name to filter search to specific index
|
||||
limit: Maximum number of results (default: 10)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Returns:
|
||||
dict: Query results with matches containing index_name, index_value,
|
||||
text, and score
|
||||
|
||||
Example:
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
# Search for customers by name similarity
|
||||
results = flow.row_embeddings_query(
|
||||
text="John Smith",
|
||||
schema_name="customers",
|
||||
user="trustgraph",
|
||||
collection="sales",
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Filter to specific index
|
||||
results = flow.row_embeddings_query(
|
||||
text="machine learning engineer",
|
||||
schema_name="employees",
|
||||
index_name="job_title",
|
||||
limit=10
|
||||
)
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
emb_result = self.embeddings(text=text)
|
||||
vectors = emb_result.get("vectors", [])
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
}
|
||||
if index_name:
|
||||
request["index_name"] = index_name
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("row-embeddings", self.flow_id, request, False)
|
||||
|
|
|
|||
|
|
@ -34,5 +34,6 @@ from . tool_service import ToolService
|
|||
from . tool_client import ToolClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ embeddings.
|
|||
import logging
|
||||
|
||||
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .. schema import Error, Value
|
||||
from .. schema import Error, Term
|
||||
|
||||
from . flow_processor import FlowProcessor
|
||||
from . consumer_spec import ConsumerSpec
|
||||
|
|
@ -16,7 +16,7 @@ from . producer_spec import ProducerSpec
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-query"
|
||||
default_ident = "doc-embeddings-query"
|
||||
|
||||
class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,21 @@
|
|||
import logging
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse, IRI, LITERAL
|
||||
from .. knowledge import Uri, Literal
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_value(x):
|
||||
if x.is_uri: return Uri(x.value)
|
||||
return Literal(x.value)
|
||||
"""Convert schema Term to Uri or Literal."""
|
||||
if x.type == IRI:
|
||||
return Uri(x.iri)
|
||||
elif x.type == LITERAL:
|
||||
return Literal(x.value)
|
||||
# Fallback
|
||||
return Literal(x.value or x.iri)
|
||||
|
||||
class GraphEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vectors, limit=20, user="trustgraph",
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ embeddings.
|
|||
import logging
|
||||
|
||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .. schema import Error, Value
|
||||
from .. schema import Error, Term
|
||||
|
||||
from . flow_processor import FlowProcessor
|
||||
from . consumer_spec import ConsumerSpec
|
||||
|
|
@ -16,7 +16,7 @@ from . producer_spec import ProducerSpec
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "ge-query"
|
||||
default_ident = "graph-embeddings-query"
|
||||
|
||||
class GraphEmbeddingsQueryService(FlowProcessor):
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
|
||||
class RowEmbeddingsQueryClient(RequestResponse):
|
||||
async def row_embeddings_query(
|
||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
||||
index_name=None, limit=10, timeout=600
|
||||
):
|
||||
request = RowEmbeddingsRequest(
|
||||
vectors=vectors,
|
||||
schema_name=schema_name,
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=limit
|
||||
)
|
||||
if index_name:
|
||||
request.index_name = index_name
|
||||
|
||||
resp = await self.request(request, timeout=timeout)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
# Return matches as list of dicts
|
||||
return [
|
||||
{
|
||||
"index_name": match.index_name,
|
||||
"index_value": match.index_value,
|
||||
"text": match.text,
|
||||
"score": match.score
|
||||
}
|
||||
for match in (resp.matches or [])
|
||||
]
|
||||
|
||||
class RowEmbeddingsQueryClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
self, request_name, response_name,
|
||||
):
|
||||
super(RowEmbeddingsQueryClientSpec, self).__init__(
|
||||
request_name = request_name,
|
||||
request_schema = RowEmbeddingsRequest,
|
||||
response_name = response_name,
|
||||
response_schema = RowEmbeddingsResponse,
|
||||
impl = RowEmbeddingsQueryClient,
|
||||
)
|
||||
|
|
@ -222,35 +222,50 @@ class Subscriber:
|
|||
# Store message for later acknowledgment
|
||||
msg_id = str(uuid.uuid4())
|
||||
self.pending_acks[msg_id] = msg
|
||||
|
||||
|
||||
try:
|
||||
id = msg.properties()["id"]
|
||||
except:
|
||||
id = None
|
||||
|
||||
|
||||
value = msg.value()
|
||||
delivery_success = False
|
||||
|
||||
has_matching_waiter = False
|
||||
|
||||
async with self.lock:
|
||||
# Deliver to specific subscribers
|
||||
if id in self.q:
|
||||
has_matching_waiter = True
|
||||
delivery_success = await self._deliver_to_queue(
|
||||
self.q[id], value
|
||||
)
|
||||
|
||||
|
||||
# Deliver to all subscribers
|
||||
for q in self.full.values():
|
||||
has_matching_waiter = True
|
||||
if await self._deliver_to_queue(q, value):
|
||||
delivery_success = True
|
||||
|
||||
# Acknowledge only on successful delivery
|
||||
if delivery_success:
|
||||
self.consumer.acknowledge(msg)
|
||||
del self.pending_acks[msg_id]
|
||||
else:
|
||||
# Negative acknowledge for retry
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
del self.pending_acks[msg_id]
|
||||
|
||||
# Always acknowledge the message to prevent redelivery storms
|
||||
# on shared topics. Negative acknowledging orphaned messages
|
||||
# (no matching waiter) causes immediate redelivery to all
|
||||
# subscribers, none of whom can handle it either.
|
||||
self.consumer.acknowledge(msg)
|
||||
del self.pending_acks[msg_id]
|
||||
|
||||
if not delivery_success:
|
||||
if not has_matching_waiter:
|
||||
# Message arrived for a waiter that no longer exists
|
||||
# (likely due to client disconnect or timeout)
|
||||
logger.debug(
|
||||
f"Discarding orphaned message with id={id} - "
|
||||
"no matching waiter"
|
||||
)
|
||||
else:
|
||||
# Delivery failed (e.g., queue full with drop_new strategy)
|
||||
logger.debug(
|
||||
f"Message with id={id} dropped due to backpressure"
|
||||
)
|
||||
|
||||
async def _deliver_to_queue(self, queue, value):
|
||||
"""Deliver message to queue with backpressure handling"""
|
||||
|
|
|
|||
|
|
@ -1,24 +1,34 @@
|
|||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL
|
||||
from .. knowledge import Uri, Literal
|
||||
|
||||
|
||||
class Triple:
|
||||
def __init__(self, s, p, o):
|
||||
self.s = s
|
||||
self.p = p
|
||||
self.o = o
|
||||
|
||||
|
||||
def to_value(x):
|
||||
if x.is_uri: return Uri(x.value)
|
||||
return Literal(x.value)
|
||||
"""Convert schema Term to Uri or Literal."""
|
||||
if x.type == IRI:
|
||||
return Uri(x.iri)
|
||||
elif x.type == LITERAL:
|
||||
return Literal(x.value)
|
||||
# Fallback
|
||||
return Literal(x.value or x.iri)
|
||||
|
||||
|
||||
def from_value(x):
|
||||
if x is None: return None
|
||||
"""Convert Uri or Literal to schema Term."""
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, Uri):
|
||||
return Value(value=str(x), is_uri=True)
|
||||
return Term(type=IRI, iri=str(x))
|
||||
else:
|
||||
return Value(value=str(x), is_uri=False)
|
||||
return Term(type=LITERAL, value=str(x))
|
||||
|
||||
class TriplesClient(RequestResponse):
|
||||
async def query(self, s=None, p=None, o=None, limit=20,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ null. Output is a list of triples.
|
|||
import logging
|
||||
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .. schema import Value, Triple
|
||||
from .. schema import Term, Triple
|
||||
|
||||
from . flow_processor import FlowProcessor
|
||||
from . consumer_spec import ConsumerSpec
|
||||
|
|
|
|||
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
from .. schema import row_embeddings_request_queue
|
||||
from .. schema import row_embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class RowEmbeddingsClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pulsar_api_key=None,
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = row_embeddings_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = row_embeddings_response_queue
|
||||
|
||||
super(RowEmbeddingsClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
pulsar_api_key=pulsar_api_key,
|
||||
input_schema=RowEmbeddingsRequest,
|
||||
output_schema=RowEmbeddingsResponse,
|
||||
)
|
||||
|
||||
def request(
|
||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
||||
index_name=None, limit=10, timeout=300
|
||||
):
|
||||
kwargs = dict(
|
||||
user=user, collection=collection,
|
||||
vectors=vectors, schema_name=schema_name,
|
||||
limit=limit, timeout=timeout
|
||||
)
|
||||
if index_name:
|
||||
kwargs["index_name"] = index_name
|
||||
|
||||
response = self.call(**kwargs)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(f"{response.error.type}: {response.error.message}")
|
||||
|
||||
return response.matches
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Term, IRI, LITERAL
|
||||
from .. schema import triples_request_queue
|
||||
from .. schema import triples_response_queue
|
||||
from . base import BaseClient
|
||||
|
|
@ -46,9 +46,9 @@ class TriplesQueryClient(BaseClient):
|
|||
if ent == None: return None
|
||||
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
return Term(type=IRI, iri=ent)
|
||||
|
||||
return Value(value=ent, is_uri=False)
|
||||
return Term(type=LITERAL, value=ent)
|
||||
|
||||
def request(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
|
|||
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||
from .translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
|
||||
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
|
||||
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
|
|
@ -107,15 +108,21 @@ TranslatorRegistry.register_service(
|
|||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
GraphEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"objects-query",
|
||||
ObjectsQueryRequestTranslator(),
|
||||
ObjectsQueryResponseTranslator()
|
||||
"row-embeddings-query",
|
||||
RowEmbeddingsRequestTranslator(),
|
||||
RowEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"rows-query",
|
||||
RowsQueryRequestTranslator(),
|
||||
RowsQueryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from .base import Translator, MessageTranslator
|
||||
from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator, RowSchemaTranslator, FieldTranslator, row_schema_translator, field_translator
|
||||
from .primitives import TermTranslator, ValueTranslator, TripleTranslator, SubgraphTranslator, RowSchemaTranslator, FieldTranslator, row_schema_translator, field_translator
|
||||
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
|
||||
from .agent import AgentRequestTranslator, AgentResponseTranslator
|
||||
from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
|
||||
|
|
@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
|
|||
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
|
||||
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator
|
||||
|
|
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
|||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class RowEmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for RowEmbeddingsRequest schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
||||
return RowEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
schema_name=data.get("schema_name", ""),
|
||||
index_name=data.get("index_name")
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
||||
result = {
|
||||
"vectors": obj.vectors,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection,
|
||||
"schema_name": obj.schema_name,
|
||||
}
|
||||
if obj.index_name:
|
||||
result["index_name"] = obj.index_name
|
||||
return result
|
||||
|
||||
|
||||
class RowEmbeddingsResponseTranslator(MessageTranslator):
|
||||
"""Translator for RowEmbeddingsResponse schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.error is not None:
|
||||
result["error"] = {
|
||||
"type": obj.error.type,
|
||||
"message": obj.error.message
|
||||
}
|
||||
|
||||
if obj.matches is not None:
|
||||
result["matches"] = [
|
||||
{
|
||||
"index_name": match.index_name,
|
||||
"index_value": match.index_value,
|
||||
"text": match.text,
|
||||
"score": match.score
|
||||
}
|
||||
for match in obj.matches
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue