Structured data 2 (#645)

* Structured data refactor - multi-index tables, remove need for manual mods to the Cassandra tables

* Tech spec updated to track implementation
cybermaggedon 2026-02-23 15:56:29 +00:00 committed by GitHub
parent 5ffad92345
commit 1809c1f56d
87 changed files with 5233 additions and 3235 deletions


@ -5,7 +5,7 @@ VERSION=0.0.0
DOCKER=podman
all: container
all: containers
# Not used
wheels:
@ -49,7 +49,9 @@ update-package-versions:
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
container: update-package-versions
FORCE:
containers: FORCE
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \


@ -0,0 +1,467 @@
# Structured Data Technical Specification (Part 2)
## Overview
This specification addresses issues and gaps identified during the initial implementation of TrustGraph's structured data integration, as described in `structured-data.md`.
## Problem Statements
### 1. Naming Inconsistency: "Object" vs "Row"
The current implementation uses "object" terminology throughout (e.g., `ExtractedObject`, object extraction, object embeddings). This naming is too generic and causes confusion:
- "Object" is an overloaded term in software (Python objects, JSON objects, etc.)
- The data being handled is fundamentally tabular - rows in tables with defined schemas
- "Row" more accurately describes the data model and aligns with database terminology
This inconsistency appears in module names, class names, message types, and documentation.
### 2. Row Store Query Limitations
The current row store implementation has significant query limitations:
**Natural Language Mismatch**: Queries struggle with real-world data variations. For example:
- A street database containing `"CHESTNUT ST"` is difficult to find when asking about `"Chestnut Street"`
- Abbreviations, case differences, and formatting variations break exact-match queries
- Users expect semantic understanding, but the store provides literal matching
**Schema Evolution Issues**: Changing schemas causes problems:
- Existing data may not conform to updated schemas
- Table structure changes can break queries and data integrity
- No clear migration path for schema updates
### 3. Row Embeddings Required
Related to problem 2, the system needs vector embeddings for row data to enable:
- Semantic search across structured data (finding "Chestnut Street" when data contains "CHESTNUT ST")
- Similarity matching for fuzzy queries
- Hybrid search combining structured filters with semantic similarity
- Better natural language query support
The embedding service was specified but not implemented.
### 4. Row Data Ingestion Incomplete
The structured data ingestion pipeline is not fully operational:
- Diagnostic prompts exist to classify input formats (CSV, JSON, etc.)
- The ingestion service that uses these prompts is not plumbed into the system
- No end-to-end path for loading pre-structured data into the row store
## Goals
- **Schema Flexibility**: Enable schema evolution without breaking existing data or requiring migrations
- **Consistent Naming**: Standardize on "row" terminology throughout the codebase
- **Semantic Queryability**: Support fuzzy/semantic matching via row embeddings
- **Complete Ingestion Pipeline**: Provide end-to-end path for loading structured data
## Technical Design
### Unified Row Storage Schema
The previous implementation created a separate Cassandra table for each schema. This caused problems when schemas evolved, as table structure changes required migrations.
The new design uses a single unified table for all row data:
```sql
CREATE TABLE rows (
    collection text,
    schema_name text,
    index_name text,
    index_value frozen<list<text>>,
    data map<text, text>,
    source text,
    PRIMARY KEY ((collection, schema_name, index_name), index_value)
);
```
#### Column Definitions
| Column | Type | Description |
|--------|------|-------------|
| `collection` | `text` | Data collection/import identifier (from metadata) |
| `schema_name` | `text` | Name of the schema this row conforms to |
| `index_name` | `text` | Name of the indexed field(s), comma-joined for composites |
| `index_value` | `frozen<list<text>>` | Index value(s) as a list |
| `data` | `map<text, text>` | Row data as key-value pairs |
| `source` | `text` | Optional URI linking to provenance information in the knowledge graph. Empty string or NULL indicates no source. |
#### Index Handling
Each row is stored multiple times - once per indexed field defined in the schema. The primary key fields are treated as an index with no special marker, providing future flexibility.
**Single-field index example:**
- Schema defines `email` as indexed
- `index_name = "email"`
- `index_value = ['foo@bar.com']`
**Composite index example:**
- Schema defines composite index on `region` and `status`
- `index_name = "region,status"` (field names sorted and comma-joined)
- `index_value = ['US', 'active']` (values in same order as field names)
**Primary key example:**
- Schema defines `customer_id` as primary key
- `index_name = "customer_id"`
- `index_value = ['CUST001']`
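The write fan-out implied by these examples can be sketched as a small helper. The function name and shapes here are illustrative assumptions, not TrustGraph code:

```python
# Hypothetical sketch of the write fan-out: one (index_name, index_value)
# entry per index defined in the schema.

def fan_out_indexes(row: dict, indexes: list[list[str]]) -> list[tuple[str, list[str]]]:
    """For each index (a list of field names), build the comma-joined
    index_name and the ordered list of index values."""
    entries = []
    for fields in indexes:
        ordered = sorted(fields)                      # field names sorted
        index_name = ",".join(ordered)                # e.g. "region,status"
        index_value = [str(row[f]) for f in ordered]  # values in the same order
        entries.append((index_name, index_value))
    return entries

row = {"customer_id": "CUST001", "region": "US", "status": "active"}
entries = fan_out_indexes(row, [["customer_id"], ["status", "region"]])
# entries == [("customer_id", ["CUST001"]), ("region,status", ["US", "active"])]
```

Each entry then becomes one insert into the `rows` table, which is where the write amplification noted below comes from.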
#### Query Patterns
All queries follow the same pattern regardless of which index is used:
```sql
SELECT * FROM rows
WHERE collection = 'import_2024'
AND schema_name = 'customers'
AND index_name = 'email'
AND index_value = ['foo@bar.com']
```
#### Design Trade-offs
**Advantages:**
- Schema changes don't require table structure changes
- Row data is opaque to Cassandra - field additions/removals are transparent
- Consistent query pattern for all access methods
- No Cassandra secondary indexes (which can be slow at scale)
- Native Cassandra types throughout (`map`, `frozen<list>`)
**Trade-offs:**
- Write amplification: each row insert = N inserts (one per indexed field)
- Storage overhead from duplicated row data
- Type information stored in schema config, conversion at application layer
#### Consistency Model
The design accepts certain simplifications:
1. **No row updates**: The system is append-only. This eliminates consistency concerns about updating multiple copies of the same row.
2. **Schema change tolerance**: When schemas change (e.g., indexes added/removed), existing rows retain their original indexing. Old rows won't be discoverable via new indexes. Users can delete and recreate a schema to ensure consistency if needed.
### Partition Tracking and Deletion
#### The Problem
With the partition key `(collection, schema_name, index_name)`, efficient deletion requires knowing all partition keys to delete. Deleting by just `collection` or `collection + schema_name` requires knowing all the `index_name` values that have data.
#### Partition Tracking Table
A secondary lookup table tracks which partitions exist:
```sql
CREATE TABLE row_partitions (
    collection text,
    schema_name text,
    index_name text,
    PRIMARY KEY ((collection), schema_name, index_name)
);
```
This enables efficient discovery of partitions for deletion operations.
#### Row Writer Behavior
The row writer maintains an in-memory cache of registered `(collection, schema_name)` pairs. When processing a row:
1. Check if `(collection, schema_name)` is in the cache
2. If not cached (first row for this pair):
- Look up the schema config to get all index names
- Insert entries into `row_partitions` for each `(collection, schema_name, index_name)`
- Add the pair to the cache
3. Proceed with writing the row data
The row writer also monitors schema config change events. When a schema changes, relevant cache entries are cleared so the next row triggers re-registration with the updated index names.
This approach ensures:
- Lookup table writes happen once per `(collection, schema_name)` pair, not per row
- The lookup table reflects the indexes that were active when data was written
- Schema changes mid-import are picked up correctly
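The caching behaviour above can be sketched as follows. The schema lookup and Cassandra insert are passed in as stubs; the class and its names are illustrative assumptions, not the actual row writer:

```python
# Illustrative sketch of the row writer's partition-registration cache.

class PartitionRegistry:
    def __init__(self, get_index_names, insert_partition):
        self.get_index_names = get_index_names    # schema_name -> [index_name]
        self.insert_partition = insert_partition  # writes one row_partitions entry
        self.cache = set()                        # registered (collection, schema_name)

    def ensure_registered(self, collection, schema_name):
        key = (collection, schema_name)
        if key in self.cache:
            return                                # already registered, no writes
        for index_name in self.get_index_names(schema_name):
            self.insert_partition(collection, schema_name, index_name)
        self.cache.add(key)

    def on_schema_changed(self, schema_name):
        # Clear cache entries so the next row re-registers with new indexes
        self.cache = {k for k in self.cache if k[1] != schema_name}

written = []
reg = PartitionRegistry(
    get_index_names=lambda s: ["email", "customer_id"],
    insert_partition=lambda c, s, i: written.append((c, s, i)),
)
reg.ensure_registered("import_2024", "customers")
reg.ensure_registered("import_2024", "customers")  # cached, no extra writes
# written == [("import_2024", "customers", "email"),
#             ("import_2024", "customers", "customer_id")]
```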
#### Deletion Operations
**Delete collection:**
```sql
-- 1. Discover all partitions
SELECT schema_name, index_name FROM row_partitions WHERE collection = 'X';
-- 2. Delete each partition from rows table
DELETE FROM rows WHERE collection = 'X' AND schema_name = '...' AND index_name = '...';
-- (repeat for each discovered partition)
-- 3. Clean up the lookup table
DELETE FROM row_partitions WHERE collection = 'X';
```
**Delete collection + schema:**
```sql
-- 1. Discover partitions for this schema
SELECT index_name FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
-- 2. Delete each partition from rows table
DELETE FROM rows WHERE collection = 'X' AND schema_name = 'Y' AND index_name = '...';
-- (repeat for each discovered partition)
-- 3. Clean up the lookup table entries
DELETE FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
```
### Row Embeddings
Row embeddings enable semantic/fuzzy matching on indexed values, solving the natural language mismatch problem (e.g., finding "CHESTNUT ST" when querying for "Chestnut Street").
#### Design Overview
Each indexed value is embedded and stored in a vector store (Qdrant). At query time, the query is embedded, similar vectors are found, and the associated metadata is used to look up the actual rows in Cassandra.
#### Qdrant Collection Structure
One Qdrant collection per `(user, collection, schema_name, dimension)` tuple:
- **Collection naming:** `rows_{user}_{collection}_{schema_name}_{dimension}`
- Names are sanitized: lowercased, with non-alphanumeric characters replaced by `_`, and an `r_` prefix added when the name does not start with a letter
- **Rationale:** Enables clean deletion of a `(user, collection, schema_name)` instance by dropping matching Qdrant collections; dimension suffix allows different embedding models to coexist
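A minimal sketch of the naming rule, assuming the sanitization described above (the exact implementation may differ):

```python
import re

# Lowercase, replace non-alphanumerics with '_', and prefix 'r_' when the
# result does not start with a letter. Hypothetical helper names.

def sanitize(part: str) -> str:
    s = re.sub(r"[^a-z0-9]", "_", part.lower())
    if s and not s[0].isalpha():
        s = "r_" + s
    return s

def qdrant_collection(user, collection, schema_name, dimension):
    parts = [sanitize(p) for p in (user, collection, schema_name)]
    return f"rows_{parts[0]}_{parts[1]}_{parts[2]}_{dimension}"

qdrant_collection("alice", "import-2024", "Street.Names", 384)
# -> "rows_alice_import_2024_street_names_384"
```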
#### What Gets Embedded
The text representation of index values:
| Index Type | Example `index_value` | Text to Embed |
|------------|----------------------|---------------|
| Single-field | `['foo@bar.com']` | `"foo@bar.com"` |
| Composite | `['US', 'active']` | `"US active"` (space-joined) |
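The text representation is simply the space-joined index values, e.g. with a hypothetical helper:

```python
def embed_text(index_value: list[str]) -> str:
    # Single-field values pass through; composite values are space-joined
    return " ".join(index_value)

embed_text(["foo@bar.com"])   # -> "foo@bar.com"
embed_text(["US", "active"])  # -> "US active"
```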
#### Point Structure
Each Qdrant point contains:
```json
{
    "id": "<uuid>",
    "vector": [0.1, 0.2, ...],
    "payload": {
        "index_name": "street_name",
        "index_value": ["CHESTNUT ST"],
        "text": "CHESTNUT ST"
    }
}
```
| Payload Field | Description |
|---------------|-------------|
| `index_name` | The indexed field(s) this embedding represents |
| `index_value` | The original list of values (for Cassandra lookup) |
| `text` | The text that was embedded (for debugging/display) |
Note: `user`, `collection`, and `schema_name` are implicit from the Qdrant collection name.
#### Query Flow
1. User queries for "Chestnut Street" within user U, collection X, schema Y
2. Embed the query text
3. Determine Qdrant collection name(s) matching prefix `rows_U_X_Y_`
4. Search matching Qdrant collection(s) for nearest vectors
5. Get matching points with payloads containing `index_name` and `index_value`
6. Query Cassandra:
```sql
SELECT * FROM rows
WHERE collection = 'X'
AND schema_name = 'Y'
AND index_name = '<from payload>'
AND index_value = <from payload>
```
7. Return matched rows
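The seven-step flow above can be sketched end-to-end with the embedding model, Qdrant search, and Cassandra lookup injected as stubs. Every name here is an illustrative assumption, not the actual query service:

```python
# Sketch of the semantic query flow; the numbered comments map to the steps.

def semantic_row_query(query_text, user, collection, schema,
                       embed, list_collections, vector_search, fetch_rows):
    vector = embed(query_text)                       # step 2: embed the query
    prefix = f"rows_{user}_{collection}_{schema}_"   # step 3: collection prefix
    results = []
    for qc in list_collections():
        if not qc.startswith(prefix):
            continue
        for payload in vector_search(qc, vector):    # steps 4-5: nearest vectors
            results.extend(fetch_rows(               # step 6: Cassandra lookup
                collection, schema,
                payload["index_name"], payload["index_value"]))
    return results                                   # step 7

rows = semantic_row_query(
    "Chestnut Street", "u", "x", "y",
    embed=lambda text: [0.0],
    list_collections=lambda: ["rows_u_x_y_384", "rows_u_x_z_384"],
    vector_search=lambda qc, vec: [{"index_name": "street_name",
                                    "index_value": ["CHESTNUT ST"]}],
    fetch_rows=lambda c, s, name, value: [{"street_name": "CHESTNUT ST"}],
)
# rows == [{"street_name": "CHESTNUT ST"}]
```

Note that only the collection matching the `rows_u_x_y_` prefix is searched.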
#### Optional: Filtering by Index Name
Queries can optionally filter by `index_name` in Qdrant to search only specific fields:
- **"Find any field matching 'Chestnut'"** → search all vectors in the collection
- **"Find street_name matching 'Chestnut'"** → filter where `payload.index_name = 'street_name'`
#### Architecture
Row embeddings follow the **two-stage pattern** used by GraphRAG (graph-embeddings, document-embeddings):
- **Stage 1: Embedding computation** (`trustgraph-flow/trustgraph/embeddings/row_embeddings/`) - Consumes `ExtractedObject`, computes embeddings via the embeddings service, outputs `RowEmbeddings`
- **Stage 2: Embedding storage** (`trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/`) - Consumes `RowEmbeddings`, writes vectors to Qdrant
The Cassandra row writer is a separate parallel consumer:
- **Cassandra row writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra`) - Consumes `ExtractedObject`, writes rows to Cassandra
All three services consume from the same flow, keeping them decoupled. This allows:
- Independent scaling of Cassandra writes vs embedding generation vs vector storage
- Embedding services can be disabled if not needed
- Failures in one service don't affect the others
- Consistent architecture with GraphRAG pipelines
#### Write Path
**Stage 1 (row-embeddings processor):** When receiving an `ExtractedObject`:
1. Look up the schema to find indexed fields
2. For each indexed field:
- Build the text representation of the index value
- Compute embedding via the embeddings service
3. Output a `RowEmbeddings` message containing all computed vectors
**Stage 2 (row-embeddings-write-qdrant):** When receiving a `RowEmbeddings`:
1. For each embedding in the message:
- Determine Qdrant collection from `(user, collection, schema_name, dimension)`
- Create collection if needed (lazy creation on first write)
- Upsert point with vector and payload
#### Message Types
```python
@dataclass
class RowIndexEmbedding:
    index_name: str             # The indexed field name(s)
    index_value: list[str]      # The field value(s)
    text: str                   # Text that was embedded
    vectors: list[list[float]]  # Computed embedding vectors

@dataclass
class RowEmbeddings:
    metadata: Metadata
    schema_name: str
    embeddings: list[RowIndexEmbedding]
```
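Putting the Stage 1 write path and the message types together, the processor might be sketched as below. The embeddings service call is stubbed as `embed`; `build_row_embeddings` and its shapes are illustrative assumptions:

```python
from dataclasses import dataclass

# Minimal local stand-in for the RowIndexEmbedding message type above
@dataclass
class RowIndexEmbedding:
    index_name: str
    index_value: list
    text: str
    vectors: list

def build_row_embeddings(entries, embed):
    """entries: (index_name, index_value) pairs for the row's indexed fields.
    embed: text -> vector, standing in for the embeddings service."""
    out = []
    for index_name, index_value in entries:
        text = " ".join(index_value)   # composite values are space-joined
        out.append(RowIndexEmbedding(index_name, index_value, text, [embed(text)]))
    return out

embs = build_row_embeddings([("street_name", ["CHESTNUT ST"])],
                            embed=lambda t: [0.1, 0.2])
# embs[0].text == "CHESTNUT ST"; embs[0].vectors == [[0.1, 0.2]]
```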
#### Deletion Integration
Qdrant collections are discovered by prefix matching on the collection name pattern:
**Delete `(user, collection)`:**
1. List all Qdrant collections matching prefix `rows_{user}_{collection}_`
2. Delete each matching collection
3. Delete Cassandra rows partitions (as documented above)
4. Clean up `row_partitions` entries
**Delete `(user, collection, schema_name)`:**
1. List all Qdrant collections matching prefix `rows_{user}_{collection}_{schema_name}_`
2. Delete each matching collection (handles multiple dimensions)
3. Delete Cassandra rows partitions
4. Clean up `row_partitions`
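The `(user, collection, schema_name)` deletion steps can be sketched with the Qdrant and Cassandra operations passed in as stubs; all names are illustrative, not the actual deletion service:

```python
# Sketch of schema-instance deletion: drop matching Qdrant collections by
# prefix, then clean up the Cassandra side.

def delete_schema_instance(user, collection, schema,
                           list_collections, drop_collection, delete_partitions):
    prefix = f"rows_{user}_{collection}_{schema}_"
    for qc in list_collections():
        if qc.startswith(prefix):          # matches every dimension suffix
            drop_collection(qc)
    delete_partitions(collection, schema)  # rows partitions + row_partitions cleanup

dropped = []
delete_schema_instance(
    "u", "x", "y",
    list_collections=lambda: ["rows_u_x_y_384", "rows_u_x_y_768", "rows_u_x_z_384"],
    drop_collection=dropped.append,
    delete_partitions=lambda c, s: None,
)
# dropped == ["rows_u_x_y_384", "rows_u_x_y_768"]
```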
#### Module Locations
| Stage | Module | Entry Point |
|-------|--------|-------------|
| Stage 1 | `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | `row-embeddings` |
| Stage 2 | `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | `row-embeddings-write-qdrant` |
### Row Embeddings Query API
The row embeddings query is a **separate API** from the GraphQL row query service:
| API | Purpose | Backend |
|-----|---------|---------|
| Row Query (GraphQL) | Exact matching on indexed fields | Cassandra |
| Row Embeddings Query | Fuzzy/semantic matching | Qdrant |
This separation keeps concerns clean:
- GraphQL service focuses on exact, structured queries
- Embeddings API handles semantic similarity
- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
### Row Data Ingestion
Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
## Implementation Impact
### Current State Analysis
The existing implementation has two main components:
| Component | Location | Lines | Description |
|-----------|----------|-------|-------------|
| Query Service | `trustgraph-flow/trustgraph/query/objects/cassandra/service.py` | ~740 | Monolithic: GraphQL schema generation, filter parsing, Cassandra queries, request handling |
| Writer | `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` | ~540 | Per-schema table creation, secondary indexes, insert/delete |
**Current Query Pattern:**
```sql
SELECT * FROM {keyspace}.o_{schema_name}
WHERE collection = 'X' AND email = 'foo@bar.com'
ALLOW FILTERING
```
**New Query Pattern:**
```sql
SELECT * FROM {keyspace}.rows
WHERE collection = 'X' AND schema_name = 'customers'
AND index_name = 'email' AND index_value = ['foo@bar.com']
```
### Key Changes
1. **Query semantics simplify**: The new schema only supports exact matches on `index_value`. The current GraphQL filters (`gt`, `lt`, `contains`, etc.) either:
- Become post-filtering on returned data (if still needed)
- Are removed in favor of using the embeddings API for fuzzy matching
2. **GraphQL code is tightly coupled**: The current `service.py` bundles Strawberry type generation, filter parsing, and Cassandra-specific queries. Adding another row store backend would duplicate ~400 lines of GraphQL code.
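If the post-filtering option is taken, it could look like the sketch below: apply the old GraphQL-style operators to rows already returned by the exact-match query. The operator names mirror the existing filters (`gt`, `lt`, `contains`); the code itself is illustrative:

```python
# Hypothetical post-filter over rows returned from the unified table.

OPS = {
    "eq": lambda v, a: v == a,
    "gt": lambda v, a: v > a,
    "lt": lambda v, a: v < a,
    "contains": lambda v, a: a in v,
}

def post_filter(rows, field, op, arg):
    """Keep rows whose `field` satisfies the operator against `arg`."""
    return [r for r in rows if field in r and OPS[op](r[field], arg)]

rows = [{"name": "CHESTNUT ST"}, {"name": "OAK AVE"}]
post_filter(rows, "name", "contains", "CHESTNUT")
# -> [{"name": "CHESTNUT ST"}]
```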
### Proposed Refactor
The refactor has two parts:
#### 1. Break Out GraphQL Code
Extract reusable GraphQL components into a shared module:
```
trustgraph-flow/trustgraph/query/graphql/
├── __init__.py
├── types.py # Filter types (IntFilter, StringFilter, FloatFilter)
├── schema.py # Dynamic schema generation from RowSchema
└── filters.py # Filter parsing utilities
```
This enables:
- Reuse across different row store backends
- Cleaner separation of concerns
- Easier testing of GraphQL logic independently
#### 2. Implement New Table Schema
Refactor the Cassandra-specific code to use the unified table:
**Writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra/`):
- Single `rows` table instead of per-schema tables
- Write N copies per row (one per index)
- Register to `row_partitions` table
- Simpler table creation (one-time setup)
**Query Service** (`trustgraph-flow/trustgraph/query/rows/cassandra/`):
- Query the unified `rows` table
- Use extracted GraphQL module for schema generation
- Simplified filter handling (exact match only at DB level)
### Module Renames
As part of the "object" → "row" naming cleanup:
| Current | New |
|---------|-----|
| `storage/objects/cassandra/` | `storage/rows/cassandra/` |
| `query/objects/cassandra/` | `query/rows/cassandra/` |
| `embeddings/object_embeddings/` | `embeddings/row_embeddings/` |
### New Modules
| Module | Purpose |
|--------|---------|
| `trustgraph-flow/trustgraph/query/graphql/` | Shared GraphQL utilities |
| `trustgraph-flow/trustgraph/query/row_embeddings/qdrant/` | Row embeddings query API |
| `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | Row embeddings computation (Stage 1) |
| `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | Row embeddings storage (Stage 2) |
## References
- [Structured Data Technical Specification](structured-data.md)


@ -1,6 +1,6 @@
type: object
description: |
Objects query request - GraphQL query over knowledge graph.
Rows query request - GraphQL query over structured data.
required:
- query
properties:


@ -1,5 +1,5 @@
type: object
description: Objects query response (GraphQL format)
description: Rows query response (GraphQL format)
properties:
data:
description: GraphQL response data (JSON object or null)


@ -121,8 +121,8 @@ paths:
$ref: './paths/flow/mcp-tool.yaml'
/api/v1/flow/{flow}/service/triples:
$ref: './paths/flow/triples.yaml'
/api/v1/flow/{flow}/service/objects:
$ref: './paths/flow/objects.yaml'
/api/v1/flow/{flow}/service/rows:
$ref: './paths/flow/rows.yaml'
/api/v1/flow/{flow}/service/nlp-query:
$ref: './paths/flow/nlp-query.yaml'
/api/v1/flow/{flow}/service/structured-query:


@ -34,7 +34,7 @@ post:
```
1. User asks: "Who does Alice know?"
2. NLP Query generates GraphQL
3. Execute via /api/v1/flow/{flow}/service/objects
3. Execute via /api/v1/flow/{flow}/service/rows
4. Return results to user
```


@ -1,19 +1,19 @@
post:
tags:
- Flow Services
summary: Objects query - GraphQL over knowledge graph
summary: Rows query - GraphQL over structured data
description: |
Query knowledge graph using GraphQL for object-oriented data access.
Query structured data using GraphQL for row-oriented data access.
## Objects Query Overview
## Rows Query Overview
GraphQL interface to knowledge graph:
GraphQL interface to structured data:
- **Schema-driven**: Predefined types and relationships
- **Flexible queries**: Request exactly what you need
- **Nested data**: Traverse relationships in single query
- **Type-safe**: Strong typing with introspection
Abstracts RDF triples into familiar object model.
Abstracts structured rows into familiar object model.
## GraphQL Benefits
@ -61,7 +61,7 @@ post:
Schema defines available types via config service.
Use introspection query to discover schema.
operationId: objectsQueryService
operationId: rowsQueryService
security:
- bearerAuth: []
parameters:
@ -77,7 +77,7 @@ post:
content:
application/json:
schema:
$ref: '../../components/schemas/query/ObjectsQueryRequest.yaml'
$ref: '../../components/schemas/query/RowsQueryRequest.yaml'
examples:
simpleQuery:
summary: Simple query
@ -129,7 +129,7 @@ post:
content:
application/json:
schema:
$ref: '../../components/schemas/query/ObjectsQueryResponse.yaml'
$ref: '../../components/schemas/query/RowsQueryResponse.yaml'
examples:
successfulQuery:
summary: Successful query


@ -9,7 +9,7 @@ post:
Combines two operations in one call:
1. **NLP Query**: Generate GraphQL from question
2. **Objects Query**: Execute generated query
2. **Rows Query**: Execute generated query
3. **Return Results**: Direct answer data
Simplest way to query knowledge graph with natural language.
@ -21,7 +21,7 @@ post:
- **Output**: Query results (data)
- **Use when**: Want simple, direct answers
### NLP Query + Objects Query (separate calls)
### NLP Query + Rows Query (separate calls)
- **Step 1**: Convert question → GraphQL
- **Step 2**: Execute GraphQL → results
- **Use when**: Need to inspect/modify query before execution


@ -25,7 +25,7 @@ payload:
- $ref: './requests/EmbeddingsRequest.yaml'
- $ref: './requests/McpToolRequest.yaml'
- $ref: './requests/TriplesRequest.yaml'
- $ref: './requests/ObjectsRequest.yaml'
- $ref: './requests/RowsRequest.yaml'
- $ref: './requests/NlpQueryRequest.yaml'
- $ref: './requests/StructuredQueryRequest.yaml'
- $ref: './requests/StructuredDiagRequest.yaml'


@ -1,5 +1,5 @@
type: object
description: WebSocket request for objects service (flow-hosted service)
description: WebSocket request for rows service (flow-hosted service)
required:
- id
- service
@ -11,16 +11,16 @@ properties:
description: Unique request identifier
service:
type: string
const: objects
description: Service identifier for objects service
const: rows
description: Service identifier for rows service
flow:
type: string
description: Flow ID
request:
$ref: '../../../../api/components/schemas/query/ObjectsQueryRequest.yaml'
$ref: '../../../../api/components/schemas/query/RowsQueryRequest.yaml'
examples:
- id: req-1
service: objects
service: rows
flow: my-flow
request:
query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }"


@ -1,8 +1,8 @@
"""
Contract tests for Cassandra Object Storage
Contract tests for Cassandra Row Storage
These tests verify the message contracts and schema compatibility
for the objects storage processor.
for the rows storage processor.
"""
import pytest
@ -10,12 +10,12 @@ import json
from pulsar.schema import AvroSchema
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.storage.rows.cassandra.write import Processor
@pytest.mark.contract
class TestObjectsCassandraContracts:
"""Contract tests for Cassandra object storage messages"""
class TestRowsCassandraContracts:
"""Contract tests for Cassandra row storage messages"""
def test_extracted_object_input_contract(self):
"""Test that ExtractedObject schema matches expected input format"""
@ -145,50 +145,6 @@ class TestObjectsCassandraContracts:
assert required_field_keys.issubset(field.keys())
assert set(field.keys()).issubset(required_field_keys | optional_field_keys)
def test_cassandra_type_mapping_contract(self):
"""Test that all supported field types have Cassandra mappings"""
processor = Processor.__new__(Processor)
# All field types that should be supported
supported_types = [
("string", "text"),
("integer", "int"), # or bigint based on size
("float", "float"), # or double based on size
("boolean", "boolean"),
("timestamp", "timestamp"),
("date", "date"),
("time", "time"),
("uuid", "uuid")
]
for field_type, expected_cassandra_type in supported_types:
cassandra_type = processor.get_cassandra_type(field_type)
# For integer and float, the exact type depends on size
if field_type in ["integer", "float"]:
assert cassandra_type in ["int", "bigint", "float", "double"]
else:
assert cassandra_type == expected_cassandra_type
def test_value_conversion_contract(self):
"""Test value conversion for all supported types"""
processor = Processor.__new__(Processor)
# Test conversions maintain data integrity
test_cases = [
# (input_value, field_type, expected_output, expected_type)
("123", "integer", 123, int),
("123.45", "float", 123.45, float),
("true", "boolean", True, bool),
("false", "boolean", False, bool),
("test string", "string", "test string", str),
(None, "string", None, type(None)),
]
for input_val, field_type, expected_val, expected_type in test_cases:
result = processor.convert_value(input_val, field_type)
assert result == expected_val
assert isinstance(result, expected_type) or result is None
@pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
def test_extracted_object_serialization_contract(self):
"""Test that ExtractedObject can be serialized/deserialized correctly"""
@ -222,43 +178,31 @@ class TestObjectsCassandraContracts:
assert decoded.confidence == original.confidence
assert decoded.source_span == original.source_span
def test_cassandra_table_naming_contract(self):
def test_cassandra_name_sanitization_contract(self):
"""Test Cassandra naming conventions and constraints"""
processor = Processor.__new__(Processor)
# Test table naming (always gets o_ prefix)
table_test_names = [
("simple_name", "o_simple_name"),
("Name-With-Dashes", "o_name_with_dashes"),
("name.with.dots", "o_name_with_dots"),
("123_numbers", "o_123_numbers"),
("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores
("UPPERCASE", "o_uppercase"),
("CamelCase", "o_camelcase"),
("", "o_"), # Edge case - empty string becomes o_
]
for input_name, expected_name in table_test_names:
result = processor.sanitize_table(input_name)
assert result == expected_name
# Verify result is valid Cassandra identifier (starts with letter)
assert result.startswith('o_')
assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_'
# Test regular name sanitization (only adds o_ prefix if starts with number)
# Test name sanitization for Cassandra identifiers
# - Non-alphanumeric chars (except underscore) become underscores
# - Names starting with non-letter get 'r_' prefix
# - All names converted to lowercase
name_test_cases = [
("simple_name", "simple_name"),
("Name-With-Dashes", "name_with_dashes"),
("name.with.dots", "name_with_dots"),
("123_numbers", "o_123_numbers"), # Only this gets o_ prefix
("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number)
("special!@#chars", "special___chars"), # 3 special chars become 3 underscores
("UPPERCASE", "uppercase"),
("CamelCase", "camelcase"),
("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore)
]
for input_name, expected_name in name_test_cases:
result = processor.sanitize_name(input_name)
assert result == expected_name
assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}"
# Verify result is valid Cassandra identifier (starts with letter)
if result: # Skip empty string case
assert result[0].isalpha(), f"Result {result} should start with a letter"
def test_primary_key_structure_contract(self):
"""Test that primary key structure follows Cassandra best practices"""
@ -308,8 +252,8 @@ class TestObjectsCassandraContracts:
@pytest.mark.contract
class TestObjectsCassandraContractsBatch:
"""Contract tests for Cassandra object storage batch processing"""
class TestRowsCassandraContractsBatch:
"""Contract tests for Cassandra row storage batch processing"""
def test_extracted_object_batch_input_contract(self):
"""Test that batched ExtractedObject schema matches expected input format"""


@ -1,26 +1,26 @@
"""
Contract tests for Objects GraphQL Query Service
Contract tests for Rows GraphQL Query Service
These tests verify the message contracts and schema compatibility
for the objects GraphQL query processor.
for the rows GraphQL query processor.
"""
import pytest
import json
from pulsar.schema import AvroSchema
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
@pytest.mark.contract
class TestObjectsGraphQLQueryContracts:
class TestRowsGraphQLQueryContracts:
"""Contract tests for GraphQL query service messages"""
def test_objects_query_request_contract(self):
"""Test ObjectsQueryRequest schema structure and required fields"""
def test_rows_query_request_contract(self):
"""Test RowsQueryRequest schema structure and required fields"""
# Create test request with all required fields
test_request = ObjectsQueryRequest(
test_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ customers { id name email } }',
@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts:
assert test_request.variables["status"] == "active"
assert test_request.operation_name == "GetCustomers"
def test_objects_query_request_minimal(self):
"""Test ObjectsQueryRequest with minimal required fields"""
def test_rows_query_request_minimal(self):
"""Test RowsQueryRequest with minimal required fields"""
# Create request with only essential fields
minimal_request = ObjectsQueryRequest(
minimal_request = RowsQueryRequest(
user="user",
collection="collection",
query='{ test }',
@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts:
assert test_error.path == ["customers", "0", "nonexistent"]
assert test_error.extensions["code"] == "FIELD_ERROR"
def test_objects_query_response_success_contract(self):
"""Test ObjectsQueryResponse schema for successful queries"""
def test_rows_query_response_success_contract(self):
"""Test RowsQueryResponse schema for successful queries"""
# Create successful response
success_response = ObjectsQueryResponse(
success_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=[],
@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts:
assert len(parsed_data["customers"]) == 1
assert parsed_data["customers"][0]["id"] == "1"
def test_objects_query_response_error_contract(self):
"""Test ObjectsQueryResponse schema for error cases"""
def test_rows_query_response_error_contract(self):
"""Test RowsQueryResponse schema for error cases"""
# Create GraphQL errors - work around Pulsar Array(Record) validation bug
# by creating a response without the problematic errors array first
error_response = ObjectsQueryResponse(
error_response = RowsQueryResponse(
error=None, # System error is None - these are GraphQL errors
data=None, # No data due to errors
errors=[], # Empty errors array to avoid Pulsar bug
@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts:
assert validation_error.path == ["customers", "email"]
assert validation_error.extensions["details"] == "Invalid email format"
def test_objects_query_response_system_error_contract(self):
"""Test ObjectsQueryResponse schema for system errors"""
def test_rows_query_response_system_error_contract(self):
"""Test RowsQueryResponse schema for system errors"""
from trustgraph.schema import Error
# Create system error response
system_error_response = ObjectsQueryResponse(
system_error_response = RowsQueryResponse(
error=Error(
type="objects-query-error",
type="rows-query-error",
message="Failed to connect to Cassandra cluster"
),
data=None,
@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts:
# Verify system error structure
assert system_error_response.error is not None
assert system_error_response.error.type == "objects-query-error"
assert system_error_response.error.type == "rows-query-error"
assert "Cassandra" in system_error_response.error.message
assert system_error_response.data is None
assert len(system_error_response.errors) == 0
@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts:
def test_request_response_serialization_contract(self):
"""Test that request/response can be serialized/deserialized correctly"""
# Create original request
original_request = ObjectsQueryRequest(
original_request = RowsQueryRequest(
user="serialization_test",
collection="test_data",
query='{ orders(limit: 5) { id total customer { name } } }',
@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts:
)
# Test request serialization using Pulsar schema
request_schema = AvroSchema(ObjectsQueryRequest)
request_schema = AvroSchema(RowsQueryRequest)
# Encode and decode request
encoded_request = request_schema.encode(original_request)
@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts:
assert decoded_request.operation_name == original_request.operation_name
# Create original response - work around Pulsar Array(Record) bug
original_response = ObjectsQueryResponse(
original_response = RowsQueryResponse(
error=None,
data='{"orders": []}',
errors=[], # Empty to avoid Pulsar validation bug
@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts:
)
# Test response serialization
response_schema = AvroSchema(ObjectsQueryResponse)
response_schema = AvroSchema(RowsQueryResponse)
# Encode and decode response
encoded_response = response_schema.encode(original_response)
@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts:
def test_graphql_query_format_contract(self):
"""Test supported GraphQL query formats"""
# Test basic query
basic_query = ObjectsQueryRequest(
basic_query = RowsQueryRequest(
user="test", collection="test", query='{ customers { id } }',
variables={}, operation_name=""
)
@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts:
assert basic_query.query.strip().endswith('}')
# Test query with variables
parameterized_query = ObjectsQueryRequest(
parameterized_query = RowsQueryRequest(
user="test", collection="test",
query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }',
variables={"status": "active", "limit": "10"},
@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts:
assert parameterized_query.operation_name == "GetCustomers"
# Test complex nested query
nested_query = ObjectsQueryRequest(
nested_query = RowsQueryRequest(
user="test", collection="test",
query='''
{
@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts:
# Note: Current schema uses Map(String()) which only supports string values
# This test verifies the current contract, though ideally we'd support all JSON types
variables_test = ObjectsQueryRequest(
variables_test = RowsQueryRequest(
user="test", collection="test", query='{ test }',
variables={
"string_var": "test_value",
@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts:
def test_cassandra_context_fields_contract(self):
"""Test that request contains necessary fields for Cassandra operations"""
# Verify request has fields needed for Cassandra keyspace/table targeting
request = ObjectsQueryRequest(
request = RowsQueryRequest(
user="keyspace_name", # Maps to Cassandra keyspace
collection="partition_collection", # Used in partition key
query='{ objects { id } }',
@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts:
def test_graphql_extensions_contract(self):
"""Test GraphQL extensions field format and usage"""
# Extensions should support query metadata
response_with_extensions = ObjectsQueryResponse(
response_with_extensions = RowsQueryResponse(
error=None,
data='{"test": "data"}',
errors=[],
@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts:
'''
# Request to execute specific operation
multi_op_request = ObjectsQueryRequest(
multi_op_request = RowsQueryRequest(
user="test", collection="test",
query=multi_op_query,
variables={},
@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts:
assert "GetOrders" in multi_op_request.query
# Test single operation (operation_name optional)
single_op_request = ObjectsQueryRequest(
single_op_request = RowsQueryRequest(
user="test", collection="test",
query='{ customers { id } }',
variables={}, operation_name=""


@ -12,7 +12,7 @@ from argparse import ArgumentParser
# Import processors that use Cassandra configuration
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow:
assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3']
assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
"""Test complete flow from environment variables to Cassandra Cluster connection."""
env_vars = {
@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
# Trigger Cassandra connection
processor.connect_cassandra()
@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd:
class TestMultipleHostsHandling:
"""Test multiple Cassandra hosts handling end-to-end."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
"""Test that multiple hosts are correctly passed to Cassandra cluster."""
env_vars = {
@ -333,7 +333,7 @@ class TestMultipleHostsHandling:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify all hosts were passed to Cluster
@ -386,8 +386,8 @@ class TestMultipleHostsHandling:
class TestAuthenticationFlow:
"""Test authentication configuration flow end-to-end."""
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
"""Test that authentication is enabled when both username and password are provided."""
env_vars = {
@ -402,7 +402,7 @@ class TestAuthenticationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Auth provider should be created
@ -416,8 +416,8 @@ class TestAuthenticationFlow:
assert 'auth_provider' in call_args.kwargs
assert call_args.kwargs['auth_provider'] == mock_auth_instance
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
"""Test that authentication is not used when credentials are missing."""
env_vars = {
@ -429,7 +429,7 @@ class TestAuthenticationFlow:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = ObjectsWriter(taskgroup=MagicMock())
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Auth provider should not be created
@ -439,11 +439,11 @@ class TestAuthenticationFlow:
call_args = mock_cluster.call_args
assert 'auth_provider' not in call_args.kwargs
@patch('trustgraph.storage.objects.cassandra.write.Cluster')
@patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
"""Test that authentication is not used when only username is provided."""
processor = ObjectsWriter(
processor = RowsWriter(
taskgroup=MagicMock(),
cassandra_host='partial-auth-host',
cassandra_username='partial-user'
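The environment-to-cluster tests above rely on `patch.dict(os.environ, ...)` to inject host lists. A self-contained sketch of that pattern, with a hypothetical `read_cassandra_hosts` helper standing in for the writer's own configuration code (the `CASSANDRA_HOST` variable name is an assumption for illustration):

```python
import os
from unittest.mock import patch

def read_cassandra_hosts(default="localhost"):
    # Hypothetical helper: a comma-separated CASSANDRA_HOST value
    # becomes the list of contact points handed to Cluster(...).
    raw = os.environ.get("CASSANDRA_HOST", default)
    return [h.strip() for h in raw.split(",") if h.strip()]

# patch.dict with clear=True gives the test a hermetic environment,
# exactly as the integration tests above do with their env_vars dicts.
with patch.dict(os.environ, {"CASSANDRA_HOST": "host1, host2,host3"}, clear=True):
    hosts = read_cassandra_hosts()

assert hosts == ["host1", "host2", "host3"]
```

Because `patch.dict` restores the environment on exit, each test sees only the variables it declares, which is what makes the multiple-hosts and missing-credentials cases independent of each other.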


@ -11,7 +11,7 @@ import json
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration
@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Mock flow with failing prompt service
@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration:
processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor)
# Import and bind the convert_values_to_strings function
from trustgraph.extract.kg.objects.processor import convert_values_to_strings
from trustgraph.extract.kg.rows.processor import convert_values_to_strings
processor.convert_values_to_strings = convert_values_to_strings
# Load configuration


@ -1,608 +0,0 @@
"""
Integration tests for Cassandra Object Storage
These tests verify the end-to-end functionality of storing ExtractedObjects
in Cassandra, including table creation, data insertion, and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import uuid
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
@pytest.mark.integration
class TestObjectsCassandraIntegration:
"""Integration tests for Cassandra object storage"""
@pytest.fixture
def mock_cassandra_session(self):
"""Mock Cassandra session for integration tests"""
session = MagicMock()
# Track if keyspaces have been created
created_keyspaces = set()
# Mock the execute method to return a valid result for keyspace checks
def execute_mock(query, *args, **kwargs):
result = MagicMock()
query_str = str(query)
# Track keyspace creation
if "CREATE KEYSPACE" in query_str:
# Extract keyspace name from query
import re
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
if match:
created_keyspaces.add(match.group(1))
# For keyspace existence checks
if "system_schema.keyspaces" in query_str:
# Check if this keyspace was created
if args and args[0] in created_keyspaces:
result.one.return_value = MagicMock() # Exists
else:
result.one.return_value = None # Doesn't exist
else:
result.one.return_value = None
return result
session.execute = MagicMock(side_effect=execute_mock)
return session
@pytest.fixture
def mock_cassandra_cluster(self, mock_cassandra_session):
"""Mock Cassandra cluster"""
cluster = MagicMock()
cluster.connect.return_value = mock_cassandra_session
cluster.shutdown = MagicMock()
return cluster
@pytest.fixture
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
"""Create processor with mocked Cassandra dependencies"""
processor = MagicMock()
processor.graph_host = "localhost"
processor.graph_username = None
processor.graph_password = None
processor.config_key = "schema"
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.cluster = None
processor.session = None
# Bind actual methods
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.create_collection = Processor.create_collection.__get__(processor, Processor)
return processor, mock_cassandra_cluster, mock_cassandra_session
@pytest.mark.asyncio
async def test_end_to_end_object_storage(self, processor_with_mocks):
"""Test complete flow from schema config to object storage"""
processor, mock_cluster, mock_session = processor_with_mocks
# Mock Cluster creation
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Step 1: Configure schema
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer information",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True},
{"name": "age", "type": "integer"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
# Step 1.5: Create the collection first (simulate tg-set-collection)
await processor.create_collection("test_user", "import_2024", {})
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
id="doc-001",
user="test_user",
collection="import_2024",
metadata=[]
),
schema_name="customer_records",
values=[{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": "30"
}],
confidence=0.95,
source_span="Customer: John Doe..."
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify Cassandra interactions
assert mock_cluster.connect.called
# Verify keyspace creation
keyspace_calls = [call for call in mock_session.execute.call_args_list
if "CREATE KEYSPACE" in str(call)]
assert len(keyspace_calls) == 1
assert "test_user" in str(keyspace_calls[0])
# Verify table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "o_customer_records" in str(table_calls[0]) # Table gets o_ prefix
assert "collection text" in str(table_calls[0])
assert "PRIMARY KEY ((collection, customer_id))" in str(table_calls[0])
# Verify index creation
index_calls = [call for call in mock_session.execute.call_args_list
if "CREATE INDEX" in str(call)]
assert len(index_calls) == 1
assert "email" in str(index_calls[0])
# Verify data insertion
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 1
insert_call = insert_calls[0]
assert "test_user.o_customer_records" in str(insert_call) # Table gets o_ prefix
# Check inserted values
values = insert_call[0][1]
assert "import_2024" in values # collection
assert "CUST001" in values # customer_id
assert "John Doe" in values # name
assert "john@example.com" in values # email
assert 30 in values # age (converted to int)
@pytest.mark.asyncio
async def test_multi_schema_handling(self, processor_with_mocks):
"""Test handling multiple schemas and objects"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure multiple schemas
config = {
"schema": {
"products": json.dumps({
"name": "products",
"fields": [
{"name": "product_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string"},
{"name": "price", "type": "float"}
]
}),
"orders": json.dumps({
"name": "orders",
"fields": [
{"name": "order_id", "type": "string", "primary_key": True},
{"name": "customer_id", "type": "string"},
{"name": "total", "type": "float"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
# Create collections first
await processor.create_collection("shop", "catalog", {})
await processor.create_collection("shop", "sales", {})
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
source_span="Product..."
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
source_span="Order..."
)
# Process both objects
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Verify separate tables were created
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 2
assert any("o_products" in str(call) for call in table_calls) # Tables get o_ prefix
assert any("o_orders" in str(call) for call in table_calls) # Tables get o_ prefix
@pytest.mark.asyncio
async def test_missing_required_fields(self, processor_with_mocks):
"""Test handling of objects with missing required fields"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema with required field
processor.schemas["test_schema"] = RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True, required=True),
Field(name="required_field", type="string", size=100, required=True)
]
)
# Create collection first
await processor.create_collection("test", "test", {})
# Create object missing required field
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="test_schema",
values=[{"id": "123"}], # missing required_field
confidence=0.8,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
# Should still process (Cassandra doesn't enforce NOT NULL)
await processor.on_object(msg, None, None)
# Verify insert was attempted
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 1
@pytest.mark.asyncio
async def test_schema_without_primary_key(self, processor_with_mocks):
"""Test handling schemas without defined primary keys"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema without primary key
processor.schemas["events"] = RowSchema(
name="events",
description="Event log",
fields=[
Field(name="event_type", type="string", size=50),
Field(name="timestamp", type="timestamp", size=0)
]
)
# Create collection first
await processor.create_collection("logger", "app_events", {})
# Process object
test_obj = ExtractedObject(
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
schema_name="events",
values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
confidence=1.0,
source_span="Event"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify synthetic_id was added
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "synthetic_id uuid" in str(table_calls[0])
# Verify insert includes UUID
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 1
values = insert_calls[0][0][1]
# Check that a UUID was generated (will be in values list)
uuid_found = any(isinstance(v, uuid.UUID) for v in values)
assert uuid_found
@pytest.mark.asyncio
async def test_authentication_handling(self, processor_with_mocks):
"""Test Cassandra authentication"""
processor, mock_cluster, mock_session = processor_with_mocks
processor.cassandra_username = "cassandra_user"
processor.cassandra_password = "cassandra_pass"
with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class:
with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
mock_cluster_class.return_value = mock_cluster
# Trigger connection
processor.connect_cassandra()
# Verify authentication was configured
mock_auth.assert_called_once_with(
username="cassandra_user",
password="cassandra_pass"
)
mock_cluster_class.assert_called_once()
call_kwargs = mock_cluster_class.call_args[1]
assert 'auth_provider' in call_kwargs
@pytest.mark.asyncio
async def test_error_handling_during_insert(self, processor_with_mocks):
"""Test error handling when insertion fails"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["test"] = RowSchema(
name="test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Make insert fail
mock_result = MagicMock()
mock_result.one.return_value = MagicMock() # Keyspace exists
mock_session.execute.side_effect = [
mock_result, # keyspace existence check succeeds
None, # table creation succeeds
Exception("Connection timeout") # insert fails
]
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="test",
values=[{"id": "123"}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
# Should raise the exception
with pytest.raises(Exception, match="Connection timeout"):
await processor.on_object(msg, None, None)
@pytest.mark.asyncio
async def test_collection_partitioning(self, processor_with_mocks):
"""Test that objects are properly partitioned by collection"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["data"] = RowSchema(
name="data",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Process objects from different collections
collections = ["import_jan", "import_feb", "import_mar"]
# Create all collections first
for coll in collections:
await processor.create_collection("analytics", coll, {})
for coll in collections:
obj = ExtractedObject(
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
schema_name="data",
values=[{"id": f"ID-{coll}"}],
confidence=0.9,
source_span="Data"
)
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Verify all inserts include collection in values
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 3
# Check each insert has the correct collection
for i, call in enumerate(insert_calls):
values = call[0][1]
assert collections[i] in values
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
"""Test processing objects with batched values"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema
config = {
"schema": {
"batch_customers": json.dumps({
"name": "batch_customers",
"description": "Customer batch data",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True}
]
})
}
}
await processor.on_schema_config(config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
metadata=[]
),
schema_name="batch_customers",
values=[
{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com"
},
{
"customer_id": "CUST002",
"name": "Jane Smith",
"email": "jane@example.com"
},
{
"customer_id": "CUST003",
"name": "Bob Johnson",
"email": "bob@example.com"
}
],
confidence=0.92,
source_span="Multiple customers extracted from document"
)
# Create collection first
await processor.create_collection("test_user", "batch_import", {})
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Verify table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
assert "o_batch_customers" in str(table_calls[0])
# Verify multiple inserts for batch values
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
# Should have 3 separate inserts for the 3 objects in the batch
assert len(insert_calls) == 3
# Check each insert has correct data
for i, call in enumerate(insert_calls):
values = call[0][1]
assert "batch_import" in values # collection
assert f"CUST00{i+1}" in values # customer_id
if i == 0:
assert "John Doe" in values
assert "john@example.com" in values
elif i == 1:
assert "Jane Smith" in values
assert "jane@example.com" in values
elif i == 2:
assert "Bob Johnson" in values
assert "bob@example.com" in values
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
"""Test processing objects with empty values array"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Create collection first
await processor.create_collection("test", "empty", {})
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
source_span="No objects found"
)
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
# Should still create table
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 1
# Should not create any insert statements for empty batch
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 0
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
"""Test processing mix of single and batch objects"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["mixed_test"] = RowSchema(
name="mixed_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
# Create collection first
await processor.create_collection("test", "mixed", {})
# Single object (backward compatibility)
single_obj = ExtractedObject(
metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
schema_name="mixed_test",
values=[{"id": "single-1", "data": "single data"}], # Array with single item
confidence=0.9,
source_span="Single object"
)
# Batch object
batch_obj = ExtractedObject(
metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
schema_name="mixed_test",
values=[
{"id": "batch-1", "data": "batch data 1"},
{"id": "batch-2", "data": "batch data 2"}
],
confidence=0.85,
source_span="Batch objects"
)
# Process both
for obj in [single_obj, batch_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# Should have 3 total inserts (1 + 2)
insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call)]
assert len(insert_calls) == 3


@ -0,0 +1,492 @@
"""
Integration tests for Cassandra Row Storage (Unified Table Implementation)
These tests verify the end-to-end functionality of storing ExtractedObjects
in the unified Cassandra rows table, including table creation, data insertion,
and error handling.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
@pytest.mark.integration
class TestRowsCassandraIntegration:
"""Integration tests for Cassandra row storage with unified table"""
@pytest.fixture
def mock_cassandra_session(self):
"""Mock Cassandra session for integration tests"""
session = MagicMock()
# Track if keyspaces have been created
created_keyspaces = set()
# Mock the execute method to return a valid result for keyspace checks
def execute_mock(query, *args, **kwargs):
result = MagicMock()
query_str = str(query)
# Track keyspace creation
if "CREATE KEYSPACE" in query_str:
import re
match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str)
if match:
created_keyspaces.add(match.group(1))
# For keyspace existence checks
if "system_schema.keyspaces" in query_str:
if args and args[0] in created_keyspaces:
result.one.return_value = MagicMock() # Exists
else:
result.one.return_value = None # Doesn't exist
else:
result.one.return_value = None
return result
session.execute = MagicMock(side_effect=execute_mock)
return session
@pytest.fixture
def mock_cassandra_cluster(self, mock_cassandra_session):
"""Mock Cassandra cluster"""
cluster = MagicMock()
cluster.connect.return_value = mock_cassandra_session
cluster.shutdown = MagicMock()
return cluster
@pytest.fixture
def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session):
"""Create processor with mocked Cassandra dependencies"""
processor = MagicMock()
processor.cassandra_host = ["localhost"]
processor.cassandra_username = None
processor.cassandra_password = None
processor.config_key = "schema"
processor.schemas = {}
processor.known_keyspaces = set()
processor.tables_initialized = set()
processor.registered_partitions = set()
processor.cluster = None
processor.session = None
# Bind actual methods from the new unified table implementation
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.collection_exists = MagicMock(return_value=True)
return processor, mock_cassandra_cluster, mock_cassandra_session
@pytest.mark.asyncio
async def test_end_to_end_object_storage(self, processor_with_mocks):
"""Test complete flow from schema config to object storage"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Step 1: Configure schema
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer information",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True},
{"name": "age", "type": "integer"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert "customer_records" in processor.schemas
# Step 2: Process an ExtractedObject
test_obj = ExtractedObject(
metadata=Metadata(
id="doc-001",
user="test_user",
collection="import_2024",
metadata=[]
),
schema_name="customer_records",
values=[{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com",
"age": "30"
}],
confidence=0.95,
source_span="Customer: John Doe..."
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify Cassandra interactions
assert mock_cluster.connect.called
# Verify keyspace creation
keyspace_calls = [call for call in mock_session.execute.call_args_list
if "CREATE KEYSPACE" in str(call)]
assert len(keyspace_calls) == 1
assert "test_user" in str(keyspace_calls[0])
# Verify unified table creation (rows table, not per-schema table)
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 2 # rows table + row_partitions table
assert any("rows" in str(call) for call in table_calls)
assert any("row_partitions" in str(call) for call in table_calls)
# Verify the rows table has correct structure
rows_table_call = [call for call in table_calls if ".rows" in str(call)][0]
assert "collection text" in str(rows_table_call)
assert "schema_name text" in str(rows_table_call)
assert "index_name text" in str(rows_table_call)
assert "data map<text, text>" in str(rows_table_call)
# Verify data insertion into unified table
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
# Should have 2 data inserts: one for customer_id (primary), one for email (indexed)
assert len(rows_insert_calls) == 2
@pytest.mark.asyncio
async def test_multi_schema_handling(self, processor_with_mocks):
"""Test handling multiple schemas stored in unified table"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Configure multiple schemas
config = {
"schema": {
"products": json.dumps({
"name": "products",
"fields": [
{"name": "product_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string"},
{"name": "price", "type": "float"}
]
}),
"orders": json.dumps({
"name": "orders",
"fields": [
{"name": "order_id", "type": "string", "primary_key": True},
{"name": "customer_id", "type": "string"},
{"name": "total", "type": "float"}
]
})
}
}
await processor.on_schema_config(config, version=1)
assert len(processor.schemas) == 2
# Process objects for different schemas
product_obj = ExtractedObject(
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
schema_name="products",
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
confidence=0.9,
source_span="Product..."
)
order_obj = ExtractedObject(
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
schema_name="orders",
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
confidence=0.85,
source_span="Order..."
)
# Process both objects
for obj in [product_obj, order_obj]:
msg = MagicMock()
msg.value.return_value = obj
await processor.on_object(msg, None, None)
# All data goes into the same unified rows table
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
# Should only create 2 tables: rows + row_partitions (not per-schema tables)
assert len(table_calls) == 2
# Verify data inserts go to unified rows table
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) > 0
for call in rows_insert_calls:
assert ".rows" in str(call)
@pytest.mark.asyncio
async def test_multi_index_storage(self, processor_with_mocks):
"""Test that rows are stored with multiple indexes"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Schema with multiple indexed fields
processor.schemas["indexed_data"] = RowSchema(
name="indexed_data",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True),
Field(name="status", type="string", size=50, indexed=True),
Field(name="description", type="string", size=200) # Not indexed
]
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="indexed_data",
values=[{
"id": "123",
"category": "electronics",
"status": "active",
"description": "A product"
}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Should have 3 data inserts (one per indexed field: id, category, status)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) == 3
# Verify different index names were used
index_names = set()
for call in rows_insert_calls:
values = call[0][1]
index_names.add(values[2]) # index_name is 3rd parameter
assert index_names == {"id", "category", "status"}
@pytest.mark.asyncio
async def test_authentication_handling(self, processor_with_mocks):
"""Test Cassandra authentication"""
processor, mock_cluster, mock_session = processor_with_mocks
processor.cassandra_username = "cassandra_user"
processor.cassandra_password = "cassandra_pass"
with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class:
with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth:
mock_cluster_class.return_value = mock_cluster
# Trigger connection
processor.connect_cassandra()
# Verify authentication was configured
mock_auth.assert_called_once_with(
username="cassandra_user",
password="cassandra_pass"
)
mock_cluster_class.assert_called_once()
call_kwargs = mock_cluster_class.call_args[1]
assert 'auth_provider' in call_kwargs
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
"""Test processing objects with batched values"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
# Configure schema
config = {
"schema": {
"batch_customers": json.dumps({
"name": "batch_customers",
"description": "Customer batch data",
"fields": [
{"name": "customer_id", "type": "string", "primary_key": True},
{"name": "name", "type": "string", "required": True},
{"name": "email", "type": "string", "indexed": True}
]
})
}
}
await processor.on_schema_config(config, version=1)
# Process batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_import",
metadata=[]
),
schema_name="batch_customers",
values=[
{
"customer_id": "CUST001",
"name": "John Doe",
"email": "john@example.com"
},
{
"customer_id": "CUST002",
"name": "Jane Smith",
"email": "jane@example.com"
},
{
"customer_id": "CUST003",
"name": "Bob Johnson",
"email": "bob@example.com"
}
],
confidence=0.92,
source_span="Multiple customers extracted from document"
)
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Verify unified table creation
table_calls = [call for call in mock_session.execute.call_args_list
if "CREATE TABLE" in str(call)]
assert len(table_calls) == 2 # rows + row_partitions
# Each row in batch gets 2 data inserts (customer_id primary + email indexed)
# 3 rows * 2 indexes = 6 data inserts
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) == 6
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
"""Test processing objects with empty values array"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["empty_test"] = RowSchema(
name="empty_test",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
# Process empty batch object
empty_obj = ExtractedObject(
metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
schema_name="empty_test",
values=[], # Empty batch
confidence=1.0,
source_span="No objects found"
)
msg = MagicMock()
msg.value.return_value = empty_obj
await processor.on_object(msg, None, None)
# Should not create any data insert statements for empty batch
# (partition registration may still happen)
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) == 0
@pytest.mark.asyncio
async def test_data_stored_as_map(self, processor_with_mocks):
"""Test that data is stored as map<text, text>"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["map_test"] = RowSchema(
name="map_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="count", type="integer", size=0)
]
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
schema_name="map_test",
values=[{"id": "123", "name": "Test Item", "count": "42"}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify insert uses map for data
rows_insert_calls = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and ".rows" in str(call)
and "row_partitions" not in str(call)]
assert len(rows_insert_calls) >= 1
# Check that data is passed as a dict (will be map in Cassandra)
insert_call = rows_insert_calls[0]
values = insert_call[0][1]
# Values are: (collection, schema_name, index_name, index_value, data, source)
# values[4] should be the data map
data_map = values[4]
assert isinstance(data_map, dict)
assert data_map["id"] == "123"
assert data_map["name"] == "Test Item"
assert data_map["count"] == "42"
@pytest.mark.asyncio
async def test_partition_registration(self, processor_with_mocks):
"""Test that partitions are registered for efficient querying"""
processor, mock_cluster, mock_session = processor_with_mocks
with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster):
processor.schemas["partition_test"] = RowSchema(
name="partition_test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="category", type="string", size=50, indexed=True)
]
)
test_obj = ExtractedObject(
metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]),
schema_name="partition_test",
values=[{"id": "123", "category": "test"}],
confidence=0.9,
source_span="Test"
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Verify partition registration
partition_inserts = [call for call in mock_session.execute.call_args_list
if "INSERT INTO" in str(call) and "row_partitions" in str(call)]
# Should register partitions for each index (id, category)
assert len(partition_inserts) == 2
# Verify cache was updated
assert ("my_collection", "partition_test") in processor.registered_partitions
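For reference, illustrative CQL for the two unified tables these tests assert on: a single `rows` table shared by every schema, plus a `row_partitions` registry used to discover which (collection, schema, index) partitions hold data. The `PRIMARY KEY` layouts here are assumptions for illustration, not taken from the real implementation.

```python
# Sketch of the unified-table DDL these tests check for. Key layout is an
# assumption; column names follow the assertions in the tests above.

ROWS_TABLE_CQL = """
CREATE TABLE IF NOT EXISTS rows (
    collection text,
    schema_name text,
    index_name text,
    index_value text,
    data map<text, text>,
    source text,
    PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
"""

ROW_PARTITIONS_CQL = """
CREATE TABLE IF NOT EXISTS row_partitions (
    collection text,
    schema_name text,
    index_name text,
    PRIMARY KEY ((collection, schema_name), index_name)
)
"""

# the column assertions in test_end_to_end_object_storage hold on this sketch
for column in ("collection text", "schema_name text",
               "index_name text", "data map<text, text>"):
    assert column in ROWS_TABLE_CQL
```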


@@ -1,5 +1,5 @@
"""
Integration tests for Objects GraphQL Query Service
Integration tests for Rows GraphQL Query Service
These tests verify end-to-end functionality including:
- Real Cassandra database operations
@@ -24,8 +24,8 @@ except Exception:
DOCKER_AVAILABLE = False
CassandraContainer = None
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata
@@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration:
processor.connect_cassandra()
# Create mock message
request = ObjectsQueryRequest(
request = RowsQueryRequest(
user="msg_test_user",
collection="msg_test_collection",
query='{ customer_objects { customer_id name } }',
@@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration:
# Verify response structure
sent_response = mock_response_producer.send.call_args[0][0]
assert isinstance(sent_response, ObjectsQueryResponse)
assert isinstance(sent_response, RowsQueryResponse)
# Should have no system error (even if no data)
assert sent_response.error is None


@@ -2,7 +2,7 @@
Integration tests for Structured Query Service
These tests verify the end-to-end functionality of the structured query service,
testing orchestration between nlp-query and objects-query services.
testing orchestration between nlp-query and rows-query services.
Following the TEST_STRATEGY.md approach for integration testing.
"""
@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
ObjectsQueryRequest, ObjectsQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects Query Service Response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
errors=None,
@@ -99,7 +99,7 @@
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration:
# Verify Objects service call
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, ObjectsQueryRequest)
assert isinstance(objects_call_args, RowsQueryRequest)
assert "customers" in objects_call_args.query
assert "orders" in objects_call_args.query
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
@@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration:
)
# Mock Objects service failure
objects_error_response = ObjectsQueryResponse(
objects_error_response = RowsQueryResponse(
error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
data=None,
errors=None,
@@ -237,7 +237,7 @@
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@@ -255,7 +255,7 @@
assert response.error is not None
assert response.error.type == "structured-query-error"
assert "Objects query service error" in response.error.message
assert "Rows query service error" in response.error.message
assert "nonexistent_table" in response.error.message
@pytest.mark.asyncio
@@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration:
)
]
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=None, # No data when validation fails
errors=validation_errors,
@@ -315,7 +315,7 @@
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@@ -422,7 +422,7 @@
]
}
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=json.dumps(complex_data),
errors=None,
@@ -443,7 +443,7 @@
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@@ -503,7 +503,7 @@
)
# Mock empty Objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"customers": []}', # Empty result set
errors=None,
@@ -520,7 +520,7 @@
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@@ -577,7 +577,7 @@
confidence=0.9
)
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
errors=None,
@@ -599,7 +599,7 @@
if service_name == "nlp-query-request":
service_call_count += 1
return nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
service_call_count += 1
return objects_client
elif service_name == "response":
@@ -700,7 +700,7 @@
)
# Mock Objects response
objects_response = ObjectsQueryResponse(
objects_response = RowsQueryResponse(
error=None,
data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
errors=None,
@@ -717,7 +717,7 @@
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "objects-query-request":
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
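The `flow_router` stub repeated through these tests can be sketched as a small factory — a minimal illustration (names like `make_flow_router` are invented here) of mapping logical service names, including the renamed `rows-query-request`, to their mocked clients:

```python
# Sketch of the service-routing stub these tests repeat: a flow callable
# that resolves logical service names to mocked clients.

def make_flow_router(nlp_client, rows_client, response_producer):
    routes = {
        "nlp-query-request": nlp_client,
        "rows-query-request": rows_client,
        "response": response_producer,
    }
    def flow_router(service_name):
        # unknown names raise KeyError, surfacing wiring mistakes in tests
        return routes[service_name]
    return flow_router

router = make_flow_router("nlp", "rows", "resp")
print(router("rows-query-request"))  # rows
```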


@@ -0,0 +1,380 @@
"""
Unit tests for trustgraph.embeddings.row_embeddings.embeddings
Tests the Stage 1 processor that computes embeddings for row index fields.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase):
"""Test row embeddings processor functionality"""
async def test_processor_initialization(self):
"""Test basic processor initialization"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-row-embeddings'
}
processor = Processor(**config)
assert hasattr(processor, 'schemas')
assert processor.schemas == {}
assert processor.batch_size == 10 # default
async def test_processor_initialization_with_custom_batch_size(self):
"""Test processor initialization with custom batch size"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-row-embeddings',
'batch_size': 25
}
processor = Processor(**config)
assert processor.batch_size == 25
async def test_get_index_names_single_index(self):
"""Test getting index names with single indexed field"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import RowSchema, Field
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
schema = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
Field(name='email', type='text', indexed=False),
]
)
index_names = processor.get_index_names(schema)
# Should include primary key and indexed field
assert 'id' in index_names
assert 'name' in index_names
assert 'email' not in index_names
async def test_get_index_names_no_indexes(self):
"""Test getting index names when no fields are indexed"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import RowSchema, Field
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
schema = RowSchema(
name='logs',
description='Log records',
fields=[
Field(name='timestamp', type='text'),
Field(name='message', type='text'),
]
)
index_names = processor.get_index_names(schema)
assert index_names == []
async def test_build_index_value_single_field(self):
"""Test building index value for single field"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'id': 'CUST001',
'name': 'John Doe',
'email': 'john@example.com'
}
result = processor.build_index_value(value_map, 'name')
assert result == ['John Doe']
async def test_build_index_value_composite_index(self):
"""Test building index value for composite index"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'first_name': 'John',
'last_name': 'Doe',
'city': 'New York'
}
result = processor.build_index_value(value_map, 'first_name, last_name')
assert result == ['John', 'Doe']
async def test_build_index_value_missing_field(self):
"""Test building index value when field is missing"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
value_map = {
'name': 'John Doe'
}
result = processor.build_index_value(value_map, 'missing_field')
assert result == ['']
async def test_build_text_for_embedding_single_value(self):
"""Test building text representation for single value"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
result = processor.build_text_for_embedding(['John Doe'])
assert result == 'John Doe'
async def test_build_text_for_embedding_multiple_values(self):
"""Test building text representation for multiple values"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
result = processor.build_text_for_embedding(['John', 'Doe', 'NYC'])
assert result == 'John Doe NYC'
async def test_on_schema_config_loads_schemas(self):
"""Test that schema configuration is loaded correctly"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
import json
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
schema_def = {
'name': 'customers',
'description': 'Customer records',
'fields': [
{'name': 'id', 'type': 'text', 'primary_key': True},
{'name': 'name', 'type': 'text', 'indexed': True},
{'name': 'email', 'type': 'text'}
]
}
config_data = {
'schema': {
'customers': json.dumps(schema_def)
}
}
await processor.on_schema_config(config_data, 1)
assert 'customers' in processor.schemas
assert processor.schemas['customers'].name == 'customers'
assert len(processor.schemas['customers'].fields) == 3
async def test_on_schema_config_handles_missing_type(self):
"""Test that missing schema type is handled gracefully"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
config_data = {
'other_type': {}
}
await processor.on_schema_config(config_data, 1)
assert processor.schemas == {}
async def test_on_message_drops_unknown_collection(self):
"""Test that messages for unknown collections are dropped"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='customers',
values=[{'id': '123', 'name': 'Test'}]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
mock_flow = MagicMock()
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Flow should not be called for output
mock_flow.assert_not_called()
async def test_on_message_drops_unknown_schema(self):
"""Test that messages for unknown schemas are dropped"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# No schemas registered
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='unknown_schema',
values=[{'id': '123', 'name': 'Test'}]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
mock_flow = MagicMock()
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Flow should not be called for output
mock_flow.assert_not_called()
async def test_on_message_processes_embeddings(self):
"""Test processing a message and computing embeddings"""
from trustgraph.embeddings.row_embeddings.embeddings import Processor
from trustgraph.schema import ExtractedObject, RowSchema, Field
import json
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor',
'config_type': 'schema'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# Set up schema
processor.schemas['customers'] = RowSchema(
name='customers',
description='Customer records',
fields=[
Field(name='id', type='text', primary=True),
Field(name='name', type='text', indexed=True),
]
)
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
obj = ExtractedObject(
metadata=metadata,
schema_name='customers',
values=[
{'id': 'CUST001', 'name': 'John Doe'},
{'id': 'CUST002', 'name': 'Jane Smith'}
]
)
mock_msg = MagicMock()
mock_msg.value.return_value = obj
# Mock the flow
mock_embeddings_request = AsyncMock()
mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]]
mock_output = AsyncMock()
def flow_factory(name):
if name == 'embeddings-request':
return mock_embeddings_request
elif name == 'output':
return mock_output
return MagicMock()
mock_flow = MagicMock(side_effect=flow_factory)
await processor.on_message(mock_msg, MagicMock(), mock_flow)
# Should have called embed for each unique text
# 4 values: CUST001, John Doe, CUST002, Jane Smith
assert mock_embeddings_request.embed.call_count == 4
# Should have sent output
mock_output.send.assert_called()
if __name__ == '__main__':
pytest.main([__file__])
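Stand-in implementations of the two helpers exercised above, written to match the behaviour the assertions pin down (composite index names are comma-separated field lists, a missing field yields `""`, embedding text is a space-join). The real methods live on the `Processor` class; this is a sketch, not that implementation.

```python
# Minimal sketch of the index-value helpers, matching the test assertions:
# - composite index names split on commas, whitespace stripped
# - missing fields fall back to the empty string
# - the embedding text is a plain space-join of the index values

def build_index_value(value_map, index_name):
    return [str(value_map.get(field.strip(), ""))
            for field in index_name.split(",")]

def build_text_for_embedding(values):
    return " ".join(values)

row = {"first_name": "John", "last_name": "Doe", "city": "New York"}
print(build_index_value(row, "first_name, last_name"))   # ['John', 'Doe']
print(build_index_value(row, "missing_field"))           # ['']
print(build_text_for_embedding(["John", "Doe", "NYC"]))  # John Doe NYC
```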


@@ -1,7 +1,7 @@
"""
Unit tests for objects import dispatcher.
Unit tests for rows import dispatcher.
Tests the business logic of objects import dispatcher
Tests the business logic of rows import dispatcher
while mocking the Publisher and websocket components.
"""
@@ -11,7 +11,7 @@ import asyncio
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from aiohttp import web
from trustgraph.gateway.dispatch.objects_import import ObjectsImport
from trustgraph.gateway.dispatch.rows_import import RowsImport
from trustgraph.schema import Metadata, ExtractedObject
@@ -92,16 +92,16 @@
}
class TestObjectsImportInitialization:
"""Test ObjectsImport initialization."""
class TestRowsImportInitialization:
"""Test RowsImport initialization."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport creates Publisher with correct parameters."""
"""Test that RowsImport creates Publisher with correct parameters."""
mock_publisher_instance = Mock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@@ -116,28 +116,28 @@
)
# Verify instance variables are set correctly
assert objects_import.ws == mock_websocket
assert objects_import.running == mock_running
assert objects_import.publisher == mock_publisher_instance
assert rows_import.ws == mock_websocket
assert rows_import.running == mock_running
assert rows_import.publisher == mock_publisher_instance
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that ObjectsImport stores all required references."""
objects_import = ObjectsImport(
"""Test that RowsImport stores all required references."""
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="objects-queue"
)
assert objects_import.ws is mock_websocket
assert objects_import.running is mock_running
assert rows_import.ws is mock_websocket
assert rows_import.running is mock_running
class TestObjectsImportLifecycle:
"""Test ObjectsImport lifecycle methods."""
class TestRowsImportLifecycle:
"""Test RowsImport lifecycle methods."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that start() calls publisher.start()."""
@ -145,18 +145,18 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.start = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.start()
await rows_import.start()
mock_publisher_instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that destroy() properly stops publisher and closes websocket."""
@ -164,21 +164,21 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.destroy()
await rows_import.destroy()
# Verify sequence of operations
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
"""Test that destroy() handles None websocket gracefully."""
@ -186,7 +186,7 @@ class TestObjectsImportLifecycle:
mock_publisher_instance.stop = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -194,16 +194,16 @@ class TestObjectsImportLifecycle:
)
# Should not raise exception
await objects_import.destroy()
await rows_import.destroy()
mock_running.stop.assert_called_once()
mock_publisher_instance.stop.assert_called_once()
class TestObjectsImportMessageProcessing:
"""Test ObjectsImport message processing."""
class TestRowsImportMessageProcessing:
"""Test RowsImport message processing."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() processes complete message correctly."""
@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = sample_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.metadata.collection == "testcollection"
assert len(sent_object.metadata.metadata) == 1 # One triple in metadata
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
"""Test that receive() handles message with minimal required fields."""
@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = minimal_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == "" # Default value
assert len(sent_object.metadata.metadata) == 0 # Default empty list
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() uses appropriate default values for optional fields."""
@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing:
mock_msg = Mock()
mock_msg.json.return_value = message_data
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Get the sent object and verify defaults
sent_object = mock_publisher_instance.send.call_args[0][1]
@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing:
assert sent_object.source_span == ""
class TestObjectsImportRunMethod:
"""Test ObjectsImport run method."""
class TestRowsImportRunMethod:
"""Test RowsImport run method."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that run() loops while running.get() returns True."""
@ -331,14 +331,14 @@ class TestObjectsImportRunMethod:
# Set up running state to return True twice, then False
mock_running.get.side_effect = [True, True, False]
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="test-queue"
)
await objects_import.run()
await rows_import.run()
# Verify sleep was called twice (for the two True iterations)
assert mock_sleep.call_count == 2
@ -348,10 +348,10 @@ class TestObjectsImportRunMethod:
mock_websocket.close.assert_called_once()
# Verify websocket was set to None
assert objects_import.ws is None
assert rows_import.ws is None
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep')
@pytest.mark.asyncio
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
"""Test that run() handles None websocket gracefully."""
@ -360,7 +360,7 @@ class TestObjectsImportRunMethod:
mock_running.get.return_value = False # Exit immediately
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=None, # None websocket
running=mock_running,
backend=mock_backend,
@ -368,14 +368,14 @@ class TestObjectsImportRunMethod:
)
# Should not raise exception
await objects_import.run()
await rows_import.run()
# Verify websocket remains None
assert objects_import.ws is None
assert rows_import.ws is None
class TestObjectsImportBatchProcessing:
"""Test ObjectsImport batch processing functionality."""
class TestRowsImportBatchProcessing:
"""Test RowsImport batch processing functionality."""
@pytest.fixture
def batch_objects_message(self):
@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing:
"source_span": "Multiple people found in document"
}
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
"""Test that receive() processes batch message correctly."""
@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = batch_objects_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Verify publisher.send was called
mock_publisher_instance.send.assert_called_once()
@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing:
assert sent_object.confidence == 0.85
assert sent_object.source_span == "Multiple people found in document"
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles empty batch correctly."""
@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing:
mock_publisher_instance.send = AsyncMock()
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing:
mock_msg = Mock()
mock_msg.json.return_value = empty_batch_message
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
# Should still send the message
mock_publisher_instance.send.assert_called_once()
@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing:
assert len(sent_object.values) == 0
class TestObjectsImportErrorHandling:
"""Test error handling in ObjectsImport."""
class TestRowsImportErrorHandling:
"""Test error handling in RowsImport."""
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
"""Test that receive() propagates publisher send errors."""
@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling:
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
mock_publisher_class.return_value = mock_publisher_instance
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling:
mock_msg.json.return_value = sample_objects_message
with pytest.raises(Exception, match="Publisher error"):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
"""Test that receive() handles malformed JSON appropriately."""
mock_publisher_class.return_value = Mock()
objects_import = ObjectsImport(
rows_import = RowsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling:
mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
with pytest.raises(json.JSONDecodeError):
await objects_import.receive(mock_msg)
await rows_import.receive(mock_msg)

View file

@ -76,7 +76,7 @@ def cities_schema():
def validator():
"""Create a mock processor with just the validation method"""
from unittest.mock import MagicMock
from trustgraph.extract.kg.objects.processor import Processor
from trustgraph.extract.kg.rows.processor import Processor
# Create a mock processor
mock_processor = MagicMock()

View file

@ -167,7 +167,7 @@ class TestFlowClient:
expected_methods = [
'text_completion', 'agent', 'graph_rag', 'document_rag',
'graph_embeddings_query', 'embeddings', 'prompt',
'triples_query', 'objects_query'
'triples_query', 'rows_query'
]
for method in expected_methods:
@ -216,7 +216,7 @@ class TestSocketClient:
expected_methods = [
'agent', 'text_completion', 'graph_rag', 'document_rag',
'prompt', 'graph_embeddings_query', 'embeddings',
'triples_query', 'objects_query', 'mcp_tool'
'triples_query', 'rows_query', 'mcp_tool'
]
for method in expected_methods:
@ -243,7 +243,7 @@ class TestBulkClient:
'import_graph_embeddings',
'import_document_embeddings',
'import_entity_contexts',
'import_objects'
'import_rows'
]
for method in import_methods:

View file

@ -1,10 +1,11 @@
"""
Unit tests for Cassandra Objects GraphQL Query Processor
Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation)
Tests the business logic of the GraphQL query processor including:
- GraphQL schema generation from RowSchema
- Query execution and validation
- CQL translation logic
- Schema configuration handling
- Query execution using unified rows table
- Name sanitization
- GraphQL query execution
- Message processing logic
"""
@ -12,119 +13,91 @@ import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import strawberry
from strawberry import Schema
from trustgraph.query.objects.cassandra.service import Processor
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from trustgraph.query.rows.cassandra.service import Processor
from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from trustgraph.schema import RowSchema, Field
class TestObjectsGraphQLQueryLogic:
"""Test business logic without external dependencies"""
def test_get_python_type_mapping(self):
"""Test schema field type conversion to Python types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_python_type("string") == str
assert processor.get_python_type("integer") == int
assert processor.get_python_type("float") == float
assert processor.get_python_type("boolean") == bool
assert processor.get_python_type("timestamp") == str
assert processor.get_python_type("date") == str
assert processor.get_python_type("time") == str
assert processor.get_python_type("uuid") == str
# Unknown type defaults to str
assert processor.get_python_type("unknown_type") == str
def test_create_graphql_type_basic_fields(self):
"""Test GraphQL type creation for basic field types"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="test_table",
description="Test table",
fields=[
Field(
name="id",
type="string",
primary=True,
required=True,
description="Primary key"
),
Field(
name="name",
type="string",
required=True,
description="Name field"
),
Field(
name="age",
type="integer",
required=False,
description="Optional age"
),
Field(
name="active",
type="boolean",
required=False,
description="Status flag"
)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test_table", schema)
# Verify type was created
assert graphql_type is not None
assert hasattr(graphql_type, '__name__')
assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower()
class TestRowsGraphQLQueryLogic:
"""Test business logic for unified table query implementation"""
def test_sanitize_name_cassandra_compatibility(self):
"""Test name sanitization for Cassandra field names"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test field name sanitization (matches storage processor)
# Test field name sanitization (uses r_ prefix like storage processor)
assert processor.sanitize_name("simple_field") == "simple_field"
assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes"
assert processor.sanitize_name("field.with.dots") == "field_with_dots"
assert processor.sanitize_name("123_field") == "o_123_field"
assert processor.sanitize_name("123_field") == "r_123_field"
assert processor.sanitize_name("field with spaces") == "field_with_spaces"
assert processor.sanitize_name("special!@#chars") == "special___chars"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
assert processor.sanitize_name("CamelCase") == "camelcase"
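The assertions above pin down the sanitization rules (lowercase, non-alphanumerics mapped to underscores, `r_` prefix when a name starts with a digit). A minimal sketch consistent with those assertions — not the actual trustgraph implementation — could look like:

```python
import re

def sanitize_name(name: str) -> str:
    # Lowercase, then map every character outside [a-z0-9_] to "_"
    cleaned = re.sub(r"[^a-z0-9_]", "_", name.lower())
    # Cassandra identifiers cannot start with a digit, so prefix "r_"
    if cleaned and cleaned[0].isdigit():
        cleaned = "r_" + cleaned
    return cleaned
```

The same rules apply on the storage side, so a field name always sanitizes to the same Cassandra identifier on write and on query.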
def test_sanitize_table_name(self):
"""Test table name sanitization (always gets o_ prefix)"""
def test_get_index_names(self):
"""Test extraction of index names from schema"""
processor = MagicMock()
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Table names always get o_ prefix
assert processor.sanitize_table("simple_table") == "o_simple_table"
assert processor.sanitize_table("Table-Name") == "o_table_name"
assert processor.sanitize_table("123table") == "o_123table"
assert processor.sanitize_table("") == "o_"
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
schema = RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string"), # Not indexed
Field(name="status", type="string", indexed=True)
]
)
index_names = processor.get_index_names(schema)
assert "id" in index_names
assert "category" in index_names
assert "status" in index_names
assert "name" not in index_names
assert len(index_names) == 3
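Per the test above, both primary-key fields and explicitly indexed fields count as queryable indexes. A sketch of that selection logic, using a stand-in `Field` dataclass rather than `trustgraph.schema.Field`, might be:

```python
from dataclasses import dataclass

@dataclass
class Field:
    # Stand-in for trustgraph.schema.Field; only the flags used here
    name: str
    type: str = "string"
    primary: bool = False
    indexed: bool = False

def get_index_names(fields):
    # Primary-key fields and indexed fields are both written to the
    # unified index table, so both support index lookups at query time
    return [f.name for f in fields if f.primary or f.indexed]
```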
def test_find_matching_index_exact_match(self):
"""Test finding matching index for exact match query"""
processor = MagicMock()
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
schema = RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string") # Not indexed
]
)
# Filter on indexed field should return match
filters = {"category": "electronics"}
result = processor.find_matching_index(schema, filters)
assert result is not None
assert result[0] == "category"
assert result[1] == ["electronics"]
# Filter on non-indexed field should return None
filters = {"name": "test"}
result = processor.find_matching_index(schema, filters)
assert result is None
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
processor = MagicMock()
processor.schemas = {}
processor.graphql_types = {}
processor.graphql_schema = None
processor.config_key = "schema" # Set the config key
processor.generate_graphql_schema = AsyncMock()
processor.config_key = "schema"
processor.schema_builder = MagicMock()
processor.schema_builder.clear = MagicMock()
processor.schema_builder.add_schema = MagicMock()
processor.schema_builder.build = MagicMock(return_value=MagicMock())
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
schema_config = {
"schema": {
@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic:
})
}
}
# Process config
await processor.on_schema_config(schema_config, version=1)
# Verify schema was loaded
assert "customer" in processor.schemas
schema = processor.schemas["customer"]
assert schema.name == "customer"
assert len(schema.fields) == 3
# Verify fields
id_field = next(f for f in schema.fields if f.name == "id")
assert id_field.primary is True
# The field should have been created correctly from JSON
# Let's test what we can verify - that the field has the right attributes
assert hasattr(id_field, 'required') # Has the required attribute
assert hasattr(id_field, 'primary') # Has the primary attribute
email_field = next(f for f in schema.fields if f.name == "email")
assert email_field.indexed is True
status_field = next(f for f in schema.fields if f.name == "status")
assert status_field.enum_values == ["active", "inactive"]
# Verify GraphQL schema regeneration was called
processor.generate_graphql_schema.assert_called_once()
def test_cql_query_building_basic(self):
"""Test basic CQL query construction"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to capture the query
mock_result = []
processor.session.execute.return_value = mock_result
# Create test schema
schema = RowSchema(
name="test_table",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string", indexed=True),
Field(name="status", type="string")
]
)
# Test query building
asyncio = pytest.importorskip("asyncio")
async def run_test():
await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="test_table",
row_schema=schema,
filters={"name": "John", "invalid_filter": "ignored"},
limit=10
)
# Run the async test
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_test())
finally:
loop.close()
# Verify Cassandra connection and query execution
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify the query structure (can't easily test exact query without complex mocking)
call_args = processor.session.execute.call_args
query = call_args[0][0] # First positional argument is the query
params = call_args[0][1] # Second positional argument is parameters
# Basic query structure checks
assert "SELECT * FROM test_user.o_test_table" in query
assert "WHERE" in query
assert "collection = %s" in query
assert "LIMIT 10" in query
# Parameters should include collection and name filter
assert "test_collection" in params
assert "John" in params
# Verify schema builder was called
processor.schema_builder.add_schema.assert_called_once()
processor.schema_builder.build.assert_called_once()
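The config-handling flow the test exercises is: clear the builder, re-register every schema from the pushed config, then build one GraphQL schema. A simplified, synchronous sketch of that flow (the real `on_schema_config` is async and parses full `RowSchema` objects; the builder interface here is assumed from the mocks):

```python
import json

class SchemaConfigHandler:
    """Minimal sketch of the config-driven schema rebuild."""

    def __init__(self, schema_builder):
        self.schemas = {}
        self.schema_builder = schema_builder
        self.graphql_schema = None

    def on_schema_config(self, config, version):
        # Rebuild from scratch on every config push: clear, re-add
        # every declared schema, then build a single GraphQL schema
        self.schemas = {}
        self.schema_builder.clear()
        for name, definition_json in config.get("schema", {}).items():
            definition = json.loads(definition_json)
            self.schemas[name] = definition
            self.schema_builder.add_schema(name, definition)
        self.graphql_schema = self.schema_builder.build()
```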
@pytest.mark.asyncio
async def test_graphql_context_handling(self):
@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Mock schema execution
mock_result = MagicMock()
mock_result.data = {"customers": [{"id": "1", "name": "Test"}]}
mock_result.errors = None
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { id name } }',
variables={},
@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify schema.execute was called with correct context
processor.graphql_schema.execute.assert_called_once()
call_args = processor.graphql_schema.execute.call_args
# Verify context was passed
context = call_args[1]['context_value'] # keyword argument
context = call_args[1]['context_value']
assert context["processor"] == processor
assert context["user"] == "test_user"
assert context["collection"] == "test_collection"
# Verify result structure
assert "data" in result
assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]}
@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.graphql_schema = AsyncMock()
processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor)
# Create a simple object to simulate GraphQL error instead of MagicMock
# Create a simple object to simulate GraphQL error
class MockError:
def __init__(self, message, path, extensions):
self.message = message
self.path = path
self.extensions = extensions
def __str__(self):
return self.message
mock_error = MockError(
message="Field 'invalid_field' doesn't exist",
path=["customers", "0", "invalid_field"],
extensions={"code": "FIELD_NOT_FOUND"}
)
mock_result = MagicMock()
mock_result.data = None
mock_result.errors = [mock_error]
processor.graphql_schema.execute.return_value = mock_result
result = await processor.execute_graphql_query(
query='{ customers { invalid_field } }',
variables={},
operation_name=None,
user="test_user",
user="test_user",
collection="test_collection"
)
# Verify error handling
assert "errors" in result
assert len(result["errors"]) == 1
error = result["errors"][0]
assert error["message"] == "Field 'invalid_field' doesn't exist"
assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path
assert error["path"] == ["customers", "0", "invalid_field"]
assert error["extensions"] == {"code": "FIELD_NOT_FOUND"}
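The error shape asserted above (message, path, extensions) suggests a flattening step from GraphQL error objects into JSON-safe dicts. One hedged sketch of that conversion, under the assumption that error objects expose `path` and `extensions` attributes and stringify to their message:

```python
def serialize_graphql_errors(errors):
    # Flatten GraphQL error objects into the JSON-safe shape the
    # query response carries: message, path, extensions
    return [
        {
            "message": str(err),
            "path": list(err.path) if err.path else None,
            "extensions": dict(err.extensions) if err.extensions else {},
        }
        for err in errors
    ]
```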
def test_schema_generation_basic_structure(self):
"""Test basic GraphQL schema generation structure"""
processor = MagicMock()
processor.schemas = {
"customer": RowSchema(
name="customer",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
processor.graphql_types = {}
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Test individual type creation (avoiding the full schema generation which has annotation issues)
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
processor.graphql_types["customer"] = graphql_type
# Verify type was created
assert len(processor.graphql_types) == 1
assert "customer" in processor.graphql_types
assert processor.graphql_types["customer"] is not None
@pytest.mark.asyncio
async def test_message_processing_success(self):
"""Test successful message processing flow"""
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock successful query result
processor.execute_graphql_query.return_value = {
"data": {"customers": [{"id": "1", "name": "John"}]},
"errors": [],
"extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String())
"extensions": {}
}
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
collection="test_collection",
query='{ customers { id name } }',
variables={},
operation_name=None
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-123"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify query was executed
processor.execute_graphql_query.assert_called_once_with(
query='{ customers { id name } }',
@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic:
user="test_user",
collection="test_collection"
)
# Verify response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is None
assert '"customers"' in response_call.data # JSON encoded
assert len(response_call.errors) == 0
@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic:
processor = MagicMock()
processor.execute_graphql_query = AsyncMock()
processor.on_message = Processor.on_message.__get__(processor, Processor)
# Mock query execution error
processor.execute_graphql_query.side_effect = RuntimeError("No schema available")
# Create mock message
mock_msg = MagicMock()
mock_request = ObjectsQueryRequest(
mock_request = RowsQueryRequest(
user="test_user",
collection="test_collection",
query='{ invalid_query }',
@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic:
)
mock_msg.value.return_value = mock_request
mock_msg.properties.return_value = {"id": "test-456"}
# Mock flow
mock_flow = MagicMock()
mock_response_flow = AsyncMock()
mock_flow.return_value = mock_response_flow
# Process message
await processor.on_message(mock_msg, None, mock_flow)
# Verify error response was sent
mock_response_flow.send.assert_called_once()
response_call = mock_response_flow.send.call_args[0][0]
# Verify error response structure
assert isinstance(response_call, ObjectsQueryResponse)
assert isinstance(response_call, RowsQueryResponse)
assert response_call.error is not None
assert response_call.error.type == "objects-query-error"
assert response_call.error.type == "rows-query-error"
assert "No schema available" in response_call.error.message
assert response_call.data is None
class TestCQLQueryGeneration:
"""Test CQL query generation logic in isolation"""
def test_partition_key_inclusion(self):
"""Test that collection is always included in queries"""
class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
async def test_query_with_index_match(self):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
# Mock the query building (simplified version)
keyspace = processor.sanitize_name("test_user")
table = processor.sanitize_table("test_table")
query = f"SELECT * FROM {keyspace}.{table}"
where_clauses = ["collection = %s"]
assert "collection = %s" in where_clauses
assert keyspace == "test_user"
assert table == "o_test_table"
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
processor.session.execute.return_value = [mock_row]
schema = RowSchema(
name="products",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string")
]
)
# Query with filter on indexed field
results = await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="products",
row_schema=schema,
filters={"category": "electronics"},
limit=10
)
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
processor.session.execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = processor.session.execute.call_args
query = call_args[0][0]
params = call_args[0][1]
assert "SELECT data, source FROM test_user.rows" in query
assert "collection = %s" in query
assert "schema_name = %s" in query
assert "index_name = %s" in query
assert "index_value = %s" in query
assert params[0] == "test_collection"
assert params[1] == "products"
assert params[2] == "category"
assert params[3] == ["electronics"]
# Verify results
assert len(results) == 1
assert results[0]["id"] == "123"
assert results[0]["category"] == "electronics"
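The query shape asserted above can be summarized with a small sketch. `build_rows_query` is a hypothetical helper (not the processor's actual code) that mirrors how an indexed filter is routed to the unified table's `index_name`/`index_value` columns:

```python
def build_rows_query(keyspace, collection, schema_name, index_name, index_values):
    # Route an indexed filter to the unified rows table: the matched
    # field name goes into index_name, its value list into index_value.
    query = (
        f"SELECT data, source FROM {keyspace}.rows "
        "WHERE collection = %s AND schema_name = %s "
        "AND index_name = %s AND index_value = %s"
    )
    return query, [collection, schema_name, index_name, index_values]

query, params = build_rows_query(
    "test_user", "test_collection", "products", "category", ["electronics"])
```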
@pytest.mark.asyncio
async def test_query_without_index_match(self):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock session execute to return test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
processor.session.execute.return_value = [mock_row1, mock_row2]
schema = RowSchema(
name="products",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string"), # Not indexed
Field(name="price", type="string") # Not indexed
]
)
# Query with filter on non-indexed field
results = await processor.query_cassandra(
user="test_user",
collection="test_collection",
schema_name="products",
row_schema=schema,
filters={"name": "Product A"},
limit=10
)
# Query should use ALLOW FILTERING for scan
call_args = processor.session.execute.call_args
query = call_args[0][0]
assert "ALLOW FILTERING" in query
# Should post-filter results
assert len(results) == 1
assert results[0]["name"] == "Product A"
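When no index matches, the test expects a scan over the schema's rows. A minimal sketch of that fallback (hypothetical helper; clause layout assumed from the assertion above):

```python
def build_scan_query(keyspace, collection, schema_name):
    # No matching index: scan the schema's rows and post-filter in the
    # application. Cassandra requires ALLOW FILTERING when the WHERE
    # clause is not fully index-backed.
    query = (
        f"SELECT data, source FROM {keyspace}.rows "
        "WHERE collection = %s AND schema_name = %s ALLOW FILTERING"
    )
    return query, [collection, schema_name]

query, params = build_scan_query("test_user", "test_collection", "products")
```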
class TestFilterMatching:
"""Test filter matching logic"""
def test_matches_filters_exact_match(self):
"""Test exact match filter"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
row = {"status": "active", "name": "test"}
assert processor._matches_filters(row, {"status": "active"}, schema) is True
assert processor._matches_filters(row, {"status": "inactive"}, schema) is False
def test_matches_filters_comparison_operators(self):
"""Test comparison operators in filters"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="price", type="float")])
row = {"price": "100.0"}
# Greater than
assert processor._matches_filters(row, {"price_gt": 50}, schema) is True
assert processor._matches_filters(row, {"price_gt": 150}, schema) is False
# Less than
assert processor._matches_filters(row, {"price_lt": 150}, schema) is True
assert processor._matches_filters(row, {"price_lt": 50}, schema) is False
# Greater than or equal
assert processor._matches_filters(row, {"price_gte": 100}, schema) is True
assert processor._matches_filters(row, {"price_gte": 101}, schema) is False
# Less than or equal
assert processor._matches_filters(row, {"price_lte": 100}, schema) is True
assert processor._matches_filters(row, {"price_lte": 99}, schema) is False
def test_matches_filters_contains(self):
"""Test contains filter"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="description", type="string")])
row = {"description": "A great product for everyone"}
assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True
assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False
def test_matches_filters_in_list(self):
"""Test in-list filter"""
processor = MagicMock()
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
schema = RowSchema(name="test", fields=[Field(name="status", type="string")])
row = {"status": "active"}
assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True
assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False
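The filter semantics exercised above can be sketched as follows. This is a hypothetical stand-in for `Processor._matches_filters`; the `_gt`/`_lt`/`_gte`/`_lte`/`_contains`/`_in` suffix conventions are assumed from the tests:

```python
import operator

# Comparison suffixes, longest first so "_gte" is never mistaken for "_gt".
OPS = {"_gte": operator.ge, "_lte": operator.le,
       "_gt": operator.gt, "_lt": operator.lt}

def matches_filters(row, filters):
    # Post-filter a row dict against filter keys carrying operator suffixes.
    for key, expected in filters.items():
        if key.endswith("_contains"):
            if expected not in row[key[:-len("_contains")]]:
                return False
        elif key.endswith("_in"):
            if row[key[:-len("_in")]] not in expected:
                return False
        else:
            for suffix, op in OPS.items():
                if key.endswith(suffix):
                    if not op(float(row[key[:-len(suffix)]]), float(expected)):
                        return False
                    break
            else:
                # No suffix: exact equality match.
                if row.get(key) != expected:
                    return False
    return True
```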
class TestIndexedFieldFiltering:
"""Test that only indexed or primary key fields can be directly filtered"""
def test_indexed_field_filtering(self):
"""Test that only indexed or primary key fields can be filtered"""
# Create schema with mixed field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", primary=True),
Field(name="indexed_field", type="string", indexed=True),
Field(name="normal_field", type="string", indexed=False),
Field(name="another_field", type="string")
]
)
filters = {
"id": "test123", # Primary key - should be included
"indexed_field": "value", # Indexed - should be included
"normal_field": "ignored", # Not indexed - should be ignored
"another_field": "also_ignored" # Not indexed - should be ignored
}
# Simulate the filtering logic from the processor
valid_filters = []
for field_name, value in filters.items():
@ -492,7 +531,7 @@ class TestCQLQueryGeneration:
schema_field = next((f for f in schema.fields if f.name == field_name), None)
if schema_field and (schema_field.indexed or schema_field.primary):
valid_filters.append((field_name, value))
# Only id and indexed_field should be included
assert len(valid_filters) == 2
field_names = [f[0] for f in valid_filters]
@ -500,52 +539,3 @@ class TestCQLQueryGeneration:
assert "indexed_field" in field_names
assert "normal_field" not in field_names
assert "another_field" not in field_names
class TestGraphQLSchemaGeneration:
"""Test GraphQL schema generation in detail"""
def test_field_type_annotations(self):
"""Test that GraphQL types have correct field annotations"""
processor = MagicMock()
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create schema with various field types
schema = RowSchema(
name="test",
fields=[
Field(name="id", type="string", required=True, primary=True),
Field(name="count", type="integer", required=True),
Field(name="price", type="float", required=False),
Field(name="active", type="boolean", required=False),
Field(name="optional_text", type="string", required=False)
]
)
# Create GraphQL type
graphql_type = processor.create_graphql_type("test", schema)
# Verify type was created successfully
assert graphql_type is not None
def test_basic_type_creation(self):
"""Test that GraphQL types are created correctly"""
processor = MagicMock()
processor.schemas = {
"customer": RowSchema(
name="customer",
fields=[Field(name="id", type="string", primary=True)]
)
}
processor.graphql_types = {}
processor.get_python_type = Processor.get_python_type.__get__(processor, Processor)
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor)
# Create GraphQL type directly
graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"])
processor.graphql_types["customer"] = graphql_type
# Verify customer type was created
assert "customer" in processor.graphql_types
assert processor.graphql_types["customer"] is not None


@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
StructuredQueryRequest, StructuredQueryResponse,
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
RowsQueryRequest, RowsQueryResponse,
Error, GraphQLError
)
from trustgraph.retrieval.structured_query.service import Processor
@ -68,7 +68,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service response
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}',
errors=None,
@ -86,7 +86,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -108,7 +108,7 @@ class TestStructuredQueryProcessor:
# Verify objects query service was called correctly
mock_objects_client.request.assert_called_once()
objects_call_args = mock_objects_client.request.call_args[0][0]
assert isinstance(objects_call_args, RowsQueryRequest)
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
assert objects_call_args.variables == {"state": "NY"}
assert objects_call_args.user == "trustgraph"
@ -224,7 +224,7 @@ class TestStructuredQueryProcessor:
assert response.error is not None
assert "empty GraphQL query" in response.error.message
async def test_rows_query_service_error(self, processor):
"""Test handling of rows query service errors"""
# Arrange
request = StructuredQueryRequest(
@ -250,7 +250,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects query service error
objects_response = RowsQueryResponse(
error=Error(type="graphql-execution-error", message="Table 'customers' not found"),
data=None,
errors=None,
@ -267,7 +267,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -284,7 +284,7 @@ class TestStructuredQueryProcessor:
response = response_call[0][0]
assert response.error is not None
assert "Rows query service error" in response.error.message
assert "Table 'customers' not found" in response.error.message
async def test_graphql_errors_handling(self, processor):
@ -321,7 +321,7 @@ class TestStructuredQueryProcessor:
)
]
objects_response = RowsQueryResponse(
error=None,
data=None,
errors=graphql_errors,
@ -338,7 +338,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -400,7 +400,7 @@ class TestStructuredQueryProcessor:
)
# Mock objects response
objects_response = RowsQueryResponse(
error=None,
data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}',
errors=None
@ -416,7 +416,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response
@ -464,7 +464,7 @@ class TestStructuredQueryProcessor:
confidence=0.9
)
objects_response = RowsQueryResponse(
error=None,
data=None, # Null data
errors=None,
@ -481,7 +481,7 @@ class TestStructuredQueryProcessor:
def flow_router(service_name):
if service_name == "nlp-query-request":
return mock_nlp_client
elif service_name == "rows-query-request":
return mock_objects_client
elif service_name == "response":
return flow_response


@ -10,7 +10,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
from trustgraph.storage.knowledge.store import Processor as KgStore
@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_password is None
class TestRowsWriterConfiguration:
"""Test Cassandra configuration in rows writer processor."""
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_environment_variable_configuration(self, mock_cluster):
"""Test processor picks up configuration from environment variables."""
env_vars = {
@ -97,13 +97,13 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = RowsWriter(taskgroup=MagicMock())
assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2']
assert processor.cassandra_username == 'obj-env-user'
assert processor.cassandra_password == 'obj-env-pass'
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
def test_cassandra_connection_with_hosts_list(self, mock_cluster):
"""Test that Cassandra connection uses hosts list correctly."""
env_vars = {
@ -118,7 +118,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify cluster was called with hosts list
@ -129,8 +129,8 @@ class TestObjectsWriterConfiguration:
assert 'contact_points' in call_args.kwargs
assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3']
@patch('trustgraph.storage.rows.cassandra.write.Cluster')
@patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider')
def test_authentication_configuration(self, mock_auth_provider, mock_cluster):
"""Test authentication is configured when credentials are provided."""
env_vars = {
@ -145,7 +145,7 @@ class TestObjectsWriterConfiguration:
mock_cluster.return_value = mock_cluster_instance
with patch.dict(os.environ, env_vars, clear=True):
processor = RowsWriter(taskgroup=MagicMock())
processor.connect_cassandra()
# Verify auth provider was created with correct credentials
@ -302,10 +302,10 @@ class TestCommandLineArgumentHandling:
def test_rows_writer_add_args(self):
"""Test that rows writer adds standard Cassandra arguments."""
import argparse
from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter
parser = argparse.ArgumentParser()
RowsWriter.add_args(parser)
# Parse empty args to check that arguments exist
args = parser.parse_args([])


@ -1,533 +0,0 @@
"""
Unit tests for Cassandra Object Storage Processor
Tests the business logic of the object storage processor including:
- Schema configuration handling
- Type conversions
- Name sanitization
- Table structure generation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.objects.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestObjectsCassandraStorageLogic:
"""Test business logic without FlowProcessor dependencies"""
def test_sanitize_name(self):
"""Test name sanitization for Cassandra compatibility"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test various name patterns (back to original logic)
assert processor.sanitize_name("simple_name") == "simple_name"
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number"
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
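The sanitization rules these assertions pin down can be sketched as a small standalone function (an assumption reconstructed from the test expectations, not the processor's actual code: lowercase, non-alphanumerics become underscores, and an `o_` prefix is added when the name would otherwise start with a digit):

```python
import re

def sanitize_name(name):
    # Lowercase and replace anything outside [a-z0-9_] with underscores.
    s = re.sub(r"[^a-z0-9_]", "_", name.lower())
    # Cassandra identifiers cannot start with a digit; prefix if needed.
    if s and s[0].isdigit():
        s = "o_" + s
    return s
```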
def test_get_cassandra_type(self):
"""Test field type conversion to Cassandra types"""
processor = MagicMock()
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
# Basic type mappings
assert processor.get_cassandra_type("string") == "text"
assert processor.get_cassandra_type("boolean") == "boolean"
assert processor.get_cassandra_type("timestamp") == "timestamp"
assert processor.get_cassandra_type("uuid") == "uuid"
# Integer types with size hints
assert processor.get_cassandra_type("integer", size=2) == "int"
assert processor.get_cassandra_type("integer", size=8) == "bigint"
# Float types with size hints
assert processor.get_cassandra_type("float", size=2) == "float"
assert processor.get_cassandra_type("float", size=8) == "double"
# Unknown type defaults to text
assert processor.get_cassandra_type("unknown_type") == "text"
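The type mapping asserted above amounts to the following sketch (the size threshold of 4 bytes for promoting to `bigint`/`double` is an assumption consistent with the test cases):

```python
def get_cassandra_type(field_type, size=None):
    # Fixed one-to-one mappings.
    fixed = {"string": "text", "boolean": "boolean",
             "timestamp": "timestamp", "uuid": "uuid"}
    if field_type in fixed:
        return fixed[field_type]
    # Numeric types widen with the size hint.
    if field_type == "integer":
        return "bigint" if (size or 0) > 4 else "int"
    if field_type == "float":
        return "double" if (size or 0) > 4 else "float"
    # Unknown types default to text.
    return "text"
```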
def test_convert_value(self):
"""Test value conversion for different field types"""
processor = MagicMock()
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Integer conversions
assert processor.convert_value("123", "integer") == 123
assert processor.convert_value(123.5, "integer") == 123
assert processor.convert_value(None, "integer") is None
# Float conversions
assert processor.convert_value("123.45", "float") == 123.45
assert processor.convert_value(123, "float") == 123.0
# Boolean conversions
assert processor.convert_value("true", "boolean") is True
assert processor.convert_value("false", "boolean") is False
assert processor.convert_value("1", "boolean") is True
assert processor.convert_value("0", "boolean") is False
assert processor.convert_value("yes", "boolean") is True
assert processor.convert_value("no", "boolean") is False
# String conversions
assert processor.convert_value(123, "string") == "123"
assert processor.convert_value(True, "string") == "True"
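A minimal sketch of `convert_value` consistent with these assertions (hypothetical reconstruction, not the processor's code):

```python
def convert_value(value, field_type):
    if value is None:
        return None
    if field_type == "integer":
        # Go through float so "123" and 123.5 both truncate cleanly.
        return int(float(value))
    if field_type == "float":
        return float(value)
    if field_type == "boolean":
        # Accept common truthy spellings; everything else is False.
        return str(value).lower() in ("true", "1", "yes")
    # Default: stringify.
    return str(value)
```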
def test_table_creation_cql_generation(self):
"""Test CQL generation for table creation"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create test schema
schema = RowSchema(
name="customer_records",
description="Test customer schema",
fields=[
Field(
name="customer_id",
type="string",
size=50,
primary=True,
required=True,
indexed=False
),
Field(
name="email",
type="string",
size=100,
required=True,
indexed=True
),
Field(
name="age",
type="integer",
size=4,
required=False,
indexed=False
)
]
)
# Call ensure_table
processor.ensure_table("test_user", "customer_records", schema)
# Verify keyspace was ensured (check that it was added to known_keyspaces)
assert "test_user" in processor.known_keyspaces
# Check the CQL that was executed (first call should be table creation)
all_calls = processor.session.execute.call_args_list
table_creation_cql = all_calls[0][0][0] # First call
# Verify table structure (keyspace uses sanitize_name, table uses sanitize_table)
assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql
assert "collection text" in table_creation_cql
assert "customer_id text" in table_creation_cql
assert "email text" in table_creation_cql
assert "age int" in table_creation_cql
assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql
def test_table_creation_without_primary_key(self):
"""Test table creation when no primary key is defined"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = set()
processor.known_tables = {}
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema without primary key
schema = RowSchema(
name="events",
description="Event log",
fields=[
Field(name="event_type", type="string", size=50),
Field(name="timestamp", type="timestamp", size=0)
]
)
# Call ensure_table
processor.ensure_table("test_user", "events", schema)
# Check the CQL includes synthetic_id (field names don't get o_ prefix)
executed_cql = processor.session.execute.call_args[0][0]
assert "synthetic_id uuid" in executed_cql
assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql
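The CQL shape these assertions expect can be illustrated with a sketch (hypothetical helper; column names taken from the test, and the synthetic-key rule assumed from the assertions):

```python
def build_create_table_cql(keyspace, table, columns, primary_fields):
    # columns: mapping of field name -> Cassandra type.
    cols = dict(columns)
    if primary_fields:
        key = ", ".join(primary_fields)
    else:
        # No declared primary key: add a synthetic uuid column instead.
        cols["synthetic_id"] = "uuid"
        key = "synthetic_id"
    col_defs = ", ".join(f"{name} {ctype}" for name, ctype in cols.items())
    return (f"CREATE TABLE IF NOT EXISTS {keyspace}.{table} "
            f"(collection text, {col_defs}, "
            f"PRIMARY KEY ((collection, {key})))")

cql = build_create_table_cql(
    "test_user", "o_events",
    {"event_type": "text", "timestamp": "timestamp"}, [])
```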
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer data",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True
},
{
"name": "name",
"type": "string",
"required": True
},
{
"name": "balance",
"type": "float",
"size": 8
}
]
})
}
}
# Process configuration
await processor.on_schema_config(config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
# Check field properties
id_field = schema.fields[0]
assert id_field.name == "id"
assert id_field.type == "string"
assert id_field.primary is True
# Note: Field.required always returns False due to Pulsar schema limitations
# The actual required value is tracked during schema parsing
@pytest.mark.asyncio
async def test_object_processing_logic(self):
"""Test the logic for processing ExtractedObject"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="integer", size=4)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="test_schema",
values=[{"id": "123", "value": "456"}],
confidence=0.9,
source_span="test source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"])
# Verify insert was executed (keyspace normal, table with o_ prefix)
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_test_schema" in insert_cql
assert "collection" in insert_cql
assert values[0] == "test_collection" # collection value
assert values[1] == "123" # id value (from values[0])
assert values[2] == 456 # converted integer value (from values[0])
def test_secondary_index_creation(self):
"""Test that secondary indexes are created for indexed fields"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
if keyspace not in processor.known_tables:
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema with indexed field
schema = RowSchema(
name="products",
description="Product catalog",
fields=[
Field(name="product_id", type="string", size=50, primary=True),
Field(name="category", type="string", size=30, indexed=True),
Field(name="price", type="float", size=8, indexed=True)
]
)
# Call ensure_table
processor.ensure_table("test_user", "products", schema)
# Should have 3 calls: create table + 2 indexes
assert processor.session.execute.call_count == 3
# Check index creation calls (table has o_ prefix, fields don't)
calls = processor.session.execute.call_args_list
index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]]
assert len(index_calls) == 2
assert any("o_products_category_idx" in call for call in index_calls)
assert any("o_products_price_idx" in call for call in index_calls)
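The index statements these assertions expect follow a simple pattern, sketched here (hypothetical helper; the `<table>_<field>_idx` naming is assumed from the assertions above):

```python
def build_index_cql(keyspace, table, field):
    # One secondary index per indexed field.
    return (f"CREATE INDEX IF NOT EXISTS {table}_{field}_idx "
            f"ON {keyspace}.{table} ({field})")

stmts = [build_index_cql("test_user", "o_products", f)
         for f in ("category", "price")]
```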
class TestObjectsCassandraStorageBatchLogic:
"""Test batch processing logic in Cassandra storage"""
@pytest.mark.asyncio
async def test_batch_object_processing_logic(self):
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
description="Test batch schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="name", type="string", size=100),
Field(name="value", type="integer", size=4)
]
)
}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
metadata=[]
),
schema_name="batch_schema",
values=[
{"id": "001", "name": "First", "value": "100"},
{"id": "002", "name": "Second", "value": "200"},
{"id": "003", "name": "Third", "value": "300"}
],
confidence=0.95,
source_span="batch source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = batch_obj
# Process batch object
await processor.on_object(msg, None, None)
# Verify table was ensured once
processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"])
# Verify 3 separate insert calls (one per batch item)
assert processor.session.execute.call_count == 3
# Check each insert call
calls = processor.session.execute.call_args_list
for i, call in enumerate(calls):
insert_cql = call[0][0]
values = call[0][1]
assert "INSERT INTO test_user.o_batch_schema" in insert_cql
assert "collection" in insert_cql
# Check values for each batch item
assert values[0] == "batch_collection" # collection
assert values[1] == f"00{i+1}" # id from batch item i
assert values[2] == ("First" if i == 0 else "Second" if i == 1 else "Third") # name
assert values[3] == (i+1) * 100 # converted integer value
@pytest.mark.asyncio
async def test_empty_batch_processing_logic(self):
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", size=50, primary=True)]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create empty batch object
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
metadata=[]
),
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span="empty source"
)
msg = MagicMock()
msg.value.return_value = empty_batch_obj
# Process empty batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
@pytest.mark.asyncio
async def test_single_item_batch_processing_logic(self):
"""Test processing of single-item batch (backward compatibility)"""
processor = MagicMock()
processor.schemas = {
"single_schema": RowSchema(
name="single_schema",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="data", type="string", size=100)
]
)
}
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create single-item batch object (backward compatibility case)
single_batch_obj = ExtractedObject(
metadata=Metadata(
id="single-001",
user="test_user",
collection="single_collection",
metadata=[]
),
schema_name="single_schema",
values=[{"id": "single-1", "data": "single data"}], # Array with one item
confidence=0.8,
source_span="single source"
)
msg = MagicMock()
msg.value.return_value = single_batch_obj
# Process single-item batch object
await processor.on_object(msg, None, None)
# Verify table was ensured
processor.ensure_table.assert_called_once()
# Verify exactly one insert call
processor.session.execute.assert_called_once()
insert_cql = processor.session.execute.call_args[0][0]
values = processor.session.execute.call_args[0][1]
assert "INSERT INTO test_user.o_single_schema" in insert_cql
assert values[0] == "single_collection" # collection
assert values[1] == "single-1" # id value
assert values[2] == "single data" # data value
def test_batch_value_conversion_logic(self):
"""Test value conversion works correctly for batch items"""
processor = MagicMock()
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
# Test various conversion scenarios that would occur in batch processing
test_cases = [
# Integer conversions for batch items
("123", "integer", 123),
("456", "integer", 456),
("789", "integer", 789),
# Float conversions for batch items
("12.5", "float", 12.5),
("34.7", "float", 34.7),
# Boolean conversions for batch items
("true", "boolean", True),
("false", "boolean", False),
("1", "boolean", True),
("0", "boolean", False),
# String conversions for batch items
(123, "string", "123"),
(45.6, "string", "45.6"),
]
for input_val, field_type, expected_output in test_cases:
result = processor.convert_value(input_val, field_type)
assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}"
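The conversion table above implies a simple dispatch on the field type. A minimal self-contained sketch that satisfies the same cases (hypothetical; the real `Processor.convert_value` may handle more types and error cases):

```python
def convert_value(value, field_type):
    # Hypothetical sketch mirroring the conversions asserted above;
    # not the actual trustgraph implementation.
    if field_type == "integer":
        return int(value)
    if field_type == "float":
        return float(value)
    if field_type == "boolean":
        # "true" / "1" (case-insensitive) are truthy; everything else is False
        return str(value).strip().lower() in ("true", "1")
    return str(value)
```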


@ -0,0 +1,435 @@
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
"""Test Qdrant row embeddings storage functionality"""
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_basic(self, mock_qdrant_client):
"""Test basic Qdrant processor initialization"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
mock_qdrant_client.assert_called_once_with(
url='http://localhost:6333', api_key='test-api-key'
)
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
"""Test processor initialization with default values"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-qdrant-processor'
}
processor = Processor(**config)
mock_qdrant_client.assert_called_once_with(
url='http://localhost:6333', api_key=None
)
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_sanitize_name(self, mock_qdrant_client):
"""Test name sanitization for Qdrant collections"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Test basic sanitization
assert processor.sanitize_name("simple") == "simple"
assert processor.sanitize_name("with-dash") == "with_dash"
assert processor.sanitize_name("with.dot") == "with_dot"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
# Test numeric prefix handling
assert processor.sanitize_name("123start") == "r_123start"
assert processor.sanitize_name("_underscore") == "r__underscore"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_get_collection_name(self, mock_qdrant_client):
"""Test Qdrant collection name generation"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_client.return_value = MagicMock()
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
collection_name = processor.get_collection_name(
user="test_user",
collection="test_collection",
schema_name="customer_data",
dimension=384
)
assert collection_name == "rows_test_user_test_collection_customer_data_384"
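The expected name encodes user, collection, schema, and vector dimension. Assuming each component has already been sanitized, the naming scheme the test asserts can be sketched as:

```python
def get_collection_name(user, collection, schema_name, dimension):
    # Sketch of the Qdrant collection naming scheme asserted above;
    # assumes the caller has already sanitized each component.
    return f"rows_{user}_{collection}_{schema_name}_{dimension}"
```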
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
"""Test that ensure_collection creates a new collection when needed"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.ensure_collection("test_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
mock_qdrant_instance.create_collection.assert_called_once()
# Verify the collection is cached
assert "test_collection" in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
"""Test that ensure_collection skips creation when collection exists"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.ensure_collection("existing_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once()
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_uses_cache(self, mock_qdrant_client):
"""Test that ensure_collection uses cache for previously created collections"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.created_collections.add("cached_collection")
processor.ensure_collection("cached_collection", 384)
# Should not check or create - just return
mock_qdrant_instance.collection_exists.assert_not_called()
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client):
"""Test processing basic row embeddings message"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid-123'
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
# Create embeddings message
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
embedding = RowIndexEmbedding(
index_name='customer_id',
index_value=['CUST001'],
text='CUST001',
vectors=[[0.1, 0.2, 0.3]]
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='customers',
embeddings=[embedding]
)
# Mock message wrapper
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
assert point.payload['index_name'] == 'customer_id'
assert point.payload['index_value'] == ['CUST001']
assert point.payload['text'] == 'CUST001'
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.row_embeddings.qdrant.write.uuid')
async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client):
"""Test processing embeddings with multiple vectors"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = 'test-uuid'
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
# Embedding with multiple vectors
embedding = RowIndexEmbedding(
index_name='name',
index_value=['John Doe'],
text='John Doe',
vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='people',
embeddings=[embedding]
)
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should be called 3 times (once per vector)
assert mock_qdrant_instance.upsert.call_count == 3
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client):
"""Test that embeddings with no vectors are skipped"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
# Embedding with no vectors
embedding = RowIndexEmbedding(
index_name='id',
index_value=['123'],
text='123',
vectors=[] # Empty vectors
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='items',
embeddings=[embedding]
)
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should not call upsert for empty vectors
mock_qdrant_instance.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client):
"""Test that messages for unknown collections are dropped"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
from trustgraph.schema import RowEmbeddings, RowIndexEmbedding
mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
embedding = RowIndexEmbedding(
index_name='id',
index_value=['123'],
text='123',
vectors=[[0.1, 0.2]]
)
embeddings_msg = RowEmbeddings(
metadata=metadata,
schema_name='items',
embeddings=[embedding]
)
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
# Should not call upsert for unknown collection
mock_qdrant_instance.upsert.assert_not_called()
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection(self, mock_qdrant_client):
"""Test deleting all collections for a user/collection"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
# Mock collections list
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
mock_coll3 = MagicMock()
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
mock_qdrant_instance.get_collections.return_value = mock_collections
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
await processor.delete_collection('test_user', 'test_collection')
# Should delete only the matching collections
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
"""Test deleting collections for a specific schema"""
from trustgraph.storage.row_embeddings.qdrant.write import Processor
mock_qdrant_instance = MagicMock()
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2]
mock_qdrant_instance.get_collections.return_value = mock_collections
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
await processor.delete_collection_schema(
'test_user', 'test_collection', 'customers'
)
# Should only delete the customers schema collection
mock_qdrant_instance.delete_collection.assert_called_once()
call_args = mock_qdrant_instance.delete_collection.call_args[0]
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
if __name__ == '__main__':
pytest.main([__file__])


@ -0,0 +1,474 @@
"""
Unit tests for Cassandra Row Storage Processor (Unified Table Implementation)
Tests the business logic of the row storage processor including:
- Schema configuration handling
- Name sanitization
- Unified table structure
- Index management
- Row storage with multi-index support
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
from trustgraph.storage.rows.cassandra.write import Processor
from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field
class TestRowsCassandraStorageLogic:
"""Test business logic for unified table implementation"""
def test_sanitize_name(self):
"""Test name sanitization for Cassandra compatibility"""
processor = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
# Test various name patterns
assert processor.sanitize_name("simple_name") == "simple_name"
assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes"
assert processor.sanitize_name("name.with.dots") == "name_with_dots"
assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number"
assert processor.sanitize_name("name with spaces") == "name_with_spaces"
assert processor.sanitize_name("special!@#$%^chars") == "special______chars"
assert processor.sanitize_name("UPPERCASE") == "uppercase"
assert processor.sanitize_name("CamelCase") == "camelcase"
assert processor.sanitize_name("_underscore_start") == "r__underscore_start"
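A pure-Python sketch consistent with every assertion above (lowercase, non-alphanumerics mapped to underscores, `r_` prefix when the result starts with a digit or underscore); the actual `Processor.sanitize_name` may differ in detail:

```python
import re

def sanitize_name(name):
    # Hypothetical sketch of the sanitization behaviour asserted above.
    s = re.sub(r"[^a-z0-9]", "_", name.lower())
    if s and (s[0].isdigit() or s[0] == "_"):
        s = "r_" + s  # Cassandra identifiers must not start with a digit
    return s
```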
def test_get_index_names(self):
"""Test extraction of index names from schema"""
processor = MagicMock()
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
# Schema with primary and indexed fields
schema = RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="name", type="string"), # Not indexed
Field(name="status", type="string", indexed=True)
]
)
index_names = processor.get_index_names(schema)
# Should include primary key and indexed fields
assert "id" in index_names
assert "category" in index_names
assert "status" in index_names
assert "name" not in index_names # Not indexed
assert len(index_names) == 3
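Index names come from fields flagged either `primary` or `indexed`. A self-contained sketch, using a stand-in dataclass rather than the real `trustgraph.schema.Field`:

```python
from dataclasses import dataclass

@dataclass
class FieldStub:
    # Stand-in for trustgraph.schema.Field, reduced to the flags used here
    name: str
    primary: bool = False
    indexed: bool = False

def get_index_names(fields):
    # Sketch: primary-key and indexed fields each become an index partition
    return [f.name for f in fields if f.primary or f.indexed]
```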
def test_get_index_names_no_indexes(self):
"""Test schema with no indexed fields"""
processor = MagicMock()
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
schema = RowSchema(
name="no_index_schema",
fields=[
Field(name="data1", type="string"),
Field(name="data2", type="string")
]
)
index_names = processor.get_index_names(schema)
assert len(index_names) == 0
def test_build_index_value(self):
"""Test building index values from row data"""
processor = MagicMock()
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
value_map = {"id": "123", "category": "electronics", "name": "Widget"}
# Single field index
result = processor.build_index_value(value_map, "id")
assert result == ["123"]
result = processor.build_index_value(value_map, "category")
assert result == ["electronics"]
# Missing field returns empty string
result = processor.build_index_value(value_map, "missing")
assert result == [""]
def test_build_index_value_composite(self):
"""Test building composite index values"""
processor = MagicMock()
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
value_map = {"region": "us-west", "category": "electronics", "id": "123"}
# Composite index (comma-separated field names)
result = processor.build_index_value(value_map, "region,category")
assert result == ["us-west", "electronics"]
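Composite indexes are expressed as comma-separated field names, and a missing field contributes an empty string so the value list keeps a stable shape. A minimal sketch under those assumptions:

```python
def build_index_value(value_map, index_name):
    # Sketch: one list element per comma-separated field; a missing
    # field yields "" rather than being dropped.
    return [str(value_map.get(field, "")) for field in index_name.split(",")]
```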
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.registered_partitions = set()
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration
config = {
"schema": {
"customer_records": json.dumps({
"name": "customer_records",
"description": "Customer data",
"fields": [
{
"name": "id",
"type": "string",
"primary_key": True,
"required": True
},
{
"name": "name",
"type": "string",
"required": True
},
{
"name": "category",
"type": "string",
"indexed": True
}
]
})
}
}
# Process configuration
await processor.on_schema_config(config, version=1)
# Verify schema was loaded
assert "customer_records" in processor.schemas
schema = processor.schemas["customer_records"]
assert schema.name == "customer_records"
assert len(schema.fields) == 3
# Check field properties
id_field = schema.fields[0]
assert id_field.name == "id"
assert id_field.type == "string"
assert id_field.primary is True
@pytest.mark.asyncio
async def test_object_processing_stores_data_map(self):
"""Test that row processing stores data as map<text, text>"""
processor = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
description="Test",
fields=[
Field(name="id", type="string", size=50, primary=True),
Field(name="value", type="string", size=100)
]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create test object
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="test_schema",
values=[{"id": "123", "value": "test_data"}],
confidence=0.9,
source_span="test source"
)
# Create mock message
msg = MagicMock()
msg.value.return_value = test_obj
# Process object
await processor.on_object(msg, None, None)
# Verify insert was executed
processor.session.execute.assert_called()
insert_call = processor.session.execute.call_args
insert_cql = insert_call[0][0]
values = insert_call[0][1]
# Verify using unified rows table
assert "INSERT INTO test_user.rows" in insert_cql
# Values should be: (collection, schema_name, index_name, index_value, data, source)
assert values[0] == "test_collection" # collection
assert values[1] == "test_schema" # schema_name
assert values[2] == "id" # index_name (primary key field)
assert values[3] == ["123"] # index_value as list
assert values[4] == {"id": "123", "value": "test_data"} # data map
assert values[5] == "" # source
@pytest.mark.asyncio
async def test_object_processing_multiple_indexes(self):
"""Test that row is written once per indexed field"""
processor = MagicMock()
processor.schemas = {
"multi_index_schema": RowSchema(
name="multi_index_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True),
Field(name="status", type="string", indexed=True)
]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
metadata=[]
),
schema_name="multi_index_schema",
values=[{"id": "123", "category": "electronics", "status": "active"}],
confidence=0.9,
source_span=""
)
msg = MagicMock()
msg.value.return_value = test_obj
await processor.on_object(msg, None, None)
# Should have 3 inserts (one per indexed field: id, category, status)
assert processor.session.execute.call_count == 3
# Check that different index_names were used
index_names_used = set()
for call in processor.session.execute.call_args_list:
values = call[0][1]
index_names_used.add(values[2]) # index_name is 3rd value
assert index_names_used == {"id", "category", "status"}
class TestRowsCassandraStorageBatchLogic:
"""Test batch processing logic for unified table implementation"""
@pytest.mark.asyncio
async def test_batch_object_processing(self):
"""Test processing of batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"batch_schema": RowSchema(
name="batch_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="name", type="string")
]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create batch object with multiple values
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
metadata=[]
),
schema_name="batch_schema",
values=[
{"id": "001", "name": "First"},
{"id": "002", "name": "Second"},
{"id": "003", "name": "Third"}
],
confidence=0.95,
source_span=""
)
msg = MagicMock()
msg.value.return_value = batch_obj
await processor.on_object(msg, None, None)
# Should have 3 inserts (one per row; each row has a single index, the primary key)
assert processor.session.execute.call_count == 3
# Check each insert has different id
ids_inserted = set()
for call in processor.session.execute.call_args_list:
values = call[0][1]
ids_inserted.add(tuple(values[3])) # index_value is 4th value
assert ids_inserted == {("001",), ("002",), ("003",)}
@pytest.mark.asyncio
async def test_empty_batch_processing(self):
"""Test processing of empty batch ExtractedObjects"""
processor = MagicMock()
processor.schemas = {
"empty_schema": RowSchema(
name="empty_schema",
fields=[Field(name="id", type="string", primary=True)]
)
}
processor.tables_initialized = {"test_user"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.ensure_tables = MagicMock()
processor.register_partitions = MagicMock()
processor.collection_exists = MagicMock(return_value=True)
processor.on_object = Processor.on_object.__get__(processor, Processor)
# Create empty batch object
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
metadata=[]
),
schema_name="empty_schema",
values=[], # Empty batch
confidence=1.0,
source_span=""
)
msg = MagicMock()
msg.value.return_value = empty_batch_obj
await processor.on_object(msg, None, None)
# Verify no insert calls for empty batch
processor.session.execute.assert_not_called()
class TestUnifiedTableStructure:
"""Test the unified rows table structure"""
def test_ensure_tables_creates_unified_structure(self):
"""Test that ensure_tables creates the unified rows table"""
processor = MagicMock()
processor.known_keyspaces = {"test_user"}
processor.tables_initialized = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.ensure_keyspace = MagicMock()
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
processor.ensure_tables("test_user")
# Should have 2 calls: create rows table + create row_partitions table
assert processor.session.execute.call_count == 2
# Check rows table creation
rows_cql = processor.session.execute.call_args_list[0][0][0]
assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql
assert "collection text" in rows_cql
assert "schema_name text" in rows_cql
assert "index_name text" in rows_cql
assert "index_value frozen<list<text>>" in rows_cql
assert "data map<text, text>" in rows_cql
assert "source text" in rows_cql
assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql
# Check row_partitions table creation
partitions_cql = processor.session.execute.call_args_list[1][0][0]
assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql
assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql
# Verify keyspace added to initialized set
assert "test_user" in processor.tables_initialized
def test_ensure_tables_idempotent(self):
"""Test that ensure_tables is idempotent"""
processor = MagicMock()
processor.tables_initialized = {"test_user"} # Already initialized
processor.session = MagicMock()
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
processor.ensure_tables("test_user")
# Should not execute any CQL since already initialized
processor.session.execute.assert_not_called()
class TestPartitionRegistration:
"""Test partition registration for tracking what's stored"""
def test_register_partitions(self):
"""Test registering partitions for a collection/schema pair"""
processor = MagicMock()
processor.registered_partitions = set()
processor.session = MagicMock()
processor.schemas = {
"test_schema": RowSchema(
name="test_schema",
fields=[
Field(name="id", type="string", primary=True),
Field(name="category", type="string", indexed=True)
]
)
}
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
# Should have 2 inserts (one per index: id, category)
assert processor.session.execute.call_count == 2
# Verify cache was updated
assert ("test_collection", "test_schema") in processor.registered_partitions
def test_register_partitions_idempotent(self):
"""Test that partition registration is idempotent"""
processor = MagicMock()
processor.registered_partitions = {("test_collection", "test_schema")} # Already registered
processor.session = MagicMock()
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor.register_partitions("test_user", "test_collection", "test_schema")
# Should not execute any CQL since already registered
processor.session.execute.assert_not_called()


@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'client')
assert hasattr(processor, 'safety_settings')
assert len(processor.safety_settings) == 4 # 4 safety categories
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False)
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# Assert
# Verify Google AI Studio client was called with correct API key
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False)
# Verify processor has the client
assert processor.client == mock_genai_client


@ -101,7 +101,7 @@ from .exceptions import (
LoadError,
LookupError,
NLPQueryError,
ObjectsQueryError,
RowsQueryError,
RequestError,
StructuredQueryError,
UnexpectedError,
@ -161,7 +161,7 @@ __all__ = [
"LoadError",
"LookupError",
"NLPQueryError",
"ObjectsQueryError",
"RowsQueryError",
"RequestError",
"StructuredQueryError",
"UnexpectedError",


@ -115,15 +115,15 @@ class AsyncBulkClient:
async for raw_message in websocket:
yield json.loads(raw_message)
async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import objects via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for obj in objects:
await websocket.send(json.dumps(obj))
async for row in rows:
await websocket.send(json.dumps(row))
async def aclose(self) -> None:
"""Close connections"""

View file

@ -708,18 +708,18 @@ class AsyncFlowInstance:
return await self.request("triples", request_data)
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs: Any):
async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs: Any):
"""
Execute a GraphQL query on stored objects.
Execute a GraphQL query on stored rows.
Queries structured data objects using GraphQL syntax. Supports complex
Queries structured data rows using GraphQL syntax. Supports complex
queries with variables and named operations.
Args:
query: GraphQL query string
user: User identifier
collection: Collection identifier containing objects
collection: Collection identifier containing rows
variables: Optional GraphQL query variables
operation_name: Optional operation name for multi-operation queries
**kwargs: Additional service-specific parameters
@ -743,7 +743,7 @@ class AsyncFlowInstance:
}
'''
result = await flow.objects_query(
result = await flow.rows_query(
query=query,
user="trustgraph",
collection="users",
@ -765,4 +765,4 @@ class AsyncFlowInstance:
request_data["operationName"] = operation_name
request_data.update(kwargs)
return await self.request("objects", request_data)
return await self.request("rows", request_data)

View file

@ -320,9 +320,9 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("triples", self.flow_id, request)
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs):
"""GraphQL query"""
async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs):
"""GraphQL query against structured rows"""
request = {
"query": query,
"user": user,
@ -334,7 +334,7 @@ class AsyncSocketFlowInstance:
request["operationName"] = operation_name
request.update(kwargs)
return await self.client._send_request("objects", self.flow_id, request)
return await self.client._send_request("rows", self.flow_id, request)
async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
"""Execute MCP tool"""

View file

@ -530,45 +530,45 @@ class BulkClient:
async for raw_message in websocket:
yield json.loads(raw_message)
def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
def import_rows(self, flow: str, rows: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
"""
Bulk import structured objects into a flow.
Bulk import structured rows into a flow.
Efficiently uploads structured data objects via WebSocket streaming
Efficiently uploads structured data rows via WebSocket streaming
for use in GraphQL queries.
Args:
flow: Flow identifier
objects: Iterator yielding object dictionaries
rows: Iterator yielding row dictionaries
**kwargs: Additional parameters (reserved for future use)
Example:
```python
bulk = api.bulk()
# Generate objects to import
def object_generator():
yield {"id": "obj1", "name": "Object 1", "value": 100}
yield {"id": "obj2", "name": "Object 2", "value": 200}
# ... more objects
# Generate rows to import
def row_generator():
yield {"id": "row1", "name": "Row 1", "value": 100}
yield {"id": "row2", "name": "Row 2", "value": 200}
# ... more rows
bulk.import_objects(
bulk.import_rows(
flow="default",
objects=object_generator()
rows=row_generator()
)
```
"""
self._run_async(self._import_objects_async(flow, objects))
self._run_async(self._import_rows_async(flow, rows))
async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of objects import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for obj in objects:
await websocket.send(json.dumps(obj))
for row in rows:
await websocket.send(json.dumps(row))
def close(self) -> None:
"""Close connections"""

View file

@ -71,8 +71,8 @@ class NLPQueryError(TrustGraphException):
pass
class ObjectsQueryError(TrustGraphException):
"""Objects query service error"""
class RowsQueryError(TrustGraphException):
"""Rows query service error"""
pass
@ -103,7 +103,7 @@ ERROR_TYPE_MAPPING = {
"load-error": LoadError,
"lookup-error": LookupError,
"nlp-query-error": NLPQueryError,
"objects-query-error": ObjectsQueryError,
"rows-query-error": RowsQueryError,
"request-error": RequestError,
"structured-query-error": StructuredQueryError,
"unexpected-error": UnexpectedError,

View file

@ -1001,12 +1001,12 @@ class FlowInstance:
input
)
def objects_query(
def rows_query(
self, query, user="trustgraph", collection="default",
variables=None, operation_name=None
):
"""
Execute a GraphQL query against structured objects in the knowledge graph.
Execute a GraphQL query against structured rows in the knowledge graph.
Queries structured data using GraphQL syntax, allowing complex queries
with filtering, aggregation, and relationship traversal.
@ -1038,7 +1038,7 @@ class FlowInstance:
}
}
'''
result = flow.objects_query(
result = flow.rows_query(
query=query,
user="trustgraph",
collection="scientists"
@ -1053,7 +1053,7 @@ class FlowInstance:
}
}
'''
result = flow.objects_query(
result = flow.rows_query(
query=query,
variables={"name": "Marie Curie"}
)
@ -1074,7 +1074,7 @@ class FlowInstance:
input["operation_name"] = operation_name
response = self.request(
"service/objects",
"service/rows",
input
)

View file

@ -789,7 +789,7 @@ class SocketFlowInstance:
return self.client._send_request_sync("triples", self.flow_id, request, False)
def objects_query(
def rows_query(
self,
query: str,
user: str,
@ -799,7 +799,7 @@ class SocketFlowInstance:
**kwargs: Any
) -> Dict[str, Any]:
"""
Execute a GraphQL query against structured objects.
Execute a GraphQL query against structured rows.
Args:
query: GraphQL query string
@ -826,7 +826,7 @@ class SocketFlowInstance:
}
}
'''
result = flow.objects_query(
result = flow.rows_query(
query=query,
user="trustgraph",
collection="scientists"
@ -844,7 +844,7 @@ class SocketFlowInstance:
request["operationName"] = operation_name
request.update(kwargs)
return self.client._send_request_sync("objects", self.flow_id, request, False)
return self.client._send_request_sync("rows", self.flow_id, request, False)
def mcp_tool(
self,

View file

@ -21,7 +21,7 @@ from .translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
)
from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
@ -113,9 +113,9 @@ TranslatorRegistry.register_service(
)
TranslatorRegistry.register_service(
"objects-query",
ObjectsQueryRequestTranslator(),
ObjectsQueryResponseTranslator()
"rows-query",
RowsQueryRequestTranslator(),
RowsQueryResponseTranslator()
)
TranslatorRegistry.register_service(

View file

@ -17,5 +17,5 @@ from .embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
)
from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator

View file

@ -1,44 +1,44 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
from ...schema import RowsQueryRequest, RowsQueryResponse
from .base import MessageTranslator
import json
class ObjectsQueryRequestTranslator(MessageTranslator):
"""Translator for ObjectsQueryRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryRequest:
return ObjectsQueryRequest(
class RowsQueryRequestTranslator(MessageTranslator):
"""Translator for RowsQueryRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryRequest:
return RowsQueryRequest(
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
query=data.get("query", ""),
variables=data.get("variables", {}),
operation_name=data.get("operation_name", None)
)
def from_pulsar(self, obj: ObjectsQueryRequest) -> Dict[str, Any]:
def from_pulsar(self, obj: RowsQueryRequest) -> Dict[str, Any]:
result = {
"user": obj.user,
"collection": obj.collection,
"query": obj.query,
"variables": dict(obj.variables) if obj.variables else {}
}
if obj.operation_name:
result["operation_name"] = obj.operation_name
return result
class ObjectsQueryResponseTranslator(MessageTranslator):
"""Translator for ObjectsQueryResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryResponse:
class RowsQueryResponseTranslator(MessageTranslator):
"""Translator for RowsQueryResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: ObjectsQueryResponse) -> Dict[str, Any]:
def from_pulsar(self, obj: RowsQueryResponse) -> Dict[str, Any]:
result = {}
# Handle GraphQL response data
if obj.data:
try:
@ -47,7 +47,7 @@ class ObjectsQueryResponseTranslator(MessageTranslator):
result["data"] = obj.data
else:
result["data"] = None
# Handle GraphQL errors
if obj.errors:
result["errors"] = []
@ -60,20 +60,20 @@ class ObjectsQueryResponseTranslator(MessageTranslator):
if error.extensions:
error_dict["extensions"] = dict(error.extensions)
result["errors"].append(error_dict)
# Handle extensions
if obj.extensions:
result["extensions"] = dict(obj.extensions)
# Handle system-level error
if obj.error:
result["error"] = {
"type": obj.error.type,
"message": obj.error.message
}
return result
def from_response_with_completion(self, obj: ObjectsQueryResponse) -> Tuple[Dict[str, Any], bool]:
def from_response_with_completion(self, obj: RowsQueryResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
return self.from_pulsar(obj), True

View file

@ -60,3 +60,23 @@ class StructuredObjectEmbedding:
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
############################################################################
# Row embeddings are embeddings associated with indexed field values
# in structured row data. Each index gets embedded separately.
@dataclass
class RowIndexEmbedding:
"""Single row's embedding for one index"""
index_name: str = "" # The indexed field name(s)
index_value: list[str] = field(default_factory=list) # The field value(s)
text: str = "" # Text that was embedded
vectors: list[list[float]] = field(default_factory=list)
@dataclass
class RowEmbeddings:
"""Batched row embeddings for a schema"""
metadata: Metadata | None = None
schema_name: str = ""
embeddings: list[RowIndexEmbedding] = field(default_factory=list)
############################################################################
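To make the batching model above concrete, here is a minimal local sketch of how these dataclasses compose. The dataclasses are redefined locally so the snippet is self-contained; the `metadata` field is omitted and all values are invented for illustration.

```python
from dataclasses import dataclass, field

# Local sketch of the RowIndexEmbedding / RowEmbeddings shapes defined above
# (metadata omitted for brevity; vectors are invented example values).
@dataclass
class RowIndexEmbedding:
    index_name: str = ""                 # indexed field name(s), comma-joined if composite
    index_value: list = field(default_factory=list)  # one value per component field
    text: str = ""                       # text that was embedded
    vectors: list = field(default_factory=list)

@dataclass
class RowEmbeddings:
    schema_name: str = ""
    embeddings: list = field(default_factory=list)

emb = RowIndexEmbedding(
    index_name="city,country",
    index_value=["Paris", "France"],
    text="Paris France",                 # space-joined composite value
    vectors=[[0.12, -0.03, 0.88]],
)
batch = RowEmbeddings(schema_name="customers", embeddings=[emb])
```

One `RowEmbeddings` message carries a batch of per-index embeddings for a single schema, which is what allows the write stage to fan them out to the vector store without re-reading the schema.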

View file

@ -9,7 +9,7 @@ from .library import *
from .lookup import *
from .nlp_query import *
from .structured_query import *
from .objects_query import *
from .rows_query import *
from .diagnosis import *
from .collection import *
from .storage import *

View file

@ -59,4 +59,39 @@ document_embeddings_request_queue = topic(
)
document_embeddings_response_queue = topic(
"document-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow'
)
############################################################################
# Row embeddings query - for semantic/fuzzy matching on row index values
@dataclass
class RowIndexMatch:
"""A single matching row index from a semantic search"""
index_name: str = "" # The indexed field(s)
index_value: list[str] = field(default_factory=list) # The index values
text: str = "" # The text that was embedded
score: float = 0.0 # Similarity score
@dataclass
class RowEmbeddingsRequest:
"""Request for row embeddings semantic search"""
vectors: list[list[float]] = field(default_factory=list) # Query vectors
limit: int = 10 # Max results to return
user: str = "" # User/keyspace
collection: str = "" # Collection name
schema_name: str = "" # Schema name to search within
index_name: str | None = None # Optional: filter to specific index
@dataclass
class RowEmbeddingsResponse:
"""Response from row embeddings semantic search"""
error: Error | None = None
matches: list[RowIndexMatch] = field(default_factory=list)
row_embeddings_request_queue = topic(
"row-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
)
row_embeddings_response_queue = topic(
"row-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow'
)
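A minimal sketch of the request/response contract defined above, with the dataclasses redefined locally so the snippet stands alone; the vectors and score are invented, and `Error`/queue plumbing is omitted.

```python
from dataclasses import dataclass, field
from typing import Optional

# Local sketch of RowEmbeddingsRequest / RowIndexMatch (illustrative values).
@dataclass
class RowEmbeddingsRequest:
    vectors: list = field(default_factory=list)   # query vectors
    limit: int = 10
    user: str = ""
    collection: str = ""
    schema_name: str = ""
    index_name: Optional[str] = None              # None searches all indexes

@dataclass
class RowIndexMatch:
    index_name: str = ""
    index_value: list = field(default_factory=list)
    text: str = ""
    score: float = 0.0

req = RowEmbeddingsRequest(
    vectors=[[0.1, 0.2, 0.3]],
    limit=5,
    user="trustgraph",
    collection="default",
    schema_name="customers",
    index_name="city,country",    # restrict the search to one composite index
)
match = RowIndexMatch(index_name="city,country",
                      index_value=["Paris", "France"],
                      text="Paris France", score=0.92)
```

Leaving `index_name` as `None` asks the query service to search every index of the schema, while setting it narrows the semantic match to a single (possibly composite) index.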

View file

@ -6,7 +6,7 @@ from ..core.topic import topic
############################################################################
# Objects Query Service - executes GraphQL queries against structured data
# Rows Query Service - executes GraphQL queries against structured data
@dataclass
class GraphQLError:
@ -15,7 +15,7 @@ class GraphQLError:
extensions: dict[str, str] = field(default_factory=dict) # Additional error metadata
@dataclass
class ObjectsQueryRequest:
class RowsQueryRequest:
user: str = "" # Cassandra keyspace (follows pattern from TriplesQueryRequest)
collection: str = "" # Data collection identifier (required for partition key)
query: str = "" # GraphQL query string
@ -23,7 +23,7 @@ class ObjectsQueryRequest:
operation_name: Optional[str] = None # Operation to execute for multi-operation documents
@dataclass
class ObjectsQueryResponse:
class RowsQueryResponse:
error: Error | None = None # System-level error (connection, timeout, etc.)
data: str = "" # JSON-encoded GraphQL response data
errors: list[GraphQLError] = field(default_factory=list) # GraphQL field-level errors

View file

@ -48,7 +48,7 @@ tg-invoke-graph-embeddings = "trustgraph.cli.invoke_graph_embeddings:main"
tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main"
tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main"
tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main"
tg-invoke-objects-query = "trustgraph.cli.invoke_objects_query:main"
tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"

View file

@ -1,5 +1,5 @@
"""
Uses the ObjectsQuery service to execute GraphQL queries against structured data
Uses the RowsQuery service to execute GraphQL queries against structured data
"""
import argparse
@ -81,7 +81,7 @@ def format_table_data(rows, table_name, output_format):
else:
return json.dumps({table_name: rows}, indent=2)
def objects_query(
def rows_query(
url, flow_id, query, user, collection, variables, operation_name, output_format='table'
):
@ -96,7 +96,7 @@ def objects_query(
print(f"Error parsing variables JSON: {e}", file=sys.stderr)
sys.exit(1)
resp = api.objects_query(
resp = api.rows_query(
query=query,
user=user,
collection=collection,
@ -126,7 +126,7 @@ def objects_query(
def main():
parser = argparse.ArgumentParser(
prog='tg-invoke-objects-query',
prog='tg-invoke-rows-query',
description=__doc__,
)
@ -181,7 +181,7 @@ def main():
try:
objects_query(
rows_query(
url=args.url,
flow_id=args.flow_id,
query=args.query,

View file

@ -573,19 +573,19 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample
return output_records, descriptor
def _send_to_trustgraph(objects, api_url, flow, batch_size=1000, token=None):
def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None):
"""Send ExtractedObject records to TrustGraph using the Python API"""
from trustgraph.api import Api
try:
total_records = len(objects)
total_records = len(rows)
logger.info(f"Importing {total_records} records to TrustGraph...")
# Use Python API bulk import
api = Api(api_url, token=token)
bulk = api.bulk()
bulk.import_objects(flow=flow, objects=iter(objects))
bulk.import_rows(flow=flow, rows=iter(rows))
logger.info(f"Successfully imported {total_records} records to TrustGraph")

View file

@ -60,27 +60,27 @@ api-gateway = "trustgraph.gateway:run"
chunker-recursive = "trustgraph.chunking.recursive:run"
chunker-token = "trustgraph.chunking.token:run"
config-svc = "trustgraph.config.service:run"
de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run"
doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run"
doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run"
doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run"
doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run"
doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run"
document-embeddings = "trustgraph.embeddings.document_embeddings:run"
document-rag = "trustgraph.retrieval.document_rag:run"
embeddings-fastembed = "trustgraph.embeddings.fastembed:run"
embeddings-ollama = "trustgraph.embeddings.ollama:run"
ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run"
graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run"
graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run"
graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run"
graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run"
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-objects = "trustgraph.extract.kg.objects:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"
kg-extract-relationships = "trustgraph.extract.kg.relationships:run"
kg-extract-topics = "trustgraph.extract.kg.topics:run"
kg-extract-ontology = "trustgraph.extract.kg.ontology:run"
@ -90,8 +90,11 @@ librarian = "trustgraph.librarian:run"
mcp-tool = "trustgraph.agent.mcp_tool:run"
metering = "trustgraph.metering:run"
nlp-query = "trustgraph.retrieval.nlp_query:run"
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
rows-write-cassandra = "trustgraph.storage.rows.cassandra:run"
rows-query-cassandra = "trustgraph.query.rows.cassandra:run"
row-embeddings = "trustgraph.embeddings.row_embeddings:run"
row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run"
row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run"
pdf-decoder = "trustgraph.decoding.pdf:run"
pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run"
prompt-template = "trustgraph.prompt.template:run"

View file

@ -0,0 +1,3 @@
from . embeddings import *

View file

@ -0,0 +1,6 @@
from . embeddings import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,263 @@
"""
Row embeddings processor. Calls the embeddings service to compute embeddings
for indexed field values in extracted row data.
Input is ExtractedObject (structured row data with schema).
Output is RowEmbeddings (row data with embeddings for indexed fields).
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (this processor): Compute embeddings
Stage 2 (row-embeddings-write-*): Store embeddings
"""
import json
import logging
from typing import Dict, List
from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding
from ... schema import RowSchema, Field
from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec
from ... base import ProducerSpec, CollectionConfigHandler
logger = logging.getLogger(__name__)
default_ident = "row-embeddings"
default_batch_size = 10
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
self.batch_size = params.get("batch_size", default_batch_size)
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=ExtractedObject,
handler=self.on_message,
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name="embeddings-request",
response_name="embeddings-response",
)
)
self.register_specification(
ProducerSpec(
name="output",
schema=RowEmbeddings
)
)
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
index_names = []
for field in schema.fields:
if field.primary or field.indexed:
index_names.append(field.name)
return index_names
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
"""Build the index_value list for a given index."""
field_names = [f.strip() for f in index_name.split(',')]
values = []
for field_name in field_names:
value = value_map.get(field_name)
values.append(str(value) if value is not None else "")
return values
def build_text_for_embedding(self, index_value: List[str]) -> str:
"""Build text representation for embedding from index values."""
# Space-join the values for composite indexes
return " ".join(index_value)
async def on_message(self, msg, consumer, flow):
"""Process incoming ExtractedObject and compute embeddings"""
obj = msg.value()
logger.info(
f"Computing embeddings for {len(obj.values)} rows, "
f"schema {obj.schema_name}, doc {obj.metadata.id}"
)
# Validate collection exists before processing
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
logger.warning(
f"Collection {obj.metadata.collection} for user {obj.metadata.user} "
f"does not exist in config. Dropping message."
)
return
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Get all index names for this schema
index_names = self.get_index_names(schema)
if not index_names:
logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping")
return
# Track unique texts to avoid duplicate embeddings
# text -> (index_name, index_value)
texts_to_embed: Dict[str, tuple] = {}
# Collect all texts that need embeddings
for value_map in obj.values:
for index_name in index_names:
index_value = self.build_index_value(value_map, index_name)
# Skip empty values
if not index_value or all(v == "" for v in index_value):
continue
text = self.build_text_for_embedding(index_value)
if text and text not in texts_to_embed:
texts_to_embed[text] = (index_name, index_value)
if not texts_to_embed:
logger.info("No texts to embed")
return
# Compute embeddings
embeddings_list = []
try:
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
embeddings_list.append(
RowIndexEmbedding(
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors
)
)
# Send in batches to avoid oversized messages
for i in range(0, len(embeddings_list), self.batch_size):
batch = embeddings_list[i:i + self.batch_size]
result = RowEmbeddings(
metadata=obj.metadata,
schema_name=obj.schema_name,
embeddings=batch,
)
await flow("output").send(result)
logger.info(
f"Computed {len(embeddings_list)} embeddings for "
f"{len(obj.values)} rows ({len(index_names)} indexes)"
)
except Exception as e:
logger.error("Exception during embedding computation", exc_info=True)
raise e
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Collection creation notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection notification for {user}/{collection}")
async def delete_collection(self, user: str, collection: str):
"""Collection deletion notification - no action needed for embedding stage"""
logger.debug(f"Row embeddings collection delete notification for {user}/{collection}")
@staticmethod
def add_args(parser):
FlowProcessor.add_args(parser)
parser.add_argument(
'--batch-size',
type=int,
default=default_batch_size,
help=f'Maximum embeddings per output message (default: {default_batch_size})'
)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
Processor.launch(default_ident, __doc__)
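The index-text convention used by `build_index_value` and `build_text_for_embedding` can be exercised standalone. This sketch copies the logic from the processor above: composite index names are comma-separated field names, and the embedded text is the space-joined values.

```python
def build_index_value(value_map, index_name):
    """Resolve a (possibly composite) index name to a list of field values."""
    values = []
    for field_name in (f.strip() for f in index_name.split(',')):
        value = value_map.get(field_name)
        values.append(str(value) if value is not None else "")
    return values

def build_text_for_embedding(index_value):
    """Space-join composite values to form the text that gets embedded."""
    return " ".join(index_value)

row = {"city": "Paris", "country": "France"}
print(build_index_value(row, "city, country"))        # ['Paris', 'France']
print(build_text_for_embedding(["Paris", "France"]))  # Paris France
```

Missing fields map to the empty string rather than raising, which is what lets the processor skip all-empty index values before embedding.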

View file

@ -1,5 +1,5 @@
"""
Object extraction service - extracts structured objects from text chunks
Row extraction service - extracts structured rows from text chunks
based on configured schemas.
"""
@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... messaging.translators import row_schema_translator
default_ident = "kg-extract-objects"
default_ident = "kg-extract-rows"
def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]:
@ -310,5 +310,5 @@ class Processor(FlowProcessor):
FlowProcessor.add_args(parser)
def run():
"""Entry point for kg-extract-objects command"""
"""Entry point for kg-extract-rows command"""
Processor.launch(default_ident, __doc__)

View file

@ -20,7 +20,7 @@ from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . objects_query import ObjectsQueryRequestor
from . rows_query import RowsQueryRequestor
from . nlp_query import NLPQueryRequestor
from . structured_query import StructuredQueryRequestor
from . structured_diag import StructuredDiagRequestor
@ -40,7 +40,7 @@ from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
from . entity_contexts_import import EntityContextsImport
from . objects_import import ObjectsImport
from . rows_import import RowsImport
from . core_export import CoreExport
from . core_import import CoreImport
@ -58,7 +58,7 @@ request_response_dispatchers = {
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"document-embeddings": DocumentEmbeddingsQueryRequestor,
"triples": TriplesQueryRequestor,
"objects": ObjectsQueryRequestor,
"rows": RowsQueryRequestor,
"nlp-query": NLPQueryRequestor,
"structured-query": StructuredQueryRequestor,
"structured-diag": StructuredDiagRequestor,
@ -89,7 +89,7 @@ import_dispatchers = {
"graph-embeddings": GraphEmbeddingsImport,
"document-embeddings": DocumentEmbeddingsImport,
"entity-contexts": EntityContextsImport,
"objects": ObjectsImport,
"rows": RowsImport,
}
class DispatcherWrapper:

View file

@ -12,7 +12,7 @@ from . serialize import to_subgraph
# Module logger
logger = logging.getLogger(__name__)
class ObjectsImport:
class RowsImport:
def __init__(
self, ws, running, backend, queue
@ -20,7 +20,7 @@ class ObjectsImport:
self.ws = ws
self.running = running
self.publisher = Publisher(
backend, topic = queue, schema = ExtractedObject
)
@ -73,4 +73,4 @@ class ObjectsImport:
if self.ws:
await self.ws.close()
self.ws = None
self.ws = None

View file

@ -1,30 +1,30 @@
from ... schema import ObjectsQueryRequest, ObjectsQueryResponse
from ... schema import RowsQueryRequest, RowsQueryResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class ObjectsQueryRequestor(ServiceRequestor):
class RowsQueryRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(ObjectsQueryRequestor, self).__init__(
super(RowsQueryRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=ObjectsQueryRequest,
response_schema=ObjectsQueryResponse,
request_schema=RowsQueryRequest,
response_schema=RowsQueryResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("objects-query")
self.response_translator = TranslatorRegistry.get_response_translator("objects-query")
self.request_translator = TranslatorRegistry.get_request_translator("rows-query")
self.response_translator = TranslatorRegistry.get_response_translator("rows-query")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)


@@ -0,0 +1,22 @@
"""
Shared GraphQL utilities for row query services.
This module provides reusable GraphQL components including:
- Filter types (IntFilter, StringFilter, FloatFilter)
- Dynamic schema generation from RowSchema definitions
- Filter parsing utilities
"""
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
from .schema import GraphQLSchemaBuilder
from .filters import parse_filter_key, parse_where_clause
__all__ = [
"IntFilter",
"StringFilter",
"FloatFilter",
"SortDirection",
"GraphQLSchemaBuilder",
"parse_filter_key",
"parse_where_clause",
]


@@ -0,0 +1,104 @@
"""
Filter parsing utilities for GraphQL row queries.
Provides functions to parse GraphQL filter objects into a normalized
format that can be used by different query backends.
"""
import logging
from typing import Dict, Any, Tuple
logger = logging.getLogger(__name__)
def parse_filter_key(filter_key: str) -> Tuple[str, str]:
"""
Parse GraphQL filter key into field name and operator.
Supports common GraphQL filter patterns:
- field_name -> (field_name, "eq")
- field_name_gt -> (field_name, "gt")
- field_name_gte -> (field_name, "gte")
- field_name_lt -> (field_name, "lt")
- field_name_lte -> (field_name, "lte")
- field_name_in -> (field_name, "in")
Args:
filter_key: The filter key string from GraphQL
Returns:
Tuple of (field_name, operator)
"""
if not filter_key:
return ("", "eq")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
def parse_where_clause(where_obj) -> Dict[str, Any]:
"""
Parse the idiomatic nested GraphQL filter structure into a flat dict.
Converts Strawberry filter objects (StringFilter, IntFilter, etc.)
into a dictionary mapping field names with operators to values.
Example:
Input: where_obj with email.eq = "foo@bar.com"
Output: {"email": "foo@bar.com"}
Input: where_obj with age.gt = 21
Output: {"age_gt": 21}
Args:
where_obj: The GraphQL where clause object
Returns:
Dictionary mapping field_operator keys to values
"""
if not where_obj:
return {}
conditions = {}
logger.debug(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.debug(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
elif operator == "startsWith":
conditions[f"{field_name}_startsWith"] = value
elif operator == "endsWith":
conditions[f"{field_name}_endsWith"] = value
elif operator == "not_":
conditions[f"{field_name}_not"] = value
elif operator == "not_in":
conditions[f"{field_name}_not_in"] = value
logger.debug(f"Final parsed conditions: {conditions}")
return conditions
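The suffix parsing above can be exercised standalone. This sketch copies the `parse_filter_key` logic as written (the filter keys are illustrative):

```python
# Standalone copy of the suffix-splitting logic in parse_filter_key.
# "_gte"/"_lte" are checked before "_gt"/"_lt" so "age_gte" is not
# mis-split; note that a field literally named e.g. "checked_in"
# would still be read as ("checked", "in").
def parse_filter_key(filter_key):
    if not filter_key:
        return ("", "eq")
    for op_suffix in ("_gte", "_lte", "_gt", "_lt", "_in", "_eq"):
        if filter_key.endswith(op_suffix):
            return (filter_key[:-len(op_suffix)], op_suffix[1:])
    return (filter_key, "eq")

print(parse_filter_key("age_gte"))    # ('age', 'gte')
print(parse_filter_key("email"))      # ('email', 'eq')
print(parse_filter_key("status_in"))  # ('status', 'in')
```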


@@ -0,0 +1,251 @@
"""
Dynamic GraphQL schema generation from RowSchema definitions.
Provides a builder class that creates Strawberry GraphQL schemas
from TrustGraph RowSchema definitions, with pluggable query backends.
"""
import logging
from typing import Dict, Any, Optional, List, Callable, Awaitable
import strawberry
from strawberry import Schema
from strawberry.types import Info
from .types import IntFilter, StringFilter, FloatFilter, SortDirection
logger = logging.getLogger(__name__)
# Type alias for query callback function
QueryCallback = Callable[
[str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]],
Awaitable[List[Dict[str, Any]]]
]
class GraphQLSchemaBuilder:
"""
Builds GraphQL schemas from RowSchema definitions.
This class extracts the GraphQL schema generation logic so it can be
reused across different query backends (Cassandra, etc.).
Usage:
builder = GraphQLSchemaBuilder()
# Add schemas
for name, row_schema in schemas.items():
builder.add_schema(name, row_schema)
# Build with a query callback
schema = builder.build(query_callback)
"""
def __init__(self):
self.schemas: Dict[str, Any] = {} # name -> RowSchema
self.graphql_types: Dict[str, type] = {}
self.filter_types: Dict[str, type] = {}
def add_schema(self, name: str, row_schema) -> None:
"""
Add a RowSchema to the builder.
Args:
name: The schema name (used as the GraphQL query field name)
row_schema: The RowSchema object defining fields
"""
self.schemas[name] = row_schema
self.graphql_types[name] = self._create_graphql_type(name, row_schema)
self.filter_types[name] = self._create_filter_type(name, row_schema)
logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields")
def clear(self) -> None:
"""Clear all schemas from the builder."""
self.schemas = {}
self.graphql_types = {}
self.filter_types = {}
def build(self, query_callback: QueryCallback) -> Optional[Schema]:
"""
Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with:
- user: str
- collection: str
- schema_name: str
- row_schema: RowSchema
- filters: Dict[str, Any]
- limit: int
- order_by: Optional[str]
- direction: Optional[SortDirection]
It should return a list of row dictionaries.
Args:
query_callback: Async function to execute queries
Returns:
Strawberry Schema, or None if no schemas are loaded
"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
return None
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = self.filter_types[schema_name]
# Create resolver function for this schema
resolver_func = self._make_resolver(
schema_name, row_schema, graphql_type, filter_type, query_callback
)
# Add field to query dictionary
query_dict[schema_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][schema_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
return schema
def _get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL."""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def _create_graphql_type(self, schema_name: str, row_schema) -> type:
"""Create a GraphQL output type from a RowSchema."""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self._get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def _create_filter_type(self, schema_name: str, row_schema) -> type:
"""Create a dynamic filter input type for a schema."""
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.debug(
f"Field {field.name}: type={field.type}, "
f"indexed={field.indexed}, primary={field.primary}"
)
# Allow filtering on any field
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.debug(
f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}"
)
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def _make_resolver(
self,
schema_name: str,
row_schema,
graphql_type: type,
filter_type: type,
query_callback: QueryCallback
):
"""Create a resolver function for a schema."""
from .filters import parse_where_clause
async def resolver(
info: Info,
where: Optional[filter_type] = None,
order_by: Optional[str] = None,
direction: Optional[SortDirection] = None,
limit: Optional[int] = 100
) -> List[graphql_type]:
# Get context values
user = info.context["user"]
collection = info.context["collection"]
# Parse the where clause
filters = parse_where_clause(where)
# Call the query backend
results = await query_callback(
user, collection, schema_name, row_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = graphql_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver
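The dynamic class construction in `_create_graphql_type` and `_create_filter_type` reduces to building an `__annotations__` dict and passing it to the three-argument `type()` call. A minimal sketch under hypothetical field definitions, with the `strawberry.type` decorator step omitted:

```python
from typing import Optional

# Hypothetical schema fields: (name, python type, required?)
fields = [("email", str, True), ("age", int, False)]

annotations, defaults = {}, {}
for name, py_type, required in fields:
    if required:
        annotations[name] = py_type
    else:
        # optional fields get Optional[...] and a None class default
        annotations[name] = Optional[py_type]
        defaults[name] = None

# Equivalent to the type(type_name, (), {...}) call above;
# strawberry.type(...) would then be applied to the result.
CustomerType = type("CustomerType", (), {"__annotations__": annotations, **defaults})

print(CustomerType.__annotations__)
print(CustomerType.age)  # None
```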


@@ -0,0 +1,56 @@
"""
GraphQL filter and sort types for row queries.
These types are used to build dynamic GraphQL schemas for querying
structured row data.
"""
from typing import Optional, List
from enum import Enum
import strawberry
@strawberry.input
class IntFilter:
"""Filter type for integer fields."""
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
"""Filter type for string fields."""
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
"""Filter type for float fields."""
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
@strawberry.enum
class SortDirection(Enum):
"""Sort direction for query results."""
ASC = "asc"
DESC = "desc"
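Each filter input type carries one optional slot per operator, and only the slots a query actually sets are non-None; that is the invariant `parse_where_clause` relies on. A plain-dataclass stand-in (no strawberry dependency; the values are illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional

# Plain-dataclass stand-in for the strawberry IntFilter input type
@dataclass
class IntFilter:
    eq: Optional[int] = None
    gt: Optional[int] = None
    gte: Optional[int] = None
    lt: Optional[int] = None
    lte: Optional[int] = None
    in_: Optional[List[int]] = None   # exposed to GraphQL as "in" via strawberry.field(name="in")
    not_: Optional[int] = None        # exposed as "not"
    not_in: Optional[List[int]] = None

f = IntFilter(gte=18, lt=65)
# Only the operators that were set survive this pass
active = {k: v for k, v in f.__dict__.items() if v is not None}
print(active)  # {'gte': 18, 'lt': 65}
```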


@@ -1,738 +0,0 @@
"""
Objects query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
"""
import json
import logging
import asyncio
from typing import Dict, Any, Optional, List, Set
from enum import Enum
from dataclasses import dataclass, field
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import strawberry
from strawberry import Schema
from strawberry.types import Info
from strawberry.scalars import JSON
from strawberry.tools import create_type
from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-query"
# GraphQL filter input types
@strawberry.input
class IntFilter:
eq: Optional[int] = None
gt: Optional[int] = None
gte: Optional[int] = None
lt: Optional[int] = None
lte: Optional[int] = None
in_: Optional[List[int]] = strawberry.field(name="in", default=None)
not_: Optional[int] = strawberry.field(name="not", default=None)
not_in: Optional[List[int]] = None
@strawberry.input
class StringFilter:
eq: Optional[str] = None
contains: Optional[str] = None
startsWith: Optional[str] = None
endsWith: Optional[str] = None
in_: Optional[List[str]] = strawberry.field(name="in", default=None)
not_: Optional[str] = strawberry.field(name="not", default=None)
not_in: Optional[List[str]] = None
@strawberry.input
class FloatFilter:
eq: Optional[float] = None
gt: Optional[float] = None
gte: Optional[float] = None
lt: Optional[float] = None
lte: Optional[float] = None
in_: Optional[List[float]] = strawberry.field(name="in", default=None)
not_: Optional[float] = strawberry.field(name="not", default=None)
not_in: Optional[List[float]] = None
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "request",
schema = ObjectsQueryRequest,
handler = self.on_message
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = ObjectsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema
self.graphql_schema: Optional[Schema] = None
# GraphQL types cache
self.graphql_types: Dict[str, type] = {}
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces and tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {}
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize table names for Cassandra compatibility"""
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
safe_name = 'o_' + safe_name
return safe_name.lower()
def parse_filter_key(self, filter_key: str) -> tuple[str, str]:
"""Parse GraphQL filter key into field name and operator"""
if not filter_key:
return ("", "eq")
# Support common GraphQL filter patterns:
# field_name -> (field_name, "eq")
# field_name_gt -> (field_name, "gt")
# field_name_gte -> (field_name, "gte")
# field_name_lt -> (field_name, "lt")
# field_name_lte -> (field_name, "lte")
# field_name_in -> (field_name, "in")
operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"]
for op_suffix in operators:
if filter_key.endswith(op_suffix):
field_name = filter_key[:-len(op_suffix)]
operator = op_suffix[1:] # Remove the leading underscore
return (field_name, operator)
# Default to equality if no operator suffix found
return (filter_key, "eq")
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.graphql_types = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.generate_graphql_schema()
def get_python_type(self, field_type: str):
"""Convert schema field type to Python type for GraphQL"""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"timestamp": str, # Use string for timestamps in GraphQL
"date": str,
"time": str,
"uuid": str
}
return type_mapping.get(field_type, str)
def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type:
"""Create a GraphQL type from a RowSchema"""
# Create annotations for the GraphQL type
annotations = {}
defaults = {}
for field in row_schema.fields:
python_type = self.get_python_type(field.type)
# Make field optional if not required
if not field.required and not field.primary:
annotations[field.name] = Optional[python_type]
defaults[field.name] = None
else:
annotations[field.name] = python_type
# Create the class dynamically
type_name = f"{schema_name.capitalize()}Type"
graphql_class = type(
type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry decorator
return strawberry.type(graphql_class)
def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema):
"""Create a dynamic filter input type for a schema"""
# Create the filter type dynamically
filter_type_name = f"{schema_name.capitalize()}Filter"
# Add __annotations__ and defaults for the fields
annotations = {}
defaults = {}
logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}")
for field in row_schema.fields:
logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}")
# Allow filtering on any field for now, not just indexed/primary
# if field.indexed or field.primary:
if field.type == "integer":
annotations[field.name] = Optional[IntFilter]
defaults[field.name] = None
logger.info(f"Added IntFilter for {field.name}")
elif field.type == "float":
annotations[field.name] = Optional[FloatFilter]
defaults[field.name] = None
logger.info(f"Added FloatFilter for {field.name}")
elif field.type == "string":
annotations[field.name] = Optional[StringFilter]
defaults[field.name] = None
logger.info(f"Added StringFilter for {field.name}")
logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}")
# Create the class dynamically
FilterType = type(
filter_type_name,
(),
{
"__annotations__": annotations,
**defaults
}
)
# Apply strawberry input decorator
FilterType = strawberry.input(FilterType)
return FilterType
def create_sort_direction_enum(self):
"""Create sort direction enum"""
@strawberry.enum
class SortDirection(Enum):
ASC = "asc"
DESC = "desc"
return SortDirection
def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]:
"""Parse the idiomatic nested filter structure"""
if not where_obj:
return {}
conditions = {}
logger.info(f"Parsing where clause: {where_obj}")
for field_name, filter_obj in where_obj.__dict__.items():
if filter_obj is None:
continue
logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}")
if hasattr(filter_obj, '__dict__'):
# This is a filter object (StringFilter, IntFilter, etc.)
for operator, value in filter_obj.__dict__.items():
if value is not None:
logger.info(f"Found operator {operator} with value {value}")
# Map GraphQL operators to our internal format
if operator == "eq":
conditions[field_name] = value
elif operator in ["gt", "gte", "lt", "lte"]:
conditions[f"{field_name}_{operator}"] = value
elif operator == "in_":
conditions[f"{field_name}_in"] = value
elif operator == "contains":
conditions[f"{field_name}_contains"] = value
logger.info(f"Final parsed conditions: {conditions}")
return conditions
def generate_graphql_schema(self):
"""Generate GraphQL schema from loaded schemas using dynamic filter types"""
if not self.schemas:
logger.warning("No schemas loaded, cannot generate GraphQL schema")
self.graphql_schema = None
return
# Create GraphQL types and filter types for each schema
filter_types = {}
sort_direction_enum = self.create_sort_direction_enum()
for schema_name, row_schema in self.schemas.items():
graphql_type = self.create_graphql_type(schema_name, row_schema)
filter_type = self.create_filter_type_for_schema(schema_name, row_schema)
self.graphql_types[schema_name] = graphql_type
filter_types[schema_name] = filter_type
# Create the Query class with resolvers
query_dict = {'__annotations__': {}}
for schema_name, row_schema in self.schemas.items():
graphql_type = self.graphql_types[schema_name]
filter_type = filter_types[schema_name]
# Create resolver function for this schema
def make_resolver(s_name, r_schema, g_type, f_type, sort_enum):
async def resolver(
info: Info,
where: Optional[f_type] = None,
order_by: Optional[str] = None,
direction: Optional[sort_enum] = None,
limit: Optional[int] = 100
) -> List[g_type]:
# Get the processor instance from context
processor = info.context["processor"]
user = info.context["user"]
collection = info.context["collection"]
# Parse the idiomatic where clause
filters = processor.parse_idiomatic_where_clause(where)
# Query Cassandra
results = await processor.query_cassandra(
user, collection, s_name, r_schema,
filters, limit, order_by, direction
)
# Convert to GraphQL types
graphql_results = []
for row in results:
graphql_obj = g_type(**row)
graphql_results.append(graphql_obj)
return graphql_results
return resolver
# Add resolver to query
resolver_name = schema_name
resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum)
# Add field to query dictionary
query_dict[resolver_name] = strawberry.field(resolver=resolver_func)
query_dict['__annotations__'][resolver_name] = List[graphql_type]
# Create the Query class
Query = type('Query', (), query_dict)
Query = strawberry.type(Query)
# Create the schema with auto_camel_case disabled to keep snake_case field names
self.graphql_schema = strawberry.Schema(
query=Query,
config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False)
)
logger.info(f"Generated GraphQL schema with {len(self.schemas)} types")
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[Any] = None
) -> List[Dict[str, Any]]:
"""Execute a query against Cassandra"""
# Connect if needed
self.connect_cassandra()
# Build the query
keyspace = self.sanitize_name(user)
table = self.sanitize_table(schema_name)
# Start with basic SELECT
query = f"SELECT * FROM {keyspace}.{table}"
# Add WHERE clauses
where_clauses = [f"collection = %s"]
params = [collection]
# Add filters for indexed or primary key fields
for filter_key, value in filters.items():
if value is not None:
# Parse field name and operator from filter key
logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})")
result = self.parse_filter_key(filter_key)
logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})")
if not result or len(result) != 2:
logger.error(f"parse_filter_key returned invalid result: {result}")
continue # Skip this filter
field_name, operator = result
# Find the field in schema
schema_field = None
for f in row_schema.fields:
if f.name == field_name:
schema_field = f
break
if schema_field:
safe_field = self.sanitize_name(field_name)
# Build WHERE clause based on operator
if operator == "eq":
where_clauses.append(f"{safe_field} = %s")
params.append(value)
elif operator == "gt":
where_clauses.append(f"{safe_field} > %s")
params.append(value)
elif operator == "gte":
where_clauses.append(f"{safe_field} >= %s")
params.append(value)
elif operator == "lt":
where_clauses.append(f"{safe_field} < %s")
params.append(value)
elif operator == "lte":
where_clauses.append(f"{safe_field} <= %s")
params.append(value)
elif operator == "in":
if isinstance(value, list):
placeholders = ",".join(["%s"] * len(value))
where_clauses.append(f"{safe_field} IN ({placeholders})")
params.extend(value)
else:
# Default to equality for unknown operators
where_clauses.append(f"{safe_field} = %s")
params.append(value)
if where_clauses:
query += " WHERE " + " AND ".join(where_clauses)
# Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort)
cassandra_order_by_added = False
if order_by and direction:
# Validate that order_by field exists in schema
order_field_exists = any(f.name == order_by for f in row_schema.fields)
if order_field_exists:
safe_order_field = self.sanitize_name(order_by)
direction_str = "ASC" if direction.value == "asc" else "DESC"
# Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution
query += f" ORDER BY {safe_order_field} {direction_str}"
# Add limit first (must come before ALLOW FILTERING)
if limit:
query += f" LIMIT {limit}"
# Add ALLOW FILTERING for now (should optimize with proper indexes later)
query += " ALLOW FILTERING"
# Execute query
try:
result = self.session.execute(query, params)
cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY
except Exception as e:
# If ORDER BY fails, try without it
if order_by and direction and "ORDER BY" in query:
logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}")
# Remove ORDER BY clause and retry
query_parts = query.split(" ORDER BY ")
if len(query_parts) == 2:
query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING"
result = self.session.execute(query_without_order, params)
cassandra_order_by_added = False
else:
raise
else:
raise
# Convert rows to dicts
results = []
for row in result:
row_dict = {}
for field in row_schema.fields:
safe_field = self.sanitize_name(field.name)
if hasattr(row, safe_field):
value = getattr(row, safe_field)
# Use original field name in result
row_dict[field.name] = value
results.append(row_dict)
# Post-query sorting if Cassandra didn't handle ORDER BY
if order_by and direction and not cassandra_order_by_added:
reverse_order = (direction.value == "desc")
try:
results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
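The filter-to-CQL translation performed by `query_cassandra` can be condensed into a standalone sketch. The keyspace, table, and filter values below are hypothetical, and the schema validation and ORDER BY fallback of the real method are omitted:

```python
# Standalone sketch of the WHERE-clause assembly in query_cassandra.
# "ks"/"o_table" stand in for the sanitized keyspace and table names.
filters = {"age_gt": 21, "email": "foo@bar.com"}
ops = {"eq": "=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<="}

where_clauses, params = ["collection = %s"], ["research"]
for key, value in filters.items():
    # split "age_gt" into ("age", "gt"); bare keys default to "eq"
    for suffix in ("_gte", "_lte", "_gt", "_lt", "_in", "_eq"):
        if key.endswith(suffix):
            field, op = key[:-len(suffix)], suffix[1:]
            break
    else:
        field, op = key, "eq"
    if op == "in" and isinstance(value, list):
        placeholders = ",".join(["%s"] * len(value))
        where_clauses.append(f"{field} IN ({placeholders})")
        params.extend(value)
    else:
        where_clauses.append(f"{field} {ops[op]} %s")
        params.append(value)

query = (
    "SELECT * FROM ks.o_table WHERE "
    + " AND ".join(where_clauses)
    + " LIMIT 100 ALLOW FILTERING"
)
print(query)
print(params)
```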
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling objects query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = ObjectsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending objects query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Objects query request completed")
except Exception as e:
logger.error(f"Exception in objects query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = ObjectsQueryResponse(
error = Error(
type = "objects-query-error",
message = str(e),
),
data = None,
errors = [],
extensions = {}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-query-graphql-cassandra command"""
Processor.launch(default_ident, __doc__)


@@ -0,0 +1,3 @@
"""
Row embeddings query modules.
"""


@@ -0,0 +1,5 @@
"""
Qdrant row embeddings query service.
"""
from .service import Processor, run, default_ident

View file

@ -0,0 +1,4 @@
from .service import run
run()

View file

@ -0,0 +1,209 @@
"""
Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
import logging
import re
from typing import Optional
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from .... schema import (
RowEmbeddingsRequest, RowEmbeddingsResponse,
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.register_specification(
ConsumerSpec(
name="request",
schema=RowEmbeddingsRequest,
handler=self.on_message
)
)
self.register_specification(
ProducerSpec(
name="response",
schema=RowEmbeddingsResponse
)
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
try:
all_collections = self.qdrant.get_collections().collections
matching = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if matching:
# Return first match (there should typically be only one per dimension)
return matching[0]
except Exception as e:
logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True)
return None
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
matches = []
# Find the collection for this user/collection/schema
qdrant_collection = self.find_collection(
request.user, request.collection, request.schema_name
)
if not qdrant_collection:
logger.info(
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
)
return matches
for vec in request.vectors:
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(
f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..."
)
# Execute query
matches = await self.query_row_embeddings(request)
response = RowEmbeddingsResponse(
error=None,
matches=matches
)
logger.debug(f"Returning {len(matches)} matches")
await flow("response").send(response, properties={"id": id})
except Exception as e:
logger.error(f"Exception in row embeddings query: {e}", exc_info=True)
response = RowEmbeddingsResponse(
error=Error(
type="row-embeddings-query-error",
message=str(e)
),
matches=[]
)
await flow("response").send(response, properties={"id": id})
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
def run():
"""Entry point for row-embeddings-query-qdrant command"""
Processor.launch(default_ident, __doc__)
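The `find_collection` lookup above works purely by prefix matching, so reader and writer must agree on the naming convention. A standalone sketch of the prefix construction and sanitization; it assumes the writer side appends a vector-dimension suffix after the trailing underscore (e.g. `..._384`), which is why the query side matches on prefix rather than exact name:

```python
import re

def sanitize_name(name: str) -> str:
    # Mirror the service's sanitizer: non-alphanumerics become '_',
    # an 'r_' prefix is added if the name doesn't start with a letter,
    # and the result is lowercased.
    safe = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe and not safe[0].isalpha():
        safe = 'r_' + safe
    return safe.lower()

def collection_prefix(user: str, collection: str, schema_name: str) -> str:
    # Prefix matched by find_collection(); a dimension suffix such as
    # 'rows_trustgraph_default_invoice_2024_384' is assumed to follow.
    return (
        f"rows_{sanitize_name(user)}_"
        f"{sanitize_name(collection)}_{sanitize_name(schema_name)}_"
    )

print(collection_prefix("trustgraph", "default", "Invoice-2024"))
# -> rows_trustgraph_default_invoice_2024_
```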

View file

@ -0,0 +1,523 @@
"""
Row query service using GraphQL. Input is a GraphQL query with variables.
Output is GraphQL response data with any errors.
Queries against the unified 'rows' table with schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
"""
import json
import logging
import re
from typing import Dict, Any, Optional, List, Set
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from ... graphql import GraphQLSchemaBuilder, SortDirection
# Module logger
logger = logging.getLogger(__name__)
default_ident = "rows-query"
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name="request",
schema=RowsQueryRequest,
handler=self.on_message
)
)
self.register_specification(
ProducerSpec(
name="response",
schema=RowsQueryResponse,
)
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# GraphQL schema builder and generated schema
self.schema_builder = GraphQLSchemaBuilder()
self.graphql_schema = None
# Cassandra session
self.cluster = None
self.session = None
# Known keyspaces
self.known_keyspaces: Set[str] = set()
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
self.schema_builder.clear()
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
self.schema_builder.add_schema(schema_name, row_schema)
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Regenerate GraphQL schema
self.graphql_schema = self.schema_builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]:
"""Get all index names for a schema."""
index_names = []
for field in schema.fields:
if field.primary or field.indexed:
index_names.append(field.name)
return index_names
def find_matching_index(
self,
schema: RowSchema,
filters: Dict[str, Any]
) -> Optional[tuple]:
"""
Find an index that can satisfy the query filters.
Returns (index_name, index_value) if found, None otherwise.
For exact match queries, we need a filter on an indexed field.
"""
index_names = self.get_index_names(schema)
# Look for an exact match filter on an indexed field
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
return None
async def query_cassandra(
self,
user: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
filters: Dict[str, Any],
limit: int,
order_by: Optional[str] = None,
direction: Optional[SortDirection] = None
) -> List[Dict[str, Any]]:
"""
Execute a query against the unified Cassandra rows table.
For exact match queries on indexed fields, we can query directly.
For other queries, we need to scan and post-filter.
"""
# Connect if needed
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters)
results = []
if index_match:
# Direct query using index
index_name, index_value = index_match
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
AND schema_name = %s
AND index_name = %s
AND index_value = %s
"""
params = [collection, schema_name, index_name, index_value]
if limit:
query += f" LIMIT {limit}"
try:
rows = self.session.execute(query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
else:
# No direct index match - scan all rows for this schema
# This is less efficient but necessary for non-indexed queries
logger.warning(
f"No index match for filters {filters} - falling back to a scan via the first index"
)
# Get all index names for this schema
index_names = self.get_index_names(row_schema)
if not index_names:
logger.warning(f"Schema {schema_name} has no indexes")
return []
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
AND schema_name = %s
AND index_name = %s
ALLOW FILTERING
"""
params = [collection, schema_name, primary_index]
try:
rows = self.session.execute(query, params)
for row in rows:
row_dict = dict(row.data) if row.data else {}
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)
if limit and len(results) >= limit:
break
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
raise
# Post-query sorting if requested
if order_by and results:
reverse_order = bool(direction and direction.value == "desc")
try:
results.sort(
key=lambda x: x.get(order_by, ""),
reverse=reverse_order
)
except Exception as e:
logger.warning(f"Failed to sort results by {order_by}: {e}")
return results
def _matches_filters(
self,
row_dict: Dict[str, Any],
filters: Dict[str, Any],
row_schema: RowSchema
) -> bool:
"""Check if a row matches the given filters."""
for filter_key, filter_value in filters.items():
if filter_value is None:
continue
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
field_name = parts[0]
operator = parts[1]
else:
field_name = filter_key
operator = 'eq'
else:
field_name = filter_key
operator = 'eq'
row_value = row_dict.get(field_name)
if row_value is None:
return False
# Convert types for comparison
try:
if operator == 'eq':
if str(row_value) != str(filter_value):
return False
elif operator == 'gt':
if float(row_value) <= float(filter_value):
return False
elif operator == 'gte':
if float(row_value) < float(filter_value):
return False
elif operator == 'lt':
if float(row_value) >= float(filter_value):
return False
elif operator == 'lte':
if float(row_value) > float(filter_value):
return False
elif operator == 'contains':
if str(filter_value) not in str(row_value):
return False
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False
return True
async def execute_graphql_query(
self,
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query"""
if not self.graphql_schema:
raise RuntimeError("No GraphQL schema available - no schemas loaded")
# Create context for the query
context = {
"processor": self,
"user": user,
"collection": collection
}
# Execute the query
result = await self.graphql_schema.execute(
query,
variable_values=variables,
operation_name=operation_name,
context_value=context
)
# Build response
response = {}
if result.data:
response["data"] = result.data
else:
response["data"] = None
if result.errors:
response["errors"] = [
{
"message": str(error),
"path": getattr(error, "path", []),
"extensions": getattr(error, "extensions", {})
}
for error in result.errors
]
else:
response["errors"] = []
# Add extensions if any
if hasattr(result, "extensions") and result.extensions:
response["extensions"] = result.extensions
return response
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
try:
request = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling rows query request {id}...")
# Execute GraphQL query
result = await self.execute_graphql_query(
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)
# Create response
graphql_errors = []
if "errors" in result and result["errors"]:
for err in result["errors"]:
graphql_error = GraphQLError(
message=err.get("message", ""),
path=err.get("path", []),
extensions=err.get("extensions", {})
)
graphql_errors.append(graphql_error)
response = RowsQueryResponse(
error=None,
data=json.dumps(result.get("data")) if result.get("data") else "null",
errors=graphql_errors,
extensions=result.get("extensions", {})
)
logger.debug("Sending rows query response...")
await flow("response").send(response, properties={"id": id})
logger.debug("Rows query request completed")
except Exception as e:
logger.error(f"Exception in rows query service: {e}", exc_info=True)
logger.info("Sending error response...")
response = RowsQueryResponse(
error=Error(
type="rows-query-error",
message=str(e),
),
data=None,
errors=[],
extensions={}
)
await flow("response").send(response, properties={"id": id})
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for rows-query-cassandra command"""
Processor.launch(default_ident, __doc__)
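The post-filter path in `_matches_filters` keys comparison operators off a trailing suffix in the filter name (`amount_gt`, `tags_contains`), defaulting to equality. A standalone sketch of that parsing convention (the helper name `parse_filter_key` is illustrative):

```python
# Operator suffixes recognised by the post-filter convention; anything
# else, including a trailing segment that merely looks like a word,
# falls through to an equality match on the whole key.
OPS = {'gt', 'gte', 'lt', 'lte', 'contains', 'in'}

def parse_filter_key(filter_key: str) -> tuple[str, str]:
    if '_' in filter_key:
        field, suffix = filter_key.rsplit('_', 1)
        if suffix in OPS:
            return field, suffix
    return filter_key, 'eq'

print(parse_filter_key("amount_gt"))       # -> ('amount', 'gt')
print(parse_filter_key("invoice_number"))  # -> ('invoice_number', 'eq')
```

Note one consequence of the suffix convention: a field literally named with an operator-like tail (say `items_in`) would be parsed as field `items` with operator `in`, so schema field names should avoid those suffixes.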

View file

@ -1,6 +1,6 @@
"""
Structured Query Service - orchestrates natural language question processing.
Takes a question, converts it to GraphQL via nlp-query, executes via objects-query,
Takes a question, converts it to GraphQL via nlp-query, executes via rows-query,
and returns the results.
"""
@ -10,7 +10,7 @@ from typing import Dict, Any, Optional
from ...schema import StructuredQueryRequest, StructuredQueryResponse
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
from ...schema import ObjectsQueryRequest, ObjectsQueryResponse
from ...schema import RowsQueryRequest, RowsQueryResponse
from ...schema import Error
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec
@ -57,13 +57,13 @@ class Processor(FlowProcessor):
)
)
# Client spec for calling objects query service
# Client spec for calling rows query service
self.register_specification(
RequestResponseSpec(
request_name = "objects-query-request",
response_name = "objects-query-response",
request_schema = ObjectsQueryRequest,
response_schema = ObjectsQueryResponse
request_name = "rows-query-request",
response_name = "rows-query-response",
request_schema = RowsQueryRequest,
response_schema = RowsQueryResponse
)
)
@ -112,7 +112,7 @@ class Processor(FlowProcessor):
variables_as_strings[key] = str(value)
# Use user/collection values from request
objects_request = ObjectsQueryRequest(
objects_request = RowsQueryRequest(
user=request.user,
collection=request.collection,
query=nlp_response.graphql_query,
@ -120,12 +120,12 @@ class Processor(FlowProcessor):
operation_name=None
)
objects_response = await flow("objects-query-request").request(objects_request)
objects_response = await flow("rows-query-request").request(objects_request)
if objects_response.error is not None:
raise Exception(f"Objects query service error: {objects_response.error.message}")
# Handle GraphQL errors from the objects query service
raise Exception(f"Rows query service error: {objects_response.error.message}")
# Handle GraphQL errors from the rows query service
graphql_errors = []
if objects_response.errors:
for gql_error in objects_response.errors:

View file

@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:19530'
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):

View file

@ -18,7 +18,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"

View file

@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:6333'

View file

@ -27,7 +27,7 @@ def get_term_value(term):
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "ge-write"
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:19530'
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):

View file

@ -32,7 +32,7 @@ def get_term_value(term):
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "ge-write"
default_ident = "graph-embeddings-write"
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
default_cloud = "aws"
default_region = "us-east-1"

View file

@ -31,7 +31,7 @@ def get_term_value(term):
return term.id or term.value
default_ident = "ge-write"
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:6333'

View file

@ -1 +0,0 @@
# Objects storage module

View file

@ -1 +0,0 @@
from . write import *

View file

@ -1,3 +0,0 @@
from . write import run
run()

View file

@ -1,538 +0,0 @@
"""
Object writer for Cassandra. Input is ExtractedObject.
Writes structured objects to Cassandra tables based on schema definitions.
"""
import json
import logging
from typing import Dict, Set, Optional, Any
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cqlengine import connection
from cassandra import ConsistencyLevel
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "objects-write"
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"id": id,
"config_type": self.config_key,
}
)
self.register_specification(
ConsumerSpec(
name = "input",
schema = ExtractedObject,
handler = self.on_object
)
)
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
# Cache of known keyspaces/tables
self.known_keyspaces: Set[str] = set()
self.known_tables: Dict[str, Set[str]] = {} # keyspace -> set of tables
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Cassandra session
self.cluster = None
self.session = None
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
def ensure_keyspace(self, keyspace: str):
"""Ensure keyspace exists in Cassandra"""
if keyspace in self.known_keyspaces:
return
# Connect if needed
self.connect_cassandra()
# Sanitize keyspace name
safe_keyspace = self.sanitize_name(keyspace)
# Create keyspace if not exists
create_keyspace_cql = f"""
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
}}
"""
try:
self.session.execute(create_keyspace_cql)
self.known_keyspaces.add(keyspace)
self.known_tables[keyspace] = set()
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
except Exception as e:
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
raise
def get_cassandra_type(self, field_type: str, size: int = 0) -> str:
"""Convert schema field type to Cassandra type"""
# Handle None size
if size is None:
size = 0
type_mapping = {
"string": "text",
"integer": "bigint" if size > 4 else "int",
"float": "double" if size > 4 else "float",
"boolean": "boolean",
"timestamp": "timestamp",
"date": "date",
"time": "time",
"uuid": "uuid"
}
return type_mapping.get(field_type, "text")
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'o_' + safe_name
return safe_name.lower()
def sanitize_table(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
# Replace non-alphanumeric characters with underscore
import re
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
safe_name = 'o_' + safe_name
return safe_name.lower()
def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema):
"""Ensure table exists with proper structure"""
table_key = f"{keyspace}.{table_name}"
if table_key in self.known_tables.get(keyspace, set()):
return
# Ensure keyspace exists first
self.ensure_keyspace(keyspace)
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Build column definitions
columns = ["collection text"] # Collection is always part of table
primary_key_fields = []
clustering_fields = []
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
cassandra_type = self.get_cassandra_type(field.type, field.size)
columns.append(f"{safe_field_name} {cassandra_type}")
if field.primary:
primary_key_fields.append(safe_field_name)
# Build primary key - collection is always first in partition key
if primary_key_fields:
primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))"
else:
# If no primary key defined, use collection and a synthetic id
columns.append("synthetic_id uuid")
primary_key = "PRIMARY KEY ((collection, synthetic_id))"
# Create table
create_table_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} (
{', '.join(columns)},
{primary_key}
)
"""
try:
self.session.execute(create_table_cql)
if keyspace not in self.known_tables:
self.known_tables[keyspace] = set()
self.known_tables[keyspace].add(table_key)
logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}")
# Create secondary indexes for indexed fields
for field in schema.fields:
if field.indexed and not field.primary:
safe_field_name = self.sanitize_name(field.name)
index_name = f"{safe_table}_{safe_field_name}_idx"
create_index_cql = f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {safe_keyspace}.{safe_table} ({safe_field_name})
"""
try:
self.session.execute(create_index_cql)
logger.info(f"Created index: {index_name}")
except Exception as e:
logger.warning(f"Failed to create index {index_name}: {e}")
except Exception as e:
logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True)
raise
def convert_value(self, value: Any, field_type: str) -> Any:
"""Convert value to appropriate type for Cassandra"""
if value is None:
return None
try:
if field_type == "integer":
return int(value)
elif field_type == "float":
return float(value)
elif field_type == "boolean":
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes')
return bool(value)
elif field_type == "timestamp":
# Handle timestamp conversion if needed
return value
else:
return str(value)
except Exception as e:
logger.warning(f"Failed to convert value {value} to type {field_type}: {e}")
return str(value)
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}")
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
# Ensure table exists
keyspace = obj.metadata.user
table_name = obj.schema_name
self.ensure_table(keyspace, table_name, schema)
# Prepare data for insertion
safe_keyspace = self.sanitize_name(keyspace)
safe_table = self.sanitize_table(table_name)
# Process each object in the batch
for obj_index, value_map in enumerate(obj.values):
# Build column names and values for this object
columns = ["collection"]
values = [obj.metadata.collection]
placeholders = ["%s"]
# Check if we need a synthetic ID
has_primary_key = any(field.primary for field in schema.fields)
if not has_primary_key:
import uuid
columns.append("synthetic_id")
values.append(uuid.uuid4())
placeholders.append("%s")
# Process fields for this object
skip_object = False
for field in schema.fields:
safe_field_name = self.sanitize_name(field.name)
raw_value = value_map.get(field.name)
# Handle required fields
if field.required and raw_value is None:
logger.warning(f"Required field {field.name} is missing in object {obj_index}")
# Continue anyway - Cassandra doesn't enforce NOT NULL
# Check if primary key field is NULL
if field.primary and raw_value is None:
logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}")
skip_object = True
break
# Convert value to appropriate type
converted_value = self.convert_value(raw_value, field.type)
columns.append(safe_field_name)
values.append(converted_value)
placeholders.append("%s")
# Skip this object if primary key validation failed
if skip_object:
continue
# Build and execute insert query for this object
insert_cql = f"""
INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
"""
# Debug: Show data being inserted
logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}")
if len(columns) != len(values) or len(columns) != len(placeholders):
raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}")
try:
# Convert to tuple - Cassandra driver requires tuple for parameters
self.session.execute(insert_cql, tuple(values))
except Exception as e:
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
raise
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra object store"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Ensure keyspace exists
if safe_keyspace not in self.known_keyspaces:
self.ensure_keyspace(safe_keyspace)
self.known_keyspaces.add(safe_keyspace)
# For the Cassandra object store, collection is just a column value on each row
# No need to create separate tables per collection
# Just mark that we've seen this collection
logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using schema information"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if safe_keyspace not in self.known_keyspaces:
# Query to verify keyspace exists
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(safe_keyspace)
# Iterate over schemas we manage to delete from relevant tables
tables_deleted = 0
for schema_name, schema in self.schemas.items():
safe_table = self.sanitize_table(schema_name)
# Check if table exists
table_key = f"{user}.{schema_name}"
if table_key not in self.known_tables.get(user, set()):
logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping")
continue
try:
# Get primary key fields from schema
primary_key_fields = [field for field in schema.fields if field.primary]
if primary_key_fields:
# Schema has primary keys: need to query for partition keys first
# Build SELECT query for primary key fields
pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields]
select_cql = f"""
SELECT {', '.join(pk_field_names)}
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using full partition key
for row in rows:
where_clauses = ["collection = %s"]
values = [collection]
for field_name in pk_field_names:
where_clauses.append(f"{field_name} = %s")
values.append(getattr(row, field_name))
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE {' AND '.join(where_clauses)}
"""
self.session.execute(delete_cql, tuple(values))
else:
# No primary keys: the table falls back to synthetic_id
# Need to query for synthetic_ids first
select_cql = f"""
SELECT synthetic_id
FROM {safe_keyspace}.{safe_table}
WHERE collection = %s
ALLOW FILTERING
"""
rows = self.session.execute(select_cql, (collection,))
# Delete each row using collection and synthetic_id
for row in rows:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{safe_table}
WHERE collection = %s AND synthetic_id = %s
"""
self.session.execute(delete_cql, (collection, row.synthetic_id))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}")
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
"""Entry point for objects-write-cassandra command"""
Processor.launch(default_ident, __doc__)
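The synthetic-ID fallback used above (adding a UUID partition key when a schema defines no primary-key field) can be sketched in isolation. `build_insert` and its parameters are illustrative names for this sketch, not part of the processor's API:

```python
import uuid

def build_insert(keyspace, table, collection, field_names, value_map, has_primary_key):
    # Mirror the writer's column assembly: 'collection' always comes first,
    # then a synthetic UUID key when the schema has no primary-key field.
    columns, values = ["collection"], [collection]
    if not has_primary_key:
        columns.append("synthetic_id")
        values.append(uuid.uuid4())
    for name in field_names:
        columns.append(name)
        values.append(value_map.get(name))
    cql = (
        f"INSERT INTO {keyspace}.{table} ({', '.join(columns)}) "
        f"VALUES ({', '.join(['%s'] * len(columns))})"
    )
    return cql, tuple(values)

cql, params = build_insert(
    "alice", "products", "default",
    ["name", "price"], {"name": "Widget"}, has_primary_key=False,
)
print(cql)
# INSERT INTO alice.products (collection, synthetic_id, name, price) VALUES (%s, %s, %s, %s)
```

Note that `values` is passed to the driver as a tuple, matching the writer's requirement above.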

View file

@@ -0,0 +1,3 @@
"""
Row embeddings storage modules.
"""

View file

@@ -0,0 +1,5 @@
"""
Qdrant storage for row embeddings.
"""
from .write import Processor, run, default_ident

View file

@@ -0,0 +1,4 @@
from .write import run
run()

View file

@@ -0,0 +1,264 @@
"""
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
- index_value: The original list of values (for Cassandra lookup)
- text: The text that was embedded (for debugging/display)
"""
import logging
import re
import uuid
from typing import Set, Tuple
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.register_specification(
ConsumerSpec(
name="input",
schema=RowEmbeddings,
handler=self.on_embeddings
)
)
# Register config handler for collection management
self.register_config_handler(self.on_collection_config)
# Cache of created Qdrant collections
self.created_collections: Set[str] = set()
# Qdrant client
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int
) -> str:
"""Generate Qdrant collection name"""
safe_user = self.sanitize_name(user)
safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
def ensure_collection(self, collection_name: str, dimension: int):
"""Create Qdrant collection if it doesn't exist"""
if collection_name in self.created_collections:
return
if not self.qdrant.collection_exists(collection_name):
logger.info(
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
)
self.qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
)
)
self.created_collections.add(collection_name)
async def on_embeddings(self, msg, consumer, flow):
"""Process incoming RowEmbeddings and write to Qdrant"""
embeddings = msg.value()
logger.info(
f"Writing {len(embeddings.embeddings)} embeddings for schema "
f"{embeddings.schema_name} from {embeddings.metadata.id}"
)
# Validate collection exists in config before processing
if not self.collection_exists(
embeddings.metadata.user, embeddings.metadata.collection
):
logger.warning(
f"Collection {embeddings.metadata.collection} for user "
f"{embeddings.metadata.user} does not exist in config. "
f"Dropping message."
)
return
user = embeddings.metadata.user
collection = embeddings.metadata.collection
schema_name = embeddings.schema_name
embeddings_written = 0
qdrant_collection = None
for row_emb in embeddings.embeddings:
if not row_emb.vectors:
logger.warning(
f"No vectors for index {row_emb.index_name} - skipping"
)
continue
# Write each vector (there may be multiple from different models)
for vector in row_emb.vectors:
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write"""
logger.info(
f"Row embeddings collection create request for {user}/{collection} - "
f"will be created lazily on first write"
)
async def delete_collection(self, user: str, collection: str):
"""Delete all Qdrant collections for a given user/collection"""
try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}"
)
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}: {e}",
exc_info=True
)
raise
async def delete_collection_schema(
self, user: str, collection: str, schema_name: str
):
"""Delete Qdrant collection for a specific user/collection/schema"""
try:
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if not matching_collections:
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
exc_info=True
)
raise
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='Qdrant API key (default: None)'
)
def run():
"""Entry point for row-embeddings-write-qdrant command"""
Processor.launch(default_ident, __doc__)
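The collection naming scheme from the module docstring can be exercised standalone. This is a sketch that re-implements the sanitization rules for illustration; the real processor does this via its `sanitize_name` and `get_collection_name` methods:

```python
import re

def sanitize(name):
    # Replace characters outside [a-zA-Z0-9_], force a leading letter,
    # then lowercase - the same rules the writer applies.
    safe = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe and not safe[0].isalpha():
        safe = 'r_' + safe
    return safe.lower()

def qdrant_collection_name(user, collection, schema_name, dimension):
    return (
        f"rows_{sanitize(user)}_{sanitize(collection)}"
        f"_{sanitize(schema_name)}_{dimension}"
    )

print(qdrant_collection_name("Alice", "my-data", "9products", 384))
# rows_alice_my_data_r_9products_384
```

Because the vector dimension is part of the name, the same schema embedded by models of different dimensions lands in separate Qdrant collections.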

View file

@@ -1,46 +1,49 @@
"""
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
Row writer for Cassandra. Input is ExtractedObject.
Writes structured rows to a unified Cassandra table with multi-index support.
Uses a single 'rows' table with the schema:
- collection: text
- schema_name: text
- index_name: text
- index_value: frozen<list<text>>
- data: map<text, text>
- source: text
Each row is written multiple times - once per indexed field defined in the schema.
"""
raise RuntimeError("This code is no longer in use")
import pulsar
import base64
import os
import argparse
import time
import json
import logging
import re
from typing import Dict, Set, Optional, Any, List, Tuple
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from ssl import SSLContext, PROTOCOL_TLSv1_2
from .... schema import Rows
from .... log_level import LogLevel
from .... base import Consumer
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
logger = logging.getLogger(__name__)
module = "rows-write"
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
default_ident = "rows-write"
default_input_queue = "rows-store" # Default queue name
default_subscriber = module
class Processor(Consumer):
class Processor(CollectionConfigHandler, FlowProcessor):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
id = params.get("id", default_ident)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host,
@@ -48,99 +51,549 @@ class Processor(Consumer):
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
# Config key for schemas
self.config_key = params.get("config_type", "schema")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Rows,
"cassandra_host": ','.join(hosts),
"cassandra_username": username,
"cassandra_password": password,
"id": id,
"config_type": self.config_key,
}
)
if username and password:
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
self.tables = set()
self.register_specification(
ConsumerSpec(
name="input",
schema=ExtractedObject,
handler=self.on_object
)
)
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
""");
# Register config handlers
self.register_config_handler(self.on_schema_config)
self.register_config_handler(self.on_collection_config)
self.session.execute("use trustgraph");
# Cache of known keyspaces and whether tables exist
self.known_keyspaces: Set[str] = set()
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
async def handle(self, msg):
# Cache of registered (collection, schema_name) pairs
self.registered_partitions: Set[Tuple[str, str]] = set()
# Schema storage: name -> RowSchema
self.schemas: Dict[str, RowSchema] = {}
# Cassandra session
self.cluster = None
self.session = None
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
return
try:
v = msg.value()
name = v.row_schema.name
if name not in self.tables:
# FIXME: SQL injection?
pkey = []
stmt = "create table if not exists " + name + " ( "
for field in v.row_schema.fields:
stmt += field.name + " text, "
if field.primary:
pkey.append(field.name)
stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"
self.session.execute(stmt)
self.tables.add(name);
for row in v.rows:
field_names = []
values = []
for field in v.row_schema.fields:
field_names.append(field.name)
values.append(row[field.name])
# FIXME: SQL injection?
stmt = (
"insert into " + name + " (" + ", ".join(field_names) +
") values (" + ",".join(["%s"] * len(values)) + ")"
if self.cassandra_username and self.cassandra_password:
auth_provider = PlainTextAuthProvider(
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
self.session.execute(stmt, values)
self.session = self.cluster.connect()
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True)
raise
logger.error(f"Exception: {str(e)}", exc_info=True)
async def on_schema_config(self, config, version):
"""Handle schema configuration updates"""
logger.info(f"Loading schema configuration version {version}")
# If there's an error make sure to do table creation etc.
self.tables.remove(name)
# Track which schemas changed so we can clear partition cache
old_schema_names = set(self.schemas.keys())
raise e
# Clear existing schemas
self.schemas = {}
# Check if our config type exists
if self.config_key not in config:
logger.warning(f"No '{self.config_key}' type in configuration")
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = Field(
name=field_def["name"],
type=field_def["type"],
size=field_def.get("size", 0),
primary=field_def.get("primary_key", False),
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
fields=fields
)
self.schemas[schema_name] = row_schema
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
except Exception as e:
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
# Clear partition cache for schemas that changed
# This ensures next write will re-register partitions
new_schema_names = set(self.schemas.keys())
changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
if changed_schemas:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if sch not in changed_schemas
}
logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Cassandra compatibility"""
safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
# Ensure it starts with a letter
if safe_name and not safe_name[0].isalpha():
safe_name = 'r_' + safe_name
return safe_name.lower()
def ensure_keyspace(self, keyspace: str):
"""Ensure keyspace exists in Cassandra"""
if keyspace in self.known_keyspaces:
return
# Connect if needed
self.connect_cassandra()
# Sanitize keyspace name
safe_keyspace = self.sanitize_name(keyspace)
# Create keyspace if not exists
create_keyspace_cql = f"""
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': 1
}}
"""
try:
self.session.execute(create_keyspace_cql)
self.known_keyspaces.add(keyspace)
logger.info(f"Ensured keyspace exists: {safe_keyspace}")
except Exception as e:
logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
raise
def ensure_tables(self, keyspace: str):
"""Ensure unified rows and row_partitions tables exist"""
if keyspace in self.tables_initialized:
return
# Ensure keyspace exists first
self.ensure_keyspace(keyspace)
safe_keyspace = self.sanitize_name(keyspace)
# Create unified rows table
create_rows_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
collection text,
schema_name text,
index_name text,
index_value frozen<list<text>>,
data map<text, text>,
source text,
PRIMARY KEY ((collection, schema_name, index_name), index_value)
)
"""
# Create row_partitions tracking table
create_partitions_cql = f"""
CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions (
collection text,
schema_name text,
index_name text,
PRIMARY KEY ((collection), schema_name, index_name)
)
"""
try:
self.session.execute(create_rows_cql)
logger.info(f"Ensured rows table exists: {safe_keyspace}.rows")
self.session.execute(create_partitions_cql)
logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions")
self.tables_initialized.add(keyspace)
except Exception as e:
logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True)
raise
def get_index_names(self, schema: RowSchema) -> List[str]:
"""
Get all index names for a schema.
Returns list of index_name strings (single field names or comma-joined composites).
"""
index_names = []
for field in schema.fields:
# Primary key fields are treated as indexes
if field.primary:
index_names.append(field.name)
# Indexed fields
elif field.indexed:
index_names.append(field.name)
# TODO: Support composite indexes in the future
# For now, each indexed field is a single-field index
return index_names
def register_partitions(self, keyspace: str, collection: str, schema_name: str):
"""
Register partition entries for a (collection, schema_name) pair.
Called once on first row for each pair.
"""
cache_key = (collection, schema_name)
if cache_key in self.registered_partitions:
return
schema = self.schemas.get(schema_name)
if not schema:
logger.warning(f"Cannot register partitions - schema {schema_name} not found")
return
safe_keyspace = self.sanitize_name(keyspace)
index_names = self.get_index_names(schema)
# Insert partition entries for each index
insert_cql = f"""
INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name)
VALUES (%s, %s, %s)
"""
for index_name in index_names:
try:
self.session.execute(insert_cql, (collection, schema_name, index_name))
except Exception as e:
logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}")
self.registered_partitions.add(cache_key)
logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}")
def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]:
"""
Build the index_value list for a given index.
For single-field indexes, returns a single-element list.
For composite indexes (comma-separated), returns multiple elements.
"""
field_names = [f.strip() for f in index_name.split(',')]
values = []
for field_name in field_names:
value = value_map.get(field_name)
# Convert to string for storage
values.append(str(value) if value is not None else "")
return values
async def on_object(self, msg, consumer, flow):
"""Process incoming ExtractedObject and store in Cassandra"""
obj = msg.value()
logger.info(
f"Storing {len(obj.values)} rows for schema {obj.schema_name} "
f"from {obj.metadata.id}"
)
# Validate collection exists before accepting writes
if not self.collection_exists(obj.metadata.user, obj.metadata.collection):
error_msg = (
f"Collection {obj.metadata.collection} does not exist. "
f"Create it first via collection management API."
)
logger.error(error_msg)
raise ValueError(error_msg)
# Get schema definition
schema = self.schemas.get(obj.schema_name)
if not schema:
logger.warning(f"No schema found for {obj.schema_name} - skipping")
return
keyspace = obj.metadata.user
collection = obj.metadata.collection
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist
self.ensure_tables(keyspace)
# Register partitions if first time seeing this (collection, schema_name)
self.register_partitions(keyspace, collection, schema_name)
safe_keyspace = self.sanitize_name(keyspace)
# Get all index names for this schema
index_names = self.get_index_names(schema)
if not index_names:
logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable")
return
# Prepare insert statement
insert_cql = f"""
INSERT INTO {safe_keyspace}.rows
(collection, schema_name, index_name, index_value, data, source)
VALUES (%s, %s, %s, %s, %s, %s)
"""
# Process each row in the batch
rows_written = 0
for row_index, value_map in enumerate(obj.values):
# Convert all values to strings for the data map
data_map = {}
for field in schema.fields:
raw_value = value_map.get(field.name)
if raw_value is not None:
data_map[field.name] = str(raw_value)
# Write one copy per index
for index_name in index_names:
index_value = self.build_index_value(value_map, index_name)
# Skip if index value is empty/null
if not index_value or all(v == "" for v in index_value):
logger.debug(
f"Skipping index {index_name} for row {row_index} - "
f"empty index value"
)
continue
try:
self.session.execute(
insert_cql,
(collection, schema_name, index_name, index_value, data_map, source)
)
rows_written += 1
except Exception as e:
logger.error(
f"Failed to insert row {row_index} for index {index_name}: {e}",
exc_info=True
)
raise
logger.info(
f"Wrote {rows_written} index entries for {len(obj.values)} rows "
f"({len(index_names)} indexes per row)"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected
self.connect_cassandra()
# Ensure tables exist
self.ensure_tables(user)
logger.info(f"Collection {collection} ready for user {user}")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if user not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
# Discover all partitions for this collection
select_partitions_cql = f"""
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
WHERE collection = %s
"""
try:
partitions = self.session.execute(select_partitions_cql, (collection,))
partition_list = list(partitions)
except Exception as e:
logger.error(f"Failed to query partitions for collection {collection}: {e}")
raise
# Delete each partition from rows table
delete_rows_cql = f"""
DELETE FROM {safe_keyspace}.rows
WHERE collection = %s AND schema_name = %s AND index_name = %s
"""
partitions_deleted = 0
for partition in partition_list:
try:
self.session.execute(
delete_rows_cql,
(collection, partition.schema_name, partition.index_name)
)
partitions_deleted += 1
except Exception as e:
logger.error(
f"Failed to delete partition {collection}/{partition.schema_name}/"
f"{partition.index_name}: {e}"
)
raise
# Clean up row_partitions entries
delete_partitions_cql = f"""
DELETE FROM {safe_keyspace}.row_partitions
WHERE collection = %s
"""
try:
self.session.execute(delete_partitions_cql, (collection,))
except Exception as e:
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise
# Clear from local cache
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if col != collection
}
logger.info(
f"Deleted collection {collection}: {partitions_deleted} partitions "
f"from keyspace {safe_keyspace}"
)
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
# Discover partitions for this collection + schema
select_partitions_cql = f"""
SELECT index_name FROM {safe_keyspace}.row_partitions
WHERE collection = %s AND schema_name = %s
"""
try:
partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
partition_list = list(partitions)
except Exception as e:
logger.error(
f"Failed to query partitions for {collection}/{schema_name}: {e}"
)
raise
# Delete each partition from rows table
delete_rows_cql = f"""
DELETE FROM {safe_keyspace}.rows
WHERE collection = %s AND schema_name = %s AND index_name = %s
"""
partitions_deleted = 0
for partition in partition_list:
try:
self.session.execute(
delete_rows_cql,
(collection, schema_name, partition.index_name)
)
partitions_deleted += 1
except Exception as e:
logger.error(
f"Failed to delete partition {collection}/{schema_name}/"
f"{partition.index_name}: {e}"
)
raise
# Clean up row_partitions entries for this schema
delete_partitions_cql = f"""
DELETE FROM {safe_keyspace}.row_partitions
WHERE collection = %s AND schema_name = %s
"""
try:
self.session.execute(delete_partitions_cql, (collection, schema_name))
except Exception as e:
logger.error(
f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
)
raise
# Clear from local cache
self.registered_partitions.discard((collection, schema_name))
logger.info(
f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions "
f"from keyspace {safe_keyspace}"
)
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:
self.cluster.shutdown()
logger.info("Closed Cassandra connection")
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
FlowProcessor.add_args(parser)
add_cassandra_args(parser)
parser.add_argument(
'--config-type',
default='schema',
help='Configuration type prefix for schemas (default: schema)'
)
def run():
Processor.launch(module, __doc__)
"""Entry point for rows-write-cassandra command"""
Processor.launch(default_ident, __doc__)
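To make the multi-index write pattern concrete, here is a minimal standalone sketch of the fan-out a single row goes through. `fan_out` is a hypothetical helper name for this sketch; the real writer performs the same steps inside `on_object` using `build_index_value`:

```python
def build_index_value(value_map, index_name):
    # Composite indexes are comma-separated field lists; each field
    # contributes one string element to the index_value list.
    values = []
    for field_name in index_name.split(','):
        v = value_map.get(field_name.strip())
        values.append(str(v) if v is not None else "")
    return values

def fan_out(value_map, index_names):
    # One (index_name, index_value, data) entry per index - the same
    # row data is duplicated into each index partition.
    data = {k: str(v) for k, v in value_map.items() if v is not None}
    entries = []
    for index_name in index_names:
        index_value = build_index_value(value_map, index_name)
        if all(v == "" for v in index_value):
            continue  # skip indexes whose value is entirely empty
        entries.append((index_name, index_value, data))
    return entries

row = {"name": "Alice", "city": "Oslo", "age": 42}
for entry in fan_out(row, ["name", "city,age"]):
    print(entry)
```

With one single-field index and one composite index, the row above is written twice: once under partition `("name", ["Alice"])` and once under `("city,age", ["Oslo", "42"])`, each carrying the full `data` map.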