mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Structured data 2 (#645)
* Structured data refactor - multi-index tables, remove need for manual mods to the Cassandra tables
* Tech spec updated to track implementation
This commit is contained in:
parent
5ffad92345
commit
1809c1f56d
87 changed files with 5233 additions and 3235 deletions
6	Makefile
@@ -5,7 +5,7 @@ VERSION=0.0.0

 DOCKER=podman

-all: container
+all: containers

 # Not used
 wheels:

@@ -49,7 +49,9 @@ update-package-versions:
 	echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
 	echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py

-container: update-package-versions
+FORCE:
+
+containers: FORCE
 	${DOCKER} build -f containers/Containerfile.base \
 		-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
 	${DOCKER} build -f containers/Containerfile.flow \
467	docs/tech-specs/structured-data-2.md	New file
@@ -0,0 +1,467 @@
# Structured Data Technical Specification (Part 2)

## Overview

This specification addresses issues and gaps identified during the initial implementation of TrustGraph's structured data integration, as described in `structured-data.md`.

## Problem Statements

### 1. Naming Inconsistency: "Object" vs "Row"

The current implementation uses "object" terminology throughout (e.g., `ExtractedObject`, object extraction, object embeddings). This naming is too generic and causes confusion:

- "Object" is an overloaded term in software (Python objects, JSON objects, etc.)
- The data being handled is fundamentally tabular: rows in tables with defined schemas
- "Row" more accurately describes the data model and aligns with database terminology

This inconsistency appears in module names, class names, message types, and documentation.

### 2. Row Store Query Limitations

The current row store implementation has significant query limitations:

**Natural Language Mismatch**: Queries struggle with real-world data variations. For example:
- A street database containing `"CHESTNUT ST"` is difficult to find when asking about `"Chestnut Street"`
- Abbreviations, case differences, and formatting variations break exact-match queries
- Users expect semantic understanding, but the store provides only literal matching

**Schema Evolution Issues**: Changing schemas causes problems:
- Existing data may not conform to updated schemas
- Table structure changes can break queries and data integrity
- There is no clear migration path for schema updates

### 3. Row Embeddings Required

Related to problem 2, the system needs vector embeddings for row data to enable:

- Semantic search across structured data (finding "Chestnut Street" when data contains "CHESTNUT ST")
- Similarity matching for fuzzy queries
- Hybrid search combining structured filters with semantic similarity
- Better natural language query support

The embedding service was specified but not implemented.

### 4. Row Data Ingestion Incomplete

The structured data ingestion pipeline is not fully operational:

- Diagnostic prompts exist to classify input formats (CSV, JSON, etc.)
- The ingestion service that uses these prompts is not plumbed into the system
- There is no end-to-end path for loading pre-structured data into the row store

## Goals

- **Schema Flexibility**: Enable schema evolution without breaking existing data or requiring migrations
- **Consistent Naming**: Standardize on "row" terminology throughout the codebase
- **Semantic Queryability**: Support fuzzy/semantic matching via row embeddings
- **Complete Ingestion Pipeline**: Provide an end-to-end path for loading structured data

## Technical Design

### Unified Row Storage Schema

The previous implementation created a separate Cassandra table for each schema. This caused problems when schemas evolved, as table structure changes required migrations.

The new design uses a single unified table for all row data:

```sql
CREATE TABLE rows (
    collection text,
    schema_name text,
    index_name text,
    index_value frozen<list<text>>,
    data map<text, text>,
    source text,
    PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
```

#### Column Definitions

| Column | Type | Description |
|--------|------|-------------|
| `collection` | `text` | Data collection/import identifier (from metadata) |
| `schema_name` | `text` | Name of the schema this row conforms to |
| `index_name` | `text` | Name of the indexed field(s), comma-joined for composites |
| `index_value` | `frozen<list<text>>` | Index value(s) as a list |
| `data` | `map<text, text>` | Row data as key-value pairs |
| `source` | `text` | Optional URI linking to provenance information in the knowledge graph. Empty string or NULL indicates no source. |

#### Index Handling

Each row is stored multiple times - once per indexed field defined in the schema. The primary key fields are treated as an ordinary index with no special marker, providing future flexibility.

**Single-field index example:**
- Schema defines `email` as indexed
- `index_name = "email"`
- `index_value = ['foo@bar.com']`

**Composite index example:**
- Schema defines composite index on `region` and `status`
- `index_name = "region,status"` (field names sorted and comma-joined)
- `index_value = ['US', 'active']` (values in same order as field names)

**Primary key example:**
- Schema defines `customer_id` as primary key
- `index_name = "customer_id"`
- `index_value = ['CUST001']`
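The indexing rules above can be sketched as a small helper. The function name and shape here are illustrative, not the actual writer implementation:

```python
def index_entries(row, indexes):
    """Derive (index_name, index_value) pairs for a row.

    row: dict of field name -> value
    indexes: list of index definitions, each a list of field names
             (a single-field index is a one-element list)
    """
    entries = []
    for fields in indexes:
        ordered = sorted(fields)                # field names sorted...
        name = ",".join(ordered)                # ...and comma-joined
        value = [str(row[f]) for f in ordered]  # values in same field order
        entries.append((name, value))
    return entries

# One Cassandra insert is issued per entry, so a row with N indexes
# is written N times.
row = {"customer_id": "CUST001", "region": "US", "status": "active"}
entries = index_entries(row, [["customer_id"], ["region", "status"]])
# entries == [("customer_id", ["CUST001"]), ("region,status", ["US", "active"])]
```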
#### Query Patterns

All queries follow the same pattern regardless of which index is used:

```sql
SELECT * FROM rows
WHERE collection = 'import_2024'
  AND schema_name = 'customers'
  AND index_name = 'email'
  AND index_value = ['foo@bar.com']
```

#### Design Trade-offs

**Advantages:**
- Schema changes don't require table structure changes
- Row data is opaque to Cassandra - field additions/removals are transparent
- Consistent query pattern for all access methods
- No Cassandra secondary indexes (which can be slow at scale)
- Native Cassandra types throughout (`map`, `frozen<list>`)

**Trade-offs:**
- Write amplification: each row insert becomes N inserts (one per indexed field)
- Storage overhead from duplicated row data
- Type information is stored in the schema config, with conversion at the application layer

#### Consistency Model

The design accepts certain simplifications:

1. **No row updates**: The system is append-only. This eliminates consistency concerns about updating multiple copies of the same row.

2. **Schema change tolerance**: When schemas change (e.g., indexes added/removed), existing rows retain their original indexing. Old rows won't be discoverable via new indexes. Users can delete and recreate a schema to ensure consistency if needed.
### Partition Tracking and Deletion

#### The Problem

With the partition key `(collection, schema_name, index_name)`, efficient deletion requires knowing all partition keys to delete. Deleting by just `collection` or `collection + schema_name` requires knowing all the `index_name` values that hold data.

#### Partition Tracking Table

A secondary lookup table tracks which partitions exist:

```sql
CREATE TABLE row_partitions (
    collection text,
    schema_name text,
    index_name text,
    PRIMARY KEY ((collection), schema_name, index_name)
)
```

This enables efficient discovery of partitions for deletion operations.

#### Row Writer Behavior

The row writer maintains an in-memory cache of registered `(collection, schema_name)` pairs. When processing a row:

1. Check if `(collection, schema_name)` is in the cache
2. If not cached (first row for this pair):
   - Look up the schema config to get all index names
   - Insert entries into `row_partitions` for each `(collection, schema_name, index_name)`
   - Add the pair to the cache
3. Proceed with writing the row data

The row writer also monitors schema config change events. When a schema changes, the relevant cache entries are cleared so that the next row triggers re-registration with the updated index names.

This approach ensures:
- Lookup table writes happen once per `(collection, schema_name)` pair, not per row
- The lookup table reflects the indexes that were active when data was written
- Schema changes mid-import are picked up correctly
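The registration cache described above can be sketched as follows, with the schema lookup and Cassandra write injected as callables so the logic stays testable. All names are illustrative, not the shipped implementation:

```python
class PartitionRegistry:
    """Registers (collection, schema_name) partitions once, not per row."""

    def __init__(self, get_index_names, write_partition):
        self.get_index_names = get_index_names  # schema_name -> list of index names
        self.write_partition = write_partition  # (collection, schema, index) -> None
        self.seen = set()

    def ensure_registered(self, collection, schema_name):
        key = (collection, schema_name)
        if key in self.seen:
            return
        # First row for this pair: record every active index partition.
        for index_name in self.get_index_names(schema_name):
            self.write_partition(collection, schema_name, index_name)
        self.seen.add(key)

    def on_schema_change(self, schema_name):
        # Drop cache entries so the next row re-registers with new indexes.
        self.seen = {k for k in self.seen if k[1] != schema_name}

writes = []
reg = PartitionRegistry(lambda s: ["email", "region,status"],
                        lambda c, s, i: writes.append((c, s, i)))
reg.ensure_registered("import_2024", "customers")
reg.ensure_registered("import_2024", "customers")  # cached: no extra writes
# writes == [("import_2024", "customers", "email"),
#            ("import_2024", "customers", "region,status")]
```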
#### Deletion Operations

**Delete collection:**
```sql
-- 1. Discover all partitions
SELECT schema_name, index_name FROM row_partitions WHERE collection = 'X';

-- 2. Delete each partition from the rows table
DELETE FROM rows WHERE collection = 'X' AND schema_name = '...' AND index_name = '...';
-- (repeat for each discovered partition)

-- 3. Clean up the lookup table
DELETE FROM row_partitions WHERE collection = 'X';
```

**Delete collection + schema:**
```sql
-- 1. Discover partitions for this schema
SELECT index_name FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';

-- 2. Delete each partition from the rows table
DELETE FROM rows WHERE collection = 'X' AND schema_name = 'Y' AND index_name = '...';
-- (repeat for each discovered partition)

-- 3. Clean up the lookup table entries
DELETE FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
```
### Row Embeddings

Row embeddings enable semantic/fuzzy matching on indexed values, solving the natural language mismatch problem (e.g., finding "CHESTNUT ST" when querying for "Chestnut Street").

#### Design Overview

Each indexed value is embedded and stored in a vector store (Qdrant). At query time, the query text is embedded, similar vectors are found, and the associated metadata is used to look up the actual rows in Cassandra.

#### Qdrant Collection Structure

One Qdrant collection per `(user, collection, schema_name, dimension)` tuple:

- **Collection naming:** `rows_{user}_{collection}_{schema_name}_{dimension}`
- Names are sanitized (non-alphanumeric characters replaced with `_`, lowercased, numeric prefixes given an `r_` prefix)
- **Rationale:** Enables clean deletion of a `(user, collection, schema_name)` instance by dropping matching Qdrant collections; the dimension suffix allows different embedding models to coexist
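The sanitization rules can be sketched as follows. This is one plausible reading of the rules above, not the shipped code:

```python
import re

def collection_name(user, collection, schema_name, dimension):
    """Build a sanitized Qdrant collection name for row embeddings."""
    def clean(part):
        # Replace non-alphanumerics with _ and lowercase.
        part = re.sub(r"[^a-zA-Z0-9]", "_", part).lower()
        if part and part[0].isdigit():
            part = "r_" + part  # numeric prefixes get an r_ prefix
        return part
    return "_".join(["rows", clean(user), clean(collection),
                     clean(schema_name), str(dimension)])

name = collection_name("alice", "import-2024", "customers", 384)
# name == "rows_alice_import_2024_customers_384"
```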
#### What Gets Embedded

The text representation of index values:

| Index Type | Example `index_value` | Text to Embed |
|------------|----------------------|---------------|
| Single-field | `['foo@bar.com']` | `"foo@bar.com"` |
| Composite | `['US', 'active']` | `"US active"` (space-joined) |

#### Point Structure

Each Qdrant point contains:

```json
{
  "id": "<uuid>",
  "vector": [0.1, 0.2, ...],
  "payload": {
    "index_name": "street_name",
    "index_value": ["CHESTNUT ST"],
    "text": "CHESTNUT ST"
  }
}
```

| Payload Field | Description |
|---------------|-------------|
| `index_name` | The indexed field(s) this embedding represents |
| `index_value` | The original list of values (for Cassandra lookup) |
| `text` | The text that was embedded (for debugging/display) |

Note: `user`, `collection`, and `schema_name` are implicit from the Qdrant collection name.
#### Query Flow

1. User queries for "Chestnut Street" within user U, collection X, schema Y
2. Embed the query text
3. Determine the Qdrant collection name(s) matching the prefix `rows_U_X_Y_`
4. Search the matching Qdrant collection(s) for nearest vectors
5. Get matching points with payloads containing `index_name` and `index_value`
6. Query Cassandra:
   ```sql
   SELECT * FROM rows
   WHERE collection = 'X'
     AND schema_name = 'Y'
     AND index_name = '<from payload>'
     AND index_value = <from payload>
   ```
7. Return matched rows
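The flow above can be sketched with the embedding call, vector search, and Cassandra lookup injected as callables. All names here are illustrative stand-ins for the real services:

```python
def semantic_row_query(text, embed, search_points, fetch_rows):
    """Fuzzy row lookup: embed the query, find similar index values,
    then resolve each hit to exact rows in the row store."""
    vector = embed(text)
    rows = []
    for payload in search_points(vector):  # Qdrant nearest-neighbour hits
        rows.extend(fetch_rows(payload["index_name"], payload["index_value"]))
    return rows

# Stub backends standing in for the embeddings service, Qdrant, and Cassandra.
store = {("street_name", ("CHESTNUT ST",)):
         [{"street_name": "CHESTNUT ST", "city": "Boston"}]}
hits = semantic_row_query(
    "Chestnut Street",
    embed=lambda t: [0.1, 0.2],
    search_points=lambda v: [{"index_name": "street_name",
                              "index_value": ["CHESTNUT ST"]}],
    fetch_rows=lambda n, val: store.get((n, tuple(val)), []),
)
# hits == [{"street_name": "CHESTNUT ST", "city": "Boston"}]
```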
#### Optional: Filtering by Index Name

Queries can optionally filter by `index_name` in Qdrant to search only specific fields:

- **"Find any field matching 'Chestnut'"** → search all vectors in the collection
- **"Find street_name matching 'Chestnut'"** → filter where `payload.index_name = 'street_name'`

#### Architecture

Row embeddings follow the **two-stage pattern** used by GraphRAG (graph-embeddings, document-embeddings):

- **Stage 1: Embedding computation** (`trustgraph-flow/trustgraph/embeddings/row_embeddings/`) - Consumes `ExtractedObject`, computes embeddings via the embeddings service, outputs `RowEmbeddings`
- **Stage 2: Embedding storage** (`trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/`) - Consumes `RowEmbeddings`, writes vectors to Qdrant

The Cassandra row writer is a separate, parallel consumer:

- **Cassandra row writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra`) - Consumes `ExtractedObject`, writes rows to Cassandra

All three services consume from the same flow, keeping them decoupled. This allows:
- Independent scaling of Cassandra writes vs embedding generation vs vector storage
- Embedding services to be disabled if not needed
- Failures in one service to be isolated from the others
- A consistent architecture with the GraphRAG pipelines

#### Write Path

**Stage 1 (row-embeddings processor):** When receiving an `ExtractedObject`:

1. Look up the schema to find indexed fields
2. For each indexed field:
   - Build the text representation of the index value
   - Compute an embedding via the embeddings service
3. Output a `RowEmbeddings` message containing all computed vectors

**Stage 2 (row-embeddings-write-qdrant):** When receiving a `RowEmbeddings`:

1. For each embedding in the message:
   - Determine the Qdrant collection from `(user, collection, schema_name, dimension)`
   - Create the collection if needed (lazy creation on first write)
   - Upsert a point with the vector and payload

#### Message Types

```python
@dataclass
class RowIndexEmbedding:
    index_name: str             # The indexed field name(s)
    index_value: list[str]      # The field value(s)
    text: str                   # Text that was embedded
    vectors: list[list[float]]  # Computed embedding vectors

@dataclass
class RowEmbeddings:
    metadata: Metadata
    schema_name: str
    embeddings: list[RowIndexEmbedding]
```
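The Stage 1 steps can be sketched as a pure function producing `RowIndexEmbedding` values. The dataclass is re-declared here in simplified form, and the schema lookup and embeddings-service call are stubbed, so this is an illustration rather than the processor itself:

```python
from dataclasses import dataclass

@dataclass
class RowIndexEmbedding:
    index_name: str
    index_value: list
    text: str
    vectors: list

def compute_row_embeddings(row, indexes, embed):
    """Stage 1 core: one RowIndexEmbedding per index defined in the schema.

    row: dict of field -> value; indexes: list of field-name lists;
    embed: text -> vector (stands in for the embeddings service).
    """
    out = []
    for fields in indexes:
        ordered = sorted(fields)
        values = [str(row[f]) for f in ordered]
        text = " ".join(values)  # space-joined text for composite indexes
        out.append(RowIndexEmbedding(
            index_name=",".join(ordered),
            index_value=values,
            text=text,
            vectors=[embed(text)],
        ))
    return out

embs = compute_row_embeddings(
    {"region": "US", "status": "active"},
    [["region", "status"]],
    embed=lambda t: [0.0] * 4,  # stub embedding
)
# embs[0].text == "US active"; embs[0].index_name == "region,status"
```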
#### Deletion Integration

Qdrant collections are discovered by prefix matching on the collection name pattern:

**Delete `(user, collection)`:**
1. List all Qdrant collections matching the prefix `rows_{user}_{collection}_`
2. Delete each matching collection
3. Delete the Cassandra rows partitions (as documented above)
4. Clean up the `row_partitions` entries

**Delete `(user, collection, schema_name)`:**
1. List all Qdrant collections matching the prefix `rows_{user}_{collection}_{schema_name}_`
2. Delete each matching collection (handles multiple dimensions)
3. Delete the Cassandra rows partitions
4. Clean up `row_partitions`
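The prefix-based deletion can be sketched as follows, with the Qdrant operations injected as callables (with qdrant-client these would wrap `client.get_collections()` and `client.delete_collection()`; the function itself is illustrative):

```python
def delete_row_collections(prefix, list_collections, delete_collection):
    """Drop every Qdrant collection whose name starts with the prefix."""
    deleted = []
    for name in list_collections():
        if name.startswith(prefix):
            delete_collection(name)
            deleted.append(name)
    return deleted

# Deleting (user=u, collection=x, schema=customers) drops both
# dimension variants but leaves other schemas untouched.
existing = ["rows_u_x_customers_384", "rows_u_x_customers_768",
            "rows_u_y_orders_384"]
gone = delete_row_collections("rows_u_x_", lambda: existing, lambda n: None)
# gone == ["rows_u_x_customers_384", "rows_u_x_customers_768"]
```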
#### Module Locations

| Stage | Module | Entry Point |
|-------|--------|-------------|
| Stage 1 | `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | `row-embeddings` |
| Stage 2 | `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | `row-embeddings-write-qdrant` |

### Row Embeddings Query API

The row embeddings query is a **separate API** from the GraphQL row query service:

| API | Purpose | Backend |
|-----|---------|---------|
| Row Query (GraphQL) | Exact matching on indexed fields | Cassandra |
| Row Embeddings Query | Fuzzy/semantic matching | Qdrant |

This separation keeps concerns clean:
- The GraphQL service focuses on exact, structured queries
- The embeddings API handles semantic similarity
- User workflow: fuzzy search via embeddings to find candidates, then an exact query to get the full row data

Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`

### Row Data Ingestion

Deferred to a subsequent phase. It will be designed alongside other ingestion changes.

## Implementation Impact

### Current State Analysis

The existing implementation has two main components:

| Component | Location | Lines | Description |
|-----------|----------|-------|-------------|
| Query Service | `trustgraph-flow/trustgraph/query/objects/cassandra/service.py` | ~740 | Monolithic: GraphQL schema generation, filter parsing, Cassandra queries, request handling |
| Writer | `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` | ~540 | Per-schema table creation, secondary indexes, insert/delete |

**Current Query Pattern:**
```sql
SELECT * FROM {keyspace}.o_{schema_name}
WHERE collection = 'X' AND email = 'foo@bar.com'
ALLOW FILTERING
```

**New Query Pattern:**
```sql
SELECT * FROM {keyspace}.rows
WHERE collection = 'X' AND schema_name = 'customers'
  AND index_name = 'email' AND index_value = ['foo@bar.com']
```

### Key Changes

1. **Query semantics simplify**: The new schema only supports exact matches on `index_value`. The current GraphQL filters (`gt`, `lt`, `contains`, etc.) either:
   - Become post-filtering on returned data (if still needed)
   - Are removed in favor of using the embeddings API for fuzzy matching

2. **GraphQL code is tightly coupled**: The current `service.py` bundles Strawberry type generation, filter parsing, and Cassandra-specific queries. Adding another row store backend would duplicate ~400 lines of GraphQL code.

### Proposed Refactor

The refactor has two parts:

#### 1. Break Out GraphQL Code

Extract reusable GraphQL components into a shared module:

```
trustgraph-flow/trustgraph/query/graphql/
├── __init__.py
├── types.py    # Filter types (IntFilter, StringFilter, FloatFilter)
├── schema.py   # Dynamic schema generation from RowSchema
└── filters.py  # Filter parsing utilities
```

This enables:
- Reuse across different row store backends
- Cleaner separation of concerns
- Easier testing of GraphQL logic in isolation

#### 2. Implement New Table Schema

Refactor the Cassandra-specific code to use the unified table:

**Writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra/`):
- Single `rows` table instead of per-schema tables
- Write N copies per row (one per index)
- Register to the `row_partitions` table
- Simpler table creation (one-time setup)

**Query Service** (`trustgraph-flow/trustgraph/query/rows/cassandra/`):
- Query the unified `rows` table
- Use the extracted GraphQL module for schema generation
- Simplified filter handling (exact match only at the DB level)

### Module Renames

As part of the "object" → "row" naming cleanup:

| Current | New |
|---------|-----|
| `storage/objects/cassandra/` | `storage/rows/cassandra/` |
| `query/objects/cassandra/` | `query/rows/cassandra/` |
| `embeddings/object_embeddings/` | `embeddings/row_embeddings/` |

### New Modules

| Module | Purpose |
|--------|---------|
| `trustgraph-flow/trustgraph/query/graphql/` | Shared GraphQL utilities |
| `trustgraph-flow/trustgraph/query/row_embeddings/qdrant/` | Row embeddings query API |
| `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | Row embeddings computation (Stage 1) |
| `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | Row embeddings storage (Stage 2) |

## References

- [Structured Data Technical Specification](structured-data.md)
@ -1,6 +1,6 @@
|
||||||
type: object
|
type: object
|
||||||
description: |
|
description: |
|
||||||
Objects query request - GraphQL query over knowledge graph.
|
Rows query request - GraphQL query over structured data.
|
||||||
required:
|
required:
|
||||||
- query
|
- query
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
type: object
|
type: object
|
||||||
description: Objects query response (GraphQL format)
|
description: Rows query response (GraphQL format)
|
||||||
properties:
|
properties:
|
||||||
data:
|
data:
|
||||||
description: GraphQL response data (JSON object or null)
|
description: GraphQL response data (JSON object or null)
|
||||||
|
|
@ -121,8 +121,8 @@ paths:
|
||||||
$ref: './paths/flow/mcp-tool.yaml'
|
$ref: './paths/flow/mcp-tool.yaml'
|
||||||
/api/v1/flow/{flow}/service/triples:
|
/api/v1/flow/{flow}/service/triples:
|
||||||
$ref: './paths/flow/triples.yaml'
|
$ref: './paths/flow/triples.yaml'
|
||||||
/api/v1/flow/{flow}/service/objects:
|
/api/v1/flow/{flow}/service/rows:
|
||||||
$ref: './paths/flow/objects.yaml'
|
$ref: './paths/flow/rows.yaml'
|
||||||
/api/v1/flow/{flow}/service/nlp-query:
|
/api/v1/flow/{flow}/service/nlp-query:
|
||||||
$ref: './paths/flow/nlp-query.yaml'
|
$ref: './paths/flow/nlp-query.yaml'
|
||||||
/api/v1/flow/{flow}/service/structured-query:
|
/api/v1/flow/{flow}/service/structured-query:
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ post:
|
||||||
```
|
```
|
||||||
1. User asks: "Who does Alice know?"
|
1. User asks: "Who does Alice know?"
|
||||||
2. NLP Query generates GraphQL
|
2. NLP Query generates GraphQL
|
||||||
3. Execute via /api/v1/flow/{flow}/service/objects
|
3. Execute via /api/v1/flow/{flow}/service/rows
|
||||||
4. Return results to user
|
4. Return results to user
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
post:
|
post:
|
||||||
tags:
|
tags:
|
||||||
- Flow Services
|
- Flow Services
|
||||||
summary: Objects query - GraphQL over knowledge graph
|
summary: Rows query - GraphQL over structured data
|
||||||
description: |
|
description: |
|
||||||
Query knowledge graph using GraphQL for object-oriented data access.
|
Query structured data using GraphQL for row-oriented data access.
|
||||||
|
|
||||||
## Objects Query Overview
|
## Rows Query Overview
|
||||||
|
|
||||||
GraphQL interface to knowledge graph:
|
GraphQL interface to structured data:
|
||||||
- **Schema-driven**: Predefined types and relationships
|
- **Schema-driven**: Predefined types and relationships
|
||||||
- **Flexible queries**: Request exactly what you need
|
- **Flexible queries**: Request exactly what you need
|
||||||
- **Nested data**: Traverse relationships in single query
|
- **Nested data**: Traverse relationships in single query
|
||||||
- **Type-safe**: Strong typing with introspection
|
- **Type-safe**: Strong typing with introspection
|
||||||
|
|
||||||
Abstracts RDF triples into familiar object model.
|
Abstracts structured rows into familiar object model.
|
||||||
|
|
||||||
## GraphQL Benefits
|
## GraphQL Benefits
|
||||||
|
|
||||||
|
|
@ -61,7 +61,7 @@ post:
|
||||||
Schema defines available types via config service.
|
Schema defines available types via config service.
|
||||||
Use introspection query to discover schema.
|
Use introspection query to discover schema.
|
||||||
|
|
||||||
operationId: objectsQueryService
|
operationId: rowsQueryService
|
||||||
security:
|
security:
|
||||||
- bearerAuth: []
|
- bearerAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
|
|
@ -77,7 +77,7 @@ post:
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '../../components/schemas/query/ObjectsQueryRequest.yaml'
|
$ref: '../../components/schemas/query/RowsQueryRequest.yaml'
|
||||||
examples:
|
examples:
|
||||||
simpleQuery:
|
simpleQuery:
|
||||||
summary: Simple query
|
summary: Simple query
|
||||||
|
|
@ -129,7 +129,7 @@ post:
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '../../components/schemas/query/ObjectsQueryResponse.yaml'
|
$ref: '../../components/schemas/query/RowsQueryResponse.yaml'
|
||||||
examples:
|
examples:
|
||||||
successfulQuery:
|
successfulQuery:
|
||||||
summary: Successful query
|
summary: Successful query
|
||||||
|
|
@ -9,7 +9,7 @@ post:
|
||||||
|
|
||||||
Combines two operations in one call:
|
Combines two operations in one call:
|
||||||
1. **NLP Query**: Generate GraphQL from question
|
1. **NLP Query**: Generate GraphQL from question
|
||||||
2. **Objects Query**: Execute generated query
|
2. **Rows Query**: Execute generated query
|
||||||
3. **Return Results**: Direct answer data
|
3. **Return Results**: Direct answer data
|
||||||
|
|
||||||
Simplest way to query knowledge graph with natural language.
|
Simplest way to query knowledge graph with natural language.
|
||||||
|
|
@ -21,7 +21,7 @@ post:
|
||||||
- **Output**: Query results (data)
|
- **Output**: Query results (data)
|
||||||
- **Use when**: Want simple, direct answers
|
- **Use when**: Want simple, direct answers
|
||||||
|
|
||||||
### NLP Query + Objects Query (separate calls)
|
### NLP Query + Rows Query (separate calls)
|
||||||
- **Step 1**: Convert question → GraphQL
|
- **Step 1**: Convert question → GraphQL
|
||||||
- **Step 2**: Execute GraphQL → results
|
- **Step 2**: Execute GraphQL → results
|
||||||
- **Use when**: Need to inspect/modify query before execution
|
- **Use when**: Need to inspect/modify query before execution
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ payload:
|
||||||
- $ref: './requests/EmbeddingsRequest.yaml'
|
- $ref: './requests/EmbeddingsRequest.yaml'
|
||||||
- $ref: './requests/McpToolRequest.yaml'
|
- $ref: './requests/McpToolRequest.yaml'
|
||||||
- $ref: './requests/TriplesRequest.yaml'
|
- $ref: './requests/TriplesRequest.yaml'
|
||||||
- $ref: './requests/ObjectsRequest.yaml'
|
- $ref: './requests/RowsRequest.yaml'
|
||||||
- $ref: './requests/NlpQueryRequest.yaml'
|
- $ref: './requests/NlpQueryRequest.yaml'
|
||||||
- $ref: './requests/StructuredQueryRequest.yaml'
|
- $ref: './requests/StructuredQueryRequest.yaml'
|
||||||
- $ref: './requests/StructuredDiagRequest.yaml'
|
- $ref: './requests/StructuredDiagRequest.yaml'
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
type: object
|
type: object
|
||||||
description: WebSocket request for objects service (flow-hosted service)
|
description: WebSocket request for rows service (flow-hosted service)
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- service
|
- service
|
||||||
|
|
@ -11,16 +11,16 @@ properties:
|
||||||
description: Unique request identifier
|
description: Unique request identifier
|
||||||
service:
|
service:
|
||||||
type: string
|
type: string
|
||||||
const: objects
|
const: rows
|
||||||
description: Service identifier for objects service
|
description: Service identifier for rows service
|
||||||
flow:
|
flow:
|
||||||
type: string
|
type: string
|
||||||
description: Flow ID
|
description: Flow ID
|
||||||
request:
|
request:
|
||||||
$ref: '../../../../api/components/schemas/query/ObjectsQueryRequest.yaml'
|
$ref: '../../../../api/components/schemas/query/RowsQueryRequest.yaml'
|
||||||
examples:
|
examples:
|
||||||
- id: req-1
|
- id: req-1
|
||||||
service: objects
|
service: rows
|
||||||
flow: my-flow
|
flow: my-flow
|
||||||
request:
|
request:
|
||||||
query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }"
|
query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }"
|
||||||
|
|
@@ -1,8 +1,8 @@
 """
-Contract tests for Cassandra Object Storage
+Contract tests for Cassandra Row Storage

 These tests verify the message contracts and schema compatibility
-for the objects storage processor.
+for the rows storage processor.
 """

 import pytest
@@ -10,12 +10,12 @@ import json
 from pulsar.schema import AvroSchema

 from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
-from trustgraph.storage.objects.cassandra.write import Processor
+from trustgraph.storage.rows.cassandra.write import Processor


 @pytest.mark.contract
-class TestObjectsCassandraContracts:
-    """Contract tests for Cassandra object storage messages"""
+class TestRowsCassandraContracts:
+    """Contract tests for Cassandra row storage messages"""

     def test_extracted_object_input_contract(self):
         """Test that ExtractedObject schema matches expected input format"""
@@ -145,50 +145,6 @@ class TestObjectsCassandraContracts:
         assert required_field_keys.issubset(field.keys())
         assert set(field.keys()).issubset(required_field_keys | optional_field_keys)

-    def test_cassandra_type_mapping_contract(self):
-        """Test that all supported field types have Cassandra mappings"""
-        processor = Processor.__new__(Processor)
-
-        # All field types that should be supported
-        supported_types = [
-            ("string", "text"),
-            ("integer", "int"), # or bigint based on size
-            ("float", "float"), # or double based on size
-            ("boolean", "boolean"),
-            ("timestamp", "timestamp"),
-            ("date", "date"),
-            ("time", "time"),
-            ("uuid", "uuid")
-        ]
-
-        for field_type, expected_cassandra_type in supported_types:
-            cassandra_type = processor.get_cassandra_type(field_type)
-            # For integer and float, the exact type depends on size
-            if field_type in ["integer", "float"]:
-                assert cassandra_type in ["int", "bigint", "float", "double"]
-            else:
-                assert cassandra_type == expected_cassandra_type
-
-    def test_value_conversion_contract(self):
-        """Test value conversion for all supported types"""
-        processor = Processor.__new__(Processor)
-
-        # Test conversions maintain data integrity
-        test_cases = [
-            # (input_value, field_type, expected_output, expected_type)
-            ("123", "integer", 123, int),
-            ("123.45", "float", 123.45, float),
-            ("true", "boolean", True, bool),
-            ("false", "boolean", False, bool),
-            ("test string", "string", "test string", str),
-            (None, "string", None, type(None)),
-        ]
-
-        for input_val, field_type, expected_val, expected_type in test_cases:
-            result = processor.convert_value(input_val, field_type)
-            assert result == expected_val
-            assert isinstance(result, expected_type) or result is None
-
     @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
     def test_extracted_object_serialization_contract(self):
         """Test that ExtractedObject can be serialized/deserialized correctly"""
@@ -222,43 +178,31 @@ class TestObjectsCassandraContracts:
         assert decoded.confidence == original.confidence
         assert decoded.source_span == original.source_span

-    def test_cassandra_table_naming_contract(self):
+    def test_cassandra_name_sanitization_contract(self):
         """Test Cassandra naming conventions and constraints"""
         processor = Processor.__new__(Processor)

-        # Test table naming (always gets o_ prefix)
-        table_test_names = [
-            ("simple_name", "o_simple_name"),
-            ("Name-With-Dashes", "o_name_with_dashes"),
-            ("name.with.dots", "o_name_with_dots"),
-            ("123_numbers", "o_123_numbers"),
-            ("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores
-            ("UPPERCASE", "o_uppercase"),
-            ("CamelCase", "o_camelcase"),
-            ("", "o_"), # Edge case - empty string becomes o_
-        ]
-
-        for input_name, expected_name in table_test_names:
-            result = processor.sanitize_table(input_name)
-            assert result == expected_name
-            # Verify result is valid Cassandra identifier (starts with letter)
-            assert result.startswith('o_')
-            assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_'
-
-        # Test regular name sanitization (only adds o_ prefix if starts with number)
+        # Test name sanitization for Cassandra identifiers
+        # - Non-alphanumeric chars (except underscore) become underscores
+        # - Names starting with non-letter get 'r_' prefix
+        # - All names converted to lowercase
         name_test_cases = [
             ("simple_name", "simple_name"),
             ("Name-With-Dashes", "name_with_dashes"),
             ("name.with.dots", "name_with_dots"),
-            ("123_numbers", "o_123_numbers"), # Only this gets o_ prefix
+            ("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number)
             ("special!@#chars", "special___chars"), # 3 special chars become 3 underscores
             ("UPPERCASE", "uppercase"),
             ("CamelCase", "camelcase"),
+            ("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore)
         ]

         for input_name, expected_name in name_test_cases:
             result = processor.sanitize_name(input_name)
-            assert result == expected_name
+            assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}"
+            # Verify result is valid Cassandra identifier (starts with letter)
+            if result: # Skip empty string case
+                assert result[0].isalpha(), f"Result {result} should start with a letter"

     def test_primary_key_structure_contract(self):
         """Test that primary key structure follows Cassandra best practices"""
@@ -308,8 +252,8 @@ class TestObjectsCassandraContracts:


 @pytest.mark.contract
-class TestObjectsCassandraContractsBatch:
-    """Contract tests for Cassandra object storage batch processing"""
+class TestRowsCassandraContractsBatch:
+    """Contract tests for Cassandra row storage batch processing"""

     def test_extracted_object_batch_input_contract(self):
         """Test that batched ExtractedObject schema matches expected input format"""
@ -1,26 +1,26 @@
|
||||||
"""
|
"""
|
||||||
Contract tests for Objects GraphQL Query Service
|
Contract tests for Rows GraphQL Query Service
|
||||||
|
|
||||||
These tests verify the message contracts and schema compatibility
|
These tests verify the message contracts and schema compatibility
|
||||||
for the objects GraphQL query processor.
|
for the rows GraphQL query processor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import json
|
import json
|
||||||
from pulsar.schema import AvroSchema
|
from pulsar.schema import AvroSchema
|
||||||
|
|
||||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||||
from trustgraph.query.objects.cassandra.service import Processor
|
from trustgraph.query.rows.cassandra.service import Processor
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.contract
|
@pytest.mark.contract
|
||||||
class TestObjectsGraphQLQueryContracts:
|
class TestRowsGraphQLQueryContracts:
|
||||||
"""Contract tests for GraphQL query service messages"""
|
"""Contract tests for GraphQL query service messages"""
|
||||||
|
|
||||||
def test_objects_query_request_contract(self):
|
def test_rows_query_request_contract(self):
|
||||||
"""Test ObjectsQueryRequest schema structure and required fields"""
|
"""Test RowsQueryRequest schema structure and required fields"""
|
||||||
# Create test request with all required fields
|
# Create test request with all required fields
|
||||||
test_request = ObjectsQueryRequest(
|
test_request = RowsQueryRequest(
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
query='{ customers { id name email } }',
|
query='{ customers { id name email } }',
|
||||||
|
|
@@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts:
         assert test_request.variables["status"] == "active"
         assert test_request.operation_name == "GetCustomers"

-    def test_objects_query_request_minimal(self):
-        """Test ObjectsQueryRequest with minimal required fields"""
+    def test_rows_query_request_minimal(self):
+        """Test RowsQueryRequest with minimal required fields"""
         # Create request with only essential fields
-        minimal_request = ObjectsQueryRequest(
+        minimal_request = RowsQueryRequest(
             user="user",
             collection="collection",
             query='{ test }',
@@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts:
         assert test_error.path == ["customers", "0", "nonexistent"]
         assert test_error.extensions["code"] == "FIELD_ERROR"

-    def test_objects_query_response_success_contract(self):
-        """Test ObjectsQueryResponse schema for successful queries"""
+    def test_rows_query_response_success_contract(self):
+        """Test RowsQueryResponse schema for successful queries"""
         # Create successful response
-        success_response = ObjectsQueryResponse(
+        success_response = RowsQueryResponse(
             error=None,
             data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
             errors=[],
@@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts:
         assert len(parsed_data["customers"]) == 1
         assert parsed_data["customers"][0]["id"] == "1"

-    def test_objects_query_response_error_contract(self):
-        """Test ObjectsQueryResponse schema for error cases"""
+    def test_rows_query_response_error_contract(self):
+        """Test RowsQueryResponse schema for error cases"""
         # Create GraphQL errors - work around Pulsar Array(Record) validation bug
         # by creating a response without the problematic errors array first
-        error_response = ObjectsQueryResponse(
+        error_response = RowsQueryResponse(
             error=None, # System error is None - these are GraphQL errors
             data=None, # No data due to errors
             errors=[], # Empty errors array to avoid Pulsar bug
@@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts:
         assert validation_error.path == ["customers", "email"]
         assert validation_error.extensions["details"] == "Invalid email format"

-    def test_objects_query_response_system_error_contract(self):
-        """Test ObjectsQueryResponse schema for system errors"""
+    def test_rows_query_response_system_error_contract(self):
+        """Test RowsQueryResponse schema for system errors"""
         from trustgraph.schema import Error

         # Create system error response
-        system_error_response = ObjectsQueryResponse(
+        system_error_response = RowsQueryResponse(
             error=Error(
-                type="objects-query-error",
+                type="rows-query-error",
                 message="Failed to connect to Cassandra cluster"
             ),
             data=None,
@@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts:

         # Verify system error structure
         assert system_error_response.error is not None
-        assert system_error_response.error.type == "objects-query-error"
+        assert system_error_response.error.type == "rows-query-error"
         assert "Cassandra" in system_error_response.error.message
         assert system_error_response.data is None
         assert len(system_error_response.errors) == 0
@@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts:
     def test_request_response_serialization_contract(self):
         """Test that request/response can be serialized/deserialized correctly"""
         # Create original request
-        original_request = ObjectsQueryRequest(
+        original_request = RowsQueryRequest(
             user="serialization_test",
             collection="test_data",
             query='{ orders(limit: 5) { id total customer { name } } }',
@@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts:
         )

         # Test request serialization using Pulsar schema
-        request_schema = AvroSchema(ObjectsQueryRequest)
+        request_schema = AvroSchema(RowsQueryRequest)

         # Encode and decode request
         encoded_request = request_schema.encode(original_request)
@@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts:
         assert decoded_request.operation_name == original_request.operation_name

         # Create original response - work around Pulsar Array(Record) bug
-        original_response = ObjectsQueryResponse(
+        original_response = RowsQueryResponse(
             error=None,
             data='{"orders": []}',
             errors=[], # Empty to avoid Pulsar validation bug
@@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts:
         )

         # Test response serialization
-        response_schema = AvroSchema(ObjectsQueryResponse)
+        response_schema = AvroSchema(RowsQueryResponse)

         # Encode and decode response
         encoded_response = response_schema.encode(original_response)
@@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts:
     def test_graphql_query_format_contract(self):
         """Test supported GraphQL query formats"""
         # Test basic query
-        basic_query = ObjectsQueryRequest(
+        basic_query = RowsQueryRequest(
             user="test", collection="test", query='{ customers { id } }',
             variables={}, operation_name=""
         )
@@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts:
         assert basic_query.query.strip().endswith('}')

         # Test query with variables
-        parameterized_query = ObjectsQueryRequest(
+        parameterized_query = RowsQueryRequest(
             user="test", collection="test",
             query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
             variables={"status": "active", "limit": "10"},
@@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts:
         assert parameterized_query.operation_name == "GetCustomers"

         # Test complex nested query
-        nested_query = ObjectsQueryRequest(
+        nested_query = RowsQueryRequest(
             user="test", collection="test",
             query='''
             {
@@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts:
         # Note: Current schema uses Map(String()) which only supports string values
         # This test verifies the current contract, though ideally we'd support all JSON types

-        variables_test = ObjectsQueryRequest(
+        variables_test = RowsQueryRequest(
             user="test", collection="test", query='{ test }',
             variables={
                 "string_var": "test_value",
@@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts:
     def test_cassandra_context_fields_contract(self):
         """Test that request contains necessary fields for Cassandra operations"""
         # Verify request has fields needed for Cassandra keyspace/table targeting
-        request = ObjectsQueryRequest(
+        request = RowsQueryRequest(
             user="keyspace_name", # Maps to Cassandra keyspace
             collection="partition_collection", # Used in partition key
             query='{ objects { id } }',
@@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts:
     def test_graphql_extensions_contract(self):
         """Test GraphQL extensions field format and usage"""
         # Extensions should support query metadata
-        response_with_extensions = ObjectsQueryResponse(
+        response_with_extensions = RowsQueryResponse(
             error=None,
             data='{"test": "data"}',
             errors=[],
@@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts:
         '''

         # Request to execute specific operation
-        multi_op_request = ObjectsQueryRequest(
+        multi_op_request = RowsQueryRequest(
             user="test", collection="test",
             query=multi_op_query,
             variables={},
@@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts:
         assert "GetOrders" in multi_op_request.query

         # Test single operation (operation_name optional)
-        single_op_request = ObjectsQueryRequest(
+        single_op_request = RowsQueryRequest(
             user="test", collection="test",
             query='{ customers { id } }',
             variables={}, operation_name=""
@@ -12,7 +12,7 @@ from argparse import ArgumentParser

 # Import processors that use Cassandra configuration
 from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
-from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
+from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
 from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
 from trustgraph.storage.knowledge.store import Processor as KgStore

@@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow:
         assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3']
         assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
-    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
     def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
         """Test complete flow from environment variables to Cassandra Cluster connection."""
         env_vars = {
@@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow:
         mock_cluster.return_value = mock_cluster_instance

         with patch.dict(os.environ, env_vars, clear=True):
-            processor = ObjectsWriter(taskgroup=MagicMock())
+            processor = RowsWriter(taskgroup=MagicMock())

             # Trigger Cassandra connection
             processor.connect_cassandra()
@@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd:
 class TestMultipleHostsHandling:
     """Test multiple Cassandra hosts handling end-to-end."""

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
     def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
         """Test that multiple hosts are correctly passed to Cassandra cluster."""
         env_vars = {
@@ -333,7 +333,7 @@ class TestMultipleHostsHandling:
         mock_cluster.return_value = mock_cluster_instance

         with patch.dict(os.environ, env_vars, clear=True):
-            processor = ObjectsWriter(taskgroup=MagicMock())
+            processor = RowsWriter(taskgroup=MagicMock())
             processor.connect_cassandra()

             # Verify all hosts were passed to Cluster
@@ -386,8 +386,8 @@ class TestMultipleHostsHandling:
 class TestAuthenticationFlow:
     """Test authentication configuration flow end-to-end."""

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
-    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
     def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
         """Test that authentication is enabled when both username and password are provided."""
         env_vars = {
@@ -402,7 +402,7 @@ class TestAuthenticationFlow:
         mock_cluster.return_value = mock_cluster_instance

         with patch.dict(os.environ, env_vars, clear=True):
-            processor = ObjectsWriter(taskgroup=MagicMock())
+            processor = RowsWriter(taskgroup=MagicMock())
             processor.connect_cassandra()

             # Auth provider should be created
@@ -416,8 +416,8 @@ class TestAuthenticationFlow:
         assert 'auth_provider' in call_args.kwargs
         assert call_args.kwargs['auth_provider'] == mock_auth_instance

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
-    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
     def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
         """Test that authentication is not used when credentials are missing."""
         env_vars = {
||||||
|
|
@ -429,7 +429,7 @@ class TestAuthenticationFlow:
|
||||||
mock_cluster.return_value = mock_cluster_instance
|
mock_cluster.return_value = mock_cluster_instance
|
||||||
|
|
||||||
with patch.dict(os.environ, env_vars, clear=True):
|
with patch.dict(os.environ, env_vars, clear=True):
|
||||||
processor = ObjectsWriter(taskgroup=MagicMock())
|
processor = RowsWriter(taskgroup=MagicMock())
|
||||||
processor.connect_cassandra()
|
processor.connect_cassandra()
|
||||||
|
|
||||||
# Auth provider should not be created
|
# Auth provider should not be created
|
||||||
|
|
@@ -439,11 +439,11 @@ class TestAuthenticationFlow:
         call_args = mock_cluster.call_args
         assert 'auth_provider' not in call_args.kwargs

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
-    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
     def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
         """Test that authentication is not used when only username is provided."""
-        processor = ObjectsWriter(
+        processor = RowsWriter(
             taskgroup=MagicMock(),
             cassandra_host='partial-auth-host',
             cassandra_username='partial-user'

@@ -11,7 +11,7 @@ import json
 import asyncio
 from unittest.mock import AsyncMock, MagicMock, patch

-from trustgraph.extract.kg.objects.processor import Processor
+from trustgraph.extract.kg.rows.processor import Processor
 from trustgraph.schema import (
     Chunk, ExtractedObject, Metadata, RowSchema, Field,
     PromptRequest, PromptResponse
@@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration:
         processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)

         # Import and bind the convert_values_to_strings function
-        from trustgraph.extract.kg.objects.processor import convert_values_to_strings
+        from trustgraph.extract.kg.rows.processor import convert_values_to_strings
         processor.convert_values_to_strings = convert_values_to_strings

         # Load configuration
@@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration:
         processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
 
         # Import and bind the convert_values_to_strings function
-        from trustgraph.extract.kg.objects.processor import convert_values_to_strings
+        from trustgraph.extract.kg.rows.processor import convert_values_to_strings
         processor.convert_values_to_strings = convert_values_to_strings
 
         # Load configuration
@@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration:
         processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
 
         # Import and bind the convert_values_to_strings function
-        from trustgraph.extract.kg.objects.processor import convert_values_to_strings
+        from trustgraph.extract.kg.rows.processor import convert_values_to_strings
         processor.convert_values_to_strings = convert_values_to_strings
 
         # Load configuration
@@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration:
         processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
 
         # Import and bind the convert_values_to_strings function
-        from trustgraph.extract.kg.objects.processor import convert_values_to_strings
+        from trustgraph.extract.kg.rows.processor import convert_values_to_strings
         processor.convert_values_to_strings = convert_values_to_strings
 
         # Mock flow with failing prompt service
@@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration:
         processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
 
         # Import and bind the convert_values_to_strings function
-        from trustgraph.extract.kg.objects.processor import convert_values_to_strings
+        from trustgraph.extract.kg.rows.processor import convert_values_to_strings
         processor.convert_values_to_strings = convert_values_to_strings
 
         # Load configuration
@@ -1,608 +0,0 @@
-"""
-Integration tests for Cassandra Object Storage
-
-These tests verify the end-to-end functionality of storing ExtractedObjects
-in Cassandra, including table creation, data insertion, and error handling.
-"""
-
-import pytest
-from unittest.mock import MagicMock, AsyncMock, patch
-import json
-import uuid
-
-from trustgraph.storage.objects.cassandra.write import Processor
-from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
-
-
-@pytest.mark.integration
-class TestObjectsCassandraIntegration:
-    """Integration tests for Cassandra object storage"""
-
-    @pytest.fixture
-    def mock_cassandra_session(self):
-        """Mock Cassandra session for integration tests"""
-        session = MagicMock()
-
-        # Track if keyspaces have been created
-        created_keyspaces = set()
-
-        # Mock the execute method to return a valid result for keyspace checks
-        def execute_mock(query, *args, **kwargs):
-            result = MagicMock()
-            query_str = str(query)
-
-            # Track keyspace creation
-            if "CREATE KEYSPACE" in query_str:
-                # Extract keyspace name from query
-                import re
-                match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
-                if match:
-                    created_keyspaces.add(match.group(1))
-
-            # For keyspace existence checks
-            if "system_schema.keyspaces" in query_str:
-                # Check if this keyspace was created
-                if args and args[0] in created_keyspaces:
-                    result.one.return_value = MagicMock()  # Exists
-                else:
-                    result.one.return_value = None  # Doesn't exist
-            else:
-                result.one.return_value = None
-
-            return result
-
-        session.execute = MagicMock(side_effect=execute_mock)
-        return session
-
-    @pytest.fixture
-    def mock_cassandra_cluster(self, mock_cassandra_session):
-        """Mock Cassandra cluster"""
-        cluster = MagicMock()
-        cluster.connect.return_value = mock_cassandra_session
-        cluster.shutdown = MagicMock()
-        return cluster
-
-    @pytest.fixture
-    def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
-        """Create processor with mocked Cassandra dependencies"""
-        processor = MagicMock()
-        processor.graph_host = "localhost"
-        processor.graph_username = None
-        processor.graph_password = None
-        processor.config_key = "schema"
-        processor.schemas = {}
-        processor.known_keyspaces = set()
-        processor.known_tables = {}
-        processor.cluster = None
-        processor.session = None
-
-        # Bind actual methods
-        processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
-        processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
-        processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
-        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
-        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
-        processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
-        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
-        processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
-        processor.on_object = Processor.on_object.__get__(processor, Processor)
-        processor.create_collection = Processor.create_collection.__get__(processor, Processor)
-
-        return processor, mock_cassandra_cluster, mock_cassandra_session
-
-    @pytest.mark.asyncio
-    async def test_end_to_end_object_storage(self, processor_with_mocks):
-        """Test complete flow from schema config to object storage"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        # Mock Cluster creation
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            # Step 1: Configure schema
-            config = {
-                "schema": {
-                    "customer_records": json.dumps({
-                        "name": "customer_records",
-                        "description": "Customer information",
-                        "fields": [
-                            {"name": "customer_id", "type": "string", "primary_key": True},
-                            {"name": "name", "type": "string", "required": True},
-                            {"name": "email", "type": "string", "indexed": True},
-                            {"name": "age", "type": "integer"}
-                        ]
-                    })
-                }
-            }
-
-            await processor.on_schema_config(config, version=1)
-            assert "customer_records" in processor.schemas
-
-            # Step 1.5: Create the collection first (simulate tg-set-collection)
-            await processor.create_collection("test_user", "import_2024", {})
-
-            # Step 2: Process an ExtractedObject
-            test_obj = ExtractedObject(
-                metadata=Metadata(
-                    id="doc-001",
-                    user="test_user",
-                    collection="import_2024",
-                    metadata=[]
-                ),
-                schema_name="customer_records",
-                values=[{
-                    "customer_id": "CUST001",
-                    "name": "John Doe",
-                    "email": "john@example.com",
-                    "age": "30"
-                }],
-                confidence=0.95,
-                source_span="Customer: John Doe..."
-            )
-
-            msg = MagicMock()
-            msg.value.return_value = test_obj
-
-            await processor.on_object(msg, None, None)
-
-            # Verify Cassandra interactions
-            assert mock_cluster.connect.called
-
-            # Verify keyspace creation
-            keyspace_calls = [call for call in mock_session.execute.call_args_list
-                              if "CREATE KEYSPACE" in str(call)]
-            assert len(keyspace_calls) == 1
-            assert "test_user" in str(keyspace_calls[0])
-
-            # Verify table creation
-            table_calls = [call for call in mock_session.execute.call_args_list
-                           if "CREATE TABLE" in str(call)]
-            assert len(table_calls) == 1
-            assert "o_customer_records" in str(table_calls[0])  # Table gets o_ prefix
-            assert "collection text" in str(table_calls[0])
-            assert "PRIMARY KEY ((collection, customer_id))" in str(table_calls[0])
-
-            # Verify index creation
-            index_calls = [call for call in mock_session.execute.call_args_list
-                           if "CREATE INDEX" in str(call)]
-            assert len(index_calls) == 1
-            assert "email" in str(index_calls[0])
-
-            # Verify data insertion
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            assert len(insert_calls) == 1
-            insert_call = insert_calls[0]
-            assert "test_user.o_customer_records" in str(insert_call)  # Table gets o_ prefix
-
-            # Check inserted values
-            values = insert_call[0][1]
-            assert "import_2024" in values  # collection
-            assert "CUST001" in values  # customer_id
-            assert "John Doe" in values  # name
-            assert "john@example.com" in values  # email
-            assert 30 in values  # age (converted to int)
-
-    @pytest.mark.asyncio
-    async def test_multi_schema_handling(self, processor_with_mocks):
-        """Test handling multiple schemas and objects"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            # Configure multiple schemas
-            config = {
-                "schema": {
-                    "products": json.dumps({
-                        "name": "products",
-                        "fields": [
-                            {"name": "product_id", "type": "string", "primary_key": True},
-                            {"name": "name", "type": "string"},
-                            {"name": "price", "type": "float"}
-                        ]
-                    }),
-                    "orders": json.dumps({
-                        "name": "orders",
-                        "fields": [
-                            {"name": "order_id", "type": "string", "primary_key": True},
-                            {"name": "customer_id", "type": "string"},
-                            {"name": "total", "type": "float"}
-                        ]
-                    })
-                }
-            }
-
-            await processor.on_schema_config(config, version=1)
-            assert len(processor.schemas) == 2
-
-            # Create collections first
-            await processor.create_collection("shop", "catalog", {})
-            await processor.create_collection("shop", "sales", {})
-
-            # Process objects for different schemas
-            product_obj = ExtractedObject(
-                metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
-                schema_name="products",
-                values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
-                confidence=0.9,
-                source_span="Product..."
-            )
-
-            order_obj = ExtractedObject(
-                metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
-                schema_name="orders",
-                values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
-                confidence=0.85,
-                source_span="Order..."
-            )
-
-            # Process both objects
-            for obj in [product_obj, order_obj]:
-                msg = MagicMock()
-                msg.value.return_value = obj
-                await processor.on_object(msg, None, None)
-
-            # Verify separate tables were created
-            table_calls = [call for call in mock_session.execute.call_args_list
-                           if "CREATE TABLE" in str(call)]
-            assert len(table_calls) == 2
-            assert any("o_products" in str(call) for call in table_calls)  # Tables get o_ prefix
-            assert any("o_orders" in str(call) for call in table_calls)  # Tables get o_ prefix
-
-    @pytest.mark.asyncio
-    async def test_missing_required_fields(self, processor_with_mocks):
-        """Test handling of objects with missing required fields"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            # Configure schema with required field
-            processor.schemas["test_schema"] = RowSchema(
-                name="test_schema",
-                description="Test",
-                fields=[
-                    Field(name="id", type="string", size=50, primary=True, required=True),
-                    Field(name="required_field", type="string", size=100, required=True)
-                ]
-            )
-
-            # Create collection first
-            await processor.create_collection("test", "test", {})
-
-            # Create object missing required field
-            test_obj = ExtractedObject(
-                metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
-                schema_name="test_schema",
-                values=[{"id": "123"}],  # missing required_field
-                confidence=0.8,
-                source_span="Test"
-            )
-
-            msg = MagicMock()
-            msg.value.return_value = test_obj
-
-            # Should still process (Cassandra doesn't enforce NOT NULL)
-            await processor.on_object(msg, None, None)
-
-            # Verify insert was attempted
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            assert len(insert_calls) == 1
-
-    @pytest.mark.asyncio
-    async def test_schema_without_primary_key(self, processor_with_mocks):
-        """Test handling schemas without defined primary keys"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            # Configure schema without primary key
-            processor.schemas["events"] = RowSchema(
-                name="events",
-                description="Event log",
-                fields=[
-                    Field(name="event_type", type="string", size=50),
-                    Field(name="timestamp", type="timestamp", size=0)
-                ]
-            )
-
-            # Create collection first
-            await processor.create_collection("logger", "app_events", {})
-
-            # Process object
-            test_obj = ExtractedObject(
-                metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
-                schema_name="events",
-                values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
-                confidence=1.0,
-                source_span="Event"
-            )
-
-            msg = MagicMock()
-            msg.value.return_value = test_obj
-
-            await processor.on_object(msg, None, None)
-
-            # Verify synthetic_id was added
-            table_calls = [call for call in mock_session.execute.call_args_list
-                           if "CREATE TABLE" in str(call)]
-            assert len(table_calls) == 1
-            assert "synthetic_id uuid" in str(table_calls[0])
-
-            # Verify insert includes UUID
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            assert len(insert_calls) == 1
-            values = insert_calls[0][0][1]
-            # Check that a UUID was generated (will be in values list)
-            uuid_found = any(isinstance(v, uuid.UUID) for v in values)
-            assert uuid_found
-
-    @pytest.mark.asyncio
-    async def test_authentication_handling(self, processor_with_mocks):
-        """Test Cassandra authentication"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-        processor.cassandra_username = "cassandra_user"
-        processor.cassandra_password = "cassandra_pass"
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class:
-            with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
-                mock_cluster_class.return_value = mock_cluster
-
-                # Trigger connection
-                processor.connect_cassandra()
-
-                # Verify authentication was configured
-                mock_auth.assert_called_once_with(
-                    username="cassandra_user",
-                    password="cassandra_pass"
-                )
-                mock_cluster_class.assert_called_once()
-                call_kwargs = mock_cluster_class.call_args[1]
-                assert 'auth_provider' in call_kwargs
-
-    @pytest.mark.asyncio
-    async def test_error_handling_during_insert(self, processor_with_mocks):
-        """Test error handling when insertion fails"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["test"] = RowSchema(
-                name="test",
-                fields=[Field(name="id", type="string", size=50, primary=True)]
-            )
-
-            # Make insert fail
-            mock_result = MagicMock()
-            mock_result.one.return_value = MagicMock()  # Keyspace exists
-            mock_session.execute.side_effect = [
-                mock_result,  # keyspace existence check succeeds
-                None,  # table creation succeeds
-                Exception("Connection timeout")  # insert fails
-            ]
-
-            test_obj = ExtractedObject(
-                metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
-                schema_name="test",
-                values=[{"id": "123"}],
-                confidence=0.9,
-                source_span="Test"
-            )
-
-            msg = MagicMock()
-            msg.value.return_value = test_obj
-
-            # Should raise the exception
-            with pytest.raises(Exception, match="Connection timeout"):
-                await processor.on_object(msg, None, None)
-
-    @pytest.mark.asyncio
-    async def test_collection_partitioning(self, processor_with_mocks):
-        """Test that objects are properly partitioned by collection"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["data"] = RowSchema(
-                name="data",
-                fields=[Field(name="id", type="string", size=50, primary=True)]
-            )
-
-            # Process objects from different collections
-            collections = ["import_jan", "import_feb", "import_mar"]
-
-            # Create all collections first
-            for coll in collections:
-                await processor.create_collection("analytics", coll, {})
-
-            for coll in collections:
-                obj = ExtractedObject(
-                    metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
-                    schema_name="data",
-                    values=[{"id": f"ID-{coll}"}],
-                    confidence=0.9,
-                    source_span="Data"
-                )
-
-                msg = MagicMock()
-                msg.value.return_value = obj
-                await processor.on_object(msg, None, None)
-
-            # Verify all inserts include collection in values
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            assert len(insert_calls) == 3
-
-            # Check each insert has the correct collection
-            for i, call in enumerate(insert_calls):
-                values = call[0][1]
-                assert collections[i] in values
-
-    @pytest.mark.asyncio
-    async def test_batch_object_processing(self, processor_with_mocks):
-        """Test processing objects with batched values"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            # Configure schema
-            config = {
-                "schema": {
-                    "batch_customers": json.dumps({
-                        "name": "batch_customers",
-                        "description": "Customer batch data",
-                        "fields": [
-                            {"name": "customer_id", "type": "string", "primary_key": True},
-                            {"name": "name", "type": "string", "required": True},
-                            {"name": "email", "type": "string", "indexed": True}
-                        ]
-                    })
-                }
-            }
-
-            await processor.on_schema_config(config, version=1)
-
-            # Process batch object with multiple values
-            batch_obj = ExtractedObject(
-                metadata=Metadata(
-                    id="batch-001",
-                    user="test_user",
-                    collection="batch_import",
-                    metadata=[]
-                ),
-                schema_name="batch_customers",
-                values=[
-                    {
-                        "customer_id": "CUST001",
-                        "name": "John Doe",
-                        "email": "john@example.com"
-                    },
-                    {
-                        "customer_id": "CUST002",
-                        "name": "Jane Smith",
-                        "email": "jane@example.com"
-                    },
-                    {
-                        "customer_id": "CUST003",
-                        "name": "Bob Johnson",
-                        "email": "bob@example.com"
-                    }
-                ],
-                confidence=0.92,
-                source_span="Multiple customers extracted from document"
-            )
-
-            # Create collection first
-            await processor.create_collection("test_user", "batch_import", {})
-
-            msg = MagicMock()
-            msg.value.return_value = batch_obj
-
-            await processor.on_object(msg, None, None)
-
-            # Verify table creation
-            table_calls = [call for call in mock_session.execute.call_args_list
-                           if "CREATE TABLE" in str(call)]
-            assert len(table_calls) == 1
-            assert "o_batch_customers" in str(table_calls[0])
-
-            # Verify multiple inserts for batch values
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            # Should have 3 separate inserts for the 3 objects in the batch
-            assert len(insert_calls) == 3
-
-            # Check each insert has correct data
-            for i, call in enumerate(insert_calls):
-                values = call[0][1]
-                assert "batch_import" in values  # collection
-                assert f"CUST00{i+1}" in values  # customer_id
-                if i == 0:
-                    assert "John Doe" in values
-                    assert "john@example.com" in values
-                elif i == 1:
-                    assert "Jane Smith" in values
-                    assert "jane@example.com" in values
-                elif i == 2:
-                    assert "Bob Johnson" in values
-                    assert "bob@example.com" in values
-
-    @pytest.mark.asyncio
-    async def test_empty_batch_processing(self, processor_with_mocks):
-        """Test processing objects with empty values array"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["empty_test"] = RowSchema(
-                name="empty_test",
-                fields=[Field(name="id", type="string", size=50, primary=True)]
-            )
-
-            # Create collection first
-            await processor.create_collection("test", "empty", {})
-
-            # Process empty batch object
-            empty_obj = ExtractedObject(
-                metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
-                schema_name="empty_test",
-                values=[],  # Empty batch
-                confidence=1.0,
-                source_span="No objects found"
-            )
-
-            msg = MagicMock()
-            msg.value.return_value = empty_obj
-
-            await processor.on_object(msg, None, None)
-
-            # Should still create table
-            table_calls = [call for call in mock_session.execute.call_args_list
-                           if "CREATE TABLE" in str(call)]
-            assert len(table_calls) == 1
-
-            # Should not create any insert statements for empty batch
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            assert len(insert_calls) == 0
-
-    @pytest.mark.asyncio
-    async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
-        """Test processing mix of single and batch objects"""
-        processor, mock_cluster, mock_session = processor_with_mocks
-
-        with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
-            processor.schemas["mixed_test"] = RowSchema(
-                name="mixed_test",
-                fields=[
-                    Field(name="id", type="string", size=50, primary=True),
-                    Field(name="data", type="string", size=100)
-                ]
-            )
-
-            # Create collection first
-            await processor.create_collection("test", "mixed", {})
-
-            # Single object (backward compatibility)
-            single_obj = ExtractedObject(
-                metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
-                schema_name="mixed_test",
-                values=[{"id": "single-1", "data": "single data"}],  # Array with single item
-                confidence=0.9,
-                source_span="Single object"
-            )
-
-            # Batch object
-            batch_obj = ExtractedObject(
-                metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
-                schema_name="mixed_test",
-                values=[
-                    {"id": "batch-1", "data": "batch data 1"},
-                    {"id": "batch-2", "data": "batch data 2"}
-                ],
-                confidence=0.85,
-                source_span="Batch objects"
-            )
-
-            # Process both
-            for obj in [single_obj, batch_obj]:
-                msg = MagicMock()
-                msg.value.return_value = obj
-                await processor.on_object(msg, None, None)
-
-            # Should have 3 total inserts (1 + 2)
-            insert_calls = [call for call in mock_session.execute.call_args_list
-                            if "INSERT INTO" in str(call)]
-            assert len(insert_calls) == 3
492
tests/integration/test_rows_cassandra_integration.py
Normal file
@@ -0,0 +1,492 @@
+"""
+Integration tests for Cassandra Row Storage (Unified Table Implementation)
+
+These tests verify the end-to-end functionality of storing ExtractedObjects
+in the unified Cassandra rows table, including table creation, data insertion,
+and error handling.
+"""
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+import json
+
+from trustgraph.storage.rows.cassandra.write import Processor
+from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
+
+
+@pytest.mark.integration
+class TestRowsCassandraIntegration:
+    """Integration tests for Cassandra row storage with unified table"""
+
+    @pytest.fixture
+    def mock_cassandra_session(self):
+        """Mock Cassandra session for integration tests"""
+        session = MagicMock()
+
+        # Track if keyspaces have been created
+        created_keyspaces = set()
+
+        # Mock the execute method to return a valid result for keyspace checks
+        def execute_mock(query, *args, **kwargs):
+            result = MagicMock()
+            query_str = str(query)
+
+            # Track keyspace creation
+            if "CREATE KEYSPACE" in query_str:
+                import re
+                match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
+                if match:
+                    created_keyspaces.add(match.group(1))
+
+            # For keyspace existence checks
+            if "system_schema.keyspaces" in query_str:
+                if args and args[0] in created_keyspaces:
+                    result.one.return_value = MagicMock()  # Exists
+                else:
+                    result.one.return_value = None  # Doesn't exist
+            else:
+                result.one.return_value = None
+
+            return result
+
+        session.execute = MagicMock(side_effect=execute_mock)
+        return session
+
+    @pytest.fixture
+    def mock_cassandra_cluster(self, mock_cassandra_session):
+        """Mock Cassandra cluster"""
+        cluster = MagicMock()
+        cluster.connect.return_value = mock_cassandra_session
+        cluster.shutdown = MagicMock()
+        return cluster
+
+    @pytest.fixture
+    def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
+        """Create processor with mocked Cassandra dependencies"""
+        processor = MagicMock()
+        processor.cassandra_host = ["localhost"]
+        processor.cassandra_username = None
+        processor.cassandra_password = None
+        processor.config_key = "schema"
+        processor.schemas = {}
+        processor.known_keyspaces = set()
+        processor.tables_initialized = set()
+        processor.registered_partitions = set()
+        processor.cluster = None
+        processor.session = None
+
+        # Bind actual methods from the new unified table implementation
+        processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
+        processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
+        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
+        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
+        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
+        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
+        processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
+        processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
+        processor.on_object = Processor.on_object.__get__(processor, Processor)
+        processor.collection_exists = MagicMock(return_value=True)
+
+        return processor, mock_cassandra_cluster, mock_cassandra_session
+
+    @pytest.mark.asyncio
+    async def test_end_to_end_object_storage(self, processor_with_mocks):
+        """Test complete flow from schema config to object storage"""
+        processor, mock_cluster, mock_session = processor_with_mocks
+
+        with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
+            # Step 1: Configure schema
+            config = {
+                "schema": {
+                    "customer_records": json.dumps({
+                        "name": "customer_records",
+                        "description": "Customer information",
+                        "fields": [
+                            {"name": "customer_id", "type": "string", "primary_key": True},
+                            {"name": "name", "type": "string", "required": True},
+                            {"name": "email", "type": "string", "indexed": True},
+                            {"name": "age", "type": "integer"}
|
]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await processor.on_schema_config(config, version=1)
|
||||||
|
assert "customer_records" in processor.schemas
|
||||||
|
|
||||||
|
# Step 2: Process an ExtractedObject
|
||||||
|
test_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(
|
||||||
|
id="doc-001",
|
||||||
|
user="test_user",
|
||||||
|
collection="import_2024",
|
||||||
|
metadata=[]
|
||||||
|
),
|
||||||
|
schema_name="customer_records",
|
||||||
|
values=[{
|
||||||
|
"customer_id": "CUST001",
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
"age": "30"
|
||||||
|
}],
|
||||||
|
confidence=0.95,
|
||||||
|
source_span="Customer: John Doe..."
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = test_obj
|
||||||
|
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# Verify Cassandra interactions
|
||||||
|
assert mock_cluster.connect.called
|
||||||
|
|
||||||
|
# Verify keyspace creation
|
||||||
|
keyspace_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "CREATE KEYSPACE" in str(call)]
|
||||||
|
assert len(keyspace_calls) == 1
|
||||||
|
assert "test_user" in str(keyspace_calls[0])
|
||||||
|
|
||||||
|
# Verify unified table creation (rows table, not per-schema table)
|
||||||
|
table_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "CREATE TABLE" in str(call)]
|
||||||
|
assert len(table_calls) == 2 # rows table + row_partitions table
|
||||||
|
assert any("rows" in str(call) for call in table_calls)
|
||||||
|
assert any("row_partitions" in str(call) for call in table_calls)
|
||||||
|
|
||||||
|
# Verify the rows table has correct structure
|
||||||
|
rows_table_call = [call for call in table_calls if ".rows" in str(call)][0]
|
||||||
|
assert "collection text" in str(rows_table_call)
|
||||||
|
assert "schema_name text" in str(rows_table_call)
|
||||||
|
assert "index_name text" in str(rows_table_call)
|
||||||
|
assert "data map<text, text>" in str(rows_table_call)
|
||||||
|
|
||||||
|
# Verify data insertion into unified table
|
||||||
|
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||||
|
and "row_partitions" not in str(call)]
|
||||||
|
# Should have 2 data inserts: one for customer_id (primary), one for email (indexed)
|
||||||
|
assert len(rows_insert_calls) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_schema_handling(self, processor_with_mocks):
|
||||||
|
"""Test handling multiple schemas stored in unified table"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||||
|
# Configure multiple schemas
|
||||||
|
config = {
|
||||||
|
"schema": {
|
||||||
|
"products": json.dumps({
|
||||||
|
"name": "products",
|
||||||
|
"fields": [
|
||||||
|
{"name": "product_id", "type": "string", "primary_key": True},
|
||||||
|
{"name": "name", "type": "string"},
|
||||||
|
{"name": "price", "type": "float"}
|
||||||
|
]
|
||||||
|
}),
|
||||||
|
"orders": json.dumps({
|
||||||
|
"name": "orders",
|
||||||
|
"fields": [
|
||||||
|
{"name": "order_id", "type": "string", "primary_key": True},
|
||||||
|
{"name": "customer_id", "type": "string"},
|
||||||
|
{"name": "total", "type": "float"}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await processor.on_schema_config(config, version=1)
|
||||||
|
assert len(processor.schemas) == 2
|
||||||
|
|
||||||
|
# Process objects for different schemas
|
||||||
|
product_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||||
|
schema_name="products",
|
||||||
|
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||||
|
confidence=0.9,
|
||||||
|
source_span="Product..."
|
||||||
|
)
|
||||||
|
|
||||||
|
order_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||||
|
schema_name="orders",
|
||||||
|
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||||
|
confidence=0.85,
|
||||||
|
source_span="Order..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process both objects
|
||||||
|
for obj in [product_obj, order_obj]:
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = obj
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# All data goes into the same unified rows table
|
||||||
|
table_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "CREATE TABLE" in str(call)]
|
||||||
|
# Should only create 2 tables: rows + row_partitions (not per-schema tables)
|
||||||
|
assert len(table_calls) == 2
|
||||||
|
|
||||||
|
# Verify data inserts go to unified rows table
|
||||||
|
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||||
|
and "row_partitions" not in str(call)]
|
||||||
|
assert len(rows_insert_calls) > 0
|
||||||
|
for call in rows_insert_calls:
|
||||||
|
assert ".rows" in str(call)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_index_storage(self, processor_with_mocks):
|
||||||
|
"""Test that rows are stored with multiple indexes"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||||
|
# Schema with multiple indexed fields
|
||||||
|
processor.schemas["indexed_data"] = RowSchema(
|
||||||
|
name="indexed_data",
|
||||||
|
fields=[
|
||||||
|
Field(name="id", type="string", size=50, primary=True),
|
||||||
|
Field(name="category", type="string", size=50, indexed=True),
|
||||||
|
Field(name="status", type="string", size=50, indexed=True),
|
||||||
|
Field(name="description", type="string", size=200) # Not indexed
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
test_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||||
|
schema_name="indexed_data",
|
||||||
|
values=[{
|
||||||
|
"id": "123",
|
||||||
|
"category": "electronics",
|
||||||
|
"status": "active",
|
||||||
|
"description": "A product"
|
||||||
|
}],
|
||||||
|
confidence=0.9,
|
||||||
|
source_span="Test"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = test_obj
|
||||||
|
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# Should have 3 data inserts (one per indexed field: id, category, status)
|
||||||
|
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||||
|
and "row_partitions" not in str(call)]
|
||||||
|
assert len(rows_insert_calls) == 3
|
||||||
|
|
||||||
|
# Verify different index names were used
|
||||||
|
index_names = set()
|
||||||
|
for call in rows_insert_calls:
|
||||||
|
values = call[0][1]
|
||||||
|
index_names.add(values[2]) # index_name is 3rd parameter
|
||||||
|
|
||||||
|
assert index_names == {"id", "category", "status"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authentication_handling(self, processor_with_mocks):
|
||||||
|
"""Test Cassandra authentication"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
processor.cassandra_username = "cassandra_user"
|
||||||
|
processor.cassandra_password = "cassandra_pass"
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class:
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth:
|
||||||
|
mock_cluster_class.return_value = mock_cluster
|
||||||
|
|
||||||
|
# Trigger connection
|
||||||
|
processor.connect_cassandra()
|
||||||
|
|
||||||
|
# Verify authentication was configured
|
||||||
|
mock_auth.assert_called_once_with(
|
||||||
|
username="cassandra_user",
|
||||||
|
password="cassandra_pass"
|
||||||
|
)
|
||||||
|
mock_cluster_class.assert_called_once()
|
||||||
|
call_kwargs = mock_cluster_class.call_args[1]
|
||||||
|
assert 'auth_provider' in call_kwargs
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_object_processing(self, processor_with_mocks):
|
||||||
|
"""Test processing objects with batched values"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||||
|
# Configure schema
|
||||||
|
config = {
|
||||||
|
"schema": {
|
||||||
|
"batch_customers": json.dumps({
|
||||||
|
"name": "batch_customers",
|
||||||
|
"description": "Customer batch data",
|
||||||
|
"fields": [
|
||||||
|
{"name": "customer_id", "type": "string", "primary_key": True},
|
||||||
|
{"name": "name", "type": "string", "required": True},
|
||||||
|
{"name": "email", "type": "string", "indexed": True}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await processor.on_schema_config(config, version=1)
|
||||||
|
|
||||||
|
# Process batch object with multiple values
|
||||||
|
batch_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(
|
||||||
|
id="batch-001",
|
||||||
|
user="test_user",
|
||||||
|
collection="batch_import",
|
||||||
|
metadata=[]
|
||||||
|
),
|
||||||
|
schema_name="batch_customers",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"customer_id": "CUST001",
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"customer_id": "CUST002",
|
||||||
|
"name": "Jane Smith",
|
||||||
|
"email": "jane@example.com"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"customer_id": "CUST003",
|
||||||
|
"name": "Bob Johnson",
|
||||||
|
"email": "bob@example.com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
confidence=0.92,
|
||||||
|
source_span="Multiple customers extracted from document"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = batch_obj
|
||||||
|
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# Verify unified table creation
|
||||||
|
table_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "CREATE TABLE" in str(call)]
|
||||||
|
assert len(table_calls) == 2 # rows + row_partitions
|
||||||
|
|
||||||
|
# Each row in batch gets 2 data inserts (customer_id primary + email indexed)
|
||||||
|
# 3 rows * 2 indexes = 6 data inserts
|
||||||
|
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||||
|
and "row_partitions" not in str(call)]
|
||||||
|
assert len(rows_insert_calls) == 6
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_batch_processing(self, processor_with_mocks):
|
||||||
|
"""Test processing objects with empty values array"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||||
|
processor.schemas["empty_test"] = RowSchema(
|
||||||
|
name="empty_test",
|
||||||
|
fields=[Field(name="id", type="string", size=50, primary=True)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process empty batch object
|
||||||
|
empty_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
|
||||||
|
schema_name="empty_test",
|
||||||
|
values=[], # Empty batch
|
||||||
|
confidence=1.0,
|
||||||
|
source_span="No objects found"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = empty_obj
|
||||||
|
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# Should not create any data insert statements for empty batch
|
||||||
|
# (partition registration may still happen)
|
||||||
|
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||||
|
and "row_partitions" not in str(call)]
|
||||||
|
assert len(rows_insert_calls) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_data_stored_as_map(self, processor_with_mocks):
|
||||||
|
"""Test that data is stored as map<text, text>"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||||
|
processor.schemas["map_test"] = RowSchema(
|
||||||
|
name="map_test",
|
||||||
|
fields=[
|
||||||
|
Field(name="id", type="string", size=50, primary=True),
|
||||||
|
Field(name="name", type="string", size=100),
|
||||||
|
Field(name="count", type="integer", size=0)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
test_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||||
|
schema_name="map_test",
|
||||||
|
values=[{"id": "123", "name": "Test Item", "count": "42"}],
|
||||||
|
confidence=0.9,
|
||||||
|
source_span="Test"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = test_obj
|
||||||
|
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# Verify insert uses map for data
|
||||||
|
rows_insert_calls = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and ".rows" in str(call)
|
||||||
|
and "row_partitions" not in str(call)]
|
||||||
|
assert len(rows_insert_calls) >= 1
|
||||||
|
|
||||||
|
# Check that data is passed as a dict (will be map in Cassandra)
|
||||||
|
insert_call = rows_insert_calls[0]
|
||||||
|
values = insert_call[0][1]
|
||||||
|
# Values are: (collection, schema_name, index_name, index_value, data, source)
|
||||||
|
# values[4] should be the data map
|
||||||
|
data_map = values[4]
|
||||||
|
assert isinstance(data_map, dict)
|
||||||
|
assert data_map["id"] == "123"
|
||||||
|
assert data_map["name"] == "Test Item"
|
||||||
|
assert data_map["count"] == "42"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partition_registration(self, processor_with_mocks):
|
||||||
|
"""Test that partitions are registered for efficient querying"""
|
||||||
|
processor, mock_cluster, mock_session = processor_with_mocks
|
||||||
|
|
||||||
|
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
|
||||||
|
processor.schemas["partition_test"] = RowSchema(
|
||||||
|
name="partition_test",
|
||||||
|
fields=[
|
||||||
|
Field(name="id", type="string", size=50, primary=True),
|
||||||
|
Field(name="category", type="string", size=50, indexed=True)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
test_obj = ExtractedObject(
|
||||||
|
metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
|
||||||
|
schema_name="partition_test",
|
||||||
|
values=[{"id": "123", "category": "test"}],
|
||||||
|
confidence=0.9,
|
||||||
|
source_span="Test"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.value.return_value = test_obj
|
||||||
|
|
||||||
|
await processor.on_object(msg, None, None)
|
||||||
|
|
||||||
|
# Verify partition registration
|
||||||
|
partition_inserts = [call for call in mock_session.execute.call_args_list
|
||||||
|
if "INSERT INTO" in str(call) and "row_partitions" in str(call)]
|
||||||
|
# Should register partitions for each index (id, category)
|
||||||
|
assert len(partition_inserts) == 2
|
||||||
|
|
||||||
|
# Verify cache was updated
|
||||||
|
assert ("my_collection", "partition_test") in processor.registered_partitions
|
||||||
|
|
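The insert counts asserted in these tests follow a simple fan-out rule: each extracted row produces one write per index field (the primary key plus every `indexed` field). A minimal standalone sketch of that rule — the function name and tuple layout here are assumptions for illustration, not the actual trustgraph implementation:

```python
def build_index_inserts(collection, schema_name, index_names, row):
    # One (collection, schema_name, index_name, index_value, data) tuple
    # per index field present in the row, mirroring the unified rows table.
    return [
        (collection, schema_name, name, row[name], dict(row))
        for name in index_names
        if name in row
    ]

rows = [
    {"customer_id": "CUST001", "name": "John Doe", "email": "john@example.com"},
    {"customer_id": "CUST002", "name": "Jane Smith", "email": "jane@example.com"},
    {"customer_id": "CUST003", "name": "Bob Johnson", "email": "bob@example.com"},
]
inserts = [t for r in rows
           for t in build_index_inserts("batch_import", "batch_customers",
                                        ["customer_id", "email"], r)]
# 3 rows x 2 index fields = 6 tuples, matching the batch-test assertion
```

The full row data rides along in every copy, so a lookup on any index value returns the complete record without a second read.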
@@ -1,5 +1,5 @@
 """
-Integration tests for Objects GraphQL Query Service
+Integration tests for Rows GraphQL Query Service

 These tests verify end-to-end functionality including:
 - Real Cassandra database operations
@@ -24,8 +24,8 @@ except Exception:
     DOCKER_AVAILABLE = False
     CassandraContainer = None

-from trustgraph.query.objects.cassandra.service import Processor
-from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
+from trustgraph.query.rows.cassandra.service import Processor
+from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
 from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata


@@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration:
         processor.connect_cassandra()

         # Create mock message
-        request = ObjectsQueryRequest(
+        request = RowsQueryRequest(
             user="msg_test_user",
             collection="msg_test_collection",
             query='{ customer_objects { customer_id name } }',
@@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration:

         # Verify response structure
         sent_response = mock_response_producer.send.call_args[0][0]
-        assert isinstance(sent_response, ObjectsQueryResponse)
+        assert isinstance(sent_response, RowsQueryResponse)

         # Should have no system error (even if no data)
         assert sent_response.error is None
@@ -2,7 +2,7 @@
 Integration tests for Structured Query Service

 These tests verify the end-to-end functionality of the structured query service,
-testing orchestration between nlp-query and objects-query services.
+testing orchestration between nlp-query and rows-query services.
 Following the TEST_STRATEGY.md approach for integration testing.
 """

@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
 from trustgraph.schema import (
     StructuredQueryRequest, StructuredQueryResponse,
     QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
-    ObjectsQueryRequest, ObjectsQueryResponse,
+    RowsQueryRequest, RowsQueryResponse,
     Error, GraphQLError
 )
 from trustgraph.retrieval.structured_query.service import Processor
@@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration:
         )

         # Mock Objects Query Service Response
-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
             errors=None,
@@ -99,7 +99,7 @@
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -121,7 +121,7 @@
         # Verify Objects service call
         mock_objects_client.request.assert_called_once()
         objects_call_args = mock_objects_client.request.call_args[0][0]
-        assert isinstance(objects_call_args, ObjectsQueryRequest)
+        assert isinstance(objects_call_args, RowsQueryRequest)
         assert "customers" in objects_call_args.query
         assert "orders" in objects_call_args.query
         assert objects_call_args.variables["minAmount"] == "500.0"  # Converted to string
@@ -220,7 +220,7 @@
         )

         # Mock Objects service failure
-        objects_error_response = ObjectsQueryResponse(
+        objects_error_response = RowsQueryResponse(
             error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
             data=None,
             errors=None,
@@ -237,7 +237,7 @@
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -255,7 +255,7 @@

         assert response.error is not None
         assert response.error.type == "structured-query-error"
-        assert "Objects query service error" in response.error.message
+        assert "Rows query service error" in response.error.message
         assert "nonexistent_table" in response.error.message

     @pytest.mark.asyncio
@@ -298,7 +298,7 @@
             )
         ]

-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data=None,  # No data when validation fails
             errors=validation_errors,
@@ -315,7 +315,7 @@
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -422,7 +422,7 @@
             ]
         }

-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data=json.dumps(complex_data),
             errors=None,
@@ -443,7 +443,7 @@
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -503,7 +503,7 @@
         )

         # Mock empty Objects response
-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data='{"customers": []}',  # Empty result set
             errors=None,
@@ -520,7 +520,7 @@
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -577,7 +577,7 @@
                 confidence=0.9
             )

-            objects_response = ObjectsQueryResponse(
+            objects_response = RowsQueryResponse(
                 error=None,
                 data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
                 errors=None,
@@ -599,7 +599,7 @@
             if service_name == "nlp-query-request":
                 service_call_count += 1
                 return nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 service_call_count += 1
                 return objects_client
             elif service_name == "response":
@@ -700,7 +700,7 @@
         )

         # Mock Objects response
-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
             errors=None,
@@ -717,7 +717,7 @@
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
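The `flow_router` stub recurs in almost every hunk above, so the shape of the pattern is worth sketching on its own: the tests hand the processor a callable that maps service names to mocked clients, which stands in for the real flow lookup. The factory wrapper below is illustrative (its name is an assumption, not part of the patch), showing the routing after the `objects-query-request` to `rows-query-request` rename:

```python
from unittest.mock import AsyncMock, MagicMock

def make_flow_router(nlp_client, rows_client, response_producer):
    # Returns the same kind of stub the integration tests build inline:
    # a name-to-mock dispatcher that replaces the real service flow.
    def flow_router(service_name):
        if service_name == "nlp-query-request":
            return nlp_client
        elif service_name == "rows-query-request":
            return rows_client
        elif service_name == "response":
            return response_producer
        raise KeyError(f"unknown service: {service_name}")
    return flow_router

nlp, rows, resp = AsyncMock(), AsyncMock(), MagicMock()
router = make_flow_router(nlp, rows, resp)
```

Keeping the rename inside one router function is what lets each test change only the `elif` branch name while the rest of the wiring stays untouched.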
380 tests/unit/test_embeddings/test_row_embeddings_processor.py Normal file
@ -0,0 +1,380 @@
|
"""
Unit tests for trustgraph.embeddings.row_embeddings.embeddings

Tests the Stage 1 processor that computes embeddings for row index fields.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase


class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
    """Test row embeddings processor functionality"""

    async def test_processor_initialization(self):
        """Test basic processor initialization"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-row-embeddings'
        }

        processor = Processor(**config)

        assert hasattr(processor, 'schemas')
        assert processor.schemas == {}
        assert processor.batch_size == 10  # default

    async def test_processor_initialization_with_custom_batch_size(self):
        """Test processor initialization with custom batch size"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-row-embeddings',
            'batch_size': 25
        }

        processor = Processor(**config)

        assert processor.batch_size == 25

    async def test_get_index_names_single_index(self):
        """Test getting index names with single indexed field"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import RowSchema, Field

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        schema = RowSchema(
            name='customers',
            description='Customer records',
            fields=[
                Field(name='id', type='text', primary=True),
                Field(name='name', type='text', indexed=True),
                Field(name='email', type='text', indexed=False),
            ]
        )

        index_names = processor.get_index_names(schema)

        # Should include primary key and indexed field
        assert 'id' in index_names
        assert 'name' in index_names
        assert 'email' not in index_names

    async def test_get_index_names_no_indexes(self):
        """Test getting index names when no fields are indexed"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import RowSchema, Field

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        schema = RowSchema(
            name='logs',
            description='Log records',
            fields=[
                Field(name='timestamp', type='text'),
                Field(name='message', type='text'),
            ]
        )

        index_names = processor.get_index_names(schema)

        assert index_names == []

    async def test_build_index_value_single_field(self):
        """Test building index value for single field"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        value_map = {
            'id': 'CUST001',
            'name': 'John Doe',
            'email': 'john@example.com'
        }

        result = processor.build_index_value(value_map, 'name')

        assert result == ['John Doe']

    async def test_build_index_value_composite_index(self):
        """Test building index value for composite index"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        value_map = {
            'first_name': 'John',
            'last_name': 'Doe',
            'city': 'New York'
        }

        result = processor.build_index_value(value_map, 'first_name, last_name')

        assert result == ['John', 'Doe']

    async def test_build_index_value_missing_field(self):
        """Test building index value when field is missing"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        value_map = {
            'name': 'John Doe'
        }

        result = processor.build_index_value(value_map, 'missing_field')

        assert result == ['']

    async def test_build_text_for_embedding_single_value(self):
        """Test building text representation for single value"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        result = processor.build_text_for_embedding(['John Doe'])

        assert result == 'John Doe'

    async def test_build_text_for_embedding_multiple_values(self):
        """Test building text representation for multiple values"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])

        assert result == 'John Doe NYC'

    async def test_on_schema_config_loads_schemas(self):
        """Test that schema configuration is loaded correctly"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        import json

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'config_type': 'schema'
        }

        processor = Processor(**config)

        schema_def = {
            'name': 'customers',
            'description': 'Customer records',
            'fields': [
                {'name': 'id', 'type': 'text', 'primary_key': True},
                {'name': 'name', 'type': 'text', 'indexed': True},
                {'name': 'email', 'type': 'text'}
            ]
        }

        config_data = {
            'schema': {
                'customers': json.dumps(schema_def)
            }
        }

        await processor.on_schema_config(config_data, 1)

        assert 'customers' in processor.schemas
        assert processor.schemas['customers'].name == 'customers'
        assert len(processor.schemas['customers'].fields) == 3

    async def test_on_schema_config_handles_missing_type(self):
        """Test that missing schema type is handled gracefully"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'config_type': 'schema'
        }

        processor = Processor(**config)

        config_data = {
            'other_type': {}
        }

        await processor.on_schema_config(config_data, 1)

        assert processor.schemas == {}

    async def test_on_message_drops_unknown_collection(self):
        """Test that messages for unknown collections are dropped"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import ExtractedObject

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        # No collections registered

        metadata = MagicMock()
        metadata.user = 'unknown_user'
        metadata.collection = 'unknown_collection'
        metadata.id = 'doc-123'

        obj = ExtractedObject(
            metadata=metadata,
            schema_name='customers',
            values=[{'id': '123', 'name': 'Test'}]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = obj

        mock_flow = MagicMock()

        await processor.on_message(mock_msg, MagicMock(), mock_flow)

        # Flow should not be called for output
        mock_flow.assert_not_called()

    async def test_on_message_drops_unknown_schema(self):
        """Test that messages for unknown schemas are dropped"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import ExtractedObject

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}
        # No schemas registered

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        obj = ExtractedObject(
            metadata=metadata,
            schema_name='unknown_schema',
            values=[{'id': '123', 'name': 'Test'}]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = obj

        mock_flow = MagicMock()

        await processor.on_message(mock_msg, MagicMock(), mock_flow)

        # Flow should not be called for output
        mock_flow.assert_not_called()

    async def test_on_message_processes_embeddings(self):
        """Test processing a message and computing embeddings"""
        from trustgraph.embeddings.row_embeddings.embeddings import Processor
        from trustgraph.schema import ExtractedObject, RowSchema, Field
        import json

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor',
            'config_type': 'schema'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        # Set up schema
        processor.schemas['customers'] = RowSchema(
            name='customers',
            description='Customer records',
            fields=[
                Field(name='id', type='text', primary=True),
                Field(name='name', type='text', indexed=True),
            ]
        )

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        obj = ExtractedObject(
            metadata=metadata,
            schema_name='customers',
            values=[
                {'id': 'CUST001', 'name': 'John Doe'},
                {'id': 'CUST002', 'name': 'Jane Smith'}
            ]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = obj

        # Mock the flow
        mock_embeddings_request = AsyncMock()
        mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]

        mock_output = AsyncMock()

        def flow_factory(name):
            if name == 'embeddings-request':
                return mock_embeddings_request
            elif name == 'output':
                return mock_output
            return MagicMock()

        mock_flow = MagicMock(side_effect=flow_factory)

        await processor.on_message(mock_msg, MagicMock(), mock_flow)

        # Should have called embed for each unique text
        # 4 values: CUST001, John Doe, CUST002, Jane Smith
        assert mock_embeddings_request.embed.call_count == 4

        # Should have sent output
        mock_output.send.assert_called()


if __name__ == '__main__':
    pytest.main([__file__])
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Unit tests for objects import dispatcher.
|
Unit tests for rows import dispatcher.
|
||||||
|
|
||||||
Tests the business logic of objects import dispatcher
|
Tests the business logic of rows import dispatcher
|
||||||
while mocking the Publisher and websocket components.
|
while mocking the Publisher and websocket components.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -11,7 +11,7 @@ import asyncio
|
||||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
|
from trustgraph.gateway.dispatch.rows_import import RowsImport
|
||||||
from trustgraph.schema import Metadata, ExtractedObject
|
from trustgraph.schema import Metadata, ExtractedObject
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -92,16 +92,16 @@ def minimal_objects_message():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestObjectsImportInitialization:
|
class TestRowsImportInitialization:
|
||||||
"""Test ObjectsImport initialization."""
|
"""Test RowsImport initialization."""
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that ObjectsImport creates Publisher with correct parameters."""
|
"""Test that RowsImport creates Publisher with correct parameters."""
|
||||||
mock_publisher_instance = Mock()
|
mock_publisher_instance = Mock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -116,28 +116,28 @@ class TestObjectsImportInitialization:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify instance variables are set correctly
|
# Verify instance variables are set correctly
|
||||||
assert objects_import.ws == mock_websocket
|
assert rows_import.ws == mock_websocket
|
||||||
assert objects_import.running == mock_running
|
assert rows_import.running == mock_running
|
||||||
assert objects_import.publisher == mock_publisher_instance
|
assert rows_import.publisher == mock_publisher_instance
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that ObjectsImport stores all required references."""
|
"""Test that RowsImport stores all required references."""
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
queue="objects-queue"
|
queue="objects-queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert objects_import.ws is mock_websocket
|
assert rows_import.ws is mock_websocket
|
||||||
assert objects_import.running is mock_running
|
assert rows_import.running is mock_running
|
||||||
|
|
||||||
|
|
||||||
class TestObjectsImportLifecycle:
|
class TestRowsImportLifecycle:
|
||||||
"""Test ObjectsImport lifecycle methods."""
|
"""Test RowsImport lifecycle methods."""
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that start() calls publisher.start()."""
|
"""Test that start() calls publisher.start()."""
|
||||||
|
|
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
|
||||||
mock_publisher_instance.start = AsyncMock()
|
mock_publisher_instance.start = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
queue="test-queue"
|
queue="test-queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
await objects_import.start()
|
await rows_import.start()
|
||||||
|
|
||||||
mock_publisher_instance.start.assert_called_once()
|
mock_publisher_instance.start.assert_called_once()
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that destroy() properly stops publisher and closes websocket."""
|
"""Test that destroy() properly stops publisher and closes websocket."""
|
||||||
|
|
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
|
||||||
mock_publisher_instance.stop = AsyncMock()
|
mock_publisher_instance.stop = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
queue="test-queue"
|
queue="test-queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
await objects_import.destroy()
|
await rows_import.destroy()
|
||||||
|
|
||||||
# Verify sequence of operations
|
# Verify sequence of operations
|
||||||
mock_running.stop.assert_called_once()
|
mock_running.stop.assert_called_once()
|
||||||
mock_publisher_instance.stop.assert_called_once()
|
mock_publisher_instance.stop.assert_called_once()
|
||||||
mock_websocket.close.assert_called_once()
|
mock_websocket.close.assert_called_once()
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
|
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
|
||||||
"""Test that destroy() handles None websocket gracefully."""
|
"""Test that destroy() handles None websocket gracefully."""
|
||||||
|
|
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
|
||||||
mock_publisher_instance.stop = AsyncMock()
|
mock_publisher_instance.stop = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=None, # None websocket
|
ws=None, # None websocket
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not raise exception
|
# Should not raise exception
|
||||||
await objects_import.destroy()
|
await rows_import.destroy()
|
||||||
|
|
||||||
mock_running.stop.assert_called_once()
|
mock_running.stop.assert_called_once()
|
||||||
mock_publisher_instance.stop.assert_called_once()
|
mock_publisher_instance.stop.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestObjectsImportMessageProcessing:
|
class TestRowsImportMessageProcessing:
|
||||||
"""Test ObjectsImport message processing."""
|
"""Test RowsImport message processing."""
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||||
"""Test that receive() processes complete message correctly."""
|
"""Test that receive() processes complete message correctly."""
|
||||||
|
|
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
mock_publisher_instance.send = AsyncMock()
|
mock_publisher_instance.send = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
mock_msg = Mock()
|
mock_msg = Mock()
|
||||||
mock_msg.json.return_value = sample_objects_message
|
mock_msg.json.return_value = sample_objects_message
|
||||||
|
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
||||||
# Verify publisher.send was called
|
# Verify publisher.send was called
|
||||||
mock_publisher_instance.send.assert_called_once()
|
mock_publisher_instance.send.assert_called_once()
|
||||||
|
|
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
assert sent_object.metadata.collection == "testcollection"
|
assert sent_object.metadata.collection == "testcollection"
|
||||||
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
|
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
|
||||||
"""Test that receive() handles message with minimal required fields."""
|
"""Test that receive() handles message with minimal required fields."""
|
||||||
|
|
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
mock_publisher_instance.send = AsyncMock()
|
mock_publisher_instance.send = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
mock_msg = Mock()
|
mock_msg = Mock()
|
||||||
mock_msg.json.return_value = minimal_objects_message
|
mock_msg.json.return_value = minimal_objects_message
|
||||||
|
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
||||||
# Verify publisher.send was called
|
# Verify publisher.send was called
|
||||||
mock_publisher_instance.send.assert_called_once()
|
mock_publisher_instance.send.assert_called_once()
|
||||||
|
|
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
assert sent_object.source_span == "" # Default value
|
assert sent_object.source_span == "" # Default value
|
||||||
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
assert len(sent_object.metadata.metadata) == 0 # Default empty list
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that receive() uses appropriate default values for optional fields."""
|
"""Test that receive() uses appropriate default values for optional fields."""
|
||||||
|
|
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
mock_publisher_instance.send = AsyncMock()
|
mock_publisher_instance.send = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
|
||||||
mock_msg = Mock()
|
mock_msg = Mock()
|
||||||
mock_msg.json.return_value = message_data
|
mock_msg.json.return_value = message_data
|
||||||
|
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
||||||
# Get the sent object and verify defaults
|
# Get the sent object and verify defaults
|
||||||
sent_object = mock_publisher_instance.send.call_args[0][1]
|
sent_object = mock_publisher_instance.send.call_args[0][1]
|
||||||
|
|
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
|
||||||
assert sent_object.source_span == ""
|
assert sent_object.source_span == ""
|
||||||
|
|
||||||
|
|
||||||
class TestObjectsImportRunMethod:
|
class TestRowsImportRunMethod:
|
||||||
"""Test ObjectsImport run method."""
|
"""Test RowsImport run method."""
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that run() loops while running.get() returns True."""
|
"""Test that run() loops while running.get() returns True."""
|
||||||
|
|
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
|
||||||
# Set up running state to return True twice, then False
|
# Set up running state to return True twice, then False
|
||||||
mock_running.get.side_effect = [True, True, False]
|
mock_running.get.side_effect = [True, True, False]
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
queue="test-queue"
|
queue="test-queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
await objects_import.run()
|
await rows_import.run()
|
||||||
|
|
||||||
# Verify sleep was called twice (for the two True iterations)
|
# Verify sleep was called twice (for the two True iterations)
|
||||||
assert mock_sleep.call_count == 2
|
assert mock_sleep.call_count == 2
|
||||||
|
|
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
|
||||||
mock_websocket.close.assert_called_once()
|
mock_websocket.close.assert_called_once()
|
||||||
|
|
||||||
# Verify websocket was set to None
|
# Verify websocket was set to None
|
||||||
assert objects_import.ws is None
|
assert rows_import.ws is None
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
|
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
|
||||||
"""Test that run() handles None websocket gracefully."""
|
"""Test that run() handles None websocket gracefully."""
|
||||||
|
|
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
|
||||||
|
|
||||||
mock_running.get.return_value = False # Exit immediately
|
mock_running.get.return_value = False # Exit immediately
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=None, # None websocket
|
ws=None, # None websocket
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not raise exception
|
# Should not raise exception
|
||||||
await objects_import.run()
|
await rows_import.run()
|
||||||
|
|
||||||
# Verify websocket remains None
|
# Verify websocket remains None
|
||||||
assert objects_import.ws is None
|
assert rows_import.ws is None
|
||||||
|
|
||||||
|
|
||||||
class TestObjectsImportBatchProcessing:
|
class TestRowsImportBatchProcessing:
|
||||||
"""Test ObjectsImport batch processing functionality."""
|
"""Test RowsImport batch processing functionality."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def batch_objects_message(self):
|
def batch_objects_message(self):
|
||||||
|
|
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
|
||||||
"source_span": "Multiple people found in document"
|
"source_span": "Multiple people found in document"
|
||||||
}
|
}
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
|
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
|
||||||
"""Test that receive() processes batch message correctly."""
|
"""Test that receive() processes batch message correctly."""
|
||||||
|
|
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
|
||||||
mock_publisher_instance.send = AsyncMock()
|
mock_publisher_instance.send = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
|
||||||
mock_msg = Mock()
|
mock_msg = Mock()
|
||||||
mock_msg.json.return_value = batch_objects_message
|
mock_msg.json.return_value = batch_objects_message
|
||||||
|
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
||||||
# Verify publisher.send was called
|
# Verify publisher.send was called
|
||||||
mock_publisher_instance.send.assert_called_once()
|
mock_publisher_instance.send.assert_called_once()
|
||||||
|
|
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
|
||||||
assert sent_object.confidence == 0.85
|
assert sent_object.confidence == 0.85
|
||||||
assert sent_object.source_span == "Multiple people found in document"
|
assert sent_object.source_span == "Multiple people found in document"
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that receive() handles empty batch correctly."""
|
"""Test that receive() handles empty batch correctly."""
|
||||||
|
|
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
|
||||||
mock_publisher_instance.send = AsyncMock()
|
mock_publisher_instance.send = AsyncMock()
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
|
||||||
mock_msg = Mock()
|
mock_msg = Mock()
|
||||||
mock_msg.json.return_value = empty_batch_message
|
mock_msg.json.return_value = empty_batch_message
|
||||||
|
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
||||||
# Should still send the message
|
# Should still send the message
|
||||||
mock_publisher_instance.send.assert_called_once()
|
mock_publisher_instance.send.assert_called_once()
|
||||||
|
|
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
|
||||||
assert len(sent_object.values) == 0
|
assert len(sent_object.values) == 0
|
||||||
|
|
||||||
|
|
||||||
class TestObjectsImportErrorHandling:
|
class TestRowsImportErrorHandling:
|
||||||
"""Test error handling in ObjectsImport."""
|
"""Test error handling in RowsImport."""
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||||
"""Test that receive() propagates publisher send errors."""
|
"""Test that receive() propagates publisher send errors."""
|
||||||
|
|
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
|
||||||
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
||||||
mock_publisher_class.return_value = mock_publisher_instance
|
mock_publisher_class.return_value = mock_publisher_instance
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
|
||||||
mock_msg.json.return_value = sample_objects_message
|
mock_msg.json.return_value = sample_objects_message
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Publisher error"):
|
with pytest.raises(Exception, match="Publisher error"):
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
||||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||||
"""Test that receive() handles malformed JSON appropriately."""
|
"""Test that receive() handles malformed JSON appropriately."""
|
||||||
mock_publisher_class.return_value = Mock()
|
mock_publisher_class.return_value = Mock()
|
||||||
|
|
||||||
objects_import = ObjectsImport(
|
rows_import = RowsImport(
|
||||||
ws=mock_websocket,
|
ws=mock_websocket,
|
||||||
running=mock_running,
|
running=mock_running,
|
||||||
backend=mock_backend,
|
backend=mock_backend,
|
||||||
|
|
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
|
||||||
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||||
|
|
||||||
with pytest.raises(json.JSONDecodeError):
|
with pytest.raises(json.JSONDecodeError):
|
||||||
await objects_import.receive(mock_msg)
|
await rows_import.receive(mock_msg)
|
||||||
|
|
@ -76,7 +76,7 @@ def cities_schema():
|
||||||
def validator():
|
def validator():
|
||||||
"""Create a mock processor with just the validation method"""
|
"""Create a mock processor with just the validation method"""
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
from trustgraph.extract.kg.objects.processor import Processor
|
from trustgraph.extract.kg.rows.processor import Processor
|
||||||
|
|
||||||
# Create a mock processor
|
# Create a mock processor
|
||||||
mock_processor = MagicMock()
|
mock_processor = MagicMock()
|
||||||
|
|
|
||||||
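The renamed dispatcher tests above all follow the same pattern: construct a RowsImport with a mocked Publisher, feed it a websocket message whose `.json()` payload is a batch, and assert that the publisher's async `send()` was invoked (or that JSON errors propagate). A minimal sketch of the `receive()` contract these tests imply; the class and method bodies here are assumptions for illustration, not the actual TrustGraph implementation:

```python
import asyncio
from unittest.mock import AsyncMock, Mock


class RowsImportSketch:
    """Hypothetical stand-in for the RowsImport dispatcher under test."""

    def __init__(self, publisher):
        self.publisher = publisher

    async def receive(self, msg):
        # Decode the websocket frame; a malformed payload raises
        # json.JSONDecodeError, which the tests expect to propagate.
        payload = msg.json()
        # Forward the decoded batch to the (mocked) Pulsar publisher.
        await self.publisher.send(payload)


async def main():
    publisher = Mock()
    publisher.send = AsyncMock()
    importer = RowsImportSketch(publisher)

    msg = Mock()
    msg.json.return_value = {"metadata": {}, "values": []}
    await importer.receive(msg)
    publisher.send.assert_awaited_once()


asyncio.run(main())
```

This mirrors why the tests only need to patch `Publisher` and stub `msg.json`: everything `receive()` does is visible through those two seams.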
@@ -167,7 +167,7 @@ class TestFlowClient:
         expected_methods = [
             'text_completion', 'agent', 'graph_rag', 'document_rag',
             'graph_embeddings_query', 'embeddings', 'prompt',
-            'triples_query', 'objects_query'
+            'triples_query', 'rows_query'
         ]
 
         for method in expected_methods:

@@ -216,7 +216,7 @@ class TestSocketClient:
         expected_methods = [
             'agent', 'text_completion', 'graph_rag', 'document_rag',
             'prompt', 'graph_embeddings_query', 'embeddings',
-            'triples_query', 'objects_query', 'mcp_tool'
+            'triples_query', 'rows_query', 'mcp_tool'
         ]
 
         for method in expected_methods:

@@ -243,7 +243,7 @@ class TestBulkClient:
             'import_graph_embeddings',
             'import_document_embeddings',
             'import_entity_contexts',
-            'import_objects'
+            'import_rows'
         ]
 
         for method in import_methods:
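The client tests above pin a fixed API surface (`rows_query`, `import_rows`, and so on) after the rename. A generic sketch of that pattern, checking each expected method name exists and is callable; the placeholder class below is not the TrustGraph client:

```python
class FakeFlowClient:
    """Placeholder client exposing part of the renamed API surface."""

    def triples_query(self):
        ...

    def rows_query(self):
        ...


expected_methods = ["triples_query", "rows_query"]

client = FakeFlowClient()
for method in expected_methods:
    # The real tests do the equivalent of this for each renamed method.
    assert hasattr(client, method) and callable(getattr(client, method))
```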
@@ -1,10 +1,11 @@
 """
-Unit tests for Cassandra Objects GraphQL Query Processor
+Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
 
 Tests the business logic of the GraphQL query processor including:
-- GraphQL schema generation from RowSchema
-- Query execution and validation
-- CQL translation logic
+- Schema configuration handling
+- Query execution using unified rows table
+- Name sanitization
+- GraphQL query execution
 - Message processing logic
 """
 

@@ -12,119 +13,91 @@ import pytest
 from unittest.mock import MagicMock, AsyncMock, patch
 import json
 
-import strawberry
-from strawberry import Schema
+from trustgraph.query.rows.cassandra.service import Processor
+from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
 
-from trustgraph.query.objects.cassandra.service import Processor
-from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
 from trustgraph.schema import RowSchema, Field
 
 
-class TestObjectsGraphQLQueryLogic:
-    """Test business logic without external dependencies"""
+class TestRowsGraphQLQueryLogic:
+    """Test business logic for unified table query implementation"""
 
-    def test_get_python_type_mapping(self):
-        """Test schema field type conversion to Python types"""
-        processor = MagicMock()
-        processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
-
-        # Basic type mappings
-        assert processor.get_python_type("string") == str
-        assert processor.get_python_type("integer") == int
-        assert processor.get_python_type("float") == float
-        assert processor.get_python_type("boolean") == bool
-        assert processor.get_python_type("timestamp") == str
-        assert processor.get_python_type("date") == str
-        assert processor.get_python_type("time") == str
-        assert processor.get_python_type("uuid") == str
-
-        # Unknown type defaults to str
-        assert processor.get_python_type("unknown_type") == str
-
-    def test_create_graphql_type_basic_fields(self):
-        """Test GraphQL type creation for basic field types"""
-        processor = MagicMock()
-        processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
-        processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
-
-        # Create test schema
-        schema = RowSchema(
-            name="test_table",
-            description="Test table",
-            fields=[
-                Field(
-                    name="id",
-                    type="string",
-                    primary=True,
-                    required=True,
-                    description="Primary key"
-                ),
-                Field(
-                    name="name",
-                    type="string",
-                    required=True,
-                    description="Name field"
-                ),
-                Field(
-                    name="age",
-                    type="integer",
-                    required=False,
-                    description="Optional age"
-                ),
-                Field(
-                    name="active",
-                    type="boolean",
-                    required=False,
-                    description="Status flag"
-                )
-            ]
-        )
-
-        # Create GraphQL type
-        graphql_type = processor.create_graphql_type("test_table", schema)
-
-        # Verify type was created
-        assert graphql_type is not None
-        assert hasattr(graphql_type, '__name__')
-        assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
-
     def test_sanitize_name_cassandra_compatibility(self):
         """Test name sanitization for Cassandra field names"""
         processor = MagicMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
 
-        # Test field name sanitization (matches storage processor)
+        # Test field name sanitization (uses r_ prefix like storage processor)
         assert processor.sanitize_name("simple_field") == "simple_field"
         assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
         assert processor.sanitize_name("field.with.dots") == "field_with_dots"
-        assert processor.sanitize_name("123_field") == "o_123_field"
+        assert processor.sanitize_name("123_field") == "r_123_field"
         assert processor.sanitize_name("field with spaces") == "field_with_spaces"
         assert processor.sanitize_name("special!@#chars") == "special___chars"
         assert processor.sanitize_name("UPPERCASE") == "uppercase"
         assert processor.sanitize_name("CamelCase") == "camelcase"
 
-    def test_sanitize_table_name(self):
-        """Test table name sanitization (always gets o_ prefix)"""
+    def test_get_index_names(self):
+        """Test extraction of index names from schema"""
         processor = MagicMock()
-        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
+        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
 
-        # Table names always get o_ prefix
-        assert processor.sanitize_table("simple_table") == "o_simple_table"
-        assert processor.sanitize_table("Table-Name") == "o_table_name"
-        assert processor.sanitize_table("123table") == "o_123table"
-        assert processor.sanitize_table("") == "o_"
+        schema = RowSchema(
+            name="test_schema",
+            fields=[
+                Field(name="id", type="string", primary=True),
+                Field(name="category", type="string", indexed=True),
+                Field(name="name", type="string"),  # Not indexed
+                Field(name="status", type="string", indexed=True)
+            ]
+        )
+
+        index_names = processor.get_index_names(schema)
+
+        assert "id" in index_names
+        assert "category" in index_names
+        assert "status" in index_names
+        assert "name" not in index_names
+        assert len(index_names) == 3
+
+    def test_find_matching_index_exact_match(self):
+        """Test finding matching index for exact match query"""
+        processor = MagicMock()
+        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
+        processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
+
+        schema = RowSchema(
+            name="test_schema",
+            fields=[
+                Field(name="id", type="string", primary=True),
+                Field(name="category", type="string", indexed=True),
+                Field(name="name", type="string")  # Not indexed
+            ]
+        )
+
+        # Filter on indexed field should return match
+        filters = {"category": "electronics"}
+        result = processor.find_matching_index(schema, filters)
+        assert result is not None
+        assert result[0] == "category"
+        assert result[1] == ["electronics"]
+
+        # Filter on non-indexed field should return None
+        filters = {"name": "test"}
+        result = processor.find_matching_index(schema, filters)
+        assert result is None
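The assertions in these tests pin down concrete behaviour: `sanitize_name()` lowercases, maps non-alphanumerics to underscores, and prefixes a leading digit with `r_`; `get_index_names()` collects primary and indexed fields; `find_matching_index()` returns an `(index, values)` pair only when a filter names an indexed field. A minimal sketch that satisfies exactly those assertions, using plain dicts in place of `RowSchema`/`Field`; the real processor may differ in details:

```python
import re


def sanitize_name(name):
    # Lowercase and replace anything that is not [a-z0-9_] with "_".
    safe = re.sub(r"[^a-z0-9_]", "_", name.lower())
    # Cassandra identifiers cannot start with a digit: prefix with "r_".
    if safe and safe[0].isdigit():
        safe = "r_" + safe
    return safe


def get_index_names(schema):
    # Primary and explicitly indexed fields are all usable as indexes.
    return [f["name"] for f in schema["fields"]
            if f.get("primary") or f.get("indexed")]


def find_matching_index(schema, filters):
    # The first filter key naming an indexed field wins; None otherwise.
    indexes = set(get_index_names(schema))
    for key, value in filters.items():
        if key in indexes:
            return key, [value]
    return None
```

Run against the test's inputs, this yields `sanitize_name("123_field") == "r_123_field"` and `find_matching_index(schema, {"name": "test"}) is None`, matching the expectations above.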
     @pytest.mark.asyncio
     async def test_schema_config_parsing(self):
         """Test parsing of schema configuration"""
         processor = MagicMock()
         processor.schemas = {}
-        processor.graphql_types = {}
-        processor.graphql_schema = None
-        processor.config_key = "schema" # Set the config key
-        processor.generate_graphql_schema = AsyncMock()
+        processor.config_key = "schema"
+        processor.schema_builder = MagicMock()
+        processor.schema_builder.clear = MagicMock()
+        processor.schema_builder.add_schema = MagicMock()
+        processor.schema_builder.build = MagicMock(return_value=MagicMock())
         processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
 
         # Create test config
         schema_config = {
             "schema": {

@@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
                 })
             }
         }
 
         # Process config
         await processor.on_schema_config(schema_config, version=1)
 
         # Verify schema was loaded
         assert "customer" in processor.schemas
         schema = processor.schemas["customer"]
         assert schema.name == "customer"
         assert len(schema.fields) == 3
 
         # Verify fields
         id_field = next(f for f in schema.fields if f.name == "id")
         assert id_field.primary is True
-        # The field should have been created correctly from JSON
-        # Let's test what we can verify - that the field has the right attributes
-        assert hasattr(id_field, 'required') # Has the required attribute
-        assert hasattr(id_field, 'primary') # Has the primary attribute
 
         email_field = next(f for f in schema.fields if f.name == "email")
         assert email_field.indexed is True
 
         status_field = next(f for f in schema.fields if f.name == "status")
         assert status_field.enum_values == ["active", "inactive"]
 
-        # Verify GraphQL schema regeneration was called
-        processor.generate_graphql_schema.assert_called_once()
-
-    def test_cql_query_building_basic(self):
-        """Test basic CQL query construction"""
-        processor = MagicMock()
-        processor.session = MagicMock()
-        processor.connect_cassandra = MagicMock()
-        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
-        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
-        processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
-        processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
-
-        # Mock session execute to capture the query
-        mock_result = []
-        processor.session.execute.return_value = mock_result
-
-        # Create test schema
-        schema = RowSchema(
-            name="test_table",
-            fields=[
-                Field(name="id", type="string", primary=True),
-                Field(name="name", type="string", indexed=True),
-                Field(name="status", type="string")
-            ]
-        )
-
-        # Test query building
-        asyncio = pytest.importorskip("asyncio")
-
-        async def run_test():
-            await processor.query_cassandra(
-                user="test_user",
-                collection="test_collection",
-                schema_name="test_table",
-                row_schema=schema,
-                filters={"name": "John", "invalid_filter": "ignored"},
-                limit=10
-            )
-
-        # Run the async test
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            loop.run_until_complete(run_test())
-        finally:
-            loop.close()
-
-        # Verify Cassandra connection and query execution
-        processor.connect_cassandra.assert_called_once()
-        processor.session.execute.assert_called_once()
-
-        # Verify the query structure (can't easily test exact query without complex mocking)
-        call_args = processor.session.execute.call_args
-        query = call_args[0][0] # First positional argument is the query
-        params = call_args[0][1] # Second positional argument is parameters
-
-        # Basic query structure checks
-        assert "SELECT * FROM test_user.o_test_table" in query
-        assert "WHERE" in query
-        assert "collection = %s" in query
-        assert "LIMIT 10" in query
-
-        # Parameters should include collection and name filter
-        assert "test_collection" in params
-        assert "John" in params
+        # Verify schema builder was called
+        processor.schema_builder.add_schema.assert_called_once()
+        processor.schema_builder.build.assert_called_once()
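`test_schema_config_parsing` expects `on_schema_config()` to parse each JSON value under the `schema` config key into a row schema, store it in `self.schemas`, and rebuild the unified GraphQL schema via the schema builder. A plausible shape for that handler, using plain dicts where the real code builds `RowSchema` objects; names beyond those exercised by the test are assumptions:

```python
import json


class ProcessorSketch:
    """Simplified stand-in for the query processor's config handler."""

    def __init__(self, schema_builder, config_key="schema"):
        self.schemas = {}
        self.schema_builder = schema_builder
        self.config_key = config_key
        self.graphql_schema = None

    async def on_schema_config(self, config, version):
        # Reload every schema definition published under the config key.
        self.schema_builder.clear()
        for name, value in config.get(self.config_key, {}).items():
            schema = json.loads(value)  # RowSchema in the real code
            self.schemas[name] = schema
            self.schema_builder.add_schema(name, schema)
        # One unified GraphQL schema is rebuilt over all row schemas.
        self.graphql_schema = self.schema_builder.build()
```

This matches what the test asserts through the `MagicMock` builder: one `add_schema` call per configured schema and a single `build` call afterwards.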
     @pytest.mark.asyncio
     async def test_graphql_context_handling(self):

@@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
         processor = MagicMock()
         processor.graphql_schema = AsyncMock()
         processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
 
         # Mock schema execution
         mock_result = MagicMock()
         mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
         mock_result.errors = None
         processor.graphql_schema.execute.return_value = mock_result
 
         result = await processor.execute_graphql_query(
             query='{ customers { id name } }',
             variables={},

@@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
             user="test_user",
             collection="test_collection"
         )
 
         # Verify schema.execute was called with correct context
         processor.graphql_schema.execute.assert_called_once()
         call_args = processor.graphql_schema.execute.call_args
 
         # Verify context was passed
-        context = call_args[1]['context_value'] # keyword argument
+        context = call_args[1]['context_value']
         assert context["processor"] == processor
         assert context["user"] == "test_user"
         assert context["collection"] == "test_collection"
 
         # Verify result structure
         assert "data" in result
         assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}

@@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
         processor = MagicMock()
         processor.graphql_schema = AsyncMock()
         processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
 
-        # Create a simple object to simulate GraphQL error instead of MagicMock
+        # Create a simple object to simulate GraphQL error
         class MockError:
             def __init__(self, message, path, extensions):
                 self.message = message
                 self.path = path
                 self.extensions = extensions
 
             def __str__(self):
                 return self.message
 
         mock_error = MockError(
             message="Field 'invalid_field' doesn't exist",
             path=["customers", "0", "invalid_field"],
             extensions={"code": "FIELD_NOT_FOUND"}
         )
 
         mock_result = MagicMock()
         mock_result.data = None
         mock_result.errors = [mock_error]
         processor.graphql_schema.execute.return_value = mock_result
 
         result = await processor.execute_graphql_query(
             query='{ customers { invalid_field } }',
             variables={},
             operation_name=None,
             user="test_user",
             collection="test_collection"
         )
 
         # Verify error handling
         assert "errors" in result
         assert len(result["errors"]) == 1
 
         error = result["errors"][0]
         assert error["message"] == "Field 'invalid_field' doesn't exist"
-        assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
+        assert error["path"] == ["customers", "0", "invalid_field"]
         assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
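The error-handling test expects `execute_graphql_query()` to flatten GraphQL error objects into plain dicts with `message`, `path`, and `extensions` keys before they go back over the wire. A minimal sketch of that mapping; the helper name and exact fallbacks are assumptions, not the processor's actual code:

```python
def format_graphql_errors(errors):
    """Convert GraphQL error objects into JSON-serialisable dicts."""
    formatted = []
    for err in errors or []:
        formatted.append({
            # str(err) yields the error message (MockError.__str__ above).
            "message": str(err),
            "path": getattr(err, "path", None),
            "extensions": getattr(err, "extensions", None) or {},
        })
    return formatted
```

With the test's `MockError`, this produces exactly the dict the assertions check: message, string path segments, and the `FIELD_NOT_FOUND` extension code.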
def test_schema_generation_basic_structure(self):
|
|
||||||
"""Test basic GraphQL schema generation structure"""
|
|
||||||
processor = MagicMock()
|
|
||||||
processor.schemas = {
|
|
||||||
"customer": RowSchema(
|
|
||||||
name="customer",
|
|
||||||
fields=[
|
|
||||||
Field(name="id", type="string", primary=True),
|
|
||||||
Field(name="name", type="string")
|
|
||||||
]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
processor.graphql_types = {}
|
|
||||||
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
|
|
||||||
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
|
|
||||||
|
|
||||||
# Test individual type creation (avoiding the full schema generation which has annotation issues)
|
|
||||||
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
|
|
||||||
processor.graphql_types["customer"] = graphql_type
|
|
||||||
|
|
||||||
# Verify type was created
|
|
||||||
assert len(processor.graphql_types) == 1
|
|
||||||
assert "customer" in processor.graphql_types
|
|
||||||
assert processor.graphql_types["customer"] is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_processing_success(self):
|
async def test_message_processing_success(self):
|
||||||
"""Test successful message processing flow"""
|
"""Test successful message processing flow"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.execute_graphql_query = AsyncMock()
|
processor.execute_graphql_query = AsyncMock()
|
||||||
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
processor.on_message = Processor.on_message.__get__(processor, Processor)
|
||||||
|
|
||||||
# Mock successful query result
|
# Mock successful query result
|
||||||
processor.execute_graphql_query.return_value = {
|
processor.execute_graphql_query.return_value = {
|
||||||
"data": {"customers": [{"id": "1", "name": "John"}]},
|
"data": {"customers": [{"id": "1", "name": "John"}]},
|
||||||
"errors": [],
|
"errors": [],
|
||||||
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
|
"extensions": {}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create mock message
|
# Create mock message
|
||||||
mock_msg = MagicMock()
|
mock_msg = MagicMock()
|
||||||
mock_request = ObjectsQueryRequest(
|
mock_request = RowsQueryRequest(
|
||||||
user="test_user",
|
user="test_user",
|
||||||
collection="test_collection",
|
collection="test_collection",
|
||||||
query='{ customers { id name } }',
|
query='{ customers { id name } }',
|
||||||
variables={},
|
variables={},
|
||||||
operation_name=None
|
operation_name=None
|
||||||
)
|
)
|
||||||
mock_msg.value.return_value = mock_request
|
mock_msg.value.return_value = mock_request
|
||||||
mock_msg.properties.return_value = {"id": "test-123"}
|
mock_msg.properties.return_value = {"id": "test-123"}
|
||||||
|
|
||||||
# Mock flow
|
# Mock flow
|
||||||
mock_flow = MagicMock()
|
mock_flow = MagicMock()
|
||||||
mock_response_flow = AsyncMock()
|
mock_response_flow = AsyncMock()
|
||||||
mock_flow.return_value = mock_response_flow
|
mock_flow.return_value = mock_response_flow
|
||||||
|
|
||||||
# Process message
|
# Process message
|
||||||
await processor.on_message(mock_msg, None, mock_flow)
|
await processor.on_message(mock_msg, None, mock_flow)
|
||||||
|
|
||||||
# Verify query was executed
|
# Verify query was executed
|
||||||
processor.execute_graphql_query.assert_called_once_with(
|
processor.execute_graphql_query.assert_called_once_with(
|
||||||
query='{ customers { id name } }',
|
query='{ customers { id name } }',
|
||||||
|
|
@@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
             user="test_user",
             collection="test_collection"
         )

         # Verify response was sent
         mock_response_flow.send.assert_called_once()
         response_call = mock_response_flow.send.call_args[0][0]

         # Verify response structure
-        assert isinstance(response_call, ObjectsQueryResponse)
+        assert isinstance(response_call, RowsQueryResponse)
         assert response_call.error is None
         assert '"customers"' in response_call.data  # JSON encoded
         assert len(response_call.errors) == 0
@@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
         processor = MagicMock()
         processor.execute_graphql_query = AsyncMock()
         processor.on_message = Processor.on_message.__get__(processor, Processor)

         # Mock query execution error
         processor.execute_graphql_query.side_effect = RuntimeError("No schema available")

         # Create mock message
         mock_msg = MagicMock()
-        mock_request = ObjectsQueryRequest(
+        mock_request = RowsQueryRequest(
             user="test_user",
             collection="test_collection",
             query='{ invalid_query }',
@@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
         )
         mock_msg.value.return_value = mock_request
         mock_msg.properties.return_value = {"id": "test-456"}

         # Mock flow
         mock_flow = MagicMock()
         mock_response_flow = AsyncMock()
         mock_flow.return_value = mock_response_flow

         # Process message
         await processor.on_message(mock_msg, None, mock_flow)

         # Verify error response was sent
         mock_response_flow.send.assert_called_once()
         response_call = mock_response_flow.send.call_args[0][0]

         # Verify error response structure
-        assert isinstance(response_call, ObjectsQueryResponse)
+        assert isinstance(response_call, RowsQueryResponse)
         assert response_call.error is not None
-        assert response_call.error.type == "objects-query-error"
+        assert response_call.error.type == "rows-query-error"
         assert "No schema available" in response_call.error.message
         assert response_call.data is None


-class TestCQLQueryGeneration:
-    """Test CQL query generation logic in isolation"""
+class TestUnifiedTableQueries:
+    """Test queries against the unified rows table"""

-    def test_partition_key_inclusion(self):
-        """Test that collection is always included in queries"""
+    @pytest.mark.asyncio
+    async def test_query_with_index_match(self):
+        """Test query execution with matching index"""
         processor = MagicMock()
+        processor.session = MagicMock()
+        processor.connect_cassandra = MagicMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
-        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
-
-        # Mock the query building (simplified version)
-        keyspace = processor.sanitize_name("test_user")
-        table = processor.sanitize_table("test_table")
-
-        query = f"SELECT * FROM {keyspace}.{table}"
-        where_clauses = ["collection = %s"]
+        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
+        processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
+        processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
+
+        # Mock session execute to return test data
+        mock_row = MagicMock()
+        mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
+        processor.session.execute.return_value = [mock_row]

-        assert "collection = %s" in where_clauses
-        assert keyspace == "test_user"
-        assert table == "o_test_table"
+        schema = RowSchema(
+            name="products",
+            fields=[
+                Field(name="id", type="string", primary=True),
+                Field(name="category", type="string", indexed=True),
+                Field(name="name", type="string")
+            ]
+        )
+
+        # Query with filter on indexed field
+        results = await processor.query_cassandra(
+            user="test_user",
+            collection="test_collection",
+            schema_name="products",
+            row_schema=schema,
+            filters={"category": "electronics"},
+            limit=10
+        )
+
+        # Verify Cassandra was connected and queried
+        processor.connect_cassandra.assert_called_once()
+        processor.session.execute.assert_called_once()
+
+        # Verify query structure - should query unified rows table
+        call_args = processor.session.execute.call_args
+        query = call_args[0][0]
+        params = call_args[0][1]
+
+        assert "SELECT data, source FROM test_user.rows" in query
+        assert "collection = %s" in query
+        assert "schema_name = %s" in query
+        assert "index_name = %s" in query
+        assert "index_value = %s" in query
+
+        assert params[0] == "test_collection"
+        assert params[1] == "products"
+        assert params[2] == "category"
+        assert params[3] == ["electronics"]
+
+        # Verify results
+        assert len(results) == 1
+        assert results[0]["id"] == "123"
+        assert results[0]["category"] == "electronics"
+
+    @pytest.mark.asyncio
+    async def test_query_without_index_match(self):
+        """Test query execution without matching index (scan mode)"""
+        processor = MagicMock()
+        processor.session = MagicMock()
+        processor.connect_cassandra = MagicMock()
+        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
+        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
+        processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
+        processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
+        processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
+
+        # Mock session execute to return test data
+        mock_row1 = MagicMock()
+        mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
+        mock_row2 = MagicMock()
+        mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
+        processor.session.execute.return_value = [mock_row1, mock_row2]
+
+        schema = RowSchema(
+            name="products",
+            fields=[
+                Field(name="id", type="string", primary=True),
+                Field(name="name", type="string"),  # Not indexed
+                Field(name="price", type="string")  # Not indexed
+            ]
+        )
+
+        # Query with filter on non-indexed field
+        results = await processor.query_cassandra(
+            user="test_user",
+            collection="test_collection",
+            schema_name="products",
+            row_schema=schema,
+            filters={"name": "Product A"},
+            limit=10
+        )
+
+        # Query should use ALLOW FILTERING for scan
+        call_args = processor.session.execute.call_args
+        query = call_args[0][0]
+
+        assert "ALLOW FILTERING" in query
+
+        # Should post-filter results
+        assert len(results) == 1
+        assert results[0]["name"] == "Product A"
+
+
+class TestFilterMatching:
+    """Test filter matching logic"""
+
+    def test_matches_filters_exact_match(self):
+        """Test exact match filter"""
+        processor = MagicMock()
+        processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
+
+        schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
+
+        row = {"status": "active", "name": "test"}
+        assert processor._matches_filters(row, {"status": "active"}, schema) is True
+        assert processor._matches_filters(row, {"status": "inactive"}, schema) is False
+
+    def test_matches_filters_comparison_operators(self):
+        """Test comparison operators in filters"""
+        processor = MagicMock()
+        processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
+
+        schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
+
+        row = {"price": "100.0"}
+
+        # Greater than
+        assert processor._matches_filters(row, {"price_gt": 50}, schema) is True
+        assert processor._matches_filters(row, {"price_gt": 150}, schema) is False
+
+        # Less than
+        assert processor._matches_filters(row, {"price_lt": 150}, schema) is True
+        assert processor._matches_filters(row, {"price_lt": 50}, schema) is False
+
+        # Greater than or equal
+        assert processor._matches_filters(row, {"price_gte": 100}, schema) is True
+        assert processor._matches_filters(row, {"price_gte": 101}, schema) is False
+
+        # Less than or equal
+        assert processor._matches_filters(row, {"price_lte": 100}, schema) is True
+        assert processor._matches_filters(row, {"price_lte": 99}, schema) is False
+
+    def test_matches_filters_contains(self):
+        """Test contains filter"""
+        processor = MagicMock()
+        processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
+
+        schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
+
+        row = {"description": "A great product for everyone"}
+
+        assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True
+        assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False
+
+    def test_matches_filters_in_list(self):
+        """Test in-list filter"""
+        processor = MagicMock()
+        processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
+
+        schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
+
+        row = {"status": "active"}
+
+        assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
+        assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
+
+
+class TestIndexedFieldFiltering:
+    """Test that only indexed or primary key fields can be directly filtered"""

     def test_indexed_field_filtering(self):
         """Test that only indexed or primary key fields can be filtered"""
-        # Create schema with mixed field types
         schema = RowSchema(
             name="test",
             fields=[
                 Field(name="id", type="string", primary=True),
                 Field(name="indexed_field", type="string", indexed=True),
                 Field(name="normal_field", type="string", indexed=False),
                 Field(name="another_field", type="string")
             ]
         )

         filters = {
             "id": "test123",                 # Primary key - should be included
             "indexed_field": "value",        # Indexed - should be included
             "normal_field": "ignored",       # Not indexed - should be ignored
             "another_field": "also_ignored"  # Not indexed - should be ignored
         }

         # Simulate the filtering logic from the processor
         valid_filters = []
         for field_name, value in filters.items():
@@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
             schema_field = next((f for f in schema.fields if f.name == field_name), None)
             if schema_field and (schema_field.indexed or schema_field.primary):
                 valid_filters.append((field_name, value))

         # Only id and indexed_field should be included
         assert len(valid_filters) == 2
         field_names = [f[0] for f in valid_filters]
@@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
         assert "indexed_field" in field_names
         assert "normal_field" not in field_names
         assert "another_field" not in field_names
-
-
-class TestGraphQLSchemaGeneration:
-    """Test GraphQL schema generation in detail"""
-
-    def test_field_type_annotations(self):
-        """Test that GraphQL types have correct field annotations"""
-        processor = MagicMock()
-        processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
-        processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
-
-        # Create schema with various field types
-        schema = RowSchema(
-            name="test",
-            fields=[
-                Field(name="id", type="string", required=True, primary=True),
-                Field(name="count", type="integer", required=True),
-                Field(name="price", type="float", required=False),
-                Field(name="active", type="boolean", required=False),
-                Field(name="optional_text", type="string", required=False)
-            ]
-        )
-
-        # Create GraphQL type
-        graphql_type = processor.create_graphql_type("test", schema)
-
-        # Verify type was created successfully
-        assert graphql_type is not None
-
-    def test_basic_type_creation(self):
-        """Test that GraphQL types are created correctly"""
-        processor = MagicMock()
-        processor.schemas = {
-            "customer": RowSchema(
-                name="customer",
-                fields=[Field(name="id", type="string", primary=True)]
-            )
-        }
-        processor.graphql_types = {}
-        processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
-        processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
-
-        # Create GraphQL type directly
-        graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
-        processor.graphql_types["customer"] = graphql_type
-
-        # Verify customer type was created
-        assert "customer" in processor.graphql_types
-        assert processor.graphql_types["customer"] is not None
@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 from trustgraph.schema import (
     StructuredQueryRequest, StructuredQueryResponse,
     QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
-    ObjectsQueryRequest, ObjectsQueryResponse,
+    RowsQueryRequest, RowsQueryResponse,
     Error, GraphQLError
 )
 from trustgraph.retrieval.structured_query.service import Processor
@@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
         )

         # Mock objects query service response
-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
             errors=None,
@@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
         # Verify objects query service was called correctly
         mock_objects_client.request.assert_called_once()
         objects_call_args = mock_objects_client.request.call_args[0][0]
-        assert isinstance(objects_call_args, ObjectsQueryRequest)
+        assert isinstance(objects_call_args, RowsQueryRequest)
         assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
         assert objects_call_args.variables == {"state": "NY"}
         assert objects_call_args.user == "trustgraph"
@@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
         assert response.error is not None
         assert "empty GraphQL query" in response.error.message

-    async def test_objects_query_service_error(self, processor):
+    async def test_rows_query_service_error(self, processor):
         """Test handling of objects query service errors"""
         # Arrange
         request = StructuredQueryRequest(
@@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
         )

         # Mock objects query service error
-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
             data=None,
             errors=None,
@@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
         response = response_call[0][0]

         assert response.error is not None
-        assert "Objects query service error" in response.error.message
+        assert "Rows query service error" in response.error.message
        assert "Table 'customers' not found" in response.error.message

     async def test_graphql_errors_handling(self, processor):
@@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
             )
         ]

-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data=None,
             errors=graphql_errors,
@@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
         )

         # Mock objects response
-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
             errors=None
@@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
             confidence=0.9
         )

-        objects_response = ObjectsQueryResponse(
+        objects_response = RowsQueryResponse(
             error=None,
             data=None,  # Null data
             errors=None,
@@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
         def flow_router(service_name):
             if service_name == "nlp-query-request":
                 return mock_nlp_client
-            elif service_name == "objects-query-request":
+            elif service_name == "rows-query-request":
                 return mock_objects_client
             elif service_name == "response":
                 return flow_response
@@ -10,7 +10,7 @@ import pytest
 from unittest.mock import Mock, patch, MagicMock

 from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
-from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
+from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
 from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
 from trustgraph.storage.knowledge.store import Processor as KgStore
@@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
         assert processor.cassandra_password is None


-class TestObjectsWriterConfiguration:
+class TestRowsWriterConfiguration:
     """Test Cassandra configuration in objects writer processor."""

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
     def test_environment_variable_configuration(self, mock_cluster):
         """Test processor picks up configuration from environment variables."""
         env_vars = {
@@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
         mock_cluster.return_value = mock_cluster_instance

         with patch.dict(os.environ, env_vars, clear=True):
-            processor = ObjectsWriter(taskgroup=MagicMock())
+            processor = RowsWriter(taskgroup=MagicMock())

         assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
         assert processor.cassandra_username == 'obj-env-user'
         assert processor.cassandra_password == 'obj-env-pass'

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
     def test_cassandra_connection_with_hosts_list(self, mock_cluster):
         """Test that Cassandra connection uses hosts list correctly."""
         env_vars = {
@@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
         mock_cluster.return_value = mock_cluster_instance

         with patch.dict(os.environ, env_vars, clear=True):
-            processor = ObjectsWriter(taskgroup=MagicMock())
+            processor = RowsWriter(taskgroup=MagicMock())
             processor.connect_cassandra()

         # Verify cluster was called with hosts list
@@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
         assert 'contact_points' in call_args.kwargs
         assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']

-    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
-    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
+    @patch('trustgraph.storage.rows.cassandra.write.Cluster')
+    @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
     def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
         """Test authentication is configured when credentials are provided."""
         env_vars = {
@@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
         mock_cluster.return_value = mock_cluster_instance

         with patch.dict(os.environ, env_vars, clear=True):
-            processor = ObjectsWriter(taskgroup=MagicMock())
+            processor = RowsWriter(taskgroup=MagicMock())
             processor.connect_cassandra()

         # Verify auth provider was created with correct credentials
@@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
     def test_objects_writer_add_args(self):
         """Test that objects writer adds standard Cassandra arguments."""
         import argparse
-        from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
+        from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter

         parser = argparse.ArgumentParser()
-        ObjectsWriter.add_args(parser)
+        RowsWriter.add_args(parser)

         # Parse empty args to check that arguments exist
         args = parser.parse_args([])
@@ -1,533 +0,0 @@
"""
Unit tests for Cassandra Object Storage Processor

Tests the business logic of the object storage processor including:
- Schema configuration handling
- Type conversions
- Name sanitization
- Table structure generation
"""

import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json

from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field


class TestObjectsCassandraStorageLogic:
    """Test business logic without FlowProcessor dependencies"""

    def test_sanitize_name(self):
        """Test name sanitization for Cassandra compatibility"""
        processor = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)

        # Test various name patterns (back to original logic)
        assert processor.sanitize_name("simple_name") == "simple_name"
        assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
        assert processor.sanitize_name("name.with.dots") == "name_with_dots"
        assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
        assert processor.sanitize_name("name with spaces") == "name_with_spaces"
        assert processor.sanitize_name("special!@#$%^chars") == "special______chars"

    def test_get_cassandra_type(self):
        """Test field type conversion to Cassandra types"""
        processor = MagicMock()
        processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)

        # Basic type mappings
        assert processor.get_cassandra_type("string") == "text"
        assert processor.get_cassandra_type("boolean") == "boolean"
        assert processor.get_cassandra_type("timestamp") == "timestamp"
        assert processor.get_cassandra_type("uuid") == "uuid"

        # Integer types with size hints
        assert processor.get_cassandra_type("integer", size=2) == "int"
        assert processor.get_cassandra_type("integer", size=8) == "bigint"

        # Float types with size hints
        assert processor.get_cassandra_type("float", size=2) == "float"
        assert processor.get_cassandra_type("float", size=8) == "double"

        # Unknown type defaults to text
        assert processor.get_cassandra_type("unknown_type") == "text"

    def test_convert_value(self):
        """Test value conversion for different field types"""
        processor = MagicMock()
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)

        # Integer conversions
        assert processor.convert_value("123", "integer") == 123
        assert processor.convert_value(123.5, "integer") == 123
        assert processor.convert_value(None, "integer") is None

        # Float conversions
        assert processor.convert_value("123.45", "float") == 123.45
        assert processor.convert_value(123, "float") == 123.0

        # Boolean conversions
        assert processor.convert_value("true", "boolean") is True
        assert processor.convert_value("false", "boolean") is False
        assert processor.convert_value("1", "boolean") is True
        assert processor.convert_value("0", "boolean") is False
        assert processor.convert_value("yes", "boolean") is True
        assert processor.convert_value("no", "boolean") is False

        # String conversions
        assert processor.convert_value(123, "string") == "123"
        assert processor.convert_value(True, "string") == "True"

    def test_table_creation_cql_generation(self):
        """Test CQL generation for table creation"""
        processor = MagicMock()
        processor.schemas = {}
        processor.known_keyspaces = set()
        processor.known_tables = {}
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)

        def mock_ensure_keyspace(keyspace):
            processor.known_keyspaces.add(keyspace)
            processor.known_tables[keyspace] = set()

        processor.ensure_keyspace = mock_ensure_keyspace
        processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)

        # Create test schema
        schema = RowSchema(
            name="customer_records",
            description="Test customer schema",
            fields=[
                Field(
                    name="customer_id",
                    type="string",
                    size=50,
                    primary=True,
                    required=True,
                    indexed=False
                ),
                Field(
                    name="email",
                    type="string",
                    size=100,
                    required=True,
                    indexed=True
                ),
                Field(
                    name="age",
                    type="integer",
                    size=4,
                    required=False,
                    indexed=False
                )
            ]
        )

        # Call ensure_table
        processor.ensure_table("test_user", "customer_records", schema)

        # Verify keyspace was ensured (check that it was added to known_keyspaces)
        assert "test_user" in processor.known_keyspaces

        # Check the CQL that was executed (first call should be table creation)
        all_calls = processor.session.execute.call_args_list
        table_creation_cql = all_calls[0][0][0]  # First call

        # Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
        assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
        assert "collection text" in table_creation_cql
        assert "customer_id text" in table_creation_cql
        assert "email text" in table_creation_cql
        assert "age int" in table_creation_cql
        assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql

    def test_table_creation_without_primary_key(self):
        """Test table creation when no primary key is defined"""
        processor = MagicMock()
        processor.schemas = {}
        processor.known_keyspaces = set()
        processor.known_tables = {}
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)

        def mock_ensure_keyspace(keyspace):
            processor.known_keyspaces.add(keyspace)
            processor.known_tables[keyspace] = set()

        processor.ensure_keyspace = mock_ensure_keyspace
        processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)

        # Create schema without primary key
        schema = RowSchema(
            name="events",
            description="Event log",
            fields=[
                Field(name="event_type", type="string", size=50),
                Field(name="timestamp", type="timestamp", size=0)
            ]
        )

        # Call ensure_table
        processor.ensure_table("test_user", "events", schema)

        # Check the CQL includes synthetic_id (field names don't get o_ prefix)
        executed_cql = processor.session.execute.call_args[0][0]
        assert "synthetic_id uuid" in executed_cql
        assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql

    @pytest.mark.asyncio
    async def test_schema_config_parsing(self):
        """Test parsing of schema configurations"""
        processor = MagicMock()
        processor.schemas = {}
        processor.config_key = "schema"
        processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

        # Create test configuration
        config = {
            "schema": {
                "customer_records": json.dumps({
                    "name": "customer_records",
                    "description": "Customer data",
                    "fields": [
                        {
                            "name": "id",
                            "type": "string",
                            "primary_key": True,
                            "required": True
                        },
                        {
                            "name": "name",
                            "type": "string",
                            "required": True
                        },
                        {
                            "name": "balance",
                            "type": "float",
                            "size": 8
                        }
                    ]
                })
            }
        }

        # Process configuration
        await processor.on_schema_config(config, version=1)

        # Verify schema was loaded
        assert "customer_records" in processor.schemas
        schema = processor.schemas["customer_records"]
        assert schema.name == "customer_records"
        assert len(schema.fields) == 3

        # Check field properties
        id_field = schema.fields[0]
        assert id_field.name == "id"
        assert id_field.type == "string"
        assert id_field.primary is True
        # Note: Field.required always returns False due to Pulsar schema limitations
        # The actual required value is tracked during schema parsing

    @pytest.mark.asyncio
    async def test_object_processing_logic(self):
        """Test the logic for processing ExtractedObject"""
        processor = MagicMock()
        processor.schemas = {
            "test_schema": RowSchema(
                name="test_schema",
                description="Test",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="value", type="integer", size=4)
                ]
            )
        }
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)
        processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
        processor.known_tables = {"test_user": set()}  # Pre-populate

        # Create test object
        test_obj = ExtractedObject(
            metadata=Metadata(
                id="test-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="test_schema",
            values=[{"id": "123", "value": "456"}],
            confidence=0.9,
            source_span="test source"
        )

        # Create mock message
        msg = MagicMock()
        msg.value.return_value = test_obj

        # Process object
        await processor.on_object(msg, None, None)

        # Verify table was ensured
        processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])

        # Verify insert was executed (keyspace normal, table with o_ prefix)
        processor.session.execute.assert_called_once()
        insert_cql = processor.session.execute.call_args[0][0]
        values = processor.session.execute.call_args[0][1]

        assert "INSERT INTO test_user.o_test_schema" in insert_cql
        assert "collection" in insert_cql
        assert values[0] == "test_collection"  # collection value
        assert values[1] == "123"  # id value (from values[0])
        assert values[2] == 456  # converted integer value (from values[0])

    def test_secondary_index_creation(self):
        """Test that secondary indexes are created for indexed fields"""
        processor = MagicMock()
        processor.schemas = {}
        processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
        processor.known_tables = {"test_user": set()}  # Pre-populate
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)

        def mock_ensure_keyspace(keyspace):
            processor.known_keyspaces.add(keyspace)
            if keyspace not in processor.known_tables:
                processor.known_tables[keyspace] = set()

        processor.ensure_keyspace = mock_ensure_keyspace
        processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)

        # Create schema with indexed field
        schema = RowSchema(
            name="products",
            description="Product catalog",
            fields=[
                Field(name="product_id", type="string", size=50, primary=True),
                Field(name="category", type="string", size=30, indexed=True),
                Field(name="price", type="float", size=8, indexed=True)
            ]
        )

        # Call ensure_table
        processor.ensure_table("test_user", "products", schema)

        # Should have 3 calls: create table + 2 indexes
        assert processor.session.execute.call_count == 3

        # Check index creation calls (table has o_ prefix, fields don't)
        calls = processor.session.execute.call_args_list
        index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
        assert len(index_calls) == 2
        assert any("o_products_category_idx" in call for call in index_calls)
        assert any("o_products_price_idx" in call for call in index_calls)


class TestObjectsCassandraStorageBatchLogic:
    """Test batch processing logic in Cassandra storage"""

    @pytest.mark.asyncio
    async def test_batch_object_processing_logic(self):
        """Test processing of batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "batch_schema": RowSchema(
                name="batch_schema",
                description="Test batch schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="name", type="string", size=100),
                    Field(name="value", type="integer", size=4)
                ]
            )
        }
        processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create batch object with multiple values
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[]
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First", "value": "100"},
                {"id": "002", "name": "Second", "value": "200"},
                {"id": "003", "name": "Third", "value": "300"}
            ],
            confidence=0.95,
            source_span="batch source"
        )

        # Create mock message
        msg = MagicMock()
        msg.value.return_value = batch_obj

        # Process batch object
        await processor.on_object(msg, None, None)

        # Verify table was ensured once
        processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"])

        # Verify 3 separate insert calls (one per batch item)
        assert processor.session.execute.call_count == 3

        # Check each insert call
        calls = processor.session.execute.call_args_list
        for i, call in enumerate(calls):
            insert_cql = call[0][0]
            values = call[0][1]

            assert "INSERT INTO test_user.o_batch_schema" in insert_cql
            assert "collection" in insert_cql

            # Check values for each batch item
            assert values[0] == "batch_collection"  # collection
            assert values[1] == f"00{i+1}"  # id from batch item i
            assert values[2] == ("First" if i == 0 else "Second" if i == 1 else "Third")  # name
            assert values[3] == (i + 1) * 100  # converted integer value

    @pytest.mark.asyncio
    async def test_empty_batch_processing_logic(self):
        """Test processing of empty batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "empty_schema": RowSchema(
                name="empty_schema",
                fields=[Field(name="id", type="string", size=50, primary=True)]
            )
        }
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)
        processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
        processor.known_tables = {"test_user": set()}  # Pre-populate

        # Create empty batch object
        empty_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="empty-001",
                user="test_user",
                collection="empty_collection",
                metadata=[]
            ),
            schema_name="empty_schema",
            values=[],  # Empty batch
            confidence=1.0,
            source_span="empty source"
        )

        msg = MagicMock()
        msg.value.return_value = empty_batch_obj

        # Process empty batch object
        await processor.on_object(msg, None, None)

        # Verify table was ensured
        processor.ensure_table.assert_called_once()

        # Verify no insert calls for empty batch
        processor.session.execute.assert_not_called()

    @pytest.mark.asyncio
    async def test_single_item_batch_processing_logic(self):
        """Test processing of single-item batch (backward compatibility)"""
        processor = MagicMock()
        processor.schemas = {
            "single_schema": RowSchema(
                name="single_schema",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="data", type="string", size=100)
                ]
            )
        }
        processor.ensure_table = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)
        processor.session = MagicMock()
        processor.on_object = Processor.on_object.__get__(processor, Processor)
        processor.known_keyspaces = {"test_user"}  # Pre-populate to skip validation query
        processor.known_tables = {"test_user": set()}  # Pre-populate

        # Create single-item batch object (backward compatibility case)
        single_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="single-001",
                user="test_user",
                collection="single_collection",
                metadata=[]
            ),
            schema_name="single_schema",
            values=[{"id": "single-1", "data": "single data"}],  # Array with one item
            confidence=0.8,
            source_span="single source"
        )

        msg = MagicMock()
        msg.value.return_value = single_batch_obj

        # Process single-item batch object
        await processor.on_object(msg, None, None)

        # Verify table was ensured
        processor.ensure_table.assert_called_once()

        # Verify exactly one insert call
        processor.session.execute.assert_called_once()

        insert_cql = processor.session.execute.call_args[0][0]
        values = processor.session.execute.call_args[0][1]

        assert "INSERT INTO test_user.o_single_schema" in insert_cql
        assert values[0] == "single_collection"  # collection
        assert values[1] == "single-1"  # id value
        assert values[2] == "single data"  # data value

    def test_batch_value_conversion_logic(self):
        """Test value conversion works correctly for batch items"""
        processor = MagicMock()
        processor.convert_value = Processor.convert_value.__get__(processor, Processor)

        # Test various conversion scenarios that would occur in batch processing
        test_cases = [
            # Integer conversions for batch items
            ("123", "integer", 123),
            ("456", "integer", 456),
            ("789", "integer", 789),
            # Float conversions for batch items
            ("12.5", "float", 12.5),
            ("34.7", "float", 34.7),
            # Boolean conversions for batch items
            ("true", "boolean", True),
            ("false", "boolean", False),
            ("1", "boolean", True),
            ("0", "boolean", False),
            # String conversions for batch items
            (123, "string", "123"),
            (45.6, "string", "45.6"),
        ]

        for input_val, field_type, expected_output in test_cases:
            result = processor.convert_value(input_val, field_type)
            assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"
435  tests/unit/test_storage/test_row_embeddings_qdrant_storage.py  Normal file

@@ -0,0 +1,435 @@
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase


class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
    """Test Qdrant row embeddings storage functionality"""

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_processor_initialization_basic(self, mock_qdrant_client):
        """Test basic Qdrant processor initialization"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'store_uri': 'http://localhost:6333',
            'api_key': 'test-api-key',
            'taskgroup': AsyncMock(),
            'id': 'test-qdrant-processor'
        }

        processor = Processor(**config)

        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key='test-api-key'
        )
        assert hasattr(processor, 'qdrant')
        assert processor.qdrant == mock_qdrant_instance

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
        """Test processor initialization with default values"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-qdrant-processor'
        }

        processor = Processor(**config)

        mock_qdrant_client.assert_called_once_with(
            url='http://localhost:6333', api_key=None
        )

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_sanitize_name(self, mock_qdrant_client):
        """Test name sanitization for Qdrant collections"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Test basic sanitization
        assert processor.sanitize_name("simple") == "simple"
        assert processor.sanitize_name("with-dash") == "with_dash"
        assert processor.sanitize_name("with.dot") == "with_dot"
        assert processor.sanitize_name("UPPERCASE") == "uppercase"

        # Test numeric prefix handling
        assert processor.sanitize_name("123start") == "r_123start"
        assert processor.sanitize_name("_underscore") == "r__underscore"

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_get_collection_name(self, mock_qdrant_client):
        """Test Qdrant collection name generation"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_client.return_value = MagicMock()

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        collection_name = processor.get_collection_name(
            user="test_user",
            collection="test_collection",
            schema_name="customer_data",
            dimension=384
        )

        assert collection_name == "rows_test_user_test_collection_customer_data_384"

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_creates_new(self, mock_qdrant_client):
        """Test that ensure_collection creates a new collection when needed"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = False
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        processor.ensure_collection("test_collection", 384)

        mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_instance.create_collection.assert_called_once()

        # Verify the collection is cached
        assert "test_collection" in processor.created_collections

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
        """Test that ensure_collection skips creation when collection exists"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = True
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        processor.ensure_collection("existing_collection", 384)

        mock_qdrant_instance.collection_exists.assert_called_once()
        mock_qdrant_instance.create_collection.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
        """Test that ensure_collection uses cache for previously created collections"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor

        mock_qdrant_instance = MagicMock()
        mock_qdrant_client.return_value = mock_qdrant_instance

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.created_collections.add("cached_collection")

        processor.ensure_collection("cached_collection", 384)

        # Should not check or create - just return
        mock_qdrant_instance.collection_exists.assert_not_called()
        mock_qdrant_instance.create_collection.assert_not_called()

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
    async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
        """Test processing basic row embeddings message"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = True
        mock_qdrant_client.return_value = mock_qdrant_instance
        mock_uuid.uuid4.return_value = 'test-uuid-123'

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        # Create embeddings message
        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        embedding = RowIndexEmbedding(
            index_name='customer_id',
            index_value=['CUST001'],
            text='CUST001',
            vectors=[[0.1, 0.2, 0.3]]
        )

        embeddings_msg = RowEmbeddings(
            metadata=metadata,
            schema_name='customers',
            embeddings=[embedding]
        )

        # Mock message wrapper
        mock_msg = MagicMock()
        mock_msg.value.return_value = embeddings_msg

        await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())

        # Verify upsert was called
        mock_qdrant_instance.upsert.assert_called_once()

        # Verify upsert parameters
        upsert_call_args = mock_qdrant_instance.upsert.call_args
        assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'

        point = upsert_call_args[1]['points'][0]
        assert point.vector == [0.1, 0.2, 0.3]
        assert point.payload['index_name'] == 'customer_id'
        assert point.payload['index_value'] == ['CUST001']
        assert point.payload['text'] == 'CUST001'

    @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
    @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
    async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
        """Test processing embeddings with multiple vectors"""
        from trustgraph.storage.row_embeddings.qdrant.write import Processor
        from trustgraph.schema import RowEmbeddings, RowIndexEmbedding

        mock_qdrant_instance = MagicMock()
        mock_qdrant_instance.collection_exists.return_value = True
        mock_qdrant_client.return_value = mock_qdrant_instance
        mock_uuid.uuid4.return_value = 'test-uuid'

        config = {
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)
        processor.known_collections[('test_user', 'test_collection')] = {}

        metadata = MagicMock()
        metadata.user = 'test_user'
        metadata.collection = 'test_collection'
        metadata.id = 'doc-123'

        # Embedding with multiple vectors
        embedding = RowIndexEmbedding(
            index_name='name',
            index_value=['John Doe'],
            text='John Doe',
            vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        )

        embeddings_msg = RowEmbeddings(
            metadata=metadata,
            schema_name='people',
            embeddings=[embedding]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = embeddings_msg

        await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||||
|
|
||||||
|
# Should be called 3 times (once per vector)
|
||||||
|
assert mock_qdrant_instance.upsert.call_count == 3
|
||||||
|
|
||||||
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
|
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
|
||||||
|
"""Test that embeddings with no vectors are skipped"""
|
||||||
|
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||||
|
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||||
|
|
||||||
|
mock_qdrant_instance = MagicMock()
|
||||||
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'taskgroup': AsyncMock(),
|
||||||
|
'id': 'test-processor'
|
||||||
|
}
|
||||||
|
|
||||||
|
processor = Processor(**config)
|
||||||
|
processor.known_collections[('test_user', 'test_collection')] = {}
|
||||||
|
|
||||||
|
metadata = MagicMock()
|
||||||
|
metadata.user = 'test_user'
|
||||||
|
metadata.collection = 'test_collection'
|
||||||
|
metadata.id = 'doc-123'
|
||||||
|
|
||||||
|
# Embedding with no vectors
|
||||||
|
embedding = RowIndexEmbedding(
|
||||||
|
index_name='id',
|
||||||
|
index_value=['123'],
|
||||||
|
text='123',
|
||||||
|
vectors=[] # Empty vectors
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings_msg = RowEmbeddings(
|
||||||
|
metadata=metadata,
|
||||||
|
schema_name='items',
|
||||||
|
embeddings=[embedding]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_msg = MagicMock()
|
||||||
|
mock_msg.value.return_value = embeddings_msg
|
||||||
|
|
||||||
|
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||||
|
|
||||||
|
# Should not call upsert for empty vectors
|
||||||
|
mock_qdrant_instance.upsert.assert_not_called()
|
||||||
|
|
||||||
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
|
async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
|
||||||
|
"""Test that messages for unknown collections are dropped"""
|
||||||
|
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||||
|
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
|
||||||
|
|
||||||
|
mock_qdrant_instance = MagicMock()
|
||||||
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'taskgroup': AsyncMock(),
|
||||||
|
'id': 'test-processor'
|
||||||
|
}
|
||||||
|
|
||||||
|
processor = Processor(**config)
|
||||||
|
# No collections registered
|
||||||
|
|
||||||
|
metadata = MagicMock()
|
||||||
|
metadata.user = 'unknown_user'
|
||||||
|
metadata.collection = 'unknown_collection'
|
||||||
|
metadata.id = 'doc-123'
|
||||||
|
|
||||||
|
embedding = RowIndexEmbedding(
|
||||||
|
index_name='id',
|
||||||
|
index_value=['123'],
|
||||||
|
text='123',
|
||||||
|
vectors=[[0.1, 0.2]]
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings_msg = RowEmbeddings(
|
||||||
|
metadata=metadata,
|
||||||
|
schema_name='items',
|
||||||
|
embeddings=[embedding]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_msg = MagicMock()
|
||||||
|
mock_msg.value.return_value = embeddings_msg
|
||||||
|
|
||||||
|
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
|
||||||
|
|
||||||
|
# Should not call upsert for unknown collection
|
||||||
|
mock_qdrant_instance.upsert.assert_not_called()
|
||||||
|
|
||||||
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
|
async def test_delete_collection(self, mock_qdrant_client):
|
||||||
|
"""Test deleting all collections for a user/collection"""
|
||||||
|
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||||
|
|
||||||
|
mock_qdrant_instance = MagicMock()
|
||||||
|
|
||||||
|
# Mock collections list
|
||||||
|
mock_coll1 = MagicMock()
|
||||||
|
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
|
||||||
|
mock_coll2 = MagicMock()
|
||||||
|
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
|
||||||
|
mock_coll3 = MagicMock()
|
||||||
|
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
|
||||||
|
|
||||||
|
mock_collections = MagicMock()
|
||||||
|
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
|
||||||
|
mock_qdrant_instance.get_collections.return_value = mock_collections
|
||||||
|
|
||||||
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'taskgroup': AsyncMock(),
|
||||||
|
'id': 'test-processor'
|
||||||
|
}
|
||||||
|
|
||||||
|
processor = Processor(**config)
|
||||||
|
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
|
||||||
|
|
||||||
|
await processor.delete_collection('test_user', 'test_collection')
|
||||||
|
|
||||||
|
# Should delete only the matching collections
|
||||||
|
assert mock_qdrant_instance.delete_collection.call_count == 2
|
||||||
|
|
||||||
|
# Verify the cached collection was removed
|
||||||
|
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
|
||||||
|
|
||||||
|
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
|
||||||
|
async def test_delete_collection_schema(self, mock_qdrant_client):
|
||||||
|
"""Test deleting collections for a specific schema"""
|
||||||
|
from trustgraph.storage.row_embeddings.qdrant.write import Processor
|
||||||
|
|
||||||
|
mock_qdrant_instance = MagicMock()
|
||||||
|
|
||||||
|
mock_coll1 = MagicMock()
|
||||||
|
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
|
||||||
|
mock_coll2 = MagicMock()
|
||||||
|
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
|
||||||
|
|
||||||
|
mock_collections = MagicMock()
|
||||||
|
mock_collections.collections = [mock_coll1, mock_coll2]
|
||||||
|
mock_qdrant_instance.get_collections.return_value = mock_collections
|
||||||
|
|
||||||
|
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'taskgroup': AsyncMock(),
|
||||||
|
'id': 'test-processor'
|
||||||
|
}
|
||||||
|
|
||||||
|
processor = Processor(**config)
|
||||||
|
|
||||||
|
await processor.delete_collection_schema(
|
||||||
|
'test_user', 'test_collection', 'customers'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only delete the customers schema collection
|
||||||
|
mock_qdrant_instance.delete_collection.assert_called_once()
|
||||||
|
call_args = mock_qdrant_instance.delete_collection.call_args[0]
|
||||||
|
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__])
|
||||||
474
tests/unit/test_storage/test_rows_cassandra_storage.py
Normal file

@@ -0,0 +1,474 @@
"""
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)

Tests the business logic of the row storage processor including:
- Schema configuration handling
- Name sanitization
- Unified table structure
- Index management
- Row storage with multi-index support
"""

import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json

from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field


class TestRowsCassandraStorageLogic:
    """Test business logic for unified table implementation"""

    def test_sanitize_name(self):
        """Test name sanitization for Cassandra compatibility"""
        processor = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)

        # Test various name patterns
        assert processor.sanitize_name("simple_name") == "simple_name"
        assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
        assert processor.sanitize_name("name.with.dots") == "name_with_dots"
        assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number"
        assert processor.sanitize_name("name with spaces") == "name_with_spaces"
        assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
        assert processor.sanitize_name("UPPERCASE") == "uppercase"
        assert processor.sanitize_name("CamelCase") == "camelcase"
        assert processor.sanitize_name("_underscore_start") == "r__underscore_start"

    def test_get_index_names(self):
        """Test extraction of index names from schema"""
        processor = MagicMock()
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

        # Schema with primary and indexed fields
        schema = RowSchema(
            name="test_schema",
            description="Test",
            fields=[
                Field(name="id", type="string", primary=True),
                Field(name="category", type="string", indexed=True),
                Field(name="name", type="string"),  # Not indexed
                Field(name="status", type="string", indexed=True)
            ]
        )

        index_names = processor.get_index_names(schema)

        # Should include primary key and indexed fields
        assert "id" in index_names
        assert "category" in index_names
        assert "status" in index_names
        assert "name" not in index_names  # Not indexed
        assert len(index_names) == 3

    def test_get_index_names_no_indexes(self):
        """Test schema with no indexed fields"""
        processor = MagicMock()
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)

        schema = RowSchema(
            name="no_index_schema",
            fields=[
                Field(name="data1", type="string"),
                Field(name="data2", type="string")
            ]
        )

        index_names = processor.get_index_names(schema)
        assert len(index_names) == 0

    def test_build_index_value(self):
        """Test building index values from row data"""
        processor = MagicMock()
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)

        value_map = {"id": "123", "category": "electronics", "name": "Widget"}

        # Single field index
        result = processor.build_index_value(value_map, "id")
        assert result == ["123"]

        result = processor.build_index_value(value_map, "category")
        assert result == ["electronics"]

        # Missing field returns empty string
        result = processor.build_index_value(value_map, "missing")
        assert result == [""]

    def test_build_index_value_composite(self):
        """Test building composite index values"""
        processor = MagicMock()
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)

        value_map = {"region": "us-west", "category": "electronics", "id": "123"}

        # Composite index (comma-separated field names)
        result = processor.build_index_value(value_map, "region,category")
        assert result == ["us-west", "electronics"]

    @pytest.mark.asyncio
    async def test_schema_config_parsing(self):
        """Test parsing of schema configurations"""
        processor = MagicMock()
        processor.schemas = {}
        processor.config_key = "schema"
        processor.registered_partitions = set()
        processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

        # Create test configuration
        config = {
            "schema": {
                "customer_records": json.dumps({
                    "name": "customer_records",
                    "description": "Customer data",
                    "fields": [
                        {
                            "name": "id",
                            "type": "string",
                            "primary_key": True,
                            "required": True
                        },
                        {
                            "name": "name",
                            "type": "string",
                            "required": True
                        },
                        {
                            "name": "category",
                            "type": "string",
                            "indexed": True
                        }
                    ]
                })
            }
        }

        # Process configuration
        await processor.on_schema_config(config, version=1)

        # Verify schema was loaded
        assert "customer_records" in processor.schemas
        schema = processor.schemas["customer_records"]
        assert schema.name == "customer_records"
        assert len(schema.fields) == 3

        # Check field properties
        id_field = schema.fields[0]
        assert id_field.name == "id"
        assert id_field.type == "string"
        assert id_field.primary is True

    @pytest.mark.asyncio
    async def test_object_processing_stores_data_map(self):
        """Test that row processing stores data as map<text, text>"""
        processor = MagicMock()
        processor.schemas = {
            "test_schema": RowSchema(
                name="test_schema",
                description="Test",
                fields=[
                    Field(name="id", type="string", size=50, primary=True),
                    Field(name="value", type="string", size=100)
                ]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create test object
        test_obj = ExtractedObject(
            metadata=Metadata(
                id="test-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="test_schema",
            values=[{"id": "123", "value": "test_data"}],
            confidence=0.9,
            source_span="test source"
        )

        # Create mock message
        msg = MagicMock()
        msg.value.return_value = test_obj

        # Process object
        await processor.on_object(msg, None, None)

        # Verify insert was executed
        processor.session.execute.assert_called()
        insert_call = processor.session.execute.call_args
        insert_cql = insert_call[0][0]
        values = insert_call[0][1]

        # Verify using unified rows table
        assert "INSERT INTO test_user.rows" in insert_cql

        # Values should be: (collection, schema_name, index_name, index_value, data, source)
        assert values[0] == "test_collection"  # collection
        assert values[1] == "test_schema"  # schema_name
        assert values[2] == "id"  # index_name (primary key field)
        assert values[3] == ["123"]  # index_value as list
        assert values[4] == {"id": "123", "value": "test_data"}  # data map
        assert values[5] == ""  # source

    @pytest.mark.asyncio
    async def test_object_processing_multiple_indexes(self):
        """Test that row is written once per indexed field"""
        processor = MagicMock()
        processor.schemas = {
            "multi_index_schema": RowSchema(
                name="multi_index_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="category", type="string", indexed=True),
                    Field(name="status", type="string", indexed=True)
                ]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        test_obj = ExtractedObject(
            metadata=Metadata(
                id="test-001",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            schema_name="multi_index_schema",
            values=[{"id": "123", "category": "electronics", "status": "active"}],
            confidence=0.9,
            source_span=""
        )

        msg = MagicMock()
        msg.value.return_value = test_obj

        await processor.on_object(msg, None, None)

        # Should have 3 inserts (one per indexed field: id, category, status)
        assert processor.session.execute.call_count == 3

        # Check that different index_names were used
        index_names_used = set()
        for call in processor.session.execute.call_args_list:
            values = call[0][1]
            index_names_used.add(values[2])  # index_name is 3rd value

        assert index_names_used == {"id", "category", "status"}


class TestRowsCassandraStorageBatchLogic:
    """Test batch processing logic for unified table implementation"""

    @pytest.mark.asyncio
    async def test_batch_object_processing(self):
        """Test processing of batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "batch_schema": RowSchema(
                name="batch_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="name", type="string")
                ]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create batch object with multiple values
        batch_obj = ExtractedObject(
            metadata=Metadata(
                id="batch-001",
                user="test_user",
                collection="batch_collection",
                metadata=[]
            ),
            schema_name="batch_schema",
            values=[
                {"id": "001", "name": "First"},
                {"id": "002", "name": "Second"},
                {"id": "003", "name": "Third"}
            ],
            confidence=0.95,
            source_span=""
        )

        msg = MagicMock()
        msg.value.return_value = batch_obj

        await processor.on_object(msg, None, None)

        # Should have 3 inserts (one per row, one index per row since only primary key)
        assert processor.session.execute.call_count == 3

        # Check each insert has different id
        ids_inserted = set()
        for call in processor.session.execute.call_args_list:
            values = call[0][1]
            ids_inserted.add(tuple(values[3]))  # index_value is 4th value

        assert ids_inserted == {("001",), ("002",), ("003",)}

    @pytest.mark.asyncio
    async def test_empty_batch_processing(self):
        """Test processing of empty batch ExtractedObjects"""
        processor = MagicMock()
        processor.schemas = {
            "empty_schema": RowSchema(
                name="empty_schema",
                fields=[Field(name="id", type="string", primary=True)]
            )
        }
        processor.tables_initialized = {"test_user"}
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
        processor.ensure_tables = MagicMock()
        processor.register_partitions = MagicMock()
        processor.collection_exists = MagicMock(return_value=True)
        processor.on_object = Processor.on_object.__get__(processor, Processor)

        # Create empty batch object
        empty_batch_obj = ExtractedObject(
            metadata=Metadata(
                id="empty-001",
                user="test_user",
                collection="empty_collection",
                metadata=[]
            ),
            schema_name="empty_schema",
            values=[],  # Empty batch
            confidence=1.0,
            source_span=""
        )

        msg = MagicMock()
        msg.value.return_value = empty_batch_obj

        await processor.on_object(msg, None, None)

        # Verify no insert calls for empty batch
        processor.session.execute.assert_not_called()


class TestUnifiedTableStructure:
    """Test the unified rows table structure"""

    def test_ensure_tables_creates_unified_structure(self):
        """Test that ensure_tables creates the unified rows table"""
        processor = MagicMock()
        processor.known_keyspaces = {"test_user"}
        processor.tables_initialized = set()
        processor.session = MagicMock()
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.ensure_keyspace = MagicMock()
        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

        processor.ensure_tables("test_user")

        # Should have 2 calls: create rows table + create row_partitions table
        assert processor.session.execute.call_count == 2

        # Check rows table creation
        rows_cql = processor.session.execute.call_args_list[0][0][0]
        assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
        assert "collection text" in rows_cql
        assert "schema_name text" in rows_cql
        assert "index_name text" in rows_cql
        assert "index_value frozen<list<text>>" in rows_cql
        assert "data map<text, text>" in rows_cql
        assert "source text" in rows_cql
        assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql

        # Check row_partitions table creation
        partitions_cql = processor.session.execute.call_args_list[1][0][0]
        assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
        assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql

        # Verify keyspace added to initialized set
        assert "test_user" in processor.tables_initialized

    def test_ensure_tables_idempotent(self):
        """Test that ensure_tables is idempotent"""
        processor = MagicMock()
        processor.tables_initialized = {"test_user"}  # Already initialized
        processor.session = MagicMock()
        processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

        processor.ensure_tables("test_user")

        # Should not execute any CQL since already initialized
        processor.session.execute.assert_not_called()


class TestPartitionRegistration:
    """Test partition registration for tracking what's stored"""

    def test_register_partitions(self):
        """Test registering partitions for a collection/schema pair"""
        processor = MagicMock()
        processor.registered_partitions = set()
        processor.session = MagicMock()
        processor.schemas = {
            "test_schema": RowSchema(
                name="test_schema",
                fields=[
                    Field(name="id", type="string", primary=True),
                    Field(name="category", type="string", indexed=True)
                ]
            )
        }
        processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
        processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
        processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)

        processor.register_partitions("test_user", "test_collection", "test_schema")

        # Should have 2 inserts (one per index: id, category)
        assert processor.session.execute.call_count == 2

        # Verify cache was updated
        assert ("test_collection", "test_schema") in processor.registered_partitions

    def test_register_partitions_idempotent(self):
        """Test that partition registration is idempotent"""
        processor = MagicMock()
        processor.registered_partitions = {("test_collection", "test_schema")}  # Already registered
        processor.session = MagicMock()
        processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)

        processor.register_partitions("test_user", "test_collection", "test_schema")

        # Should not execute any CQL since already registered
        processor.session.execute.assert_not_called()
@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
assert hasattr(processor, 'client')
|
assert hasattr(processor, 'client')
|
||||||
assert hasattr(processor, 'safety_settings')
|
assert hasattr(processor, 'safety_settings')
|
||||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
|
|
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
         assert processor.default_model == 'gemini-1.5-pro'
         assert processor.temperature == 0.7
         assert processor.max_output == 4096
-        mock_genai_class.assert_called_once_with(api_key='custom-api-key')
+        mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
 
     @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
         assert processor.default_model == 'gemini-2.0-flash-001'  # default_model
         assert processor.temperature == 0.0  # default_temperature
         assert processor.max_output == 8192  # default_max_output
-        mock_genai_class.assert_called_once_with(api_key='test-api-key')
+        mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
 
     @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
 
         # Assert
         # Verify Google AI Studio client was called with correct API key
-        mock_genai_class.assert_called_once_with(api_key='gai-test-key')
+        mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
 
         # Verify processor has the client
         assert processor.client == mock_genai_client
@@ -101,7 +101,7 @@ from .exceptions import (
     LoadError,
     LookupError,
     NLPQueryError,
-    ObjectsQueryError,
+    RowsQueryError,
     RequestError,
     StructuredQueryError,
     UnexpectedError,
@@ -161,7 +161,7 @@ __all__ = [
     "LoadError",
     "LookupError",
     "NLPQueryError",
-    "ObjectsQueryError",
+    "RowsQueryError",
     "RequestError",
     "StructuredQueryError",
     "UnexpectedError",
@@ -115,15 +115,15 @@ class AsyncBulkClient:
             async for raw_message in websocket:
                 yield json.loads(raw_message)
 
-    async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
-        """Bulk import objects via WebSocket"""
-        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
+    async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import rows via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
         if self.token:
             ws_url = f"{ws_url}?token={self.token}"
 
         async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
-            async for obj in objects:
-                await websocket.send(json.dumps(obj))
+            async for row in rows:
+                await websocket.send(json.dumps(row))
 
     async def aclose(self) -> None:
         """Close connections"""
@@ -708,18 +708,18 @@ class AsyncFlowInstance:
 
         return await self.request("triples", request_data)
 
-    async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
+    async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
                             operation_name: Optional[str] = None, **kwargs: Any):
         """
-        Execute a GraphQL query on stored objects.
+        Execute a GraphQL query on stored rows.
 
-        Queries structured data objects using GraphQL syntax. Supports complex
+        Queries structured data rows using GraphQL syntax. Supports complex
         queries with variables and named operations.
 
         Args:
             query: GraphQL query string
             user: User identifier
-            collection: Collection identifier containing objects
+            collection: Collection identifier containing rows
             variables: Optional GraphQL query variables
             operation_name: Optional operation name for multi-operation queries
             **kwargs: Additional service-specific parameters
@@ -743,7 +743,7 @@ class AsyncFlowInstance:
             }
             '''
 
-            result = await flow.objects_query(
+            result = await flow.rows_query(
                 query=query,
                 user="trustgraph",
                 collection="users",
@@ -765,4 +765,4 @@ class AsyncFlowInstance:
             request_data["operationName"] = operation_name
         request_data.update(kwargs)
 
-        return await self.request("objects", request_data)
+        return await self.request("rows", request_data)
@@ -320,9 +320,9 @@ class AsyncSocketFlowInstance:
 
         return await self.client._send_request("triples", self.flow_id, request)
 
-    async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
+    async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
                             operation_name: Optional[str] = None, **kwargs):
-        """GraphQL query"""
+        """GraphQL query against structured rows"""
         request = {
             "query": query,
             "user": user,
@@ -334,7 +334,7 @@ class AsyncSocketFlowInstance:
             request["operationName"] = operation_name
         request.update(kwargs)
 
-        return await self.client._send_request("objects", self.flow_id, request)
+        return await self.client._send_request("rows", self.flow_id, request)
 
     async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
         """Execute MCP tool"""
@@ -530,45 +530,45 @@ class BulkClient:
             async for raw_message in websocket:
                 yield json.loads(raw_message)
 
-    def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
+    def import_rows(self, flow: str, rows: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
         """
-        Bulk import structured objects into a flow.
+        Bulk import structured rows into a flow.
 
-        Efficiently uploads structured data objects via WebSocket streaming
+        Efficiently uploads structured data rows via WebSocket streaming
         for use in GraphQL queries.
 
         Args:
             flow: Flow identifier
-            objects: Iterator yielding object dictionaries
+            rows: Iterator yielding row dictionaries
             **kwargs: Additional parameters (reserved for future use)
 
         Example:
            ```python
            bulk = api.bulk()
 
-           # Generate objects to import
-           def object_generator():
-               yield {"id": "obj1", "name": "Object 1", "value": 100}
-               yield {"id": "obj2", "name": "Object 2", "value": 200}
-               # ... more objects
+           # Generate rows to import
+           def row_generator():
+               yield {"id": "row1", "name": "Row 1", "value": 100}
+               yield {"id": "row2", "name": "Row 2", "value": 200}
+               # ... more rows
 
-           bulk.import_objects(
+           bulk.import_rows(
                flow="default",
-               objects=object_generator()
+               rows=row_generator()
            )
            ```
         """
-        self._run_async(self._import_objects_async(flow, objects))
+        self._run_async(self._import_rows_async(flow, rows))
 
-    async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
-        """Async implementation of objects import"""
-        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
+    async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
+        """Async implementation of rows import"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
         if self.token:
             ws_url = f"{ws_url}?token={self.token}"
 
         async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
-            for obj in objects:
-                await websocket.send(json.dumps(obj))
+            for row in rows:
+                await websocket.send(json.dumps(row))
 
     def close(self) -> None:
         """Close connections"""
@@ -71,8 +71,8 @@ class NLPQueryError(TrustGraphException):
     pass
 
 
-class ObjectsQueryError(TrustGraphException):
-    """Objects query service error"""
+class RowsQueryError(TrustGraphException):
+    """Rows query service error"""
     pass
 
 
@@ -103,7 +103,7 @@ ERROR_TYPE_MAPPING = {
     "load-error": LoadError,
     "lookup-error": LookupError,
     "nlp-query-error": NLPQueryError,
-    "objects-query-error": ObjectsQueryError,
+    "rows-query-error": RowsQueryError,
     "request-error": RequestError,
     "structured-query-error": StructuredQueryError,
     "unexpected-error": UnexpectedError,
@@ -1001,12 +1001,12 @@ class FlowInstance:
             input
         )
 
-    def objects_query(
+    def rows_query(
             self, query, user="trustgraph", collection="default",
             variables=None, operation_name=None
     ):
         """
-        Execute a GraphQL query against structured objects in the knowledge graph.
+        Execute a GraphQL query against structured rows in the knowledge graph.
 
         Queries structured data using GraphQL syntax, allowing complex queries
         with filtering, aggregation, and relationship traversal.
@@ -1038,7 +1038,7 @@ class FlowInstance:
                 }
             }
             '''
-            result = flow.objects_query(
+            result = flow.rows_query(
                 query=query,
                 user="trustgraph",
                 collection="scientists"
@@ -1053,7 +1053,7 @@ class FlowInstance:
                 }
             }
             '''
-            result = flow.objects_query(
+            result = flow.rows_query(
                 query=query,
                 variables={"name": "Marie Curie"}
             )
@@ -1074,7 +1074,7 @@ class FlowInstance:
             input["operation_name"] = operation_name
 
         response = self.request(
-            "service/objects",
+            "service/rows",
             input
         )
 
@@ -789,7 +789,7 @@ class SocketFlowInstance:
 
         return self.client._send_request_sync("triples", self.flow_id, request, False)
 
-    def objects_query(
+    def rows_query(
         self,
         query: str,
         user: str,
@@ -799,7 +799,7 @@ class SocketFlowInstance:
         **kwargs: Any
     ) -> Dict[str, Any]:
         """
-        Execute a GraphQL query against structured objects.
+        Execute a GraphQL query against structured rows.
 
         Args:
             query: GraphQL query string
@@ -826,7 +826,7 @@ class SocketFlowInstance:
                 }
             }
             '''
-            result = flow.objects_query(
+            result = flow.rows_query(
                 query=query,
                 user="trustgraph",
                 collection="scientists"
@@ -844,7 +844,7 @@ class SocketFlowInstance:
             request["operationName"] = operation_name
         request.update(kwargs)
 
-        return self.client._send_request_sync("objects", self.flow_id, request, False)
+        return self.client._send_request_sync("rows", self.flow_id, request, False)
 
     def mcp_tool(
         self,
@@ -21,7 +21,7 @@ from .translators.embeddings_query import (
     DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
     GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
 )
-from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
+from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
 from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
 from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
 from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
@@ -113,9 +113,9 @@ TranslatorRegistry.register_service(
 )
 
 TranslatorRegistry.register_service(
-    "objects-query",
-    ObjectsQueryRequestTranslator(),
-    ObjectsQueryResponseTranslator()
+    "rows-query",
+    RowsQueryRequestTranslator(),
+    RowsQueryResponseTranslator()
 )
 
 TranslatorRegistry.register_service(
@@ -17,5 +17,5 @@ from .embeddings_query import (
     DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
     GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
 )
-from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
+from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
 from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
@@ -1,44 +1,44 @@
 from typing import Dict, Any, Tuple, Optional
-from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
+from ...schema import RowsQueryRequest, RowsQueryResponse
 from .base import MessageTranslator
 import json
 
 
-class ObjectsQueryRequestTranslator(MessageTranslator):
-    """Translator for ObjectsQueryRequest schema objects"""
+class RowsQueryRequestTranslator(MessageTranslator):
+    """Translator for RowsQueryRequest schema objects"""
 
-    def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryRequest:
-        return ObjectsQueryRequest(
+    def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryRequest:
+        return RowsQueryRequest(
             user=data.get("user", "trustgraph"),
             collection=data.get("collection", "default"),
             query=data.get("query", ""),
             variables=data.get("variables", {}),
             operation_name=data.get("operation_name", None)
         )
 
-    def from_pulsar(self, obj: ObjectsQueryRequest) -> Dict[str, Any]:
+    def from_pulsar(self, obj: RowsQueryRequest) -> Dict[str, Any]:
         result = {
             "user": obj.user,
             "collection": obj.collection,
             "query": obj.query,
             "variables": dict(obj.variables) if obj.variables else {}
         }
 
         if obj.operation_name:
             result["operation_name"] = obj.operation_name
 
         return result
 
 
-class ObjectsQueryResponseTranslator(MessageTranslator):
-    """Translator for ObjectsQueryResponse schema objects"""
+class RowsQueryResponseTranslator(MessageTranslator):
+    """Translator for RowsQueryResponse schema objects"""
 
-    def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryResponse:
+    def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryResponse:
         raise NotImplementedError("Response translation to Pulsar not typically needed")
 
-    def from_pulsar(self, obj: ObjectsQueryResponse) -> Dict[str, Any]:
+    def from_pulsar(self, obj: RowsQueryResponse) -> Dict[str, Any]:
         result = {}
 
         # Handle GraphQL response data
         if obj.data:
             try:
@@ -47,7 +47,7 @@ class ObjectsQueryResponseTranslator(MessageTranslator):
                 result["data"] = obj.data
         else:
             result["data"] = None
 
         # Handle GraphQL errors
         if obj.errors:
             result["errors"] = []
@@ -60,20 +60,20 @@ class ObjectsQueryResponseTranslator(MessageTranslator):
                 if error.extensions:
                     error_dict["extensions"] = dict(error.extensions)
                 result["errors"].append(error_dict)
 
         # Handle extensions
         if obj.extensions:
             result["extensions"] = dict(obj.extensions)
 
         # Handle system-level error
         if obj.error:
             result["error"] = {
                 "type": obj.error.type,
                 "message": obj.error.message
             }
 
         return result
 
-    def from_response_with_completion(self, obj: ObjectsQueryResponse) -> Tuple[Dict[str, Any], bool]:
+    def from_response_with_completion(self, obj: RowsQueryResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
         return self.from_pulsar(obj), True
@@ -60,3 +60,23 @@ class StructuredObjectEmbedding:
     field_embeddings: dict[str, list[float]] = field(default_factory=dict)  # Per-field embeddings
 
 ############################################################################
+
+# Row embeddings are embeddings associated with indexed field values
+# in structured row data.  Each index gets embedded separately.
+
+@dataclass
+class RowIndexEmbedding:
+    """Single row's embedding for one index"""
+    index_name: str = ""  # The indexed field name(s)
+    index_value: list[str] = field(default_factory=list)  # The field value(s)
+    text: str = ""  # Text that was embedded
+    vectors: list[list[float]] = field(default_factory=list)
+
+@dataclass
+class RowEmbeddings:
+    """Batched row embeddings for a schema"""
+    metadata: Metadata | None = None
+    schema_name: str = ""
+    embeddings: list[RowIndexEmbedding] = field(default_factory=list)
+
+############################################################################
@@ -9,7 +9,7 @@ from .library import *
 from .lookup import *
 from .nlp_query import *
 from .structured_query import *
-from .objects_query import *
+from .rows_query import *
 from .diagnosis import *
 from .collection import *
 from .storage import *
@@ -59,4 +59,39 @@ document_embeddings_request_queue = topic(
 )
 document_embeddings_response_queue = topic(
     "document-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow'
+)
+
+############################################################################
+
+# Row embeddings query - for semantic/fuzzy matching on row index values
+
+@dataclass
+class RowIndexMatch:
+    """A single matching row index from a semantic search"""
+    index_name: str = ""  # The indexed field(s)
+    index_value: list[str] = field(default_factory=list)  # The index values
+    text: str = ""  # The text that was embedded
+    score: float = 0.0  # Similarity score
+
+@dataclass
+class RowEmbeddingsRequest:
+    """Request for row embeddings semantic search"""
+    vectors: list[list[float]] = field(default_factory=list)  # Query vectors
+    limit: int = 10  # Max results to return
+    user: str = ""  # User/keyspace
+    collection: str = ""  # Collection name
+    schema_name: str = ""  # Schema name to search within
+    index_name: str | None = None  # Optional: filter to specific index
+
+@dataclass
+class RowEmbeddingsResponse:
+    """Response from row embeddings semantic search"""
+    error: Error | None = None
+    matches: list[RowIndexMatch] = field(default_factory=list)
+
+row_embeddings_request_queue = topic(
+    "row-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
+)
+row_embeddings_response_queue = topic(
+    "row-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow'
 )
@@ -6,7 +6,7 @@ from ..core.topic import topic
 
 ############################################################################
 
-# Objects Query Service - executes GraphQL queries against structured data
+# Rows Query Service - executes GraphQL queries against structured data
 
 @dataclass
 class GraphQLError:
@@ -15,7 +15,7 @@ class GraphQLError:
     extensions: dict[str, str] = field(default_factory=dict)  # Additional error metadata
 
 @dataclass
-class ObjectsQueryRequest:
+class RowsQueryRequest:
     user: str = ""  # Cassandra keyspace (follows pattern from TriplesQueryRequest)
     collection: str = ""  # Data collection identifier (required for partition key)
     query: str = ""  # GraphQL query string
@@ -23,7 +23,7 @@ class ObjectsQueryRequest:
     operation_name: Optional[str] = None  # Operation to execute for multi-operation documents
 
 @dataclass
-class ObjectsQueryResponse:
+class RowsQueryResponse:
     error: Error | None = None  # System-level error (connection, timeout, etc.)
     data: str = ""  # JSON-encoded GraphQL response data
     errors: list[GraphQLError] = field(default_factory=list)  # GraphQL field-level errors
@@ -48,7 +48,7 @@ tg-invoke-graph-embeddings = "trustgraph.cli.invoke_graph_embeddings:main"
 tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main"
 tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main"
 tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main"
-tg-invoke-objects-query = "trustgraph.cli.invoke_objects_query:main"
+tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
 tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
 tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
 tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
@@ -1,5 +1,5 @@
 """
-Uses the ObjectsQuery service to execute GraphQL queries against structured data
+Uses the RowsQuery service to execute GraphQL queries against structured data
 """
 
 import argparse
@@ -81,7 +81,7 @@ def format_table_data(rows, table_name, output_format):
     else:
         return json.dumps({table_name: rows}, indent=2)
 
-def objects_query(
+def rows_query(
     url, flow_id, query, user, collection, variables, operation_name, output_format='table'
 ):
 
@@ -96,7 +96,7 @@ def objects_query(
         print(f"Error parsing variables JSON: {e}", file=sys.stderr)
         sys.exit(1)
 
-    resp = api.objects_query(
+    resp = api.rows_query(
         query=query,
         user=user,
         collection=collection,
@@ -126,7 +126,7 @@ def objects_query(
 def main():
 
     parser = argparse.ArgumentParser(
-        prog='tg-invoke-objects-query',
+        prog='tg-invoke-rows-query',
         description=__doc__,
     )
 
@@ -181,7 +181,7 @@ def main():
 
     try:
 
-        objects_query(
+        rows_query(
            url=args.url,
            flow_id=args.flow_id,
            query=args.query,
@@ -573,19 +573,19 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample
     return output_records, descriptor
 
 
-def _send_to_trustgraph(objects, api_url, flow, batch_size=1000, token=None):
+def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None):
     """Send ExtractedObject records to TrustGraph using Python API"""
     from trustgraph.api import Api
 
     try:
-        total_records = len(objects)
+        total_records = len(rows)
         logger.info(f"Importing {total_records} records to TrustGraph...")
 
         # Use Python API bulk import
         api = Api(api_url, token=token)
         bulk = api.bulk()
 
-        bulk.import_objects(flow=flow, objects=iter(objects))
+        bulk.import_rows(flow=flow, rows=iter(rows))
 
         logger.info(f"Successfully imported {total_records} records to TrustGraph")
 
@@ -60,27 +60,27 @@ api-gateway = "trustgraph.gateway:run"
 chunker-recursive = "trustgraph.chunking.recursive:run"
 chunker-token = "trustgraph.chunking.token:run"
 config-svc = "trustgraph.config.service:run"
-de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
+doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
-de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
+doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
-de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
+doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
-de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
+doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
-de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
+doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
-de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
+doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
 document-embeddings = "trustgraph.embeddings.document_embeddings:run"
 document-rag = "trustgraph.retrieval.document_rag:run"
 embeddings-fastembed = "trustgraph.embeddings.fastembed:run"
 embeddings-ollama = "trustgraph.embeddings.ollama:run"
-ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
+graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
-ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
+graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
-ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
+graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
-ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
+graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
-ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
+graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
-ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
+graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
 graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
 graph-rag = "trustgraph.retrieval.graph_rag:run"
 kg-extract-agent = "trustgraph.extract.kg.agent:run"
 kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
-kg-extract-objects = "trustgraph.extract.kg.objects:run"
+kg-extract-rows = "trustgraph.extract.kg.rows:run"
 kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
 kg-extract-topics = "trustgraph.extract.kg.topics:run"
 kg-extract-ontology = "trustgraph.extract.kg.ontology:run"

@@ -90,8 +90,11 @@ librarian = "trustgraph.librarian:run"
 mcp-tool = "trustgraph.agent.mcp_tool:run"
 metering = "trustgraph.metering:run"
 nlp-query = "trustgraph.retrieval.nlp_query:run"
-objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
+rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
-objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
+rows-query-cassandra = "trustgraph.query.rows.cassandra:run"
+row-embeddings = "trustgraph.embeddings.row_embeddings:run"
+row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run"
+row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run"
 pdf-decoder = "trustgraph.decoding.pdf:run"
 pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
 prompt-template = "trustgraph.prompt.template:run"

@@ -0,0 +1,3 @@
+
+from . embeddings import *
+

@@ -0,0 +1,6 @@
+
+from . embeddings import run
+
+if __name__ == '__main__':
+    run()
+

@@ -0,0 +1,263 @@
+
+"""
+Row embeddings processor. Calls the embeddings service to compute embeddings
+for indexed field values in extracted row data.
+
+Input is ExtractedObject (structured row data with schema).
+Output is RowEmbeddings (row data with embeddings for indexed fields).
+
+This follows the two-stage pattern used by graph-embeddings and document-embeddings:
+Stage 1 (this processor): Compute embeddings
+Stage 2 (row-embeddings-write-*): Store embeddings
+"""
+
+import json
+import logging
+from typing import Dict, List, Set
+
+from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding
+from ... schema import RowSchema, Field
+from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
+from ... base import ProducerSpec, CollectionConfigHandler
+
+logger = logging.getLogger(__name__)
+
+default_ident = "row-embeddings"
+default_batch_size = 10
+
+
+class Processor(CollectionConfigHandler, FlowProcessor):
+
+    def __init__(self, **params):
+
+        id = params.get("id", default_ident)
+        self.batch_size = params.get("batch_size", default_batch_size)
+
+        # Config key for schemas
+        self.config_key = params.get("config_type", "schema")
+
+        super(Processor, self).__init__(
+            **params | {
+                "id": id,
+                "config_type": self.config_key,
+            }
+        )
+
+        self.register_specification(
+            ConsumerSpec(
+                name="input",
+                schema=ExtractedObject,
+                handler=self.on_message,
+            )
+        )
+
+        self.register_specification(
+            EmbeddingsClientSpec(
+                request_name="embeddings-request",
+                response_name="embeddings-response",
+            )
+        )
+
+        self.register_specification(
+            ProducerSpec(
+                name="output",
+                schema=RowEmbeddings
+            )
+        )
+
+        # Register config handlers
+        self.register_config_handler(self.on_schema_config)
+        self.register_config_handler(self.on_collection_config)
+
+        # Schema storage: name -> RowSchema
+        self.schemas: Dict[str, RowSchema] = {}
+
+    async def on_schema_config(self, config, version):
+        """Handle schema configuration updates"""
+        logger.info(f"Loading schema configuration version {version}")
+
+        # Clear existing schemas
+        self.schemas = {}
+
+        # Check if our config type exists
+        if self.config_key not in config:
+            logger.warning(f"No '{self.config_key}' type in configuration")
+            return
+
+        # Get the schemas dictionary for our type
+        schemas_config = config[self.config_key]
+
+        # Process each schema in the schemas config
+        for schema_name, schema_json in schemas_config.items():
+            try:
+                # Parse the JSON schema definition
+                schema_def = json.loads(schema_json)
+
+                # Create Field objects
+                fields = []
+                for field_def in schema_def.get("fields", []):
+                    field = Field(
+                        name=field_def["name"],
+                        type=field_def["type"],
+                        size=field_def.get("size", 0),
+                        primary=field_def.get("primary_key", False),
+                        description=field_def.get("description", ""),
+                        required=field_def.get("required", False),
+                        enum_values=field_def.get("enum", []),
+                        indexed=field_def.get("indexed", False)
+                    )
+                    fields.append(field)
+
+                # Create RowSchema
+                row_schema = RowSchema(
+                    name=schema_def.get("name", schema_name),
+                    description=schema_def.get("description", ""),
+                    fields=fields
+                )
+
+                self.schemas[schema_name] = row_schema
+                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
+
+            except Exception as e:
+                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
+
+        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
+
+    def get_index_names(self, schema: RowSchema) -> List[str]:
+        """Get all index names for a schema."""
+        index_names = []
+        for field in schema.fields:
+            if field.primary or field.indexed:
+                index_names.append(field.name)
+        return index_names
+
+    def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
+        """Build the index_value list for a given index."""
+        field_names = [f.strip() for f in index_name.split(',')]
+        values = []
+        for field_name in field_names:
+            value = value_map.get(field_name)
+            values.append(str(value) if value is not None else "")
+        return values
+
+    def build_text_for_embedding(self, index_value: List[str]) -> str:
+        """Build text representation for embedding from index values."""
+        # Space-join the values for composite indexes
+        return " ".join(index_value)
+
+    async def on_message(self, msg, consumer, flow):
+        """Process incoming ExtractedObject and compute embeddings"""
+
+        obj = msg.value()
+        logger.info(
+            f"Computing embeddings for {len(obj.values)} rows, "
+            f"schema {obj.schema_name}, doc {obj.metadata.id}"
+        )
+
+        # Validate collection exists before processing
+        if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
+            logger.warning(
+                f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
+                f"does not exist in config. Dropping message."
+            )
+            return
+
+        # Get schema definition
+        schema = self.schemas.get(obj.schema_name)
+        if not schema:
+            logger.warning(f"No schema found for {obj.schema_name} - skipping")
+            return
+
+        # Get all index names for this schema
+        index_names = self.get_index_names(schema)
+
+        if not index_names:
+            logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping")
+            return
+
+        # Track unique texts to avoid duplicate embeddings
+        # text -> (index_name, index_value)
+        texts_to_embed: Dict[str, tuple] = {}
+
+        # Collect all texts that need embeddings
+        for value_map in obj.values:
+            for index_name in index_names:
+                index_value = self.build_index_value(value_map, index_name)
+
+                # Skip empty values
+                if not index_value or all(v == "" for v in index_value):
+                    continue
+
+                text = self.build_text_for_embedding(index_value)
+                if text and text not in texts_to_embed:
+                    texts_to_embed[text] = (index_name, index_value)
+
+        if not texts_to_embed:
+            logger.info("No texts to embed")
+            return
+
+        # Compute embeddings
+        embeddings_list = []
+
+        try:
+            for text, (index_name, index_value) in texts_to_embed.items():
+                vectors = await flow("embeddings-request").embed(text=text)
+
+                embeddings_list.append(
+                    RowIndexEmbedding(
+                        index_name=index_name,
+                        index_value=index_value,
+                        text=text,
+                        vectors=vectors
+                    )
+                )
+
+            # Send in batches to avoid oversized messages
+            for i in range(0, len(embeddings_list), self.batch_size):
+                batch = embeddings_list[i:i + self.batch_size]
+                result = RowEmbeddings(
+                    metadata=obj.metadata,
+                    schema_name=obj.schema_name,
+                    embeddings=batch,
+                )
+                await flow("output").send(result)
+
+            logger.info(
+                f"Computed {len(embeddings_list)} embeddings for "
+                f"{len(obj.values)} rows ({len(index_names)} indexes)"
+            )
+
+        except Exception as e:
+            logger.error("Exception during embedding computation", exc_info=True)
+            raise e
+
+    async def create_collection(self, user: str, collection: str, metadata: dict):
+        """Collection creation notification - no action needed for embedding stage"""
+        logger.debug(f"Row embeddings collection notification for {user}/{collection}")
+
+    async def delete_collection(self, user: str, collection: str):
+        """Collection deletion notification - no action needed for embedding stage"""
+        logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")
+
+    @staticmethod
+    def add_args(parser):
+
+        FlowProcessor.add_args(parser)
+
+        parser.add_argument(
+            '--batch-size',
+            type=int,
+            default=default_batch_size,
+            help=f'Maximum embeddings per output message (default: {default_batch_size})'
+        )
+
+        parser.add_argument(
+            '--config-type',
+            default='schema',
+            help='Configuration type prefix for schemas (default: schema)'
+        )
+
+
+def run():
+    Processor.launch(default_ident, __doc__)

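The index-handling helpers in the processor above can be exercised in isolation: a composite, comma-separated index name is split into field names, each value is pulled from the row's value map (missing or null values become empty strings), the parts are space-joined into the text that gets embedded, and duplicate texts are collected only once. A minimal standalone sketch of that flow (the row data and index names here are invented for illustration):

```python
# Sketch of build_index_value / build_text_for_embedding / text dedup,
# mirroring the logic in the row embeddings processor above.

def build_index_value(value_map, index_name):
    # Composite indexes are comma-separated field name lists
    field_names = [f.strip() for f in index_name.split(",")]
    return [str(value_map[f]) if value_map.get(f) is not None else ""
            for f in field_names]

def build_text_for_embedding(index_value):
    # Space-join the values for composite indexes
    return " ".join(index_value)

def collect_texts(rows, index_names):
    # text -> (index_name, index_value), deduplicated
    texts = {}
    for value_map in rows:
        for index_name in index_names:
            index_value = build_index_value(value_map, index_name)
            if all(v == "" for v in index_value):
                continue  # skip all-empty index values
            text = build_text_for_embedding(index_value)
            if text and text not in texts:
                texts[text] = (index_name, index_value)
    return texts

rows = [
    {"name": "Ada", "city": "London"},
    {"name": "Ada", "city": "London"},   # duplicate row: embedded once
    {"name": "Alan", "city": None},      # null field becomes ""
]
texts = collect_texts(rows, ["name", "name, city"])
```

Note that one embedding request is made per unique text, not per row, which is why the dedup step matters for tables with repeated key values.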
@@ -1,5 +1,5 @@
 """
-Object extraction service - extracts structured objects from text chunks
+Row extraction service - extracts structured rows from text chunks
 based on configured schemas.
 """
@@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from .... base import PromptClientSpec
 from .... messaging.translators import row_schema_translator
 
-default_ident = "kg-extract-objects"
+default_ident = "kg-extract-rows"
 
 
 def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]:
@@ -310,5 +310,5 @@ class Processor(FlowProcessor):
     FlowProcessor.add_args(parser)
 
 def run():
-    """Entry point for kg-extract-objects command"""
+    """Entry point for kg-extract-rows command"""
     Processor.launch(default_ident, __doc__)

@@ -20,7 +20,7 @@ from . prompt import PromptRequestor
 from . graph_rag import GraphRagRequestor
 from . document_rag import DocumentRagRequestor
 from . triples_query import TriplesQueryRequestor
-from . objects_query import ObjectsQueryRequestor
+from . rows_query import RowsQueryRequestor
 from . nlp_query import NLPQueryRequestor
 from . structured_query import StructuredQueryRequestor
 from . structured_diag import StructuredDiagRequestor
@@ -40,7 +40,7 @@ from . triples_import import TriplesImport
 from . graph_embeddings_import import GraphEmbeddingsImport
 from . document_embeddings_import import DocumentEmbeddingsImport
 from . entity_contexts_import import EntityContextsImport
-from . objects_import import ObjectsImport
+from . rows_import import RowsImport
 
 from . core_export import CoreExport
 from . core_import import CoreImport
@@ -58,7 +58,7 @@ request_response_dispatchers = {
     "graph-embeddings": GraphEmbeddingsQueryRequestor,
     "document-embeddings": DocumentEmbeddingsQueryRequestor,
     "triples": TriplesQueryRequestor,
-    "objects": ObjectsQueryRequestor,
+    "rows": RowsQueryRequestor,
     "nlp-query": NLPQueryRequestor,
     "structured-query": StructuredQueryRequestor,
     "structured-diag": StructuredDiagRequestor,
@@ -89,7 +89,7 @@ import_dispatchers = {
     "graph-embeddings": GraphEmbeddingsImport,
     "document-embeddings": DocumentEmbeddingsImport,
     "entity-contexts": EntityContextsImport,
-    "objects": ObjectsImport,
+    "rows": RowsImport,
 }
 
 class DispatcherWrapper:

@@ -12,7 +12,7 @@ from . serialize import to_subgraph
 # Module logger
 logger = logging.getLogger(__name__)
 
-class ObjectsImport:
+class RowsImport:
 
     def __init__(
         self, ws, running, backend, queue
@@ -20,7 +20,7 @@ class ObjectsImport:
 
         self.ws = ws
         self.running = running
 
         self.publisher = Publisher(
            backend, topic = queue, schema = ExtractedObject
        )
@@ -73,4 +73,4 @@ class ObjectsImport:
        if self.ws:
            await self.ws.close()
 
        self.ws = None

@@ -1,30 +1,30 @@
-from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
+from ... schema import RowsQueryRequest, RowsQueryResponse
 from ... messaging import TranslatorRegistry
 
 from . requestor import ServiceRequestor
 
-class ObjectsQueryRequestor(ServiceRequestor):
+class RowsQueryRequestor(ServiceRequestor):
     def __init__(
         self, backend, request_queue, response_queue, timeout,
         consumer, subscriber,
     ):
 
-        super(ObjectsQueryRequestor, self).__init__(
+        super(RowsQueryRequestor, self).__init__(
             backend=backend,
             request_queue=request_queue,
             response_queue=response_queue,
-            request_schema=ObjectsQueryRequest,
+            request_schema=RowsQueryRequest,
-            response_schema=ObjectsQueryResponse,
+            response_schema=RowsQueryResponse,
             subscription = subscriber,
             consumer_name = consumer,
             timeout=timeout,
         )
 
-        self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
+        self.request_translator = TranslatorRegistry.get_request_translator("rows-query")
-        self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
+        self.response_translator = TranslatorRegistry.get_response_translator("rows-query")
 
     def to_request(self, body):
         return self.request_translator.to_pulsar(body)
 
     def from_response(self, message):
         return self.response_translator.from_response_with_completion(message)

22
trustgraph-flow/trustgraph/query/graphql/__init__.py
Normal file
@@ -0,0 +1,22 @@
+"""
+Shared GraphQL utilities for row query services.
+
+This module provides reusable GraphQL components including:
+- Filter types (IntFilter, StringFilter, FloatFilter)
+- Dynamic schema generation from RowSchema definitions
+- Filter parsing utilities
+"""
+
+from .types import IntFilter, StringFilter, FloatFilter, SortDirection
+from .schema import GraphQLSchemaBuilder
+from .filters import parse_filter_key, parse_where_clause
+
+__all__ = [
+    "IntFilter",
+    "StringFilter",
+    "FloatFilter",
+    "SortDirection",
+    "GraphQLSchemaBuilder",
+    "parse_filter_key",
+    "parse_where_clause",
+]

104
trustgraph-flow/trustgraph/query/graphql/filters.py
Normal file
@@ -0,0 +1,104 @@
+"""
+Filter parsing utilities for GraphQL row queries.
+
+Provides functions to parse GraphQL filter objects into a normalized
+format that can be used by different query backends.
+"""
+
+import logging
+from typing import Dict, Any, Tuple
+
+logger = logging.getLogger(__name__)
+
+
+def parse_filter_key(filter_key: str) -> Tuple[str, str]:
+    """
+    Parse GraphQL filter key into field name and operator.
+
+    Supports common GraphQL filter patterns:
+    - field_name -> (field_name, "eq")
+    - field_name_gt -> (field_name, "gt")
+    - field_name_gte -> (field_name, "gte")
+    - field_name_lt -> (field_name, "lt")
+    - field_name_lte -> (field_name, "lte")
+    - field_name_in -> (field_name, "in")
+
+    Args:
+        filter_key: The filter key string from GraphQL
+
+    Returns:
+        Tuple of (field_name, operator)
+    """
+    if not filter_key:
+        return ("", "eq")
+
+    operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
+
+    for op_suffix in operators:
+        if filter_key.endswith(op_suffix):
+            field_name = filter_key[:-len(op_suffix)]
+            operator = op_suffix[1:]  # Remove the leading underscore
+            return (field_name, operator)
+
+    # Default to equality if no operator suffix found
+    return (filter_key, "eq")
+
+
+def parse_where_clause(where_obj) -> Dict[str, Any]:
+    """
+    Parse the idiomatic nested GraphQL filter structure into a flat dict.
+
+    Converts Strawberry filter objects (StringFilter, IntFilter, etc.)
+    into a dictionary mapping field names with operators to values.
+
+    Example:
+        Input: where_obj with email.eq = "foo@bar.com"
+        Output: {"email": "foo@bar.com"}
+
+        Input: where_obj with age.gt = 21
+        Output: {"age_gt": 21}
+
+    Args:
+        where_obj: The GraphQL where clause object
+
+    Returns:
+        Dictionary mapping field_operator keys to values
+    """
+    if not where_obj:
+        return {}
+
+    conditions = {}
+
+    logger.debug(f"Parsing where clause: {where_obj}")
+
+    for field_name, filter_obj in where_obj.__dict__.items():
+        if filter_obj is None:
+            continue
+
+        logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}")
+
+        if hasattr(filter_obj, '__dict__'):
+            # This is a filter object (StringFilter, IntFilter, etc.)
+            for operator, value in filter_obj.__dict__.items():
+                if value is not None:
+                    logger.debug(f"Found operator {operator} with value {value}")
+                    # Map GraphQL operators to our internal format
+                    if operator == "eq":
+                        conditions[field_name] = value
+                    elif operator in ["gt", "gte", "lt", "lte"]:
+                        conditions[f"{field_name}_{operator}"] = value
+                    elif operator == "in_":
+                        conditions[f"{field_name}_in"] = value
+                    elif operator == "contains":
+                        conditions[f"{field_name}_contains"] = value
+                    elif operator == "startsWith":
+                        conditions[f"{field_name}_startsWith"] = value
+                    elif operator == "endsWith":
+                        conditions[f"{field_name}_endsWith"] = value
+                    elif operator == "not_":
+                        conditions[f"{field_name}_not"] = value
+                    elif operator == "not_in":
+                        conditions[f"{field_name}_not_in"] = value
+
+    logger.debug(f"Final parsed conditions: {conditions}")
+    return conditions

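The suffix-stripping rule in `parse_filter_key` and the flattening done by `parse_where_clause` can be sketched standalone. Note the operator list checks `_gte` before `_gt`, so the longer suffix wins; the `where` object below is a plain-namespace stand-in for the Strawberry filter objects the real code receives:

```python
from types import SimpleNamespace

def parse_filter_key(filter_key):
    # field_name_<op> -> (field_name, op); bare names default to "eq"
    if not filter_key:
        return ("", "eq")
    for op_suffix in ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]:
        if filter_key.endswith(op_suffix):
            return (filter_key[:-len(op_suffix)], op_suffix[1:])
    return (filter_key, "eq")

def parse_where_clause(where_obj):
    # Flatten nested filter objects into {"field" / "field_op": value}
    conditions = {}
    if not where_obj:
        return conditions
    for field_name, filter_obj in vars(where_obj).items():
        if filter_obj is None or not hasattr(filter_obj, "__dict__"):
            continue
        for operator, value in vars(filter_obj).items():
            if value is None:
                continue
            if operator == "eq":
                conditions[field_name] = value
            elif operator in ["gt", "gte", "lt", "lte"]:
                conditions[f"{field_name}_{operator}"] = value
    return conditions

where = SimpleNamespace(
    email=SimpleNamespace(eq="foo@bar.com"),
    age=SimpleNamespace(gt=21),
)
flat = parse_where_clause(where)
```

This is a trimmed sketch (equality and comparison operators only); the module above also handles `in_`, `contains`, `startsWith`, `endsWith`, and negation variants.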
251
trustgraph-flow/trustgraph/query/graphql/schema.py
Normal file
@@ -0,0 +1,251 @@
+"""
+Dynamic GraphQL schema generation from RowSchema definitions.
+
+Provides a builder class that creates Strawberry GraphQL schemas
+from TrustGraph RowSchema definitions, with pluggable query backends.
+"""
+
+import logging
+from typing import Dict, Any, Optional, List, Callable, Awaitable
+
+import strawberry
+from strawberry import Schema
+from strawberry.types import Info
+
+from .types import IntFilter, StringFilter, FloatFilter, SortDirection
+
+logger = logging.getLogger(__name__)
+
+# Type alias for query callback function
+QueryCallback = Callable[
+    [str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]],
+    Awaitable[List[Dict[str, Any]]]
+]
+
+
+class GraphQLSchemaBuilder:
+    """
+    Builds GraphQL schemas from RowSchema definitions.
+
+    This class extracts the GraphQL schema generation logic so it can be
+    reused across different query backends (Cassandra, etc.).
+
+    Usage:
+        builder = GraphQLSchemaBuilder()
+
+        # Add schemas
+        for name, row_schema in schemas.items():
+            builder.add_schema(name, row_schema)
+
+        # Build with a query callback
+        schema = builder.build(query_callback)
+    """
+
+    def __init__(self):
+        self.schemas: Dict[str, Any] = {}  # name -> RowSchema
+        self.graphql_types: Dict[str, type] = {}
+        self.filter_types: Dict[str, type] = {}
+
+    def add_schema(self, name: str, row_schema) -> None:
+        """
+        Add a RowSchema to the builder.
+
+        Args:
+            name: The schema name (used as the GraphQL query field name)
+            row_schema: The RowSchema object defining fields
+        """
+        self.schemas[name] = row_schema
+        self.graphql_types[name] = self._create_graphql_type(name, row_schema)
+        self.filter_types[name] = self._create_filter_type(name, row_schema)
+        logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields")
+
+    def clear(self) -> None:
+        """Clear all schemas from the builder."""
+        self.schemas = {}
+        self.graphql_types = {}
+        self.filter_types = {}
+
+    def build(self, query_callback: QueryCallback) -> Optional[Schema]:
+        """
+        Build the GraphQL schema with the provided query callback.
+
+        The query callback will be invoked when resolving queries, with:
+        - user: str
+        - collection: str
+        - schema_name: str
+        - row_schema: RowSchema
+        - filters: Dict[str, Any]
+        - limit: int
+        - order_by: Optional[str]
+        - direction: Optional[SortDirection]
+
+        It should return a list of row dictionaries.
+
+        Args:
+            query_callback: Async function to execute queries
+
+        Returns:
+            Strawberry Schema, or None if no schemas are loaded
+        """
+        if not self.schemas:
+            logger.warning("No schemas loaded, cannot generate GraphQL schema")
+            return None
+
+        # Create the Query class with resolvers
+        query_dict = {'__annotations__': {}}
+
+        for schema_name, row_schema in self.schemas.items():
+            graphql_type = self.graphql_types[schema_name]
+            filter_type = self.filter_types[schema_name]
+
+            # Create resolver function for this schema
+            resolver_func = self._make_resolver(
+                schema_name, row_schema, graphql_type, filter_type, query_callback
+            )
+
+            # Add field to query dictionary
+            query_dict[schema_name] = strawberry.field(resolver=resolver_func)
+            query_dict['__annotations__'][schema_name] = List[graphql_type]
+
+        # Create the Query class
+        Query = type('Query', (), query_dict)
+        Query = strawberry.type(Query)
+
+        # Create the schema with auto_camel_case disabled to keep snake_case field names
+        schema = strawberry.Schema(
+            query=Query,
+            config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
+        )
+        logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
+        return schema
+
+    def _get_python_type(self, field_type: str):
+        """Convert schema field type to Python type for GraphQL."""
+        type_mapping = {
+            "string": str,
+            "integer": int,
+            "float": float,
+            "boolean": bool,
+            "timestamp": str,  # Use string for timestamps in GraphQL
+            "date": str,
+            "time": str,
+            "uuid": str
+        }
+        return type_mapping.get(field_type, str)
+
+    def _create_graphql_type(self, schema_name: str, row_schema) -> type:
+        """Create a GraphQL output type from a RowSchema."""
+        # Create annotations for the GraphQL type
+        annotations = {}
+        defaults = {}
+
+        for field in row_schema.fields:
+            python_type = self._get_python_type(field.type)
+
+            # Make field optional if not required
+            if not field.required and not field.primary:
+                annotations[field.name] = Optional[python_type]
+                defaults[field.name] = None
+            else:
+                annotations[field.name] = python_type
+
+        # Create the class dynamically
+        type_name = f"{schema_name.capitalize()}Type"
+        graphql_class = type(
+            type_name,
+            (),
+            {
+                "__annotations__": annotations,
+                **defaults
+            }
+        )
+
+        # Apply strawberry decorator
+        return strawberry.type(graphql_class)
+
+    def _create_filter_type(self, schema_name: str, row_schema) -> type:
+        """Create a dynamic filter input type for a schema."""
+        filter_type_name = f"{schema_name.capitalize()}Filter"
+
+        # Add __annotations__ and defaults for the fields
+        annotations = {}
+        defaults = {}
+
+        logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}")
+
+        for field in row_schema.fields:
+            logger.debug(
+                f"Field {field.name}: type={field.type}, "
+                f"indexed={field.indexed}, primary={field.primary}"
+            )
+
+            # Allow filtering on any field
+            if field.type == "integer":
+                annotations[field.name] = Optional[IntFilter]
+                defaults[field.name] = None
+            elif field.type == "float":
+                annotations[field.name] = Optional[FloatFilter]
+                defaults[field.name] = None
+            elif field.type == "string":
+                annotations[field.name] = Optional[StringFilter]
+                defaults[field.name] = None
+
+        logger.debug(
+            f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}"
+        )
+
+        # Create the class dynamically
+        FilterType = type(
+            filter_type_name,
+            (),
|
{
|
||||||
|
"__annotations__": annotations,
|
||||||
|
**defaults
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply strawberry input decorator
|
||||||
|
FilterType = strawberry.input(FilterType)
|
||||||
|
|
||||||
|
return FilterType
|
||||||
|
|
||||||
|
def _make_resolver(
|
||||||
|
self,
|
||||||
|
schema_name: str,
|
||||||
|
row_schema,
|
||||||
|
graphql_type: type,
|
||||||
|
filter_type: type,
|
||||||
|
query_callback: QueryCallback
|
||||||
|
):
|
||||||
|
"""Create a resolver function for a schema."""
|
||||||
|
from .filters import parse_where_clause
|
||||||
|
|
||||||
|
async def resolver(
|
||||||
|
info: Info,
|
||||||
|
where: Optional[filter_type] = None,
|
||||||
|
order_by: Optional[str] = None,
|
||||||
|
direction: Optional[SortDirection] = None,
|
||||||
|
limit: Optional[int] = 100
|
||||||
|
) -> List[graphql_type]:
|
||||||
|
# Get context values
|
||||||
|
user = info.context["user"]
|
||||||
|
collection = info.context["collection"]
|
||||||
|
|
||||||
|
# Parse the where clause
|
||||||
|
filters = parse_where_clause(where)
|
||||||
|
|
||||||
|
# Call the query backend
|
||||||
|
results = await query_callback(
|
||||||
|
user, collection, schema_name, row_schema,
|
||||||
|
filters, limit, order_by, direction
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to GraphQL types
|
||||||
|
graphql_results = []
|
||||||
|
for row in results:
|
||||||
|
graphql_obj = graphql_type(**row)
|
||||||
|
graphql_results.append(graphql_obj)
|
||||||
|
|
||||||
|
return graphql_results
|
||||||
|
|
||||||
|
return resolver
|
||||||
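The dynamic `type(...)` construction used above to build per-schema Query and output classes can be illustrated without Strawberry. This is only a sketch of the underlying Python mechanism; `ProductType` and its fields are hypothetical, not part of the patch:

```python
from typing import Optional

# Build a class at runtime the way the service builds its GraphQL types:
# __annotations__ carries the field types, extra namespace entries supply
# defaults so non-required fields behave as optional.
annotations = {"id": int, "name": Optional[str]}
namespace = {"__annotations__": annotations, "name": None}

ProductType = type("ProductType", (), namespace)

obj = ProductType()
print(ProductType.__name__)                      # ProductType
print(ProductType.__annotations__["id"] is int)  # True
print(obj.name)                                  # None
```

In the service, the resulting class is then passed through `strawberry.type(...)` (or `strawberry.input(...)` for filters), which reads these annotations to produce the GraphQL schema type.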
56  trustgraph-flow/trustgraph/query/graphql/types.py  Normal file

@@ -0,0 +1,56 @@
"""
GraphQL filter and sort types for row queries.

These types are used to build dynamic GraphQL schemas for querying
structured row data.
"""

from typing import Optional, List
from enum import Enum

import strawberry


@strawberry.input
class IntFilter:
    """Filter type for integer fields."""
    eq: Optional[int] = None
    gt: Optional[int] = None
    gte: Optional[int] = None
    lt: Optional[int] = None
    lte: Optional[int] = None
    in_: Optional[List[int]] = strawberry.field(name="in", default=None)
    not_: Optional[int] = strawberry.field(name="not", default=None)
    not_in: Optional[List[int]] = None


@strawberry.input
class StringFilter:
    """Filter type for string fields."""
    eq: Optional[str] = None
    contains: Optional[str] = None
    startsWith: Optional[str] = None
    endsWith: Optional[str] = None
    in_: Optional[List[str]] = strawberry.field(name="in", default=None)
    not_: Optional[str] = strawberry.field(name="not", default=None)
    not_in: Optional[List[str]] = None


@strawberry.input
class FloatFilter:
    """Filter type for float fields."""
    eq: Optional[float] = None
    gt: Optional[float] = None
    gte: Optional[float] = None
    lt: Optional[float] = None
    lte: Optional[float] = None
    in_: Optional[List[float]] = strawberry.field(name="in", default=None)
    not_: Optional[float] = strawberry.field(name="not", default=None)
    not_in: Optional[List[float]] = None


@strawberry.enum
class SortDirection(Enum):
    """Sort direction for query results."""
    ASC = "asc"
    DESC = "desc"
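Each operator field on these inputs corresponds to a simple predicate over a column value. As an illustrative sketch only (not part of the patch), a plain-Python analogue of `IntFilter` evaluation, with all names hypothetical:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class IntFilterSketch:
    # Mirrors a subset of the IntFilter fields above; in GraphQL, in_ and
    # not_ are exposed as "in"/"not" via strawberry.field(name=...).
    eq: Optional[int] = None
    gt: Optional[int] = None
    lte: Optional[int] = None
    in_: Optional[List[int]] = None

def matches(value: int, f: IntFilterSketch) -> bool:
    """Apply every operator that is set; unset operators are ignored."""
    if f.eq is not None and value != f.eq:
        return False
    if f.gt is not None and not value > f.gt:
        return False
    if f.lte is not None and not value <= f.lte:
        return False
    if f.in_ is not None and value not in f.in_:
        return False
    return True

print(matches(5, IntFilterSketch(gt=3, lte=10)))  # True
print(matches(2, IntFilterSketch(gt=3)))          # False
```

Setting several operators on one filter thus behaves as an implicit AND, which matches how the query backend combines WHERE clauses.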
@@ -1,738 +0,0 @@
"""
Objects query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
"""

import json
import logging
import asyncio
from typing import Dict, Any, Optional, List, Set
from enum import Enum
from dataclasses import dataclass, field
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

import strawberry
from strawberry import Schema
from strawberry.types import Info
from strawberry.scalars import JSON
from strawberry.tools import create_type

from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config

# Module logger
logger = logging.getLogger(__name__)

default_ident = "objects-query"

# GraphQL filter input types
@strawberry.input
class IntFilter:
    eq: Optional[int] = None
    gt: Optional[int] = None
    gte: Optional[int] = None
    lt: Optional[int] = None
    lte: Optional[int] = None
    in_: Optional[List[int]] = strawberry.field(name="in", default=None)
    not_: Optional[int] = strawberry.field(name="not", default=None)
    not_in: Optional[List[int]] = None


@strawberry.input
class StringFilter:
    eq: Optional[str] = None
    contains: Optional[str] = None
    startsWith: Optional[str] = None
    endsWith: Optional[str] = None
    in_: Optional[List[str]] = strawberry.field(name="in", default=None)
    not_: Optional[str] = strawberry.field(name="not", default=None)
    not_in: Optional[List[str]] = None


@strawberry.input
class FloatFilter:
    eq: Optional[float] = None
    gt: Optional[float] = None
    gte: Optional[float] = None
    lt: Optional[float] = None
    lte: Optional[float] = None
    in_: Optional[List[float]] = strawberry.field(name="in", default=None)
    not_: Optional[float] = strawberry.field(name="not", default=None)
    not_in: Optional[List[float]] = None
class Processor(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id", default_ident)

        # Get Cassandra parameters
        cassandra_host = params.get("cassandra_host")
        cassandra_username = params.get("cassandra_username")
        cassandra_password = params.get("cassandra_password")

        # Resolve configuration with environment variable fallback
        hosts, username, password, keyspace = resolve_cassandra_config(
            host=cassandra_host,
            username=cassandra_username,
            password=cassandra_password
        )

        # Store resolved configuration with proper names
        self.cassandra_host = hosts  # Store as list
        self.cassandra_username = username
        self.cassandra_password = password

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        super(Processor, self).__init__(
            **params | {
                "id": id,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = ObjectsQueryRequest,
                handler = self.on_message
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = ObjectsQueryResponse,
            )
        )

        # Register config handler for schema updates
        self.register_config_handler(self.on_schema_config)

        # Schema storage: name -> RowSchema
        self.schemas: Dict[str, RowSchema] = {}

        # GraphQL schema
        self.graphql_schema: Optional[Schema] = None

        # GraphQL types cache
        self.graphql_types: Dict[str, type] = {}

        # Cassandra session
        self.cluster = None
        self.session = None

        # Known keyspaces and tables
        self.known_keyspaces: Set[str] = set()
        self.known_tables: Dict[str, Set[str]] = {}

    def connect_cassandra(self):
        """Connect to Cassandra cluster"""
        if self.session:
            return

        try:
            if self.cassandra_username and self.cassandra_password:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_username,
                    password=self.cassandra_password
                )
                self.cluster = Cluster(
                    contact_points=self.cassandra_host,
                    auth_provider=auth_provider
                )
            else:
                self.cluster = Cluster(contact_points=self.cassandra_host)

            self.session = self.cluster.connect()
            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
            raise

    def sanitize_name(self, name: str) -> str:
        """Sanitize names for Cassandra compatibility"""
        import re
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'o_' + safe_name
        return safe_name.lower()

    def sanitize_table(self, name: str) -> str:
        """Sanitize table names for Cassandra compatibility"""
        import re
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        safe_name = 'o_' + safe_name
        return safe_name.lower()

    def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
        """Parse GraphQL filter key into field name and operator"""
        if not filter_key:
            return ("", "eq")

        # Support common GraphQL filter patterns:
        # field_name     -> (field_name, "eq")
        # field_name_gt  -> (field_name, "gt")
        # field_name_gte -> (field_name, "gte")
        # field_name_lt  -> (field_name, "lt")
        # field_name_lte -> (field_name, "lte")
        # field_name_in  -> (field_name, "in")

        operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]

        for op_suffix in operators:
            if filter_key.endswith(op_suffix):
                field_name = filter_key[:-len(op_suffix)]
                operator = op_suffix[1:]  # Remove the leading underscore
                return (field_name, operator)

        # Default to equality if no operator suffix found
        return (filter_key, "eq")
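The suffix convention that `parse_filter_key` implements can be exercised standalone; this is the same logic extracted as a free function (the filter keys shown are hypothetical):

```python
OPERATOR_SUFFIXES = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]

def parse_filter_key(filter_key: str) -> tuple[str, str]:
    """Split a flat filter key like 'price_gte' into ('price', 'gte')."""
    if not filter_key:
        return ("", "eq")
    for suffix in OPERATOR_SUFFIXES:
        if filter_key.endswith(suffix):
            return (filter_key[:-len(suffix)], suffix[1:])
    # A bare field name means equality
    return (filter_key, "eq")

print(parse_filter_key("price_gte"))  # ('price', 'gte')
print(parse_filter_key("name"))       # ('name', 'eq')
print(parse_filter_key("tags_in"))    # ('tags', 'in')
```

One consequence of the convention: a column whose real name ends in one of these suffixes would be misparsed, which is part of why the refactor moves to structured filter objects rather than flat suffixed keys.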
    async def on_schema_config(self, config, version):
        """Handle schema configuration updates"""
        logger.info(f"Loading schema configuration version {version}")

        # Clear existing schemas
        self.schemas = {}
        self.graphql_types = {}

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(f"No '{self.config_key}' type in configuration")
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        # Process each schema in the schemas config
        for schema_name, schema_json in schemas_config.items():
            try:
                # Parse the JSON schema definition
                schema_def = json.loads(schema_json)

                # Create Field objects
                fields = []
                for field_def in schema_def.get("fields", []):
                    field = SchemaField(
                        name=field_def["name"],
                        type=field_def["type"],
                        size=field_def.get("size", 0),
                        primary=field_def.get("primary_key", False),
                        description=field_def.get("description", ""),
                        required=field_def.get("required", False),
                        enum_values=field_def.get("enum", []),
                        indexed=field_def.get("indexed", False)
                    )
                    fields.append(field)

                # Create RowSchema
                row_schema = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),
                    fields=fields
                )

                self.schemas[schema_name] = row_schema
                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

            except Exception as e:
                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

        # Regenerate GraphQL schema
        self.generate_graphql_schema()

    def get_python_type(self, field_type: str):
        """Convert schema field type to Python type for GraphQL"""
        type_mapping = {
            "string": str,
            "integer": int,
            "float": float,
            "boolean": bool,
            "timestamp": str,  # Use string for timestamps in GraphQL
            "date": str,
            "time": str,
            "uuid": str
        }
        return type_mapping.get(field_type, str)

    def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
        """Create a GraphQL type from a RowSchema"""

        # Create annotations for the GraphQL type
        annotations = {}
        defaults = {}

        for field in row_schema.fields:
            python_type = self.get_python_type(field.type)

            # Make field optional if not required
            if not field.required and not field.primary:
                annotations[field.name] = Optional[python_type]
                defaults[field.name] = None
            else:
                annotations[field.name] = python_type

        # Create the class dynamically
        type_name = f"{schema_name.capitalize()}Type"
        graphql_class = type(
            type_name,
            (),
            {
                "__annotations__": annotations,
                **defaults
            }
        )

        # Apply strawberry decorator
        return strawberry.type(graphql_class)

    def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
        """Create a dynamic filter input type for a schema"""
        # Create the filter type dynamically
        filter_type_name = f"{schema_name.capitalize()}Filter"

        # Add __annotations__ and defaults for the fields
        annotations = {}
        defaults = {}

        logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")

        for field in row_schema.fields:
            logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")

            # Allow filtering on any field for now, not just indexed/primary
            # if field.indexed or field.primary:
            if field.type == "integer":
                annotations[field.name] = Optional[IntFilter]
                defaults[field.name] = None
                logger.info(f"Added IntFilter for {field.name}")
            elif field.type == "float":
                annotations[field.name] = Optional[FloatFilter]
                defaults[field.name] = None
                logger.info(f"Added FloatFilter for {field.name}")
            elif field.type == "string":
                annotations[field.name] = Optional[StringFilter]
                defaults[field.name] = None
                logger.info(f"Added StringFilter for {field.name}")

        logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")

        # Create the class dynamically
        FilterType = type(
            filter_type_name,
            (),
            {
                "__annotations__": annotations,
                **defaults
            }
        )

        # Apply strawberry input decorator
        FilterType = strawberry.input(FilterType)

        return FilterType

    def create_sort_direction_enum(self):
        """Create sort direction enum"""
        @strawberry.enum
        class SortDirection(Enum):
            ASC = "asc"
            DESC = "desc"

        return SortDirection

    def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
        """Parse the idiomatic nested filter structure"""
        if not where_obj:
            return {}

        conditions = {}

        logger.info(f"Parsing where clause: {where_obj}")

        for field_name, filter_obj in where_obj.__dict__.items():
            if filter_obj is None:
                continue

            logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")

            if hasattr(filter_obj, '__dict__'):
                # This is a filter object (StringFilter, IntFilter, etc.)
                for operator, value in filter_obj.__dict__.items():
                    if value is not None:
                        logger.info(f"Found operator {operator} with value {value}")
                        # Map GraphQL operators to our internal format
                        if operator == "eq":
                            conditions[field_name] = value
                        elif operator in ["gt", "gte", "lt", "lte"]:
                            conditions[f"{field_name}_{operator}"] = value
                        elif operator == "in_":
                            conditions[f"{field_name}_in"] = value
                        elif operator == "contains":
                            conditions[f"{field_name}_contains"] = value

        logger.info(f"Final parsed conditions: {conditions}")
        return conditions
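This flattening step (nested filter objects in, flat `field`/`field_op` keys out) can be sketched in isolation, using `types.SimpleNamespace` to stand in for the Strawberry input objects; the `price`/`name` fields are hypothetical:

```python
from types import SimpleNamespace

def parse_where(where_obj) -> dict:
    """Flatten nested filter objects into {'field' or 'field_op': value} pairs."""
    if not where_obj:
        return {}
    conditions = {}
    for field_name, filter_obj in vars(where_obj).items():
        if filter_obj is None:
            continue
        for operator, value in vars(filter_obj).items():
            if value is None:
                continue
            if operator == "eq":
                conditions[field_name] = value
            elif operator in ("gt", "gte", "lt", "lte"):
                conditions[f"{field_name}_{operator}"] = value
            elif operator == "in_":
                conditions[f"{field_name}_in"] = value
            elif operator == "contains":
                conditions[f"{field_name}_contains"] = value
    return conditions

where = SimpleNamespace(
    price=SimpleNamespace(eq=None, gt=10, gte=None, lt=None, lte=None,
                          in_=None, not_=None, not_in=None),
    name=SimpleNamespace(eq="widget", contains=None, in_=None,
                         not_=None, not_in=None),
)
print(parse_where(where))  # {'price_gt': 10, 'name': 'widget'}
```

The flat keys produced here are exactly what `parse_filter_key` later splits back into field name and operator when building the CQL query.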
    def generate_graphql_schema(self):
        """Generate GraphQL schema from loaded schemas using dynamic filter types"""
        if not self.schemas:
            logger.warning("No schemas loaded, cannot generate GraphQL schema")
            self.graphql_schema = None
            return

        # Create GraphQL types and filter types for each schema
        filter_types = {}
        sort_direction_enum = self.create_sort_direction_enum()

        for schema_name, row_schema in self.schemas.items():
            graphql_type = self.create_graphql_type(schema_name, row_schema)
            filter_type = self.create_filter_type_for_schema(schema_name, row_schema)

            self.graphql_types[schema_name] = graphql_type
            filter_types[schema_name] = filter_type

        # Create the Query class with resolvers
        query_dict = {'__annotations__': {}}

        for schema_name, row_schema in self.schemas.items():
            graphql_type = self.graphql_types[schema_name]
            filter_type = filter_types[schema_name]

            # Create resolver function for this schema
            def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
                async def resolver(
                    info: Info,
                    where: Optional[f_type] = None,
                    order_by: Optional[str] = None,
                    direction: Optional[sort_enum] = None,
                    limit: Optional[int] = 100
                ) -> List[g_type]:
                    # Get the processor instance from context
                    processor = info.context["processor"]
                    user = info.context["user"]
                    collection = info.context["collection"]

                    # Parse the idiomatic where clause
                    filters = processor.parse_idiomatic_where_clause(where)

                    # Query Cassandra
                    results = await processor.query_cassandra(
                        user, collection, s_name, r_schema,
                        filters, limit, order_by, direction
                    )

                    # Convert to GraphQL types
                    graphql_results = []
                    for row in results:
                        graphql_obj = g_type(**row)
                        graphql_results.append(graphql_obj)

                    return graphql_results

                return resolver

            # Add resolver to query
            resolver_name = schema_name
            resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum)

            # Add field to query dictionary
            query_dict[resolver_name] = strawberry.field(resolver=resolver_func)
            query_dict['__annotations__'][resolver_name] = List[graphql_type]

        # Create the Query class
        Query = type('Query', (), query_dict)
        Query = strawberry.type(Query)

        # Create the schema with auto_camel_case disabled to keep snake_case field names
        self.graphql_schema = strawberry.Schema(
            query=Query,
            config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
        )
        logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")

    async def query_cassandra(
        self,
        user: str,
        collection: str,
        schema_name: str,
        row_schema: RowSchema,
        filters: Dict[str, Any],
        limit: int,
        order_by: Optional[str] = None,
        direction: Optional[Any] = None
    ) -> List[Dict[str, Any]]:
        """Execute a query against Cassandra"""

        # Connect if needed
        self.connect_cassandra()

        # Build the query
        keyspace = self.sanitize_name(user)
        table = self.sanitize_table(schema_name)

        # Start with basic SELECT
        query = f"SELECT * FROM {keyspace}.{table}"

        # Add WHERE clauses
        where_clauses = ["collection = %s"]
        params = [collection]

        # Add filters for indexed or primary key fields
        for filter_key, value in filters.items():
            if value is not None:
                # Parse field name and operator from filter key
                logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
                result = self.parse_filter_key(filter_key)
                logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")

                if not result or len(result) != 2:
                    logger.error(f"parse_filter_key returned invalid result: {result}")
                    continue  # Skip this filter

                field_name, operator = result

                # Find the field in schema
                schema_field = None
                for f in row_schema.fields:
                    if f.name == field_name:
                        schema_field = f
                        break

                if schema_field:
                    safe_field = self.sanitize_name(field_name)

                    # Build WHERE clause based on operator
                    if operator == "eq":
                        where_clauses.append(f"{safe_field} = %s")
                        params.append(value)
                    elif operator == "gt":
                        where_clauses.append(f"{safe_field} > %s")
                        params.append(value)
                    elif operator == "gte":
                        where_clauses.append(f"{safe_field} >= %s")
                        params.append(value)
                    elif operator == "lt":
                        where_clauses.append(f"{safe_field} < %s")
                        params.append(value)
                    elif operator == "lte":
                        where_clauses.append(f"{safe_field} <= %s")
                        params.append(value)
                    elif operator == "in":
                        if isinstance(value, list):
                            placeholders = ",".join(["%s"] * len(value))
                            where_clauses.append(f"{safe_field} IN ({placeholders})")
                            params.extend(value)
                    else:
                        # Default to equality for unknown operators
                        where_clauses.append(f"{safe_field} = %s")
                        params.append(value)

        if where_clauses:
            query += " WHERE " + " AND ".join(where_clauses)

        # Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort)
        cassandra_order_by_added = False
        if order_by and direction:
            # Validate that order_by field exists in schema
            order_field_exists = any(f.name == order_by for f in row_schema.fields)
            if order_field_exists:
                safe_order_field = self.sanitize_name(order_by)
                direction_str = "ASC" if direction.value == "asc" else "DESC"
                # Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution
                query += f" ORDER BY {safe_order_field} {direction_str}"

        # Add limit first (must come before ALLOW FILTERING)
        if limit:
            query += f" LIMIT {limit}"

        # Add ALLOW FILTERING for now (should optimize with proper indexes later)
        query += " ALLOW FILTERING"

        # Execute query
        try:
            result = self.session.execute(query, params)
            cassandra_order_by_added = True  # If we get here, Cassandra handled ORDER BY
        except Exception as e:
            # If ORDER BY fails, try without it
            if order_by and direction and "ORDER BY" in query:
                logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
                # Remove ORDER BY clause and retry
                query_parts = query.split(" ORDER BY ")
                if len(query_parts) == 2:
                    query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING"
                    result = self.session.execute(query_without_order, params)
                    cassandra_order_by_added = False
                else:
                    raise
            else:
                raise

        # Convert rows to dicts
        results = []
        for row in result:
            row_dict = {}
            for field in row_schema.fields:
                safe_field = self.sanitize_name(field.name)
                if hasattr(row, safe_field):
                    value = getattr(row, safe_field)
                    # Use original field name in result
                    row_dict[field.name] = value
            results.append(row_dict)

        # Post-query sorting if Cassandra didn't handle ORDER BY
        if order_by and direction and not cassandra_order_by_added:
            reverse_order = (direction.value == "desc")
            try:
                results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
            except Exception as e:
                logger.warning(f"Failed to sort results by {order_by}: {e}")

        return results
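The WHERE-clause assembly in `query_cassandra` can be sketched in isolation. This is a simplified illustration only: the keyspace, table, and field names and the `"default"` collection value are hypothetical, and parameters stay as `%s` placeholders for the driver to bind:

```python
def build_select(keyspace, table, filters, limit):
    """Assemble a parameterised CQL SELECT in the same shape as query_cassandra.

    filters is a list of (field, operator, value) triples.
    """
    ops = {"eq": "=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<="}
    query = f"SELECT * FROM {keyspace}.{table}"
    where_clauses = ["collection = %s"]
    params = ["default"]  # hypothetical collection value
    for field, op, value in filters:
        if op == "in":
            placeholders = ",".join(["%s"] * len(value))
            where_clauses.append(f"{field} IN ({placeholders})")
            params.extend(value)
        else:
            where_clauses.append(f"{field} {ops.get(op, '=')} %s")
            params.append(value)
    query += " WHERE " + " AND ".join(where_clauses)
    if limit:
        query += f" LIMIT {limit}"  # LIMIT must come before ALLOW FILTERING
    query += " ALLOW FILTERING"
    return query, params

q, p = build_select("o_alice", "o_product",
                    [("price", "gte", 10), ("tag", "in", ["a", "b"])], 5)
print(q)
# SELECT * FROM o_alice.o_product WHERE collection = %s AND price >= %s AND tag IN (%s,%s) LIMIT 5 ALLOW FILTERING
print(p)  # ['default', 10, 'a', 'b']
```

Keeping values in a separate parameter list rather than interpolating them into the string is what makes the query safe to pass to the driver; only sanitized identifiers are spliced into the CQL text.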
async def execute_graphql_query(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
variables: Dict[str, Any],
|
|
||||||
operation_name: Optional[str],
|
|
||||||
user: str,
|
|
||||||
collection: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Execute a GraphQL query"""
|
|
||||||
|
|
||||||
if not self.graphql_schema:
|
|
||||||
raise RuntimeError("No GraphQL schema available - no schemas loaded")
|
|
||||||
|
|
||||||
# Create context for the query
|
|
||||||
context = {
|
|
||||||
"processor": self,
|
|
||||||
"user": user,
|
|
||||||
"collection": collection
|
|
||||||
        }

        # Execute the query
        result = await self.graphql_schema.execute(
            query,
            variable_values=variables,
            operation_name=operation_name,
            context_value=context
        )

        # Build response
        response = {}

        if result.data:
            response["data"] = result.data
        else:
            response["data"] = None

        if result.errors:
            response["errors"] = [
                {
                    "message": str(error),
                    "path": getattr(error, "path", []),
                    "extensions": getattr(error, "extensions", {})
                }
                for error in result.errors
            ]
        else:
            response["errors"] = []

        # Add extensions if any
        if hasattr(result, "extensions") and result.extensions:
            response["extensions"] = result.extensions

        return response

    async def on_message(self, msg, consumer, flow):
        """Handle incoming query request"""

        try:
            request = msg.value()

            # Sender-produced ID
            id = msg.properties()["id"]

            logger.debug(f"Handling objects query request {id}...")

            # Execute GraphQL query
            result = await self.execute_graphql_query(
                query=request.query,
                variables=dict(request.variables) if request.variables else {},
                operation_name=request.operation_name,
                user=request.user,
                collection=request.collection
            )

            # Create response
            graphql_errors = []
            if "errors" in result and result["errors"]:
                for err in result["errors"]:
                    graphql_error = GraphQLError(
                        message=err.get("message", ""),
                        path=err.get("path", []),
                        extensions=err.get("extensions", {})
                    )
                    graphql_errors.append(graphql_error)

            response = ObjectsQueryResponse(
                error=None,
                data=json.dumps(result.get("data")) if result.get("data") else "null",
                errors=graphql_errors,
                extensions=result.get("extensions", {})
            )

            logger.debug("Sending objects query response...")
            await flow("response").send(response, properties={"id": id})

            logger.debug("Objects query request completed")

        except Exception as e:

            logger.error(f"Exception in objects query service: {e}", exc_info=True)

            logger.info("Sending error response...")

            response = ObjectsQueryResponse(
                error = Error(
                    type = "objects-query-error",
                    message = str(e),
                ),
                data = None,
                errors = [],
                extensions = {}
            )

            await flow("response").send(response, properties={"id": id})

    def close(self):
        """Clean up Cassandra connections"""
        if self.cluster:
            self.cluster.shutdown()
            logger.info("Closed Cassandra connection")

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""

        FlowProcessor.add_args(parser)
        add_cassandra_args(parser)

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )


def run():
    """Entry point for objects-query-graphql-cassandra command"""
    Processor.launch(default_ident, __doc__)
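The response-shaping above flattens a GraphQL execution result into a plain dict with `data` and `errors` keys before it is serialised onto the message bus. A minimal standalone sketch of that envelope building (the function name here is illustrative, not part of TrustGraph):

```python
def build_envelope(data, errors):
    # Flatten a GraphQL-style result into a serialisable envelope,
    # mirroring the shaping done in execute_graphql_query() above.
    response = {"data": data if data else None}
    response["errors"] = [
        {
            "message": str(e),
            "path": getattr(e, "path", []),
            "extensions": getattr(e, "extensions", {}),
        }
        for e in (errors or [])
    ]
    return response
```

Errors of any type are accepted, since only `str()` and optional `path`/`extensions` attributes are used.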
@@ -0,0 +1,3 @@
"""
Row embeddings query modules.
"""
@@ -0,0 +1,5 @@
"""
Qdrant row embeddings query service.
"""

from .service import Processor, run, default_ident
@@ -0,0 +1,4 @@

from .service import run

run()
@@ -0,0 +1,209 @@
"""
Row embeddings query service for Qdrant.

Input is query vectors plus user/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""

import logging
import re
from typing import Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue

from .... schema import (
    RowEmbeddingsRequest, RowEmbeddingsResponse,
    RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec

# Module logger
logger = logging.getLogger(__name__)

default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'


class Processor(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id", default_ident)

        store_uri = params.get("store_uri", default_store_uri)
        api_key = params.get("api_key", None)

        super(Processor, self).__init__(
            **params | {
                "id": id,
                "store_uri": store_uri,
                "api_key": api_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="request",
                schema=RowEmbeddingsRequest,
                handler=self.on_message
            )
        )

        self.register_specification(
            ProducerSpec(
                name="response",
                schema=RowEmbeddingsResponse
            )
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)

    def sanitize_name(self, name: str) -> str:
        """Sanitize names for Qdrant collection naming"""
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
        """Find the Qdrant collection for a given user/collection/schema"""
        prefix = (
            f"rows_{self.sanitize_name(user)}_"
            f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
        )

        try:
            all_collections = self.qdrant.get_collections().collections
            matching = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)
            ]

            if matching:
                # Return first match (there should typically be only one per dimension)
                return matching[0]

        except Exception as e:
            logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True)

        return None

    async def query_row_embeddings(self, request: RowEmbeddingsRequest):
        """Execute row embeddings query"""

        matches = []

        # Find the collection for this user/collection/schema
        qdrant_collection = self.find_collection(
            request.user, request.collection, request.schema_name
        )

        if not qdrant_collection:
            logger.info(
                f"No Qdrant collection found for "
                f"{request.user}/{request.collection}/{request.schema_name}"
            )
            return matches

        for vec in request.vectors:
            try:
                # Build optional filter for index_name
                query_filter = None
                if request.index_name:
                    query_filter = Filter(
                        must=[
                            FieldCondition(
                                key="index_name",
                                match=MatchValue(value=request.index_name)
                            )
                        ]
                    )

                # Query Qdrant
                search_result = self.qdrant.query_points(
                    collection_name=qdrant_collection,
                    query=vec,
                    limit=request.limit,
                    with_payload=True,
                    query_filter=query_filter,
                ).points

                # Convert to RowIndexMatch objects
                for point in search_result:
                    payload = point.payload or {}
                    match = RowIndexMatch(
                        index_name=payload.get("index_name", ""),
                        index_value=payload.get("index_value", []),
                        text=payload.get("text", ""),
                        score=point.score if hasattr(point, 'score') else 0.0
                    )
                    matches.append(match)

            except Exception as e:
                logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
                raise

        return matches

    async def on_message(self, msg, consumer, flow):
        """Handle incoming query request"""

        try:
            request = msg.value()

            # Sender-produced ID
            id = msg.properties()["id"]

            logger.debug(
                f"Handling row embeddings query for "
                f"{request.user}/{request.collection}/{request.schema_name}..."
            )

            # Execute query
            matches = await self.query_row_embeddings(request)

            response = RowEmbeddingsResponse(
                error=None,
                matches=matches
            )

            logger.debug(f"Returning {len(matches)} matches")
            await flow("response").send(response, properties={"id": id})

        except Exception as e:
            logger.error(f"Exception in row embeddings query: {e}", exc_info=True)

            response = RowEmbeddingsResponse(
                error=Error(
                    type="row-embeddings-query-error",
                    message=str(e)
                ),
                matches=[]
            )

            await flow("response").send(response, properties={"id": id})

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""

        FlowProcessor.add_args(parser)

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Qdrant store URI (default: {default_store_uri})'
        )

        parser.add_argument(
            '-k', '--api-key',
            default=None,
            help='API key for Qdrant (default: None)'
        )


def run():
    """Entry point for row-embeddings-query-qdrant command"""
    Processor.launch(default_ident, __doc__)
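The `find_collection()` lookup above hinges on a naming convention: each of user, collection, and schema name is sanitized, then joined into the prefix `rows_<user>_<collection>_<schema>_`. A standalone sketch of that convention (function names here are illustrative re-implementations, not TrustGraph imports):

```python
import re

def sanitize_name(name: str) -> str:
    # Replace anything outside [a-zA-Z0-9_], force a leading letter,
    # and lowercase - mirroring the service's sanitize_name() above.
    safe = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe and not safe[0].isalpha():
        safe = 'r_' + safe
    return safe.lower()

def collection_prefix(user: str, collection: str, schema_name: str) -> str:
    # The prefix that matching Qdrant collections are expected to share.
    return (
        f"rows_{sanitize_name(user)}_"
        f"{sanitize_name(collection)}_{sanitize_name(schema_name)}_"
    )
```

For example, `collection_prefix("Alice", "docs", "Invoice")` yields `rows_alice_docs_invoice_`.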
trustgraph-flow/trustgraph/query/rows/cassandra/service.py (new file, 523 lines)
@@ -0,0 +1,523 @@
"""
Row query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.

Queries against the unified 'rows' table with schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
"""

import json
import logging
import re
from typing import Dict, Any, Optional, List, Set

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config

from ... graphql import GraphQLSchemaBuilder, SortDirection

# Module logger
logger = logging.getLogger(__name__)

default_ident = "rows-query"


class Processor(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id", default_ident)

        # Get Cassandra parameters
        cassandra_host = params.get("cassandra_host")
        cassandra_username = params.get("cassandra_username")
        cassandra_password = params.get("cassandra_password")

        # Resolve configuration with environment variable fallback
        hosts, username, password, keyspace = resolve_cassandra_config(
            host=cassandra_host,
            username=cassandra_username,
            password=cassandra_password
        )

        # Store resolved configuration with proper names
        self.cassandra_host = hosts  # Store as list
        self.cassandra_username = username
        self.cassandra_password = password

        # Config key for schemas
        self.config_key = params.get("config_type", "schema")

        super(Processor, self).__init__(
            **params | {
                "id": id,
                "config_type": self.config_key,
            }
        )

        self.register_specification(
            ConsumerSpec(
                name="request",
                schema=RowsQueryRequest,
                handler=self.on_message
            )
        )

        self.register_specification(
            ProducerSpec(
                name="response",
                schema=RowsQueryResponse,
            )
        )

        # Register config handler for schema updates
        self.register_config_handler(self.on_schema_config)

        # Schema storage: name -> RowSchema
        self.schemas: Dict[str, RowSchema] = {}

        # GraphQL schema builder and generated schema
        self.schema_builder = GraphQLSchemaBuilder()
        self.graphql_schema = None

        # Cassandra session
        self.cluster = None
        self.session = None

        # Known keyspaces
        self.known_keyspaces: Set[str] = set()

    def connect_cassandra(self):
        """Connect to Cassandra cluster"""
        if self.session:
            return

        try:
            if self.cassandra_username and self.cassandra_password:
                auth_provider = PlainTextAuthProvider(
                    username=self.cassandra_username,
                    password=self.cassandra_password
                )
                self.cluster = Cluster(
                    contact_points=self.cassandra_host,
                    auth_provider=auth_provider
                )
            else:
                self.cluster = Cluster(contact_points=self.cassandra_host)

            self.session = self.cluster.connect()
            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

        except Exception as e:
            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
            raise

    def sanitize_name(self, name: str) -> str:
        """Sanitize names for Cassandra compatibility"""
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if safe_name and not safe_name[0].isalpha():
            safe_name = 'r_' + safe_name
        return safe_name.lower()

    async def on_schema_config(self, config, version):
        """Handle schema configuration updates"""
        logger.info(f"Loading schema configuration version {version}")

        # Clear existing schemas
        self.schemas = {}
        self.schema_builder.clear()

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(f"No '{self.config_key}' type in configuration")
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        # Process each schema in the schemas config
        for schema_name, schema_json in schemas_config.items():
            try:
                # Parse the JSON schema definition
                schema_def = json.loads(schema_json)

                # Create Field objects
                fields = []
                for field_def in schema_def.get("fields", []):
                    field = SchemaField(
                        name=field_def["name"],
                        type=field_def["type"],
                        size=field_def.get("size", 0),
                        primary=field_def.get("primary_key", False),
                        description=field_def.get("description", ""),
                        required=field_def.get("required", False),
                        enum_values=field_def.get("enum", []),
                        indexed=field_def.get("indexed", False)
                    )
                    fields.append(field)

                # Create RowSchema
                row_schema = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),
                    fields=fields
                )

                self.schemas[schema_name] = row_schema
                self.schema_builder.add_schema(schema_name, row_schema)
                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")

            except Exception as e:
                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)

        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")

        # Regenerate GraphQL schema
        self.graphql_schema = self.schema_builder.build(self.query_cassandra)

    def get_index_names(self, schema: RowSchema) -> List[str]:
        """Get all index names for a schema."""
        index_names = []
        for field in schema.fields:
            if field.primary or field.indexed:
                index_names.append(field.name)
        return index_names

    def find_matching_index(
        self,
        schema: RowSchema,
        filters: Dict[str, Any]
    ) -> Optional[tuple]:
        """
        Find an index that can satisfy the query filters.
        Returns (index_name, index_value) if found, None otherwise.

        For exact match queries, we need a filter on an indexed field.
        """
        index_names = self.get_index_names(schema)

        # Look for an exact match filter on an indexed field
        for index_name in index_names:
            if index_name in filters:
                value = filters[index_name]
                # Single field index -> single element list
                index_value = [str(value)]
                return (index_name, index_value)

        return None

    async def query_cassandra(
        self,
        user: str,
        collection: str,
        schema_name: str,
        row_schema: RowSchema,
        filters: Dict[str, Any],
        limit: int,
        order_by: Optional[str] = None,
        direction: Optional[SortDirection] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute a query against the unified Cassandra rows table.

        For exact match queries on indexed fields, we can query directly.
        For other queries, we need to scan and post-filter.
        """
        # Connect if needed
        self.connect_cassandra()

        safe_keyspace = self.sanitize_name(user)

        # Try to find an index that matches the filters
        index_match = self.find_matching_index(row_schema, filters)

        results = []

        if index_match:
            # Direct query using index
            index_name, index_value = index_match

            query = f"""
                SELECT data, source FROM {safe_keyspace}.rows
                WHERE collection = %s
                AND schema_name = %s
                AND index_name = %s
                AND index_value = %s
            """
            params = [collection, schema_name, index_name, index_value]

            if limit:
                query += f" LIMIT {limit}"

            try:
                rows = self.session.execute(query, params)
                for row in rows:
                    # Convert data map to dict with proper field names
                    row_dict = dict(row.data) if row.data else {}
                    results.append(row_dict)
            except Exception as e:
                logger.error(f"Failed to query rows: {e}", exc_info=True)
                raise

        else:
            # No direct index match - scan all rows for this schema
            # This is less efficient but necessary for non-indexed queries
            logger.warning(
                f"No index match for filters {filters} - scanning all indexes"
            )

            # Get all index names for this schema
            index_names = self.get_index_names(row_schema)

            if not index_names:
                logger.warning(f"Schema {schema_name} has no indexes")
                return []

            # Query using the first index (arbitrary choice for scan)
            primary_index = index_names[0]

            # We need to scan all values for this index
            # This requires ALLOW FILTERING or a different approach
            query = f"""
                SELECT data, source FROM {safe_keyspace}.rows
                WHERE collection = %s
                AND schema_name = %s
                AND index_name = %s
                ALLOW FILTERING
            """
            params = [collection, schema_name, primary_index]

            try:
                rows = self.session.execute(query, params)

                for row in rows:
                    row_dict = dict(row.data) if row.data else {}

                    # Apply post-filters
                    if self._matches_filters(row_dict, filters, row_schema):
                        results.append(row_dict)

                        if limit and len(results) >= limit:
                            break

            except Exception as e:
                logger.error(f"Failed to scan rows: {e}", exc_info=True)
                raise

        # Post-query sorting if requested
        if order_by and results:
            reverse_order = direction and direction.value == "desc"
            try:
                results.sort(
                    key=lambda x: x.get(order_by, ""),
                    reverse=reverse_order
                )
            except Exception as e:
                logger.warning(f"Failed to sort results by {order_by}: {e}")

        return results

    def _matches_filters(
        self,
        row_dict: Dict[str, Any],
        filters: Dict[str, Any],
        row_schema: RowSchema
    ) -> bool:
        """Check if a row matches the given filters."""
        for filter_key, filter_value in filters.items():
            if filter_value is None:
                continue

            # Parse filter key for operator
            if '_' in filter_key:
                parts = filter_key.rsplit('_', 1)
                if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
                    field_name = parts[0]
                    operator = parts[1]
                else:
                    field_name = filter_key
                    operator = 'eq'
            else:
                field_name = filter_key
                operator = 'eq'

            row_value = row_dict.get(field_name)
            if row_value is None:
                return False

            # Convert types for comparison
            try:
                if operator == 'eq':
                    if str(row_value) != str(filter_value):
                        return False
                elif operator == 'gt':
                    if float(row_value) <= float(filter_value):
                        return False
                elif operator == 'gte':
                    if float(row_value) < float(filter_value):
                        return False
                elif operator == 'lt':
                    if float(row_value) >= float(filter_value):
                        return False
                elif operator == 'lte':
                    if float(row_value) > float(filter_value):
                        return False
                elif operator == 'contains':
                    if str(filter_value) not in str(row_value):
                        return False
                elif operator == 'in':
                    if str(row_value) not in [str(v) for v in filter_value]:
                        return False
            except (ValueError, TypeError):
                return False

        return True

    async def execute_graphql_query(
        self,
        query: str,
        variables: Dict[str, Any],
        operation_name: Optional[str],
        user: str,
        collection: str
    ) -> Dict[str, Any]:
        """Execute a GraphQL query"""

        if not self.graphql_schema:
            raise RuntimeError("No GraphQL schema available - no schemas loaded")

        # Create context for the query
        context = {
            "processor": self,
            "user": user,
            "collection": collection
        }

        # Execute the query
        result = await self.graphql_schema.execute(
            query,
            variable_values=variables,
            operation_name=operation_name,
            context_value=context
        )

        # Build response
        response = {}

        if result.data:
            response["data"] = result.data
        else:
            response["data"] = None

        if result.errors:
            response["errors"] = [
                {
                    "message": str(error),
                    "path": getattr(error, "path", []),
                    "extensions": getattr(error, "extensions", {})
                }
                for error in result.errors
            ]
        else:
            response["errors"] = []

        # Add extensions if any
        if hasattr(result, "extensions") and result.extensions:
            response["extensions"] = result.extensions

        return response

    async def on_message(self, msg, consumer, flow):
        """Handle incoming query request"""

        try:
            request = msg.value()

            # Sender-produced ID
            id = msg.properties()["id"]

            logger.debug(f"Handling rows query request {id}...")

            # Execute GraphQL query
            result = await self.execute_graphql_query(
                query=request.query,
                variables=dict(request.variables) if request.variables else {},
                operation_name=request.operation_name,
                user=request.user,
                collection=request.collection
            )

            # Create response
            graphql_errors = []
            if "errors" in result and result["errors"]:
                for err in result["errors"]:
                    graphql_error = GraphQLError(
                        message=err.get("message", ""),
                        path=err.get("path", []),
                        extensions=err.get("extensions", {})
                    )
                    graphql_errors.append(graphql_error)

            response = RowsQueryResponse(
                error=None,
                data=json.dumps(result.get("data")) if result.get("data") else "null",
                errors=graphql_errors,
                extensions=result.get("extensions", {})
            )

            logger.debug("Sending rows query response...")
            await flow("response").send(response, properties={"id": id})

            logger.debug("Rows query request completed")

        except Exception as e:

            logger.error(f"Exception in rows query service: {e}", exc_info=True)

            logger.info("Sending error response...")

            response = RowsQueryResponse(
                error=Error(
                    type="rows-query-error",
                    message=str(e),
                ),
                data=None,
                errors=[],
                extensions={}
            )

            await flow("response").send(response, properties={"id": id})

    def close(self):
        """Clean up Cassandra connections"""
        if self.cluster:
            self.cluster.shutdown()
            logger.info("Closed Cassandra connection")

    @staticmethod
    def add_args(parser):
        """Add command-line arguments"""

        FlowProcessor.add_args(parser)
        add_cassandra_args(parser)

        parser.add_argument(
            '--config-type',
            default='schema',
            help='Configuration type prefix for schemas (default: schema)'
        )


def run():
    """Entry point for rows-query-cassandra command"""
    Processor.launch(default_ident, __doc__)
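The `_matches_filters()` post-filtering above relies on a naming convention for filter keys: a trailing `_gt`, `_gte`, `_lt`, `_lte`, `_contains`, or `_in` suffix selects a comparison operator, and anything else is an exact match. A standalone sketch of that key parsing (the helper name is illustrative):

```python
# Operator suffixes recognised by the rows-query post-filtering.
OPERATORS = ['gt', 'gte', 'lt', 'lte', 'contains', 'in']

def parse_filter_key(filter_key: str):
    # Split on the last underscore; only a recognised suffix is treated
    # as an operator, so field names containing underscores still work.
    if '_' in filter_key:
        field, suffix = filter_key.rsplit('_', 1)
        if suffix in OPERATORS:
            return field, suffix
    return filter_key, 'eq'
```

Note that splitting on the last underscore means a field named `unit_price` is still parsed as an exact-match filter, because `price` is not a recognised operator suffix.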
@@ -1,6 +1,6 @@
 """
 Structured Query Service - orchestrates natural language question processing.
-Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
+Takes a question, converts it to GraphQL via nlp-query, executes via rows-query,
 and returns the results.
 """
@@ -10,7 +10,7 @@ from typing import Dict, Any, Optional

 from ...schema import StructuredQueryRequest, StructuredQueryResponse
 from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
-from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
+from ...schema import RowsQueryRequest, RowsQueryResponse
 from ...schema import Error

 from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
@@ -57,13 +57,13 @@ class Processor(FlowProcessor):
             )
         )

-        # Client spec for calling objects query service
+        # Client spec for calling rows query service
         self.register_specification(
             RequestResponseSpec(
-                request_name = "objects-query-request",
-                response_name = "objects-query-response",
-                request_schema = ObjectsQueryRequest,
-                response_schema = ObjectsQueryResponse
+                request_name = "rows-query-request",
+                response_name = "rows-query-response",
+                request_schema = RowsQueryRequest,
+                response_schema = RowsQueryResponse
             )
         )
@@ -112,7 +112,7 @@ class Processor(FlowProcessor):
             variables_as_strings[key] = str(value)

         # Use user/collection values from request
-        objects_request = ObjectsQueryRequest(
+        objects_request = RowsQueryRequest(
             user=request.user,
             collection=request.collection,
             query=nlp_response.graphql_query,
@@ -120,12 +120,12 @@ class Processor(FlowProcessor):
             operation_name=None
         )

-        objects_response = await flow("objects-query-request").request(objects_request)
+        objects_response = await flow("rows-query-request").request(objects_request)

         if objects_response.error is not None:
-            raise Exception(f"Objects query service error: {objects_response.error.message}")
+            raise Exception(f"Rows query service error: {objects_response.error.message}")

-        # Handle GraphQL errors from the objects query service
+        # Handle GraphQL errors from the rows query service
         graphql_errors = []
         if objects_response.errors:
             for gql_error in objects_response.errors:
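The hunks above rewire the structured query pipeline from the objects-query service to the rows-query service, but the overall flow is unchanged: a question is converted to GraphQL (nlp-query stage), the GraphQL is executed (rows-query stage), and the data is returned. A minimal end-to-end sketch, where both stage functions are stand-ins for the real TrustGraph request/response flows:

```python
import asyncio

async def nlp_query(question: str) -> str:
    # Stand-in: pretend the NLP stage produced a GraphQL query.
    return f'{{ results(question: "{question}") }}'

async def rows_query(graphql: str) -> dict:
    # Stand-in: pretend the rows-query stage executed the GraphQL.
    return {"data": {"graphql": graphql}, "errors": []}

async def structured_query(question: str) -> dict:
    # Orchestration mirroring the structured query service: convert,
    # execute, surface errors, return data.
    graphql = await nlp_query(question)
    response = await rows_query(graphql)
    if response["errors"]:
        raise Exception(f"Rows query service error: {response['errors']}")
    return response["data"]
```

Running `asyncio.run(structured_query("..."))` walks the same three steps the service performs over the message bus.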
@@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
 # Module logger
 logger = logging.getLogger(__name__)

-default_ident = "de-write"
+default_ident = "doc-embeddings-write"
 default_store_uri = 'http://localhost:19530'

 class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
@@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
 # Module logger
 logger = logging.getLogger(__name__)

-default_ident = "de-write"
+default_ident = "doc-embeddings-write"
 default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
 default_cloud = "aws"
 default_region = "us-east-1"
@@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
 # Module logger
 logger = logging.getLogger(__name__)

-default_ident = "de-write"
+default_ident = "doc-embeddings-write"

 default_store_uri = 'http://localhost:6333'
@@ -27,7 +27,7 @@ def get_term_value(term):
     # For blank nodes or other types, use id or value
     return term.id or term.value

-default_ident = "ge-write"
+default_ident = "graph-embeddings-write"
 default_store_uri = 'http://localhost:19530'

 class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
@@ -32,7 +32,7 @@ def get_term_value(term):
     # For blank nodes or other types, use id or value
     return term.id or term.value

-default_ident = "ge-write"
+default_ident = "graph-embeddings-write"
 default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
 default_cloud = "aws"
 default_region = "us-east-1"
@@ -31,7 +31,7 @@ def get_term_value(term):
     return term.id or term.value


-default_ident = "ge-write"
+default_ident = "graph-embeddings-write"

 default_store_uri = 'http://localhost:6333'

@@ -1 +0,0 @@
-# Objects storage module
@@ -1 +0,0 @@
-from . write import *
@@ -1,3 +0,0 @@
-from . write import run
-
-run()
@@ -1,538 +0,0 @@
-"""
-Object writer for Cassandra. Input is ExtractedObject.
-Writes structured objects to Cassandra tables based on schema definitions.
-"""
-
-import json
-import logging
-from typing import Dict, Set, Optional, Any
-from cassandra.cluster import Cluster
-from cassandra.auth import PlainTextAuthProvider
-from cassandra.cqlengine import connection
-from cassandra import ConsistencyLevel
-
-from .... schema import ExtractedObject
-from .... schema import RowSchema, Field
-from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from .... base import CollectionConfigHandler
-from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
-
-# Module logger
-logger = logging.getLogger(__name__)
-
-default_ident = "objects-write"
-
-class Processor(CollectionConfigHandler, FlowProcessor):
-
-    def __init__(self, **params):
-
-        id = params.get("id", default_ident)
-
-        # Get Cassandra parameters
-        cassandra_host = params.get("cassandra_host")
-        cassandra_username = params.get("cassandra_username")
-        cassandra_password = params.get("cassandra_password")
-
-        # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
-            host=cassandra_host,
-            username=cassandra_username,
-            password=cassandra_password
-        )
-
-        # Store resolved configuration with proper names
-        self.cassandra_host = hosts  # Store as list
-        self.cassandra_username = username
-        self.cassandra_password = password
-
-        # Config key for schemas
-        self.config_key = params.get("config_type", "schema")
-
-        super(Processor, self).__init__(
-            **params | {
-                "id": id,
-                "config_type": self.config_key,
-            }
-        )
-
-        self.register_specification(
-            ConsumerSpec(
-                name = "input",
-                schema = ExtractedObject,
-                handler = self.on_object
-            )
-        )
-
-        # Register config handlers
-        self.register_config_handler(self.on_schema_config)
-        self.register_config_handler(self.on_collection_config)
-
-        # Cache of known keyspaces/tables
-        self.known_keyspaces: Set[str] = set()
-        self.known_tables: Dict[str, Set[str]] = {}  # keyspace -> set of tables
-
-        # Schema storage: name -> RowSchema
-        self.schemas: Dict[str, RowSchema] = {}
-
-        # Cassandra session
-        self.cluster = None
-        self.session = None
-
-    def connect_cassandra(self):
-        """Connect to Cassandra cluster"""
-        if self.session:
-            return
-
-        try:
-            if self.cassandra_username and self.cassandra_password:
-                auth_provider = PlainTextAuthProvider(
-                    username=self.cassandra_username,
-                    password=self.cassandra_password
-                )
-                self.cluster = Cluster(
-                    contact_points=self.cassandra_host,
-                    auth_provider=auth_provider
-                )
-            else:
-                self.cluster = Cluster(contact_points=self.cassandra_host)
-
-            self.session = self.cluster.connect()
-            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
-
-        except Exception as e:
-            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
-            raise
-
-    async def on_schema_config(self, config, version):
-        """Handle schema configuration updates"""
-        logger.info(f"Loading schema configuration version {version}")
-
-        # Clear existing schemas
-        self.schemas = {}
-
-        # Check if our config type exists
-        if self.config_key not in config:
-            logger.warning(f"No '{self.config_key}' type in configuration")
-            return
-
-        # Get the schemas dictionary for our type
-        schemas_config = config[self.config_key]
-
-        # Process each schema in the schemas config
-        for schema_name, schema_json in schemas_config.items():
-            try:
-                # Parse the JSON schema definition
-                schema_def = json.loads(schema_json)
-
-                # Create Field objects
-                fields = []
-                for field_def in schema_def.get("fields", []):
-                    field = Field(
-                        name=field_def["name"],
-                        type=field_def["type"],
-                        size=field_def.get("size", 0),
-                        primary=field_def.get("primary_key", False),
-                        description=field_def.get("description", ""),
-                        required=field_def.get("required", False),
-                        enum_values=field_def.get("enum", []),
-                        indexed=field_def.get("indexed", False)
-                    )
-                    fields.append(field)
-
-                # Create RowSchema
-                row_schema = RowSchema(
-                    name=schema_def.get("name", schema_name),
-                    description=schema_def.get("description", ""),
-                    fields=fields
-                )
-
-                self.schemas[schema_name] = row_schema
-                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
-
-            except Exception as e:
-                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
-
-        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
-
-    def ensure_keyspace(self, keyspace: str):
-        """Ensure keyspace exists in Cassandra"""
-        if keyspace in self.known_keyspaces:
-            return
-
-        # Connect if needed
-        self.connect_cassandra()
-
-        # Sanitize keyspace name
-        safe_keyspace = self.sanitize_name(keyspace)
-
-        # Create keyspace if not exists
-        create_keyspace_cql = f"""
-            CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
-            WITH REPLICATION = {{
-                'class': 'SimpleStrategy',
-                'replication_factor': 1
-            }}
-        """
-
-        try:
-            self.session.execute(create_keyspace_cql)
-            self.known_keyspaces.add(keyspace)
-            self.known_tables[keyspace] = set()
-            logger.info(f"Ensured keyspace exists: {safe_keyspace}")
-        except Exception as e:
-            logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
-            raise
-
-    def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
-        """Convert schema field type to Cassandra type"""
-        # Handle None size
-        if size is None:
-            size = 0
-
-        type_mapping = {
-            "string": "text",
-            "integer": "bigint" if size > 4 else "int",
-            "float": "double" if size > 4 else "float",
-            "boolean": "boolean",
-            "timestamp": "timestamp",
-            "date": "date",
-            "time": "time",
-            "uuid": "uuid"
-        }
-
-        return type_mapping.get(field_type, "text")
-
-    def sanitize_name(self, name: str) -> str:
-        """Sanitize names for Cassandra compatibility"""
-        # Replace non-alphanumeric characters with underscore
-        import re
-        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
-        # Ensure it starts with a letter
-        if safe_name and not safe_name[0].isalpha():
-            safe_name = 'o_' + safe_name
-        return safe_name.lower()
-
-    def sanitize_table(self, name: str) -> str:
-        """Sanitize names for Cassandra compatibility"""
-        # Replace non-alphanumeric characters with underscore
-        import re
-        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
-        # Ensure it starts with a letter
-        safe_name = 'o_' + safe_name
-        return safe_name.lower()
-
-    def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema):
-        """Ensure table exists with proper structure"""
-        table_key = f"{keyspace}.{table_name}"
-        if table_key in self.known_tables.get(keyspace, set()):
-            return
-
-        # Ensure keyspace exists first
-        self.ensure_keyspace(keyspace)
-
-        safe_keyspace = self.sanitize_name(keyspace)
-        safe_table = self.sanitize_table(table_name)
-
-        # Build column definitions
-        columns = ["collection text"]  # Collection is always part of table
-        primary_key_fields = []
-        clustering_fields = []
-
-        for field in schema.fields:
-            safe_field_name = self.sanitize_name(field.name)
-            cassandra_type = self.get_cassandra_type(field.type, field.size)
-            columns.append(f"{safe_field_name} {cassandra_type}")
-
-            if field.primary:
-                primary_key_fields.append(safe_field_name)
-
-        # Build primary key - collection is always first in partition key
-        if primary_key_fields:
-            primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
-        else:
-            # If no primary key defined, use collection and a synthetic id
-            columns.append("synthetic_id uuid")
-            primary_key = "PRIMARY KEY ((collection, synthetic_id))"
-
-        # Create table
-        create_table_cql = f"""
-            CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
-                {', '.join(columns)},
-                {primary_key}
-            )
-        """
-
-        try:
-            self.session.execute(create_table_cql)
-            if keyspace not in self.known_tables:
-                self.known_tables[keyspace] = set()
-            self.known_tables[keyspace].add(table_key)
-            logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
-
-            # Create secondary indexes for indexed fields
-            for field in schema.fields:
-                if field.indexed and not field.primary:
-                    safe_field_name = self.sanitize_name(field.name)
-                    index_name = f"{safe_table}_{safe_field_name}_idx"
-                    create_index_cql = f"""
-                        CREATE INDEX IF NOT EXISTS {index_name}
-                        ON {safe_keyspace}.{safe_table} ({safe_field_name})
-                    """
-                    try:
-                        self.session.execute(create_index_cql)
-                        logger.info(f"Created index: {index_name}")
-                    except Exception as e:
-                        logger.warning(f"Failed to create index {index_name}: {e}")
-
-        except Exception as e:
-            logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True)
-            raise
-
-    def convert_value(self, value: Any, field_type: str) -> Any:
-        """Convert value to appropriate type for Cassandra"""
-        if value is None:
-            return None
-
-        try:
-            if field_type == "integer":
-                return int(value)
-            elif field_type == "float":
-                return float(value)
-            elif field_type == "boolean":
-                if isinstance(value, str):
-                    return value.lower() in ('true', '1', 'yes')
-                return bool(value)
-            elif field_type == "timestamp":
-                # Handle timestamp conversion if needed
-                return value
-            else:
-                return str(value)
-        except Exception as e:
-            logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
-            return str(value)
-
-    async def on_object(self, msg, consumer, flow):
-        """Process incoming ExtractedObject and store in Cassandra"""
-
-        obj = msg.value()
-        logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
-
-        # Validate collection exists before accepting writes
-        if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
-            error_msg = (
-                f"Collection {obj.metadata.collection} does not exist. "
-                f"Create it first via collection management API."
-            )
-            logger.error(error_msg)
-            raise ValueError(error_msg)
-
-        # Get schema definition
-        schema = self.schemas.get(obj.schema_name)
-        if not schema:
-            logger.warning(f"No schema found for {obj.schema_name} - skipping")
-            return
-
-        # Ensure table exists
-        keyspace = obj.metadata.user
-        table_name = obj.schema_name
-        self.ensure_table(keyspace, table_name, schema)
-
-        # Prepare data for insertion
-        safe_keyspace = self.sanitize_name(keyspace)
-        safe_table = self.sanitize_table(table_name)
-
-        # Process each object in the batch
-        for obj_index, value_map in enumerate(obj.values):
-            # Build column names and values for this object
-            columns = ["collection"]
-            values = [obj.metadata.collection]
-            placeholders = ["%s"]
-
-            # Check if we need a synthetic ID
-            has_primary_key = any(field.primary for field in schema.fields)
-            if not has_primary_key:
-                import uuid
-                columns.append("synthetic_id")
-                values.append(uuid.uuid4())
-                placeholders.append("%s")
-
-            # Process fields for this object
-            skip_object = False
-            for field in schema.fields:
-                safe_field_name = self.sanitize_name(field.name)
-                raw_value = value_map.get(field.name)
-
-                # Handle required fields
-                if field.required and raw_value is None:
-                    logger.warning(f"Required field {field.name} is missing in object {obj_index}")
-                    # Continue anyway - Cassandra doesn't enforce NOT NULL
-
-                # Check if primary key field is NULL
-                if field.primary and raw_value is None:
-                    logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
-                    skip_object = True
-                    break
-
-                # Convert value to appropriate type
-                converted_value = self.convert_value(raw_value, field.type)
-
-                columns.append(safe_field_name)
-                values.append(converted_value)
-                placeholders.append("%s")
-
-            # Skip this object if primary key validation failed
-            if skip_object:
-                continue
-
-            # Build and execute insert query for this object
-            insert_cql = f"""
-                INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
-                VALUES ({', '.join(placeholders)})
-            """
-
-            # Debug: Show data being inserted
-            logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")

-            if len(columns) != len(values) or len(columns) != len(placeholders):
-                raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
-
-            try:
-                # Convert to tuple - Cassandra driver requires tuple for parameters
-                self.session.execute(insert_cql, tuple(values))
-            except Exception as e:
-                logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
-                raise
-
-    async def create_collection(self, user: str, collection: str, metadata: dict):
-        """Create/verify collection exists in Cassandra object store"""
-        # Connect if not already connected
-        self.connect_cassandra()
-
-        # Sanitize names for safety
-        safe_keyspace = self.sanitize_name(user)
-
-        # Ensure keyspace exists
-        if safe_keyspace not in self.known_keyspaces:
-            self.ensure_keyspace(safe_keyspace)
-            self.known_keyspaces.add(safe_keyspace)
-
-        # For Cassandra objects, collection is just a property in rows
-        # No need to create separate tables per collection
-        # Just mark that we've seen this collection
-        logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
-
-    async def delete_collection(self, user: str, collection: str):
-        """Delete all data for a specific collection using schema information"""
-        # Connect if not already connected
-        self.connect_cassandra()
-
-        # Sanitize names for safety
-        safe_keyspace = self.sanitize_name(user)
-
-        # Check if keyspace exists
-        if safe_keyspace not in self.known_keyspaces:
-            # Query to verify keyspace exists
-            check_keyspace_cql = """
-                SELECT keyspace_name FROM system_schema.keyspaces
-                WHERE keyspace_name = %s
-            """
-            result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
-            if not result.one():
-                logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
-                return
-            self.known_keyspaces.add(safe_keyspace)
-
-        # Iterate over schemas we manage to delete from relevant tables
-        tables_deleted = 0
-
-        for schema_name, schema in self.schemas.items():
-            safe_table = self.sanitize_table(schema_name)
-
-            # Check if table exists
-            table_key = f"{user}.{schema_name}"
-            if table_key not in self.known_tables.get(user, set()):
-                logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
-                continue
-
-            try:
-                # Get primary key fields from schema
-                primary_key_fields = [field for field in schema.fields if field.primary]
-
-                if primary_key_fields:
-                    # Schema has primary keys: need to query for partition keys first
-                    # Build SELECT query for primary key fields
-                    pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
-                    select_cql = f"""
-                        SELECT {', '.join(pk_field_names)}
-                        FROM {safe_keyspace}.{safe_table}
-                        WHERE collection = %s
-                        ALLOW FILTERING
-                    """
-
-                    rows = self.session.execute(select_cql, (collection,))
-
-                    # Delete each row using full partition key
-                    for row in rows:
-                        where_clauses = ["collection = %s"]
-                        values = [collection]
-
-                        for field_name in pk_field_names:
-                            where_clauses.append(f"{field_name} = %s")
-                            values.append(getattr(row, field_name))
-
-                        delete_cql = f"""
-                            DELETE FROM {safe_keyspace}.{safe_table}
-                            WHERE {' AND '.join(where_clauses)}
-                        """
-
-                        self.session.execute(delete_cql, tuple(values))
-                else:
-                    # No primary keys, uses synthetic_id
-                    # Need to query for synthetic_ids first
-                    select_cql = f"""
-                        SELECT synthetic_id
-                        FROM {safe_keyspace}.{safe_table}
-                        WHERE collection = %s
-                        ALLOW FILTERING
-                    """
-
-                    rows = self.session.execute(select_cql, (collection,))
-
-                    # Delete each row using collection and synthetic_id
-                    for row in rows:
-                        delete_cql = f"""
-                            DELETE FROM {safe_keyspace}.{safe_table}
-                            WHERE collection = %s AND synthetic_id = %s
-                        """
-                        self.session.execute(delete_cql, (collection, row.synthetic_id))
-
-                tables_deleted += 1
-                logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
-
-            except Exception as e:
-                logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
-                raise
-
-        logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
-
-    def close(self):
-        """Clean up Cassandra connections"""
-        if self.cluster:
-            self.cluster.shutdown()
-            logger.info("Closed Cassandra connection")
-
-    @staticmethod
-    def add_args(parser):
-        """Add command-line arguments"""
-
-        FlowProcessor.add_args(parser)
-        add_cassandra_args(parser)
-
-        parser.add_argument(
-            '--config-type',
-            default='schema',
-            help='Configuration type prefix for schemas (default: schema)'
-        )
-
-def run():
-    """Entry point for objects-write-cassandra command"""
-    Processor.launch(default_ident, __doc__)
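For reference, the name-sanitization and type-mapping rules that the deleted Cassandra writer applied can be sketched as standalone functions (a minimal illustration mirroring the removed `sanitize_name` and `get_cassandra_type` methods; not part of the diff itself):

```python
import re

def sanitize_name(name: str) -> str:
    # Replace non-alphanumeric characters with underscore; prefix
    # with 'o_' when the result does not start with a letter.
    safe = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe and not safe[0].isalpha():
        safe = 'o_' + safe
    return safe.lower()

def get_cassandra_type(field_type: str, size: int = 0) -> str:
    # Same mapping as the removed method: size > 4 widens
    # integer/float to bigint/double.
    size = size or 0
    mapping = {
        "string": "text",
        "integer": "bigint" if size > 4 else "int",
        "float": "double" if size > 4 else "float",
        "boolean": "boolean",
        "timestamp": "timestamp",
        "date": "date",
        "time": "time",
        "uuid": "uuid",
    }
    return mapping.get(field_type, "text")

print(sanitize_name("9-my.keyspace"))    # o_9_my_keyspace
print(get_cassandra_type("integer", 8))  # bigint
```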
@@ -0,0 +1,3 @@
+"""
+Row embeddings storage modules.
+"""
@@ -0,0 +1,5 @@
+"""
+Qdrant storage for row embeddings.
+"""
+
+from .write import Processor, run, default_ident
@@ -0,0 +1,4 @@
+
+from .write import run
+
+run()
@@ -0,0 +1,264 @@
+"""
+Row embeddings writer for Qdrant (Stage 2).
+
+Consumes RowEmbeddings messages (which already contain computed vectors)
+and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
+
+This follows the two-stage pattern used by graph-embeddings and document-embeddings:
+Stage 1 (row-embeddings): Compute embeddings
+Stage 2 (this processor): Store embeddings
+
+Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
+
+Payload structure:
+- index_name: The indexed field(s) this embedding represents
+- index_value: The original list of values (for Cassandra lookup)
+- text: The text that was embedded (for debugging/display)
+"""
+
+import logging
+import re
+import uuid
+from typing import Set, Tuple
+
+from qdrant_client import QdrantClient
+from qdrant_client.models import PointStruct, Distance, VectorParams
+
+from .... schema import RowEmbeddings
+from .... base import FlowProcessor, ConsumerSpec
+from .... base import CollectionConfigHandler
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+default_ident = "row-embeddings-write"
+default_store_uri = 'http://localhost:6333'
+
+
+class Processor(CollectionConfigHandler, FlowProcessor):
+
+    def __init__(self, **params):
+
+        id = params.get("id", default_ident)
+
+        store_uri = params.get("store_uri", default_store_uri)
+        api_key = params.get("api_key", None)
+
+        super(Processor, self).__init__(
+            **params | {
+                "id": id,
+                "store_uri": store_uri,
+                "api_key": api_key,
+            }
+        )
+
+        self.register_specification(
+            ConsumerSpec(
+                name="input",
+                schema=RowEmbeddings,
+                handler=self.on_embeddings
+            )
+        )
+
+        # Register config handler for collection management
+        self.register_config_handler(self.on_collection_config)
+
+        # Cache of created Qdrant collections
+        self.created_collections: Set[str] = set()
+
+        # Qdrant client
+        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+
+    def sanitize_name(self, name: str) -> str:
+        """Sanitize names for Qdrant collection naming"""
+        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
+        if safe_name and not safe_name[0].isalpha():
+            safe_name = 'r_' + safe_name
+        return safe_name.lower()
+
+    def get_collection_name(
+        self, user: str, collection: str, schema_name: str, dimension: int
+    ) -> str:
+        """Generate Qdrant collection name"""
+        safe_user = self.sanitize_name(user)
+        safe_collection = self.sanitize_name(collection)
+        safe_schema = self.sanitize_name(schema_name)
+        return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
+
+    def ensure_collection(self, collection_name: str, dimension: int):
+        """Create Qdrant collection if it doesn't exist"""
+        if collection_name in self.created_collections:
+            return
+
+        if not self.qdrant.collection_exists(collection_name):
+            logger.info(
+                f"Creating Qdrant collection {collection_name} "
+                f"with dimension {dimension}"
+            )
+            self.qdrant.create_collection(
+                collection_name=collection_name,
+                vectors_config=VectorParams(
+                    size=dimension,
+                    distance=Distance.COSINE
+                )
+            )
+
+        self.created_collections.add(collection_name)
+
+    async def on_embeddings(self, msg, consumer, flow):
+        """Process incoming RowEmbeddings and write to Qdrant"""
+
+        embeddings = msg.value()
+        logger.info(
+            f"Writing {len(embeddings.embeddings)} embeddings for schema "
+            f"{embeddings.schema_name} from {embeddings.metadata.id}"
+        )
+
+        # Validate collection exists in config before processing
+        if not self.collection_exists(
+            embeddings.metadata.user, embeddings.metadata.collection
+        ):
+            logger.warning(
+                f"Collection {embeddings.metadata.collection} for user "
+                f"{embeddings.metadata.user} does not exist in config. "
+                f"Dropping message."
+            )
+            return
+
+        user = embeddings.metadata.user
+        collection = embeddings.metadata.collection
+        schema_name = embeddings.schema_name
+
+        embeddings_written = 0
+        qdrant_collection = None
+
+        for row_emb in embeddings.embeddings:
+            if not row_emb.vectors:
+                logger.warning(
+                    f"No vectors for index {row_emb.index_name} - skipping"
+                )
+                continue
+
+            # Use first vector (there may be multiple from different models)
+            for vector in row_emb.vectors:
+                dimension = len(vector)
+
+                # Create/get collection name (lazily on first vector)
+                if qdrant_collection is None:
+                    qdrant_collection = self.get_collection_name(
+                        user, collection, schema_name, dimension
+                    )
+                    self.ensure_collection(qdrant_collection, dimension)
+
+                # Write to Qdrant
+                self.qdrant.upsert(
+                    collection_name=qdrant_collection,
+                    points=[
+                        PointStruct(
+                            id=str(uuid.uuid4()),
+                            vector=vector,
+                            payload={
+                                "index_name": row_emb.index_name,
+                                "index_value": row_emb.index_value,
+                                "text": row_emb.text
+                            }
+                        )
+                    ]
+                )
+                embeddings_written += 1
+
+        logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
+
+    async def create_collection(self, user: str, collection: str, metadata: dict):
+        """Collection creation via config push - collections created lazily on first write"""
+        logger.info(
+            f"Row embeddings collection create request for {user}/{collection} - "
+            f"will be created lazily on first write"
+        )
+
+    async def delete_collection(self, user: str, collection: str):
+        """Delete all Qdrant collections for a given user/collection"""
+        try:
+            prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
+
+            # Get all collections and filter for matches
+            all_collections = self.qdrant.get_collections().collections
+            matching_collections = [
+                coll.name for coll in all_collections
+                if coll.name.startswith(prefix)
+            ]
+
+            if not matching_collections:
+                logger.info(f"No Qdrant collections found matching prefix {prefix}")
+            else:
+                for collection_name in matching_collections:
+                    self.qdrant.delete_collection(collection_name)
+                    self.created_collections.discard(collection_name)
+                    logger.info(f"Deleted Qdrant collection: {collection_name}")
+                logger.info(
+                    f"Deleted {len(matching_collections)} collection(s) "
+                    f"for {user}/{collection}"
+                )
+
+        except Exception as e:
+            logger.error(
+                f"Failed to delete collection {user}/{collection}: {e}",
+                exc_info=True
+            )
+            raise
+
+    async def delete_collection_schema(
+        self, user: str, collection: str, schema_name: str
+    ):
+        """Delete Qdrant collection for a specific user/collection/schema"""
+        try:
+            prefix = (
+                f"rows_{self.sanitize_name(user)}_"
+                f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
+            )
+
+            # Get all collections and filter for matches
+            all_collections = self.qdrant.get_collections().collections
+            matching_collections = [
+                coll.name for coll in all_collections
+                if coll.name.startswith(prefix)
+            ]
+
+            if not matching_collections:
+                logger.info(f"No Qdrant collections found matching prefix {prefix}")
+            else:
+                for collection_name in matching_collections:
+                    self.qdrant.delete_collection(collection_name)
+                    self.created_collections.discard(collection_name)
+                    logger.info(f"Deleted Qdrant collection: {collection_name}")
+
+        except Exception as e:
+            logger.error(
+                f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
+                exc_info=True
+            )
+            raise
+
+    @staticmethod
+    def add_args(parser):
+        """Add command-line arguments"""
+
+        FlowProcessor.add_args(parser)
+
+        parser.add_argument(
+            '-t', '--store-uri',
+            default=default_store_uri,
+            help=f'Qdrant URI (default: {default_store_uri})'
+        )
+
+        parser.add_argument(
+            '-k', '--api-key',
+            default=None,
+            help='Qdrant API key (default: None)'
+        )
+
+
+def run():
+    """Entry point for row-embeddings-write-qdrant command"""
|
||||||
|
Processor.launch(default_ident, __doc__)
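The `delete_collection` and `delete_collection_schema` paths above locate Qdrant collections purely by name prefix. A minimal sketch of that matching logic, assuming the Qdrant writer sanitizes names with the same rules as the Cassandra writer's `sanitize_name`; the `existing` collection names are hypothetical:

```python
import re

def sanitize(name: str) -> str:
    # Assumed to mirror sanitize_name: non [a-zA-Z0-9_] characters
    # become '_', an 'r_' prefix is added if the result does not
    # start with a letter, and the name is lowercased.
    safe = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe and not safe[0].isalpha():
        safe = 'r_' + safe
    return safe.lower()

def collections_to_delete(existing, user, collection, schema_name=None):
    # Build the same prefix the delete methods use and return the
    # collection names it would match.
    prefix = f"rows_{sanitize(user)}_{sanitize(collection)}_"
    if schema_name is not None:
        prefix += f"{sanitize(schema_name)}_"
    return [n for n in existing if n.startswith(prefix)]

existing = [
    "rows_alice_docs_invoice_384",   # hypothetical collection names
    "rows_alice_docs_receipt_384",
    "rows_alice_other_invoice_384",
]
print(collections_to_delete(existing, "alice", "docs"))
print(collections_to_delete(existing, "alice", "docs", "invoice"))
```

Deleting a whole user/collection matches every schema and vector-size variant under the prefix; adding the schema name narrows it to one schema's collections.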

@@ -1,46 +1,49 @@
 """
-Graph writer. Input is graph edge. Writes edges to Cassandra graph.
+Row writer for Cassandra. Input is ExtractedObject.
+Writes structured rows to a unified Cassandra table with multi-index support.
+
+Uses a single 'rows' table with the schema:
+- collection: text
+- schema_name: text
+- index_name: text
+- index_value: frozen<list<text>>
+- data: map<text, text>
+- source: text
+
+Each row is written multiple times - once per indexed field defined in the schema.
 """
 
-raise RuntimeError("This code is no longer in use")
-import pulsar
-import base64
-import os
-import argparse
-import time
+import json
 import logging
+import re
+from typing import Dict, Set, Optional, Any, List, Tuple
 
 from cassandra.cluster import Cluster
 from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
 
-from .... schema import Rows
-from .... log_level import LogLevel
-from .... base import Consumer
+from .... schema import ExtractedObject
+from .... schema import RowSchema, Field
+from .... base import FlowProcessor, ConsumerSpec
+from .... base import CollectionConfigHandler
 from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
 
 # Module logger
 logger = logging.getLogger(__name__)
 
-module = "rows-write"
-ssl_context = SSLContext(PROTOCOL_TLSv1_2)
-
-default_input_queue = "rows-store" # Default queue name
-default_subscriber = module
+default_ident = "rows-write"
 
-class Processor(Consumer):
+class Processor(CollectionConfigHandler, FlowProcessor):
 
     def __init__(self, **params):
 
-        input_queue = params.get("input_queue", default_input_queue)
-        subscriber = params.get("subscriber", default_subscriber)
+        id = params.get("id", default_ident)
 
         # Get Cassandra parameters
         cassandra_host = params.get("cassandra_host")
         cassandra_username = params.get("cassandra_username")
         cassandra_password = params.get("cassandra_password")
 
         # Resolve configuration with environment variable fallback
         hosts, username, password, keyspace = resolve_cassandra_config(
             host=cassandra_host,
@@ -48,99 +51,549 @@ class Processor(Consumer):
             password=cassandra_password
         )
 
+        # Store resolved configuration with proper names
+        self.cassandra_host = hosts  # Store as list
+        self.cassandra_username = username
+        self.cassandra_password = password
+
+        # Config key for schemas
+        self.config_key = params.get("config_type", "schema")
+
         super(Processor, self).__init__(
             **params | {
-                "input_queue": input_queue,
-                "subscriber": subscriber,
-                "input_schema": Rows,
-                "cassandra_host": ','.join(hosts),
-                "cassandra_username": username,
-                "cassandra_password": password,
+                "id": id,
+                "config_type": self.config_key,
             }
         )
 
-        if username and password:
-            auth_provider = PlainTextAuthProvider(username=username, password=password)
-            self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
-        else:
-            self.cluster = Cluster(hosts)
-        self.session = self.cluster.connect()
-
-        self.tables = set()
-
-        self.session.execute("""
-            create keyspace if not exists trustgraph
-            with replication = {
-                'class' : 'SimpleStrategy',
-                'replication_factor' : 1
-            };
-        """);
-
-        self.session.execute("use trustgraph");
+        self.register_specification(
+            ConsumerSpec(
+                name="input",
+                schema=ExtractedObject,
+                handler=self.on_object
+            )
+        )
+
+        # Register config handlers
+        self.register_config_handler(self.on_schema_config)
+        self.register_config_handler(self.on_collection_config)
+
+        # Cache of known keyspaces and whether tables exist
+        self.known_keyspaces: Set[str] = set()
+        self.tables_initialized: Set[str] = set()  # keyspaces with rows/row_partitions tables
+
+        # Cache of registered (collection, schema_name) pairs
+        self.registered_partitions: Set[Tuple[str, str]] = set()
+
+        # Schema storage: name -> RowSchema
+        self.schemas: Dict[str, RowSchema] = {}
+
+        # Cassandra session
+        self.cluster = None
+        self.session = None
 
-    async def handle(self, msg):
+    def connect_cassandra(self):
+        """Connect to Cassandra cluster"""
+        if self.session:
+            return
 
         try:
-            v = msg.value()
-            name = v.row_schema.name
-
-            if name not in self.tables:
-
-                # FIXME: SQL injection?
-
-                pkey = []
-
-                stmt = "create table if not exists " + name + " ( "
-
-                for field in v.row_schema.fields:
-
-                    stmt += field.name + " text, "
-
-                    if field.primary:
-                        pkey.append(field.name)
-
-                stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"
-
-                self.session.execute(stmt)
-
-                self.tables.add(name);
-
-            for row in v.rows:
-
-                field_names = []
-                values = []
-
-                for field in v.row_schema.fields:
-                    field_names.append(field.name)
-                    values.append(row[field.name])
-
-                # FIXME: SQL injection?
-                stmt = (
-                    "insert into " + name + " (" + ", ".join(field_names) +
-                    ") values (" + ",".join(["%s"] * len(values)) + ")"
-                )
-
-                self.session.execute(stmt, values)
+            if self.cassandra_username and self.cassandra_password:
+                auth_provider = PlainTextAuthProvider(
+                    username=self.cassandra_username,
+                    password=self.cassandra_password
+                )
+                self.cluster = Cluster(
+                    contact_points=self.cassandra_host,
+                    auth_provider=auth_provider
+                )
+            else:
+                self.cluster = Cluster(contact_points=self.cassandra_host)
+
+            self.session = self.cluster.connect()
+            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
 
         except Exception as e:
-
-            logger.error(f"Exception: {str(e)}", exc_info=True)
-
-            # If there's an error make sure to do table creation etc.
-            self.tables.remove(name)
-
-            raise e
+            logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
+            raise
+
+    async def on_schema_config(self, config, version):
+        """Handle schema configuration updates"""
+        logger.info(f"Loading schema configuration version {version}")
+
+        # Track which schemas changed so we can clear partition cache
+        old_schema_names = set(self.schemas.keys())
+
+        # Clear existing schemas
+        self.schemas = {}
+
+        # Check if our config type exists
+        if self.config_key not in config:
+            logger.warning(f"No '{self.config_key}' type in configuration")
+            return
+
+        # Get the schemas dictionary for our type
+        schemas_config = config[self.config_key]
+
+        # Process each schema in the schemas config
+        for schema_name, schema_json in schemas_config.items():
+            try:
+                # Parse the JSON schema definition
+                schema_def = json.loads(schema_json)
+
+                # Create Field objects
+                fields = []
+                for field_def in schema_def.get("fields", []):
+                    field = Field(
+                        name=field_def["name"],
+                        type=field_def["type"],
+                        size=field_def.get("size", 0),
+                        primary=field_def.get("primary_key", False),
+                        description=field_def.get("description", ""),
+                        required=field_def.get("required", False),
+                        enum_values=field_def.get("enum", []),
+                        indexed=field_def.get("indexed", False)
+                    )
+                    fields.append(field)
+
+                # Create RowSchema
+                row_schema = RowSchema(
+                    name=schema_def.get("name", schema_name),
+                    description=schema_def.get("description", ""),
+                    fields=fields
+                )
+
+                self.schemas[schema_name] = row_schema
+                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
+
+            except Exception as e:
+                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
+
+        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
+
+        # Clear partition cache for schemas that changed
+        # This ensures next write will re-register partitions
+        new_schema_names = set(self.schemas.keys())
+        changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
+        if changed_schemas:
+            self.registered_partitions = {
+                (col, sch) for col, sch in self.registered_partitions
+                if sch not in changed_schemas
+            }
+            logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
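`on_schema_config` above parses each schema's JSON definition and records which fields act as indexes (primary keys or explicitly indexed fields). A standalone sketch of that reduction; the `invoice` schema is hypothetical, but the field keys mirror the ones the handler reads:

```python
import json

# Hypothetical schema definition, shaped like the entries read
# from the "schema" config type.
schema_json = json.dumps({
    "name": "invoice",
    "description": "Invoice records",
    "fields": [
        {"name": "invoice_id", "type": "string", "primary_key": True},
        {"name": "customer", "type": "string", "indexed": True},
        {"name": "amount", "type": "float"},
    ],
})

def parse_schema(raw: str) -> dict:
    # Reduce a definition to what the writer needs: the schema name
    # plus the fields that act as indexes (primary_key or indexed).
    schema_def = json.loads(raw)
    indexes = [
        f["name"] for f in schema_def.get("fields", [])
        if f.get("primary_key", False) or f.get("indexed", False)
    ]
    return {"name": schema_def.get("name"), "indexes": indexes}

print(parse_schema(schema_json))
```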
+
+    def sanitize_name(self, name: str) -> str:
+        """Sanitize names for Cassandra compatibility"""
+        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
+        # Ensure it starts with a letter
+        if safe_name and not safe_name[0].isalpha():
+            safe_name = 'r_' + safe_name
+        return safe_name.lower()
+
+    def ensure_keyspace(self, keyspace: str):
+        """Ensure keyspace exists in Cassandra"""
+        if keyspace in self.known_keyspaces:
+            return
+
+        # Connect if needed
+        self.connect_cassandra()
+
+        # Sanitize keyspace name
+        safe_keyspace = self.sanitize_name(keyspace)
+
+        # Create keyspace if not exists
+        create_keyspace_cql = f"""
+            CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
+            WITH REPLICATION = {{
+                'class': 'SimpleStrategy',
+                'replication_factor': 1
+            }}
+        """
+
+        try:
+            self.session.execute(create_keyspace_cql)
+            self.known_keyspaces.add(keyspace)
+            logger.info(f"Ensured keyspace exists: {safe_keyspace}")
+        except Exception as e:
+            logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
+            raise
+
+    def ensure_tables(self, keyspace: str):
+        """Ensure unified rows and row_partitions tables exist"""
+        if keyspace in self.tables_initialized:
+            return
+
+        # Ensure keyspace exists first
+        self.ensure_keyspace(keyspace)
+
+        safe_keyspace = self.sanitize_name(keyspace)
+
+        # Create unified rows table
+        create_rows_cql = f"""
+            CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
+                collection text,
+                schema_name text,
+                index_name text,
+                index_value frozen<list<text>>,
+                data map<text, text>,
+                source text,
+                PRIMARY KEY ((collection, schema_name, index_name), index_value)
+            )
+        """
+
+        # Create row_partitions tracking table
+        create_partitions_cql = f"""
+            CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
+                collection text,
+                schema_name text,
+                index_name text,
+                PRIMARY KEY ((collection), schema_name, index_name)
+            )
+        """
+
+        try:
+            self.session.execute(create_rows_cql)
+            logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")
+
+            self.session.execute(create_partitions_cql)
+            logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions")
+
+            self.tables_initialized.add(keyspace)
+
+        except Exception as e:
+            logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True)
+            raise
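`sanitize_name` above is a pure function, so its behaviour is easy to pin down in isolation (the example inputs are hypothetical):

```python
import re

def sanitize_name(name: str) -> str:
    # Same rules as the method above: non [a-zA-Z0-9_] characters
    # become '_', an 'r_' prefix is added if the result does not
    # start with a letter, and the whole name is lowercased.
    safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe_name and not safe_name[0].isalpha():
        safe_name = 'r_' + safe_name
    return safe_name.lower()

print(sanitize_name("Alice@example.com"))  # alice_example_com
print(sanitize_name("123-users"))          # r_123_users
```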
+
+    def get_index_names(self, schema: RowSchema) -> List[str]:
+        """
+        Get all index names for a schema.
+        Returns list of index_name strings (single field names or comma-joined composites).
+        """
+        index_names = []
+
+        for field in schema.fields:
+            # Primary key fields are treated as indexes
+            if field.primary:
+                index_names.append(field.name)
+            # Indexed fields
+            elif field.indexed:
+                index_names.append(field.name)
+
+        # TODO: Support composite indexes in the future
+        # For now, each indexed field is a single-field index
+
+        return index_names
+
+    def register_partitions(self, keyspace: str, collection: str, schema_name: str):
+        """
+        Register partition entries for a (collection, schema_name) pair.
+        Called once on first row for each pair.
+        """
+        cache_key = (collection, schema_name)
+        if cache_key in self.registered_partitions:
+            return
+
+        schema = self.schemas.get(schema_name)
+        if not schema:
+            logger.warning(f"Cannot register partitions - schema {schema_name} not found")
+            return
+
+        safe_keyspace = self.sanitize_name(keyspace)
+        index_names = self.get_index_names(schema)
+
+        # Insert partition entries for each index
+        insert_cql = f"""
+            INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
+            VALUES (%s, %s, %s)
+        """
+
+        for index_name in index_names:
+            try:
+                self.session.execute(insert_cql, (collection, schema_name, index_name))
+            except Exception as e:
+                logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}")
+
+        self.registered_partitions.add(cache_key)
+        logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}")
+
+    def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
+        """
+        Build the index_value list for a given index.
+        For single-field indexes, returns a single-element list.
+        For composite indexes (comma-separated), returns multiple elements.
+        """
+        field_names = [f.strip() for f in index_name.split(',')]
+        values = []
+
+        for field_name in field_names:
+            value = value_map.get(field_name)
+            # Convert to string for storage
+            values.append(str(value) if value is not None else "")
+
+        return values
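`build_index_value` above splits a (possibly composite, comma-separated) index name and stringifies the matching field values, substituting `""` for anything missing. A standalone sketch with a hypothetical row:

```python
def build_index_value(value_map: dict, index_name: str) -> list:
    # Split a comma-separated index name and collect the values as
    # strings, using "" for missing fields, as the method above does.
    field_names = [f.strip() for f in index_name.split(',')]
    return [
        str(value_map[f]) if value_map.get(f) is not None else ""
        for f in field_names
    ]

row = {"invoice_id": "INV-1", "customer": "acme", "amount": 99.5}  # hypothetical
print(build_index_value(row, "invoice_id"))            # ['INV-1']
print(build_index_value(row, "customer, invoice_id"))  # ['acme', 'INV-1']
print(build_index_value(row, "missing"))               # ['']
```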
+
+    async def on_object(self, msg, consumer, flow):
+        """Process incoming ExtractedObject and store in Cassandra"""
+
+        obj = msg.value()
+        logger.info(
+            f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
+            f"from {obj.metadata.id}"
+        )
+
+        # Validate collection exists before accepting writes
+        if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
+            error_msg = (
+                f"Collection {obj.metadata.collection} does not exist. "
+                f"Create it first via collection management API."
+            )
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+
+        # Get schema definition
+        schema = self.schemas.get(obj.schema_name)
+        if not schema:
+            logger.warning(f"No schema found for {obj.schema_name} - skipping")
+            return
+
+        keyspace = obj.metadata.user
+        collection = obj.metadata.collection
+        schema_name = obj.schema_name
+        source = getattr(obj.metadata, 'source', '') or ''
+
+        # Ensure tables exist
+        self.ensure_tables(keyspace)
+
+        # Register partitions if first time seeing this (collection, schema_name)
+        self.register_partitions(keyspace, collection, schema_name)
+
+        safe_keyspace = self.sanitize_name(keyspace)
+
+        # Get all index names for this schema
+        index_names = self.get_index_names(schema)
+
+        if not index_names:
+            logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable")
+            return
+
+        # Prepare insert statement
+        insert_cql = f"""
+            INSERT INTO {safe_keyspace}.rows
+            (collection, schema_name, index_name, index_value, data, source)
+            VALUES (%s, %s, %s, %s, %s, %s)
+        """
+
+        # Process each row in the batch
+        rows_written = 0
+        for row_index, value_map in enumerate(obj.values):
+            # Convert all values to strings for the data map
+            data_map = {}
+            for field in schema.fields:
+                raw_value = value_map.get(field.name)
+                if raw_value is not None:
+                    data_map[field.name] = str(raw_value)
+
+            # Write one copy per index
+            for index_name in index_names:
+                index_value = self.build_index_value(value_map, index_name)
+
+                # Skip if index value is empty/null
+                if not index_value or all(v == "" for v in index_value):
+                    logger.debug(
+                        f"Skipping index {index_name} for row {row_index} - "
+                        f"empty index value"
+                    )
+                    continue
+
+                try:
+                    self.session.execute(
+                        insert_cql,
+                        (collection, schema_name, index_name, index_value, data_map, source)
+                    )
+                    rows_written += 1
+                except Exception as e:
+                    logger.error(
+                        f"Failed to insert row {row_index} for index {index_name}: {e}",
+                        exc_info=True
+                    )
+                    raise
+
+        logger.info(
+            f"Wrote {rows_written} index entries for {len(obj.values)} rows "
+            f"({len(index_names)} indexes per row)"
+        )
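`on_object` above fans each logical row out into one physical row per index, carrying the full stringified data map each time and skipping entries whose index value is entirely empty. A simplified, in-memory sketch of that fan-out; the record and field names are hypothetical:

```python
def fan_out(row: dict, index_names: list) -> list:
    # One logical row becomes one entry per index, each carrying the
    # full stringified data map; entries whose index value is
    # entirely empty are skipped, as in on_object.
    data = {k: str(v) for k, v in row.items() if v is not None}
    entries = []
    for index_name in index_names:
        fields = [f.strip() for f in index_name.split(',')]
        index_value = [
            str(row[f]) if row.get(f) is not None else "" for f in fields
        ]
        if all(v == "" for v in index_value):
            continue
        entries.append(
            {"index_name": index_name, "index_value": index_value, "data": data}
        )
    return entries

row = {"invoice_id": "INV-1", "customer": None}  # hypothetical record
entries = fan_out(row, ["invoice_id", "customer"])
print(len(entries))  # the customer entry is skipped: its index value is empty
```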
+
+    async def create_collection(self, user: str, collection: str, metadata: dict):
+        """Create/verify collection exists in Cassandra row store"""
+        # Connect if not already connected
+        self.connect_cassandra()
+
+        # Ensure tables exist
+        self.ensure_tables(user)
+
+        logger.info(f"Collection {collection} ready for user {user}")
+
+    async def delete_collection(self, user: str, collection: str):
+        """Delete all data for a specific collection using partition tracking"""
+        # Connect if not already connected
+        self.connect_cassandra()
+
+        safe_keyspace = self.sanitize_name(user)
+
+        # Check if keyspace exists
+        if user not in self.known_keyspaces:
+            check_keyspace_cql = """
+                SELECT keyspace_name FROM system_schema.keyspaces
+                WHERE keyspace_name = %s
+            """
+            result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
+            if not result.one():
+                logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
+                return
+            self.known_keyspaces.add(user)
+
+        # Discover all partitions for this collection
+        select_partitions_cql = f"""
+            SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
+            WHERE collection = %s
+        """
+
+        try:
+            partitions = self.session.execute(select_partitions_cql, (collection,))
+            partition_list = list(partitions)
+        except Exception as e:
+            logger.error(f"Failed to query partitions for collection {collection}: {e}")
+            raise
+
+        # Delete each partition from rows table
+        delete_rows_cql = f"""
+            DELETE FROM {safe_keyspace}.rows
+            WHERE collection = %s AND schema_name = %s AND index_name = %s
+        """
+
+        partitions_deleted = 0
+        for partition in partition_list:
+            try:
+                self.session.execute(
+                    delete_rows_cql,
+                    (collection, partition.schema_name, partition.index_name)
+                )
+                partitions_deleted += 1
+            except Exception as e:
+                logger.error(
+                    f"Failed to delete partition {collection}/{partition.schema_name}/"
+                    f"{partition.index_name}: {e}"
+                )
+                raise
+
+        # Clean up row_partitions entries
+        delete_partitions_cql = f"""
+            DELETE FROM {safe_keyspace}.row_partitions
+            WHERE collection = %s
+        """
+
+        try:
+            self.session.execute(delete_partitions_cql, (collection,))
+        except Exception as e:
+            logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
+            raise
+
+        # Clear from local cache
+        self.registered_partitions = {
+            (col, sch) for col, sch in self.registered_partitions
+            if col != collection
+        }
+
+        logger.info(
+            f"Deleted collection {collection}: {partitions_deleted} partitions "
+            f"from keyspace {safe_keyspace}"
+        )
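`delete_collection` above leans on the `row_partitions` tracking table: it enumerates every `(collection, schema_name, index_name)` partition for the collection, so deletion is a handful of whole-partition deletes rather than a table scan. An in-memory sketch of the same bookkeeping; the table contents are hypothetical:

```python
# In-memory stand-ins for the rows and row_partitions tables,
# keyed by (collection, schema_name, index_name).
rows = {
    ("docs", "invoice", "invoice_id"): ["..."],
    ("docs", "invoice", "customer"): ["..."],
    ("docs", "receipt", "receipt_id"): ["..."],
    ("other", "invoice", "invoice_id"): ["..."],
}
row_partitions = set(rows.keys())

def delete_collection(collection: str) -> int:
    # Discover the collection's partitions from the tracking set,
    # drop each whole partition, then clean up the tracking entries.
    targets = [key for key in row_partitions if key[0] == collection]
    for key in targets:
        rows.pop(key, None)
        row_partitions.discard(key)
    return len(targets)

print(delete_collection("docs"))  # three partitions, no full scan needed
```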
+
+    async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
+        """Delete all data for a specific collection + schema combination"""
+        # Connect if not already connected
+        self.connect_cassandra()
+
+        safe_keyspace = self.sanitize_name(user)
+
+        # Discover partitions for this collection + schema
+        select_partitions_cql = f"""
+            SELECT index_name FROM {safe_keyspace}.row_partitions
+            WHERE collection = %s AND schema_name = %s
+        """
+
+        try:
+            partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
+            partition_list = list(partitions)
+        except Exception as e:
+            logger.error(
+                f"Failed to query partitions for {collection}/{schema_name}: {e}"
+            )
+            raise
+
+        # Delete each partition from rows table
+        delete_rows_cql = f"""
+            DELETE FROM {safe_keyspace}.rows
+            WHERE collection = %s AND schema_name = %s AND index_name = %s
+        """
+
+        partitions_deleted = 0
+        for partition in partition_list:
+            try:
+                self.session.execute(
+                    delete_rows_cql,
+                    (collection, schema_name, partition.index_name)
+                )
+                partitions_deleted += 1
+            except Exception as e:
+                logger.error(
+                    f"Failed to delete partition {collection}/{schema_name}/"
+                    f"{partition.index_name}: {e}"
+                )
+                raise
+
+        # Clean up row_partitions entries for this schema
+        delete_partitions_cql = f"""
+            DELETE FROM {safe_keyspace}.row_partitions
+            WHERE collection = %s AND schema_name = %s
+        """
+
+        try:
+            self.session.execute(delete_partitions_cql, (collection, schema_name))
+        except Exception as e:
+            logger.error(
+                f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
+            )
+            raise
+
+        # Clear from local cache
+        self.registered_partitions.discard((collection, schema_name))
+
+        logger.info(
+            f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
+            f"from keyspace {safe_keyspace}"
+        )
+
+    def close(self):
+        """Clean up Cassandra connections"""
+        if self.cluster:
+            self.cluster.shutdown()
+            logger.info("Closed Cassandra connection")
 
     @staticmethod
     def add_args(parser):
+        """Add command-line arguments"""
 
-        Consumer.add_args(
-            parser, default_input_queue, default_subscriber,
-        )
+        FlowProcessor.add_args(parser)
 
         add_cassandra_args(parser)
 
+        parser.add_argument(
+            '--config-type',
+            default='schema',
+            help='Configuration type prefix for schemas (default: schema)'
+        )
+
 
 def run():
-    Processor.launch(module, __doc__)
+    """Entry point for rows-write-cassandra command"""
+    Processor.launch(default_ident, __doc__)