diff --git a/Makefile b/Makefile index 8eb2d324..13679d0d 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ VERSION=0.0.0 DOCKER=podman -all: container +all: containers # Not used wheels: @@ -49,7 +49,9 @@ update-package-versions: echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py -container: update-package-versions +FORCE: + +containers: FORCE ${DOCKER} build -f containers/Containerfile.base \ -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ diff --git a/docs/tech-specs/structured-data-2.md b/docs/tech-specs/structured-data-2.md new file mode 100644 index 00000000..10bbdae9 --- /dev/null +++ b/docs/tech-specs/structured-data-2.md @@ -0,0 +1,467 @@ +# Structured Data Technical Specification (Part 2) + +## Overview + +This specification addresses issues and gaps identified during the initial implementation of TrustGraph's structured data integration, as described in `structured-data.md`. + +## Problem Statements + +### 1. Naming Inconsistency: "Object" vs "Row" + +The current implementation uses "object" terminology throughout (e.g., `ExtractedObject`, object extraction, object embeddings). This naming is too generic and causes confusion: + +- "Object" is an overloaded term in software (Python objects, JSON objects, etc.) +- The data being handled is fundamentally tabular - rows in tables with defined schemas +- "Row" more accurately describes the data model and aligns with database terminology + +This inconsistency appears in module names, class names, message types, and documentation. + +### 2. Row Store Query Limitations + +The current row store implementation has significant query limitations: + +**Natural Language Mismatch**: Queries struggle with real-world data variations. For example: +- A street database containing `"CHESTNUT ST"` is difficult to find when asking about `"Chestnut Street"` +- Abbreviations, case differences, and formatting variations break exact-match queries +- Users expect semantic understanding, but the store provides literal matching + +**Schema Evolution Issues**: Changing schemas causes problems: +- Existing data may not conform to updated schemas +- Table structure changes can break queries and data integrity +- No clear migration path for schema updates + +### 3. Row Embeddings Required + +Related to problem 2, the system needs vector embeddings for row data to enable: + +- Semantic search across structured data (finding "Chestnut Street" when data contains "CHESTNUT ST") +- Similarity matching for fuzzy queries +- Hybrid search combining structured filters with semantic similarity +- Better natural language query support + +The embedding service was specified but not implemented. + +### 4. Row Data Ingestion Incomplete + +The structured data ingestion pipeline is not fully operational: + +- Diagnostic prompts exist to classify input formats (CSV, JSON, etc.) 
+- The ingestion service that uses these prompts is not plumbed into the system
+- No end-to-end path for loading pre-structured data into the row store
+
+## Goals
+
+- **Schema Flexibility**: Enable schema evolution without breaking existing data or requiring migrations
+- **Consistent Naming**: Standardize on "row" terminology throughout the codebase
+- **Semantic Queryability**: Support fuzzy/semantic matching via row embeddings
+- **Complete Ingestion Pipeline**: Provide end-to-end path for loading structured data
+
+## Technical Design
+
+### Unified Row Storage Schema
+
+The previous implementation created a separate Cassandra table for each schema. This caused problems when schemas evolved, as table structure changes required migrations.
+
+The new design uses a single unified table for all row data:
+
+```sql
+CREATE TABLE rows (
+    collection text,
+    schema_name text,
+    index_name text,
+    index_value frozen<list<text>>,
+    data map<text, text>,
+    source text,
+    PRIMARY KEY ((collection, schema_name, index_name), index_value)
+)
+```
+
+#### Column Definitions
+
+| Column | Type | Description |
+|--------|------|-------------|
+| `collection` | `text` | Data collection/import identifier (from metadata) |
+| `schema_name` | `text` | Name of the schema this row conforms to |
+| `index_name` | `text` | Name of the indexed field(s), comma-joined for composites |
+| `index_value` | `frozen<list<text>>` | Index value(s) as a list |
+| `data` | `map<text, text>` | Row data as key-value pairs |
+| `source` | `text` | Optional URI linking to provenance information in the knowledge graph. Empty string or NULL indicates no source. |
+
+#### Index Handling
+
+Each row is stored multiple times - once per indexed field defined in the schema. The primary key fields are treated as an index with no special marker, providing future flexibility.
+
+**Single-field index example:**
+- Schema defines `email` as indexed
+- `index_name = "email"`
+- `index_value = ['foo@bar.com']`
+
+**Composite index example:**
+- Schema defines composite index on `region` and `status`
+- `index_name = "region,status"` (field names sorted and comma-joined)
+- `index_value = ['US', 'active']` (values in same order as field names)
+
+**Primary key example:**
+- Schema defines `customer_id` as primary key
+- `index_name = "customer_id"`
+- `index_value = ['CUST001']`
+
+#### Query Patterns
+
+All queries follow the same pattern regardless of which index is used:
+
+```sql
+SELECT * FROM rows
+WHERE collection = 'import_2024'
+  AND schema_name = 'customers'
+  AND index_name = 'email'
+  AND index_value = ['foo@bar.com']
+```
+
+#### Design Trade-offs
+
+**Advantages:**
+- Schema changes don't require table structure changes
+- Row data is opaque to Cassandra - field additions/removals are transparent
+- Consistent query pattern for all access methods
+- No Cassandra secondary indexes (which can be slow at scale)
+- Native Cassandra types throughout (`map<text, text>`, `frozen<list<text>>`)
+
+**Trade-offs:**
+- Write amplification: each row insert = N inserts (one per indexed field)
+- Storage overhead from duplicated row data
+- Type information stored in schema config, conversion at application layer
+
+#### Consistency Model
+
+The design accepts certain simplifications:
+
+1. **No row updates**: The system is append-only. This eliminates consistency concerns about updating multiple copies of the same row.
+
+2. **Schema change tolerance**: When schemas change (e.g., indexes added/removed), existing rows retain their original indexing. Old rows won't be discoverable via new indexes. Users can delete and recreate a schema to ensure consistency if needed.
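+
+To make the write fan-out concrete, the following is a minimal sketch of how one row expands into its per-index inserts. The `index_field_sets` structure and the helper names are illustrative assumptions, not the implementation:
+
+```python
+# Sketch of the per-index write fan-out (assumed helper, not the actual writer).
+INSERT_CQL = """
+INSERT INTO rows (collection, schema_name, index_name, index_value, data, source)
+VALUES (%s, %s, %s, %s, %s, %s)
+"""
+
+def index_entries(row, index_field_sets):
+    # index_field_sets: indexed field groups from the schema config,
+    # e.g. [["email"], ["region", "status"], ["customer_id"]].
+    for fields in index_field_sets:
+        names = sorted(fields)  # composite field names are sorted...
+        yield ",".join(names), [row[f] for f in names]  # ...and comma-joined
+
+def write_row(session, collection, schema_name, row, index_field_sets, source=""):
+    # One INSERT per index: N indexes means N copies of the row.
+    for index_name, index_value in index_entries(row, index_field_sets):
+        session.execute(
+            INSERT_CQL,
+            (collection, schema_name, index_name, index_value, row, source),
+        )
+```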
+
+### Partition Tracking and Deletion
+
+#### The Problem
+
+With the partition key `(collection, schema_name, index_name)`, efficient deletion requires knowing all partition keys to delete. Deleting by just `collection` or `collection + schema_name` requires knowing all the `index_name` values that have data.
+
+#### Partition Tracking Table
+
+A secondary lookup table tracks which partitions exist:
+
+```sql
+CREATE TABLE row_partitions (
+    collection text,
+    schema_name text,
+    index_name text,
+    PRIMARY KEY ((collection), schema_name, index_name)
+)
+```
+
+This enables efficient discovery of partitions for deletion operations.
+
+#### Row Writer Behavior
+
+The row writer maintains an in-memory cache of registered `(collection, schema_name)` pairs. When processing a row:
+
+1. Check if `(collection, schema_name)` is in the cache
+2. If not cached (first row for this pair):
+   - Look up the schema config to get all index names
+   - Insert entries into `row_partitions` for each `(collection, schema_name, index_name)`
+   - Add the pair to the cache
+3. Proceed with writing the row data
+
+The row writer also monitors schema config change events. When a schema changes, relevant cache entries are cleared so the next row triggers re-registration with the updated index names.
+
+This approach ensures:
+- Lookup table writes happen once per `(collection, schema_name)` pair, not per row
+- The lookup table reflects the indexes that were active when data was written
+- Schema changes mid-import are picked up correctly
+
+#### Deletion Operations
+
+**Delete collection:**
+```sql
+-- 1. Discover all partitions
+SELECT schema_name, index_name FROM row_partitions WHERE collection = 'X';
+
+-- 2. Delete each partition from rows table
+DELETE FROM rows WHERE collection = 'X' AND schema_name = '...' AND index_name = '...';
+-- (repeat for each discovered partition)
+
+-- 3. Clean up the lookup table
+DELETE FROM row_partitions WHERE collection = 'X';
+```
+
+**Delete collection + schema:**
+```sql
+-- 1. Discover partitions for this schema
+SELECT index_name FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
+
+-- 2. Delete each partition from rows table
+DELETE FROM rows WHERE collection = 'X' AND schema_name = 'Y' AND index_name = '...';
+-- (repeat for each discovered partition)
+
+-- 3. Clean up the lookup table entries
+DELETE FROM row_partitions WHERE collection = 'X' AND schema_name = 'Y';
+```
+
+### Row Embeddings
+
+Row embeddings enable semantic/fuzzy matching on indexed values, solving the natural language mismatch problem (e.g., finding "CHESTNUT ST" when querying for "Chestnut Street").
+
+#### Design Overview
+
+Each indexed value is embedded and stored in a vector store (Qdrant). At query time, the query is embedded, similar vectors are found, and the associated metadata is used to look up the actual rows in Cassandra.
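+
+As a minimal sketch of this round trip - assuming the `qdrant-client` and `cassandra-driver` packages, a hypothetical `embed()` call into the embeddings service, and the collection naming scheme detailed below (name sanitization omitted for brevity):
+
+```python
+ROW_LOOKUP_CQL = """
+SELECT * FROM rows
+WHERE collection = %s AND schema_name = %s
+  AND index_name = %s AND index_value = %s
+"""
+
+def semantic_row_lookup(qdrant, session, embed, user, collection,
+                        schema_name, query, limit=5):
+    vector = embed(query)  # e.g. embed("Chestnut Street")
+    name = f"rows_{user}_{collection}_{schema_name}_{len(vector)}"
+    hits = qdrant.search(collection_name=name, query_vector=vector, limit=limit)
+    rows = []
+    for hit in hits:
+        # Each payload carries the exact Cassandra lookup key.
+        rows.extend(session.execute(ROW_LOOKUP_CQL, (
+            collection, schema_name,
+            hit.payload["index_name"], hit.payload["index_value"],
+        )))
+    return rows
+```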
+
+#### Qdrant Collection Structure
+
+One Qdrant collection per `(user, collection, schema_name, dimension)` tuple:
+
+- **Collection naming:** `rows_{user}_{collection}_{schema_name}_{dimension}`
+- Names are sanitized (non-alphanumeric characters replaced with `_`, lowercased, numeric prefixes get `r_` prefix)
+- **Rationale:** Enables clean deletion of a `(user, collection, schema_name)` instance by dropping matching Qdrant collections; dimension suffix allows different embedding models to coexist
+
+#### What Gets Embedded
+
+The text representation of index values:
+
+| Index Type | Example `index_value` | Text to Embed |
+|------------|----------------------|---------------|
+| Single-field | `['foo@bar.com']` | `"foo@bar.com"` |
+| Composite | `['US', 'active']` | `"US active"` (space-joined) |
+
+#### Point Structure
+
+Each Qdrant point contains:
+
+```json
+{
+  "id": "<uuid>",
+  "vector": [0.1, 0.2, ...],
+  "payload": {
+    "index_name": "street_name",
+    "index_value": ["CHESTNUT ST"],
+    "text": "CHESTNUT ST"
+  }
+}
+```
+
+| Payload Field | Description |
+|---------------|-------------|
+| `index_name` | The indexed field(s) this embedding represents |
+| `index_value` | The original list of values (for Cassandra lookup) |
+| `text` | The text that was embedded (for debugging/display) |
+
+Note: `user`, `collection`, and `schema_name` are implicit from the Qdrant collection name.
+
+#### Query Flow
+
+1. User queries for "Chestnut Street" within user U, collection X, schema Y
+2. Embed the query text
+3. Determine Qdrant collection name(s) matching prefix `rows_U_X_Y_`
+4. Search matching Qdrant collection(s) for nearest vectors
+5. Get matching points with payloads containing `index_name` and `index_value`
+6. Query Cassandra:
+   ```sql
+   SELECT * FROM rows
+   WHERE collection = 'X'
+     AND schema_name = 'Y'
+     AND index_name = '<index_name>'
+     AND index_value = <index_value>
+   ```
+7. Return matched rows
+
+#### Optional: Filtering by Index Name
+
+Queries can optionally filter by `index_name` in Qdrant to search only specific fields:
+
+- **"Find any field matching 'Chestnut'"** → search all vectors in the collection
+- **"Find street_name matching 'Chestnut'"** → filter where `payload.index_name = 'street_name'`
+
+#### Architecture
+
+Row embeddings follow the **two-stage pattern** used by GraphRAG (graph-embeddings, document-embeddings):
+
+- **Stage 1: Embedding computation** (`trustgraph-flow/trustgraph/embeddings/row_embeddings/`) - Consumes `ExtractedObject`, computes embeddings via the embeddings service, outputs `RowEmbeddings`
+- **Stage 2: Embedding storage** (`trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/`) - Consumes `RowEmbeddings`, writes vectors to Qdrant
+
+The Cassandra row writer is a separate parallel consumer:
+
+- **Cassandra row writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra`) - Consumes `ExtractedObject`, writes rows to Cassandra
+
+All three services consume from the same flow, keeping them decoupled. This allows:
+- Independent scaling of Cassandra writes vs embedding generation vs vector storage
+- Embedding services can be disabled if not needed
+- Failures in one service don't affect the others
+- Consistent architecture with GraphRAG pipelines
+
+#### Write Path
+
+**Stage 1 (row-embeddings processor):** When receiving an `ExtractedObject`:
+
+1. Look up the schema to find indexed fields
+2. For each indexed field:
+   - Build the text representation of the index value
+   - Compute embedding via the embeddings service
+3. Output a `RowEmbeddings` message containing all computed vectors
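+
+A minimal sketch of the Stage 1 fan-out, using the message types defined below; the `index_field_sets` input and the `embed()` call are illustrative assumptions:
+
+```python
+def build_row_embeddings(obj, index_field_sets, embed):
+    # obj: an ExtractedObject; obj.values is a list of row dicts.
+    # index_field_sets: indexed field groups from the schema config,
+    # e.g. [["street_name"], ["region", "status"]].
+    out = []
+    for row in obj.values:
+        for fields in index_field_sets:
+            names = sorted(fields)
+            values = [row[f] for f in names]
+            text = " ".join(values)  # space-joined text is what gets embedded
+            out.append(RowIndexEmbedding(
+                index_name=",".join(names),
+                index_value=values,
+                text=text,
+                vectors=[embed(text)],
+            ))
+    return RowEmbeddings(
+        metadata=obj.metadata,
+        schema_name=obj.schema_name,
+        embeddings=out,
+    )
+```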
+
+**Stage 2 (row-embeddings-write-qdrant):** When receiving a `RowEmbeddings`:
+
+1. For each embedding in the message:
+   - Determine Qdrant collection from `(user, collection, schema_name, dimension)`
+   - Create collection if needed (lazy creation on first write)
+   - Upsert point with vector and payload
+
+#### Message Types
+
+```python
+@dataclass
+class RowIndexEmbedding:
+    index_name: str             # The indexed field name(s)
+    index_value: list[str]      # The field value(s)
+    text: str                   # Text that was embedded
+    vectors: list[list[float]]  # Computed embedding vectors
+
+@dataclass
+class RowEmbeddings:
+    metadata: Metadata
+    schema_name: str
+    embeddings: list[RowIndexEmbedding]
+```
+
+#### Deletion Integration
+
+Qdrant collections are discovered by prefix matching on the collection name pattern:
+
+**Delete `(user, collection)`:**
+1. List all Qdrant collections matching prefix `rows_{user}_{collection}_`
+2. Delete each matching collection
+3. Delete Cassandra rows partitions (as documented above)
+4. Clean up `row_partitions` entries
+
+**Delete `(user, collection, schema_name)`:**
+1. List all Qdrant collections matching prefix `rows_{user}_{collection}_{schema_name}_`
+2. Delete each matching collection (handles multiple dimensions)
+3. Delete Cassandra rows partitions
+4. Clean up `row_partitions`
+
+#### Module Locations
+
+| Stage | Module | Entry Point |
+|-------|--------|-------------|
+| Stage 1 | `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | `row-embeddings` |
+| Stage 2 | `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | `row-embeddings-write-qdrant` |
+
+### Row Embeddings Query API
+
+The row embeddings query is a **separate API** from the GraphQL row query service:
+
+| API | Purpose | Backend |
+|-----|---------|---------|
+| Row Query (GraphQL) | Exact matching on indexed fields | Cassandra |
+| Row Embeddings Query | Fuzzy/semantic matching | Qdrant |
+
+This separation keeps concerns clean:
+- GraphQL service focuses on exact, structured queries
+- Embeddings API handles semantic similarity
+- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
+
+Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
+
+### Row Data Ingestion
+
+Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
+
+## Implementation Impact
+
+### Current State Analysis
+
+The existing implementation has two main components:
+
+| Component | Location | Lines | Description |
+|-----------|----------|-------|-------------|
+| Query Service | `trustgraph-flow/trustgraph/query/objects/cassandra/service.py` | ~740 | Monolithic: GraphQL schema generation, filter parsing, Cassandra queries, request handling |
+| Writer | `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` | ~540 | Per-schema table creation, secondary indexes, insert/delete |
+
+**Current Query Pattern:**
+```sql
+SELECT * FROM {keyspace}.o_{schema_name}
+WHERE collection = 'X' AND email = 'foo@bar.com'
+ALLOW FILTERING
+```
+
+**New Query Pattern:**
+```sql
+SELECT * FROM {keyspace}.rows
+WHERE collection = 'X' AND schema_name = 'customers'
+  AND index_name = 'email' AND index_value = ['foo@bar.com']
+```
+
+### Key Changes
+
+1. **Query semantics simplify**: The new schema only supports exact matches on `index_value`. The current GraphQL filters (`gt`, `lt`, `contains`, etc.)
either: + - Become post-filtering on returned data (if still needed) + - Are removed in favor of using the embeddings API for fuzzy matching + +2. **GraphQL code is tightly coupled**: The current `service.py` bundles Strawberry type generation, filter parsing, and Cassandra-specific queries. Adding another row store backend would duplicate ~400 lines of GraphQL code. + +### Proposed Refactor + +The refactor has two parts: + +#### 1. Break Out GraphQL Code + +Extract reusable GraphQL components into a shared module: + +``` +trustgraph-flow/trustgraph/query/graphql/ +├── __init__.py +├── types.py # Filter types (IntFilter, StringFilter, FloatFilter) +├── schema.py # Dynamic schema generation from RowSchema +└── filters.py # Filter parsing utilities +``` + +This enables: +- Reuse across different row store backends +- Cleaner separation of concerns +- Easier testing of GraphQL logic independently + +#### 2. Implement New Table Schema + +Refactor the Cassandra-specific code to use the unified table: + +**Writer** (`trustgraph-flow/trustgraph/storage/rows/cassandra/`): +- Single `rows` table instead of per-schema tables +- Write N copies per row (one per index) +- Register to `row_partitions` table +- Simpler table creation (one-time setup) + +**Query Service** (`trustgraph-flow/trustgraph/query/rows/cassandra/`): +- Query the unified `rows` table +- Use extracted GraphQL module for schema generation +- Simplified filter handling (exact match only at DB level) + +### Module Renames + +As part of the "object" → "row" naming cleanup: + +| Current | New | +|---------|-----| +| `storage/objects/cassandra/` | `storage/rows/cassandra/` | +| `query/objects/cassandra/` | `query/rows/cassandra/` | +| `embeddings/object_embeddings/` | `embeddings/row_embeddings/` | + +### New Modules + +| Module | Purpose | +|--------|---------| +| `trustgraph-flow/trustgraph/query/graphql/` | Shared GraphQL utilities | +| `trustgraph-flow/trustgraph/query/row_embeddings/qdrant/` | Row embeddings query API | +| `trustgraph-flow/trustgraph/embeddings/row_embeddings/` | Row embeddings computation (Stage 1) | +| `trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/` | Row embeddings storage (Stage 2) | + +## References + +- [Structured Data Technical Specification](structured-data.md) diff --git a/specs/api/components/schemas/query/ObjectsQueryRequest.yaml b/specs/api/components/schemas/query/RowsQueryRequest.yaml similarity index 92% rename from specs/api/components/schemas/query/ObjectsQueryRequest.yaml rename to specs/api/components/schemas/query/RowsQueryRequest.yaml index 775bbc4b..08f03ad3 100644 --- a/specs/api/components/schemas/query/ObjectsQueryRequest.yaml +++ b/specs/api/components/schemas/query/RowsQueryRequest.yaml @@ -1,6 +1,6 @@ type: object description: | - Objects query request - GraphQL query over knowledge graph. + Rows query request - GraphQL query over structured data. 
required: - query properties: diff --git a/specs/api/components/schemas/query/ObjectsQueryResponse.yaml b/specs/api/components/schemas/query/RowsQueryResponse.yaml similarity index 96% rename from specs/api/components/schemas/query/ObjectsQueryResponse.yaml rename to specs/api/components/schemas/query/RowsQueryResponse.yaml index 8fd9b6a6..a8fed63d 100644 --- a/specs/api/components/schemas/query/ObjectsQueryResponse.yaml +++ b/specs/api/components/schemas/query/RowsQueryResponse.yaml @@ -1,5 +1,5 @@ type: object -description: Objects query response (GraphQL format) +description: Rows query response (GraphQL format) properties: data: description: GraphQL response data (JSON object or null) diff --git a/specs/api/openapi.yaml b/specs/api/openapi.yaml index 55c05741..3cf5517e 100644 --- a/specs/api/openapi.yaml +++ b/specs/api/openapi.yaml @@ -121,8 +121,8 @@ paths: $ref: './paths/flow/mcp-tool.yaml' /api/v1/flow/{flow}/service/triples: $ref: './paths/flow/triples.yaml' - /api/v1/flow/{flow}/service/objects: - $ref: './paths/flow/objects.yaml' + /api/v1/flow/{flow}/service/rows: + $ref: './paths/flow/rows.yaml' /api/v1/flow/{flow}/service/nlp-query: $ref: './paths/flow/nlp-query.yaml' /api/v1/flow/{flow}/service/structured-query: diff --git a/specs/api/paths/flow/nlp-query.yaml b/specs/api/paths/flow/nlp-query.yaml index 7032b5b9..a10f3a67 100644 --- a/specs/api/paths/flow/nlp-query.yaml +++ b/specs/api/paths/flow/nlp-query.yaml @@ -34,7 +34,7 @@ post: ``` 1. User asks: "Who does Alice know?" 2. NLP Query generates GraphQL - 3. Execute via /api/v1/flow/{flow}/service/objects + 3. Execute via /api/v1/flow/{flow}/service/rows 4. Return results to user ``` diff --git a/specs/api/paths/flow/objects.yaml b/specs/api/paths/flow/rows.yaml similarity index 90% rename from specs/api/paths/flow/objects.yaml rename to specs/api/paths/flow/rows.yaml index ac94a353..d648c9db 100644 --- a/specs/api/paths/flow/objects.yaml +++ b/specs/api/paths/flow/rows.yaml @@ -1,19 +1,19 @@ post: tags: - Flow Services - summary: Objects query - GraphQL over knowledge graph + summary: Rows query - GraphQL over structured data description: | - Query knowledge graph using GraphQL for object-oriented data access. + Query structured data using GraphQL for row-oriented data access. - ## Objects Query Overview + ## Rows Query Overview - GraphQL interface to knowledge graph: + GraphQL interface to structured data: - **Schema-driven**: Predefined types and relationships - **Flexible queries**: Request exactly what you need - **Nested data**: Traverse relationships in single query - **Type-safe**: Strong typing with introspection - Abstracts RDF triples into familiar object model. + Abstracts structured rows into familiar object model. ## GraphQL Benefits @@ -61,7 +61,7 @@ post: Schema defines available types via config service. Use introspection query to discover schema. 
- operationId: objectsQueryService + operationId: rowsQueryService security: - bearerAuth: [] parameters: @@ -77,7 +77,7 @@ post: content: application/json: schema: - $ref: '../../components/schemas/query/ObjectsQueryRequest.yaml' + $ref: '../../components/schemas/query/RowsQueryRequest.yaml' examples: simpleQuery: summary: Simple query @@ -129,7 +129,7 @@ post: content: application/json: schema: - $ref: '../../components/schemas/query/ObjectsQueryResponse.yaml' + $ref: '../../components/schemas/query/RowsQueryResponse.yaml' examples: successfulQuery: summary: Successful query diff --git a/specs/api/paths/flow/structured-query.yaml b/specs/api/paths/flow/structured-query.yaml index c094c50a..6d4dfe87 100644 --- a/specs/api/paths/flow/structured-query.yaml +++ b/specs/api/paths/flow/structured-query.yaml @@ -9,7 +9,7 @@ post: Combines two operations in one call: 1. **NLP Query**: Generate GraphQL from question - 2. **Objects Query**: Execute generated query + 2. **Rows Query**: Execute generated query 3. **Return Results**: Direct answer data Simplest way to query knowledge graph with natural language. @@ -21,7 +21,7 @@ post: - **Output**: Query results (data) - **Use when**: Want simple, direct answers - ### NLP Query + Objects Query (separate calls) + ### NLP Query + Rows Query (separate calls) - **Step 1**: Convert question → GraphQL - **Step 2**: Execute GraphQL → results - **Use when**: Need to inspect/modify query before execution diff --git a/specs/websocket/components/messages/ServiceRequest.yaml b/specs/websocket/components/messages/ServiceRequest.yaml index 8df44caa..28fe7eff 100644 --- a/specs/websocket/components/messages/ServiceRequest.yaml +++ b/specs/websocket/components/messages/ServiceRequest.yaml @@ -25,7 +25,7 @@ payload: - $ref: './requests/EmbeddingsRequest.yaml' - $ref: './requests/McpToolRequest.yaml' - $ref: './requests/TriplesRequest.yaml' - - $ref: './requests/ObjectsRequest.yaml' + - $ref: './requests/RowsRequest.yaml' - $ref: './requests/NlpQueryRequest.yaml' - $ref: './requests/StructuredQueryRequest.yaml' - $ref: './requests/StructuredDiagRequest.yaml' diff --git a/specs/websocket/components/messages/requests/ObjectsRequest.yaml b/specs/websocket/components/messages/requests/RowsRequest.yaml similarity index 60% rename from specs/websocket/components/messages/requests/ObjectsRequest.yaml rename to specs/websocket/components/messages/requests/RowsRequest.yaml index 61c9ef64..8eaa0919 100644 --- a/specs/websocket/components/messages/requests/ObjectsRequest.yaml +++ b/specs/websocket/components/messages/requests/RowsRequest.yaml @@ -1,5 +1,5 @@ type: object -description: WebSocket request for objects service (flow-hosted service) +description: WebSocket request for rows service (flow-hosted service) required: - id - service @@ -11,16 +11,16 @@ properties: description: Unique request identifier service: type: string - const: objects - description: Service identifier for objects service + const: rows + description: Service identifier for rows service flow: type: string description: Flow ID request: - $ref: '../../../../api/components/schemas/query/ObjectsQueryRequest.yaml' + $ref: '../../../../api/components/schemas/query/RowsQueryRequest.yaml' examples: - id: req-1 - service: objects + service: rows flow: my-flow request: query: "{ entity(id: \"https://example.com/entity1\") { properties { key value } } }" diff --git a/tests/contract/test_objects_cassandra_contracts.py b/tests/contract/test_rows_cassandra_contracts.py similarity index 83% rename from 
tests/contract/test_objects_cassandra_contracts.py rename to tests/contract/test_rows_cassandra_contracts.py index bb8aec8a..d1a8ba26 100644 --- a/tests/contract/test_objects_cassandra_contracts.py +++ b/tests/contract/test_rows_cassandra_contracts.py @@ -1,8 +1,8 @@ """ -Contract tests for Cassandra Object Storage +Contract tests for Cassandra Row Storage These tests verify the message contracts and schema compatibility -for the objects storage processor. +for the rows storage processor. """ import pytest @@ -10,12 +10,12 @@ import json from pulsar.schema import AvroSchema from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field -from trustgraph.storage.objects.cassandra.write import Processor +from trustgraph.storage.rows.cassandra.write import Processor @pytest.mark.contract -class TestObjectsCassandraContracts: - """Contract tests for Cassandra object storage messages""" +class TestRowsCassandraContracts: + """Contract tests for Cassandra row storage messages""" def test_extracted_object_input_contract(self): """Test that ExtractedObject schema matches expected input format""" @@ -145,50 +145,6 @@ class TestObjectsCassandraContracts: assert required_field_keys.issubset(field.keys()) assert set(field.keys()).issubset(required_field_keys | optional_field_keys) - def test_cassandra_type_mapping_contract(self): - """Test that all supported field types have Cassandra mappings""" - processor = Processor.__new__(Processor) - - # All field types that should be supported - supported_types = [ - ("string", "text"), - ("integer", "int"), # or bigint based on size - ("float", "float"), # or double based on size - ("boolean", "boolean"), - ("timestamp", "timestamp"), - ("date", "date"), - ("time", "time"), - ("uuid", "uuid") - ] - - for field_type, expected_cassandra_type in supported_types: - cassandra_type = processor.get_cassandra_type(field_type) - # For integer and float, the exact type depends on size - if field_type in ["integer", "float"]: - assert cassandra_type in ["int", "bigint", "float", "double"] - else: - assert cassandra_type == expected_cassandra_type - - def test_value_conversion_contract(self): - """Test value conversion for all supported types""" - processor = Processor.__new__(Processor) - - # Test conversions maintain data integrity - test_cases = [ - # (input_value, field_type, expected_output, expected_type) - ("123", "integer", 123, int), - ("123.45", "float", 123.45, float), - ("true", "boolean", True, bool), - ("false", "boolean", False, bool), - ("test string", "string", "test string", str), - (None, "string", None, type(None)), - ] - - for input_val, field_type, expected_val, expected_type in test_cases: - result = processor.convert_value(input_val, field_type) - assert result == expected_val - assert isinstance(result, expected_type) or result is None - @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type") def test_extracted_object_serialization_contract(self): """Test that ExtractedObject can be serialized/deserialized correctly""" @@ -222,43 +178,31 @@ class TestObjectsCassandraContracts: assert decoded.confidence == original.confidence assert decoded.source_span == original.source_span - def test_cassandra_table_naming_contract(self): + def test_cassandra_name_sanitization_contract(self): """Test Cassandra naming conventions and constraints""" processor = Processor.__new__(Processor) - - # Test table naming (always gets o_ prefix) - table_test_names = [ - ("simple_name", "o_simple_name"), - ("Name-With-Dashes", 
"o_name_with_dashes"), - ("name.with.dots", "o_name_with_dots"), - ("123_numbers", "o_123_numbers"), - ("special!@#chars", "o_special___chars"), # 3 special chars become 3 underscores - ("UPPERCASE", "o_uppercase"), - ("CamelCase", "o_camelcase"), - ("", "o_"), # Edge case - empty string becomes o_ - ] - - for input_name, expected_name in table_test_names: - result = processor.sanitize_table(input_name) - assert result == expected_name - # Verify result is valid Cassandra identifier (starts with letter) - assert result.startswith('o_') - assert result.replace('o_', '').replace('_', '').isalnum() or result == 'o_' - - # Test regular name sanitization (only adds o_ prefix if starts with number) + + # Test name sanitization for Cassandra identifiers + # - Non-alphanumeric chars (except underscore) become underscores + # - Names starting with non-letter get 'r_' prefix + # - All names converted to lowercase name_test_cases = [ ("simple_name", "simple_name"), ("Name-With-Dashes", "name_with_dashes"), ("name.with.dots", "name_with_dots"), - ("123_numbers", "o_123_numbers"), # Only this gets o_ prefix + ("123_numbers", "r_123_numbers"), # Gets r_ prefix (starts with number) ("special!@#chars", "special___chars"), # 3 special chars become 3 underscores ("UPPERCASE", "uppercase"), ("CamelCase", "camelcase"), + ("_underscore_start", "r__underscore_start"), # Gets r_ prefix (starts with underscore) ] - + for input_name, expected_name in name_test_cases: result = processor.sanitize_name(input_name) - assert result == expected_name + assert result == expected_name, f"Expected {expected_name} but got {result} for input {input_name}" + # Verify result is valid Cassandra identifier (starts with letter) + if result: # Skip empty string case + assert result[0].isalpha(), f"Result {result} should start with a letter" def test_primary_key_structure_contract(self): """Test that primary key structure follows Cassandra best practices""" @@ -308,8 +252,8 @@ class TestObjectsCassandraContracts: @pytest.mark.contract -class TestObjectsCassandraContractsBatch: - """Contract tests for Cassandra object storage batch processing""" +class TestRowsCassandraContractsBatch: + """Contract tests for Cassandra row storage batch processing""" def test_extracted_object_batch_input_contract(self): """Test that batched ExtractedObject schema matches expected input format""" diff --git a/tests/contract/test_objects_graphql_query_contracts.py b/tests/contract/test_rows_graphql_query_contracts.py similarity index 89% rename from tests/contract/test_objects_graphql_query_contracts.py rename to tests/contract/test_rows_graphql_query_contracts.py index ceb9dc17..db796306 100644 --- a/tests/contract/test_objects_graphql_query_contracts.py +++ b/tests/contract/test_rows_graphql_query_contracts.py @@ -1,26 +1,26 @@ """ -Contract tests for Objects GraphQL Query Service +Contract tests for Rows GraphQL Query Service These tests verify the message contracts and schema compatibility -for the objects GraphQL query processor. +for the rows GraphQL query processor. 
""" import pytest import json from pulsar.schema import AvroSchema -from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError -from trustgraph.query.objects.cassandra.service import Processor +from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError +from trustgraph.query.rows.cassandra.service import Processor @pytest.mark.contract -class TestObjectsGraphQLQueryContracts: +class TestRowsGraphQLQueryContracts: """Contract tests for GraphQL query service messages""" - def test_objects_query_request_contract(self): - """Test ObjectsQueryRequest schema structure and required fields""" + def test_rows_query_request_contract(self): + """Test RowsQueryRequest schema structure and required fields""" # Create test request with all required fields - test_request = ObjectsQueryRequest( + test_request = RowsQueryRequest( user="test_user", collection="test_collection", query='{ customers { id name email } }', @@ -49,10 +49,10 @@ class TestObjectsGraphQLQueryContracts: assert test_request.variables["status"] == "active" assert test_request.operation_name == "GetCustomers" - def test_objects_query_request_minimal(self): - """Test ObjectsQueryRequest with minimal required fields""" + def test_rows_query_request_minimal(self): + """Test RowsQueryRequest with minimal required fields""" # Create request with only essential fields - minimal_request = ObjectsQueryRequest( + minimal_request = RowsQueryRequest( user="user", collection="collection", query='{ test }', @@ -91,10 +91,10 @@ class TestObjectsGraphQLQueryContracts: assert test_error.path == ["customers", "0", "nonexistent"] assert test_error.extensions["code"] == "FIELD_ERROR" - def test_objects_query_response_success_contract(self): - """Test ObjectsQueryResponse schema for successful queries""" + def test_rows_query_response_success_contract(self): + """Test RowsQueryResponse schema for successful queries""" # Create successful response - success_response = ObjectsQueryResponse( + success_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}', errors=[], @@ -119,11 +119,11 @@ class TestObjectsGraphQLQueryContracts: assert len(parsed_data["customers"]) == 1 assert parsed_data["customers"][0]["id"] == "1" - def test_objects_query_response_error_contract(self): - """Test ObjectsQueryResponse schema for error cases""" + def test_rows_query_response_error_contract(self): + """Test RowsQueryResponse schema for error cases""" # Create GraphQL errors - work around Pulsar Array(Record) validation bug # by creating a response without the problematic errors array first - error_response = ObjectsQueryResponse( + error_response = RowsQueryResponse( error=None, # System error is None - these are GraphQL errors data=None, # No data due to errors errors=[], # Empty errors array to avoid Pulsar bug @@ -160,14 +160,14 @@ class TestObjectsGraphQLQueryContracts: assert validation_error.path == ["customers", "email"] assert validation_error.extensions["details"] == "Invalid email format" - def test_objects_query_response_system_error_contract(self): - """Test ObjectsQueryResponse schema for system errors""" + def test_rows_query_response_system_error_contract(self): + """Test RowsQueryResponse schema for system errors""" from trustgraph.schema import Error # Create system error response - system_error_response = ObjectsQueryResponse( + system_error_response = RowsQueryResponse( error=Error( - type="objects-query-error", + type="rows-query-error", 
message="Failed to connect to Cassandra cluster" ), data=None, @@ -177,7 +177,7 @@ class TestObjectsGraphQLQueryContracts: # Verify system error structure assert system_error_response.error is not None - assert system_error_response.error.type == "objects-query-error" + assert system_error_response.error.type == "rows-query-error" assert "Cassandra" in system_error_response.error.message assert system_error_response.data is None assert len(system_error_response.errors) == 0 @@ -186,7 +186,7 @@ class TestObjectsGraphQLQueryContracts: def test_request_response_serialization_contract(self): """Test that request/response can be serialized/deserialized correctly""" # Create original request - original_request = ObjectsQueryRequest( + original_request = RowsQueryRequest( user="serialization_test", collection="test_data", query='{ orders(limit: 5) { id total customer { name } } }', @@ -195,7 +195,7 @@ class TestObjectsGraphQLQueryContracts: ) # Test request serialization using Pulsar schema - request_schema = AvroSchema(ObjectsQueryRequest) + request_schema = AvroSchema(RowsQueryRequest) # Encode and decode request encoded_request = request_schema.encode(original_request) @@ -209,7 +209,7 @@ class TestObjectsGraphQLQueryContracts: assert decoded_request.operation_name == original_request.operation_name # Create original response - work around Pulsar Array(Record) bug - original_response = ObjectsQueryResponse( + original_response = RowsQueryResponse( error=None, data='{"orders": []}', errors=[], # Empty to avoid Pulsar validation bug @@ -224,7 +224,7 @@ class TestObjectsGraphQLQueryContracts: ) # Test response serialization - response_schema = AvroSchema(ObjectsQueryResponse) + response_schema = AvroSchema(RowsQueryResponse) # Encode and decode response encoded_response = response_schema.encode(original_response) @@ -244,7 +244,7 @@ class TestObjectsGraphQLQueryContracts: def test_graphql_query_format_contract(self): """Test supported GraphQL query formats""" # Test basic query - basic_query = ObjectsQueryRequest( + basic_query = RowsQueryRequest( user="test", collection="test", query='{ customers { id } }', variables={}, operation_name="" ) @@ -253,7 +253,7 @@ class TestObjectsGraphQLQueryContracts: assert basic_query.query.strip().endswith('}') # Test query with variables - parameterized_query = ObjectsQueryRequest( + parameterized_query = RowsQueryRequest( user="test", collection="test", query='query GetCustomers($status: String, $limit: Int) { customers(status: $status, limit: $limit) { id name } }', variables={"status": "active", "limit": "10"}, @@ -265,7 +265,7 @@ class TestObjectsGraphQLQueryContracts: assert parameterized_query.operation_name == "GetCustomers" # Test complex nested query - nested_query = ObjectsQueryRequest( + nested_query = RowsQueryRequest( user="test", collection="test", query=''' { @@ -296,7 +296,7 @@ class TestObjectsGraphQLQueryContracts: # Note: Current schema uses Map(String()) which only supports string values # This test verifies the current contract, though ideally we'd support all JSON types - variables_test = ObjectsQueryRequest( + variables_test = RowsQueryRequest( user="test", collection="test", query='{ test }', variables={ "string_var": "test_value", @@ -319,7 +319,7 @@ class TestObjectsGraphQLQueryContracts: def test_cassandra_context_fields_contract(self): """Test that request contains necessary fields for Cassandra operations""" # Verify request has fields needed for Cassandra keyspace/table targeting - request = ObjectsQueryRequest( + request = 
RowsQueryRequest( user="keyspace_name", # Maps to Cassandra keyspace collection="partition_collection", # Used in partition key query='{ objects { id } }', @@ -338,7 +338,7 @@ class TestObjectsGraphQLQueryContracts: def test_graphql_extensions_contract(self): """Test GraphQL extensions field format and usage""" # Extensions should support query metadata - response_with_extensions = ObjectsQueryResponse( + response_with_extensions = RowsQueryResponse( error=None, data='{"test": "data"}', errors=[], @@ -404,7 +404,7 @@ class TestObjectsGraphQLQueryContracts: ''' # Request to execute specific operation - multi_op_request = ObjectsQueryRequest( + multi_op_request = RowsQueryRequest( user="test", collection="test", query=multi_op_query, variables={}, @@ -417,7 +417,7 @@ class TestObjectsGraphQLQueryContracts: assert "GetOrders" in multi_op_request.query # Test single operation (operation_name optional) - single_op_request = ObjectsQueryRequest( + single_op_request = RowsQueryRequest( user="test", collection="test", query='{ customers { id } }', variables={}, operation_name="" diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index a06ec509..6c83fb05 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -12,7 +12,7 @@ from argparse import ArgumentParser # Import processors that use Cassandra configuration from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter -from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter +from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery from trustgraph.storage.knowledge.store import Processor as KgStore @@ -55,8 +55,8 @@ class TestEndToEndConfigurationFlow: assert call_args.args[0] == ['integration-host1', 'integration-host2', 'integration-host3'] assert 'auth_provider' in call_args.kwargs # Should have auth since credentials provided - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster): """Test complete flow from environment variables to Cassandra Cluster connection.""" env_vars = { @@ -73,7 +73,7 @@ class TestEndToEndConfigurationFlow: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) # Trigger Cassandra connection processor.connect_cassandra() @@ -320,7 +320,7 @@ class TestNoBackwardCompatibilityEndToEnd: class TestMultipleHostsHandling: """Test multiple Cassandra hosts handling end-to-end.""" - @patch('trustgraph.storage.objects.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') def test_multiple_hosts_passed_to_cluster(self, mock_cluster): """Test that multiple hosts are correctly passed to Cassandra cluster.""" env_vars = { @@ -333,7 +333,7 @@ class TestMultipleHostsHandling: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) 
processor.connect_cassandra() # Verify all hosts were passed to Cluster @@ -386,8 +386,8 @@ class TestMultipleHostsHandling: class TestAuthenticationFlow: """Test authentication configuration flow end-to-end.""" - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster): """Test that authentication is enabled when both username and password are provided.""" env_vars = { @@ -402,7 +402,7 @@ class TestAuthenticationFlow: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Auth provider should be created @@ -416,8 +416,8 @@ class TestAuthenticationFlow: assert 'auth_provider' in call_args.kwargs assert call_args.kwargs['auth_provider'] == mock_auth_instance - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster): """Test that authentication is not used when credentials are missing.""" env_vars = { @@ -429,7 +429,7 @@ class TestAuthenticationFlow: mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Auth provider should not be created @@ -439,11 +439,11 @@ class TestAuthenticationFlow: call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster): """Test that authentication is not used when only username is provided.""" - processor = ObjectsWriter( + processor = RowsWriter( taskgroup=MagicMock(), cassandra_host='partial-auth-host', cassandra_username='partial-user' diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index 7b2245ce..dd48affe 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -11,7 +11,7 @@ import json import asyncio from unittest.mock import AsyncMock, MagicMock, patch -from trustgraph.extract.kg.objects.processor import Processor +from trustgraph.extract.kg.rows.processor import Processor from trustgraph.schema import ( Chunk, ExtractedObject, Metadata, RowSchema, Field, PromptRequest, PromptResponse @@ -220,7 +220,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor 
import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration @@ -288,7 +288,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration @@ -353,7 +353,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration @@ -447,7 +447,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Mock flow with failing prompt service @@ -496,7 +496,7 @@ class TestObjectExtractionServiceIntegration: processor.extract_objects_for_schema = Processor.extract_objects_for_schema.__get__(processor, Processor) # Import and bind the convert_values_to_strings function - from trustgraph.extract.kg.objects.processor import convert_values_to_strings + from trustgraph.extract.kg.rows.processor import convert_values_to_strings processor.convert_values_to_strings = convert_values_to_strings # Load configuration diff --git a/tests/integration/test_objects_cassandra_integration.py b/tests/integration/test_objects_cassandra_integration.py deleted file mode 100644 index 3310b396..00000000 --- a/tests/integration/test_objects_cassandra_integration.py +++ /dev/null @@ -1,608 +0,0 @@ -""" -Integration tests for Cassandra Object Storage - -These tests verify the end-to-end functionality of storing ExtractedObjects -in Cassandra, including table creation, data insertion, and error handling. 
-""" - -import pytest -from unittest.mock import MagicMock, AsyncMock, patch -import json -import uuid - -from trustgraph.storage.objects.cassandra.write import Processor -from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field - - -@pytest.mark.integration -class TestObjectsCassandraIntegration: - """Integration tests for Cassandra object storage""" - - @pytest.fixture - def mock_cassandra_session(self): - """Mock Cassandra session for integration tests""" - session = MagicMock() - - # Track if keyspaces have been created - created_keyspaces = set() - - # Mock the execute method to return a valid result for keyspace checks - def execute_mock(query, *args, **kwargs): - result = MagicMock() - query_str = str(query) - - # Track keyspace creation - if "CREATE KEYSPACE" in query_str: - # Extract keyspace name from query - import re - match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str) - if match: - created_keyspaces.add(match.group(1)) - - # For keyspace existence checks - if "system_schema.keyspaces" in query_str: - # Check if this keyspace was created - if args and args[0] in created_keyspaces: - result.one.return_value = MagicMock() # Exists - else: - result.one.return_value = None # Doesn't exist - else: - result.one.return_value = None - - return result - - session.execute = MagicMock(side_effect=execute_mock) - return session - - @pytest.fixture - def mock_cassandra_cluster(self, mock_cassandra_session): - """Mock Cassandra cluster""" - cluster = MagicMock() - cluster.connect.return_value = mock_cassandra_session - cluster.shutdown = MagicMock() - return cluster - - @pytest.fixture - def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session): - """Create processor with mocked Cassandra dependencies""" - processor = MagicMock() - processor.graph_host = "localhost" - processor.graph_username = None - processor.graph_password = None - processor.config_key = "schema" - processor.schemas = {} - processor.known_keyspaces = set() - processor.known_tables = {} - processor.cluster = None - processor.session = None - - # Bind actual methods - processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) - processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.create_collection = Processor.create_collection.__get__(processor, Processor) - - return processor, mock_cassandra_cluster, mock_cassandra_session - - @pytest.mark.asyncio - async def test_end_to_end_object_storage(self, processor_with_mocks): - """Test complete flow from schema config to object storage""" - processor, mock_cluster, mock_session = processor_with_mocks - - # Mock Cluster creation - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Step 1: Configure schema - config = { - "schema": { - "customer_records": json.dumps({ - "name": "customer_records", - "description": "Customer information", - 
"fields": [ - {"name": "customer_id", "type": "string", "primary_key": True}, - {"name": "name", "type": "string", "required": True}, - {"name": "email", "type": "string", "indexed": True}, - {"name": "age", "type": "integer"} - ] - }) - } - } - - await processor.on_schema_config(config, version=1) - assert "customer_records" in processor.schemas - - # Step 1.5: Create the collection first (simulate tg-set-collection) - await processor.create_collection("test_user", "import_2024", {}) - - # Step 2: Process an ExtractedObject - test_obj = ExtractedObject( - metadata=Metadata( - id="doc-001", - user="test_user", - collection="import_2024", - metadata=[] - ), - schema_name="customer_records", - values=[{ - "customer_id": "CUST001", - "name": "John Doe", - "email": "john@example.com", - "age": "30" - }], - confidence=0.95, - source_span="Customer: John Doe..." - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - await processor.on_object(msg, None, None) - - # Verify Cassandra interactions - assert mock_cluster.connect.called - - # Verify keyspace creation - keyspace_calls = [call for call in mock_session.execute.call_args_list - if "CREATE KEYSPACE" in str(call)] - assert len(keyspace_calls) == 1 - assert "test_user" in str(keyspace_calls[0]) - - # Verify table creation - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 1 - assert "o_customer_records" in str(table_calls[0]) # Table gets o_ prefix - assert "collection text" in str(table_calls[0]) - assert "PRIMARY KEY ((collection, customer_id))" in str(table_calls[0]) - - # Verify index creation - index_calls = [call for call in mock_session.execute.call_args_list - if "CREATE INDEX" in str(call)] - assert len(index_calls) == 1 - assert "email" in str(index_calls[0]) - - # Verify data insertion - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 1 - insert_call = insert_calls[0] - assert "test_user.o_customer_records" in str(insert_call) # Table gets o_ prefix - - # Check inserted values - values = insert_call[0][1] - assert "import_2024" in values # collection - assert "CUST001" in values # customer_id - assert "John Doe" in values # name - assert "john@example.com" in values # email - assert 30 in values # age (converted to int) - - @pytest.mark.asyncio - async def test_multi_schema_handling(self, processor_with_mocks): - """Test handling multiple schemas and objects""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure multiple schemas - config = { - "schema": { - "products": json.dumps({ - "name": "products", - "fields": [ - {"name": "product_id", "type": "string", "primary_key": True}, - {"name": "name", "type": "string"}, - {"name": "price", "type": "float"} - ] - }), - "orders": json.dumps({ - "name": "orders", - "fields": [ - {"name": "order_id", "type": "string", "primary_key": True}, - {"name": "customer_id", "type": "string"}, - {"name": "total", "type": "float"} - ] - }) - } - } - - await processor.on_schema_config(config, version=1) - assert len(processor.schemas) == 2 - - # Create collections first - await processor.create_collection("shop", "catalog", {}) - await processor.create_collection("shop", "sales", {}) - - # Process objects for different schemas - product_obj = ExtractedObject( - metadata=Metadata(id="p1", user="shop", 
collection="catalog", metadata=[]), - schema_name="products", - values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], - confidence=0.9, - source_span="Product..." - ) - - order_obj = ExtractedObject( - metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]), - schema_name="orders", - values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], - confidence=0.85, - source_span="Order..." - ) - - # Process both objects - for obj in [product_obj, order_obj]: - msg = MagicMock() - msg.value.return_value = obj - await processor.on_object(msg, None, None) - - # Verify separate tables were created - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 2 - assert any("o_products" in str(call) for call in table_calls) # Tables get o_ prefix - assert any("o_orders" in str(call) for call in table_calls) # Tables get o_ prefix - - @pytest.mark.asyncio - async def test_missing_required_fields(self, processor_with_mocks): - """Test handling of objects with missing required fields""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure schema with required field - processor.schemas["test_schema"] = RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True, required=True), - Field(name="required_field", type="string", size=100, required=True) - ] - ) - - # Create collection first - await processor.create_collection("test", "test", {}) - - # Create object missing required field - test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), - schema_name="test_schema", - values=[{"id": "123"}], # missing required_field - confidence=0.8, - source_span="Test" - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - # Should still process (Cassandra doesn't enforce NOT NULL) - await processor.on_object(msg, None, None) - - # Verify insert was attempted - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 1 - - @pytest.mark.asyncio - async def test_schema_without_primary_key(self, processor_with_mocks): - """Test handling schemas without defined primary keys""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure schema without primary key - processor.schemas["events"] = RowSchema( - name="events", - description="Event log", - fields=[ - Field(name="event_type", type="string", size=50), - Field(name="timestamp", type="timestamp", size=0) - ] - ) - - # Create collection first - await processor.create_collection("logger", "app_events", {}) - - # Process object - test_obj = ExtractedObject( - metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]), - schema_name="events", - values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}], - confidence=1.0, - source_span="Event" - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - await processor.on_object(msg, None, None) - - # Verify synthetic_id was added - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 1 - assert "synthetic_id uuid" in str(table_calls[0]) - - # Verify insert includes UUID 
- insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 1 - values = insert_calls[0][0][1] - # Check that a UUID was generated (will be in values list) - uuid_found = any(isinstance(v, uuid.UUID) for v in values) - assert uuid_found - - @pytest.mark.asyncio - async def test_authentication_handling(self, processor_with_mocks): - """Test Cassandra authentication""" - processor, mock_cluster, mock_session = processor_with_mocks - processor.cassandra_username = "cassandra_user" - processor.cassandra_password = "cassandra_pass" - - with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class: - with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth: - mock_cluster_class.return_value = mock_cluster - - # Trigger connection - processor.connect_cassandra() - - # Verify authentication was configured - mock_auth.assert_called_once_with( - username="cassandra_user", - password="cassandra_pass" - ) - mock_cluster_class.assert_called_once() - call_kwargs = mock_cluster_class.call_args[1] - assert 'auth_provider' in call_kwargs - - @pytest.mark.asyncio - async def test_error_handling_during_insert(self, processor_with_mocks): - """Test error handling when insertion fails""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["test"] = RowSchema( - name="test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - - # Make insert fail - mock_result = MagicMock() - mock_result.one.return_value = MagicMock() # Keyspace exists - mock_session.execute.side_effect = [ - mock_result, # keyspace existence check succeeds - None, # table creation succeeds - Exception("Connection timeout") # insert fails - ] - - test_obj = ExtractedObject( - metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), - schema_name="test", - values=[{"id": "123"}], - confidence=0.9, - source_span="Test" - ) - - msg = MagicMock() - msg.value.return_value = test_obj - - # Should raise the exception - with pytest.raises(Exception, match="Connection timeout"): - await processor.on_object(msg, None, None) - - @pytest.mark.asyncio - async def test_collection_partitioning(self, processor_with_mocks): - """Test that objects are properly partitioned by collection""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["data"] = RowSchema( - name="data", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - - # Process objects from different collections - collections = ["import_jan", "import_feb", "import_mar"] - - # Create all collections first - for coll in collections: - await processor.create_collection("analytics", coll, {}) - - for coll in collections: - obj = ExtractedObject( - metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]), - schema_name="data", - values=[{"id": f"ID-{coll}"}], - confidence=0.9, - source_span="Data" - ) - - msg = MagicMock() - msg.value.return_value = obj - await processor.on_object(msg, None, None) - - # Verify all inserts include collection in values - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 3 - - # Check each insert has the correct collection - for i, 
call in enumerate(insert_calls): - values = call[0][1] - assert collections[i] in values - - @pytest.mark.asyncio - async def test_batch_object_processing(self, processor_with_mocks): - """Test processing objects with batched values""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - # Configure schema - config = { - "schema": { - "batch_customers": json.dumps({ - "name": "batch_customers", - "description": "Customer batch data", - "fields": [ - {"name": "customer_id", "type": "string", "primary_key": True}, - {"name": "name", "type": "string", "required": True}, - {"name": "email", "type": "string", "indexed": True} - ] - }) - } - } - - await processor.on_schema_config(config, version=1) - - # Process batch object with multiple values - batch_obj = ExtractedObject( - metadata=Metadata( - id="batch-001", - user="test_user", - collection="batch_import", - metadata=[] - ), - schema_name="batch_customers", - values=[ - { - "customer_id": "CUST001", - "name": "John Doe", - "email": "john@example.com" - }, - { - "customer_id": "CUST002", - "name": "Jane Smith", - "email": "jane@example.com" - }, - { - "customer_id": "CUST003", - "name": "Bob Johnson", - "email": "bob@example.com" - } - ], - confidence=0.92, - source_span="Multiple customers extracted from document" - ) - - # Create collection first - await processor.create_collection("test_user", "batch_import", {}) - - msg = MagicMock() - msg.value.return_value = batch_obj - - await processor.on_object(msg, None, None) - - # Verify table creation - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - assert len(table_calls) == 1 - assert "o_batch_customers" in str(table_calls[0]) - - # Verify multiple inserts for batch values - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - # Should have 3 separate inserts for the 3 objects in the batch - assert len(insert_calls) == 3 - - # Check each insert has correct data - for i, call in enumerate(insert_calls): - values = call[0][1] - assert "batch_import" in values # collection - assert f"CUST00{i+1}" in values # customer_id - if i == 0: - assert "John Doe" in values - assert "john@example.com" in values - elif i == 1: - assert "Jane Smith" in values - assert "jane@example.com" in values - elif i == 2: - assert "Bob Johnson" in values - assert "bob@example.com" in values - - @pytest.mark.asyncio - async def test_empty_batch_processing(self, processor_with_mocks): - """Test processing objects with empty values array""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["empty_test"] = RowSchema( - name="empty_test", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - - # Create collection first - await processor.create_collection("test", "empty", {}) - - # Process empty batch object - empty_obj = ExtractedObject( - metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]), - schema_name="empty_test", - values=[], # Empty batch - confidence=1.0, - source_span="No objects found" - ) - - msg = MagicMock() - msg.value.return_value = empty_obj - - await processor.on_object(msg, None, None) - - # Should still create table - table_calls = [call for call in mock_session.execute.call_args_list - if "CREATE TABLE" in str(call)] - 
assert len(table_calls) == 1 - - # Should not create any insert statements for empty batch - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 0 - - @pytest.mark.asyncio - async def test_mixed_single_and_batch_objects(self, processor_with_mocks): - """Test processing mix of single and batch objects""" - processor, mock_cluster, mock_session = processor_with_mocks - - with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster): - processor.schemas["mixed_test"] = RowSchema( - name="mixed_test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="data", type="string", size=100) - ] - ) - - # Create collection first - await processor.create_collection("test", "mixed", {}) - - # Single object (backward compatibility) - single_obj = ExtractedObject( - metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]), - schema_name="mixed_test", - values=[{"id": "single-1", "data": "single data"}], # Array with single item - confidence=0.9, - source_span="Single object" - ) - - # Batch object - batch_obj = ExtractedObject( - metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]), - schema_name="mixed_test", - values=[ - {"id": "batch-1", "data": "batch data 1"}, - {"id": "batch-2", "data": "batch data 2"} - ], - confidence=0.85, - source_span="Batch objects" - ) - - # Process both - for obj in [single_obj, batch_obj]: - msg = MagicMock() - msg.value.return_value = obj - await processor.on_object(msg, None, None) - - # Should have 3 total inserts (1 + 2) - insert_calls = [call for call in mock_session.execute.call_args_list - if "INSERT INTO" in str(call)] - assert len(insert_calls) == 3 \ No newline at end of file diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py new file mode 100644 index 00000000..2cb973a7 --- /dev/null +++ b/tests/integration/test_rows_cassandra_integration.py @@ -0,0 +1,492 @@ +""" +Integration tests for Cassandra Row Storage (Unified Table Implementation) + +These tests verify the end-to-end functionality of storing ExtractedObjects +in the unified Cassandra rows table, including table creation, data insertion, +and error handling. 
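The deleted per-schema tests above are superseded by the unified-table suite that follows. For orientation, here is a rough sketch of the write fan-out those new tests count, one insert per (collection, schema, index) partition as described in the spec; `fan_out_row` is a hypothetical illustration, not the actual `trustgraph.storage.rows.cassandra.write` code:

```python
# Hypothetical sketch of the write fan-out the tests below assert on.
def fan_out_row(collection, schema_name, index_names, row, source=""):
    """Produce one insert tuple per index; composite index names are
    comma-joined field lists, values ordered to match the field names."""
    inserts = []
    for index_name in index_names:
        fields = [f.strip() for f in index_name.split(",")]
        index_value = [str(row.get(f, "")) for f in fields]
        inserts.append((collection, schema_name, index_name, index_value, row, source))
    return inserts

# Three indexed fields -> three writes of the same row:
assert len(fan_out_row("test", "indexed_data", ["id", "category", "status"],
                       {"id": "123", "category": "electronics",
                        "status": "active", "description": "A product"})) == 3
```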
+""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +import json + +from trustgraph.storage.rows.cassandra.write import Processor +from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +@pytest.mark.integration +class TestRowsCassandraIntegration: + """Integration tests for Cassandra row storage with unified table""" + + @pytest.fixture + def mock_cassandra_session(self): + """Mock Cassandra session for integration tests""" + session = MagicMock() + + # Track if keyspaces have been created + created_keyspaces = set() + + # Mock the execute method to return a valid result for keyspace checks + def execute_mock(query, *args, **kwargs): + result = MagicMock() + query_str = str(query) + + # Track keyspace creation + if "CREATE KEYSPACE" in query_str: + import re + match = re.search(r'CREATE KEYSPACE IF NOT EXISTS (\w+)', query_str) + if match: + created_keyspaces.add(match.group(1)) + + # For keyspace existence checks + if "system_schema.keyspaces" in query_str: + if args and args[0] in created_keyspaces: + result.one.return_value = MagicMock() # Exists + else: + result.one.return_value = None # Doesn't exist + else: + result.one.return_value = None + + return result + + session.execute = MagicMock(side_effect=execute_mock) + return session + + @pytest.fixture + def mock_cassandra_cluster(self, mock_cassandra_session): + """Mock Cassandra cluster""" + cluster = MagicMock() + cluster.connect.return_value = mock_cassandra_session + cluster.shutdown = MagicMock() + return cluster + + @pytest.fixture + def processor_with_mocks(self, mock_cassandra_cluster, mock_cassandra_session): + """Create processor with mocked Cassandra dependencies""" + processor = MagicMock() + processor.cassandra_host = ["localhost"] + processor.cassandra_username = None + processor.cassandra_password = None + processor.config_key = "schema" + processor.schemas = {} + processor.known_keyspaces = set() + processor.tables_initialized = set() + processor.registered_partitions = set() + processor.cluster = None + processor.session = None + + # Bind actual methods from the new unified table implementation + processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) + processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) + processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) + processor.on_object = Processor.on_object.__get__(processor, Processor) + processor.collection_exists = MagicMock(return_value=True) + + return processor, mock_cassandra_cluster, mock_cassandra_session + + @pytest.mark.asyncio + async def test_end_to_end_object_storage(self, processor_with_mocks): + """Test complete flow from schema config to object storage""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Step 1: Configure schema + config = { + "schema": { + "customer_records": json.dumps({ + "name": "customer_records", + "description": "Customer information", + "fields": [ + 
{"name": "customer_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string", "required": True}, + {"name": "email", "type": "string", "indexed": True}, + {"name": "age", "type": "integer"} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + assert "customer_records" in processor.schemas + + # Step 2: Process an ExtractedObject + test_obj = ExtractedObject( + metadata=Metadata( + id="doc-001", + user="test_user", + collection="import_2024", + metadata=[] + ), + schema_name="customer_records", + values=[{ + "customer_id": "CUST001", + "name": "John Doe", + "email": "john@example.com", + "age": "30" + }], + confidence=0.95, + source_span="Customer: John Doe..." + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Verify Cassandra interactions + assert mock_cluster.connect.called + + # Verify keyspace creation + keyspace_calls = [call for call in mock_session.execute.call_args_list + if "CREATE KEYSPACE" in str(call)] + assert len(keyspace_calls) == 1 + assert "test_user" in str(keyspace_calls[0]) + + # Verify unified table creation (rows table, not per-schema table) + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + assert len(table_calls) == 2 # rows table + row_partitions table + assert any("rows" in str(call) for call in table_calls) + assert any("row_partitions" in str(call) for call in table_calls) + + # Verify the rows table has correct structure + rows_table_call = [call for call in table_calls if ".rows" in str(call)][0] + assert "collection text" in str(rows_table_call) + assert "schema_name text" in str(rows_table_call) + assert "index_name text" in str(rows_table_call) + assert "data map" in str(rows_table_call) + + # Verify data insertion into unified table + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + # Should have 2 data inserts: one for customer_id (primary), one for email (indexed) + assert len(rows_insert_calls) == 2 + + @pytest.mark.asyncio + async def test_multi_schema_handling(self, processor_with_mocks): + """Test handling multiple schemas stored in unified table""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Configure multiple schemas + config = { + "schema": { + "products": json.dumps({ + "name": "products", + "fields": [ + {"name": "product_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string"}, + {"name": "price", "type": "float"} + ] + }), + "orders": json.dumps({ + "name": "orders", + "fields": [ + {"name": "order_id", "type": "string", "primary_key": True}, + {"name": "customer_id", "type": "string"}, + {"name": "total", "type": "float"} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + assert len(processor.schemas) == 2 + + # Process objects for different schemas + product_obj = ExtractedObject( + metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]), + schema_name="products", + values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}], + confidence=0.9, + source_span="Product..." 
+ ) + + order_obj = ExtractedObject( + metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]), + schema_name="orders", + values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}], + confidence=0.85, + source_span="Order..." + ) + + # Process both objects + for obj in [product_obj, order_obj]: + msg = MagicMock() + msg.value.return_value = obj + await processor.on_object(msg, None, None) + + # All data goes into the same unified rows table + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + # Should only create 2 tables: rows + row_partitions (not per-schema tables) + assert len(table_calls) == 2 + + # Verify data inserts go to unified rows table + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) > 0 + for call in rows_insert_calls: + assert ".rows" in str(call) + + @pytest.mark.asyncio + async def test_multi_index_storage(self, processor_with_mocks): + """Test that rows are stored with multiple indexes""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Schema with multiple indexed fields + processor.schemas["indexed_data"] = RowSchema( + name="indexed_data", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True), + Field(name="status", type="string", size=50, indexed=True), + Field(name="description", type="string", size=200) # Not indexed + ] + ) + + test_obj = ExtractedObject( + metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), + schema_name="indexed_data", + values=[{ + "id": "123", + "category": "electronics", + "status": "active", + "description": "A product" + }], + confidence=0.9, + source_span="Test" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Should have 3 data inserts (one per indexed field: id, category, status) + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) == 3 + + # Verify different index names were used + index_names = set() + for call in rows_insert_calls: + values = call[0][1] + index_names.add(values[2]) # index_name is 3rd parameter + + assert index_names == {"id", "category", "status"} + + @pytest.mark.asyncio + async def test_authentication_handling(self, processor_with_mocks): + """Test Cassandra authentication""" + processor, mock_cluster, mock_session = processor_with_mocks + processor.cassandra_username = "cassandra_user" + processor.cassandra_password = "cassandra_pass" + + with patch('trustgraph.storage.rows.cassandra.write.Cluster') as mock_cluster_class: + with patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') as mock_auth: + mock_cluster_class.return_value = mock_cluster + + # Trigger connection + processor.connect_cassandra() + + # Verify authentication was configured + mock_auth.assert_called_once_with( + username="cassandra_user", + password="cassandra_pass" + ) + mock_cluster_class.assert_called_once() + call_kwargs = mock_cluster_class.call_args[1] + assert 'auth_provider' in call_kwargs + + @pytest.mark.asyncio + async def test_batch_object_processing(self, 
processor_with_mocks): + """Test processing objects with batched values""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + # Configure schema + config = { + "schema": { + "batch_customers": json.dumps({ + "name": "batch_customers", + "description": "Customer batch data", + "fields": [ + {"name": "customer_id", "type": "string", "primary_key": True}, + {"name": "name", "type": "string", "required": True}, + {"name": "email", "type": "string", "indexed": True} + ] + }) + } + } + + await processor.on_schema_config(config, version=1) + + # Process batch object with multiple values + batch_obj = ExtractedObject( + metadata=Metadata( + id="batch-001", + user="test_user", + collection="batch_import", + metadata=[] + ), + schema_name="batch_customers", + values=[ + { + "customer_id": "CUST001", + "name": "John Doe", + "email": "john@example.com" + }, + { + "customer_id": "CUST002", + "name": "Jane Smith", + "email": "jane@example.com" + }, + { + "customer_id": "CUST003", + "name": "Bob Johnson", + "email": "bob@example.com" + } + ], + confidence=0.92, + source_span="Multiple customers extracted from document" + ) + + msg = MagicMock() + msg.value.return_value = batch_obj + + await processor.on_object(msg, None, None) + + # Verify unified table creation + table_calls = [call for call in mock_session.execute.call_args_list + if "CREATE TABLE" in str(call)] + assert len(table_calls) == 2 # rows + row_partitions + + # Each row in batch gets 2 data inserts (customer_id primary + email indexed) + # 3 rows * 2 indexes = 6 data inserts + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) == 6 + + @pytest.mark.asyncio + async def test_empty_batch_processing(self, processor_with_mocks): + """Test processing objects with empty values array""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["empty_test"] = RowSchema( + name="empty_test", + fields=[Field(name="id", type="string", size=50, primary=True)] + ) + + # Process empty batch object + empty_obj = ExtractedObject( + metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]), + schema_name="empty_test", + values=[], # Empty batch + confidence=1.0, + source_span="No objects found" + ) + + msg = MagicMock() + msg.value.return_value = empty_obj + + await processor.on_object(msg, None, None) + + # Should not create any data insert statements for empty batch + # (partition registration may still happen) + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) == 0 + + @pytest.mark.asyncio + async def test_data_stored_as_map(self, processor_with_mocks): + """Test that data is stored as map""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["map_test"] = RowSchema( + name="map_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="name", type="string", size=100), + Field(name="count", type="integer", size=0) + ] + ) + + test_obj = ExtractedObject( + 
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]), + schema_name="map_test", + values=[{"id": "123", "name": "Test Item", "count": "42"}], + confidence=0.9, + source_span="Test" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Verify insert uses map for data + rows_insert_calls = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and ".rows" in str(call) + and "row_partitions" not in str(call)] + assert len(rows_insert_calls) >= 1 + + # Check that data is passed as a dict (will be map in Cassandra) + insert_call = rows_insert_calls[0] + values = insert_call[0][1] + # Values are: (collection, schema_name, index_name, index_value, data, source) + # values[4] should be the data map + data_map = values[4] + assert isinstance(data_map, dict) + assert data_map["id"] == "123" + assert data_map["name"] == "Test Item" + assert data_map["count"] == "42" + + @pytest.mark.asyncio + async def test_partition_registration(self, processor_with_mocks): + """Test that partitions are registered for efficient querying""" + processor, mock_cluster, mock_session = processor_with_mocks + + with patch('trustgraph.storage.rows.cassandra.write.Cluster', return_value=mock_cluster): + processor.schemas["partition_test"] = RowSchema( + name="partition_test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="category", type="string", size=50, indexed=True) + ] + ) + + test_obj = ExtractedObject( + metadata=Metadata(id="t1", user="test", collection="my_collection", metadata=[]), + schema_name="partition_test", + values=[{"id": "123", "category": "test"}], + confidence=0.9, + source_span="Test" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Verify partition registration + partition_inserts = [call for call in mock_session.execute.call_args_list + if "INSERT INTO" in str(call) and "row_partitions" in str(call)] + # Should register partitions for each index (id, category) + assert len(partition_inserts) == 2 + + # Verify cache was updated + assert ("my_collection", "partition_test") in processor.registered_partitions diff --git a/tests/integration/test_objects_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py similarity index 98% rename from tests/integration/test_objects_graphql_query_integration.py rename to tests/integration/test_rows_graphql_query_integration.py index 13b12532..a717901b 100644 --- a/tests/integration/test_objects_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -1,5 +1,5 @@ """ -Integration tests for Objects GraphQL Query Service +Integration tests for Rows GraphQL Query Service These tests verify end-to-end functionality including: - Real Cassandra database operations @@ -24,8 +24,8 @@ except Exception: DOCKER_AVAILABLE = False CassandraContainer = None -from trustgraph.query.objects.cassandra.service import Processor -from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError +from trustgraph.query.rows.cassandra.service import Processor +from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata @@ -390,7 +390,7 @@ class TestObjectsGraphQLQueryIntegration: processor.connect_cassandra() # Create mock message - request = ObjectsQueryRequest( + request = RowsQueryRequest( 
user="msg_test_user", collection="msg_test_collection", query='{ customer_objects { customer_id name } }', @@ -415,7 +415,7 @@ class TestObjectsGraphQLQueryIntegration: # Verify response structure sent_response = mock_response_producer.send.call_args[0][0] - assert isinstance(sent_response, ObjectsQueryResponse) + assert isinstance(sent_response, RowsQueryResponse) # Should have no system error (even if no data) assert sent_response.error is None diff --git a/tests/integration/test_structured_query_integration.py b/tests/integration/test_structured_query_integration.py index cf8037d0..d5fb5672 100644 --- a/tests/integration/test_structured_query_integration.py +++ b/tests/integration/test_structured_query_integration.py @@ -2,7 +2,7 @@ Integration tests for Structured Query Service These tests verify the end-to-end functionality of the structured query service, -testing orchestration between nlp-query and objects-query services. +testing orchestration between nlp-query and rows-query services. Following the TEST_STRATEGY.md approach for integration testing. """ @@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.schema import ( StructuredQueryRequest, StructuredQueryResponse, QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse, - ObjectsQueryRequest, ObjectsQueryResponse, + RowsQueryRequest, RowsQueryResponse, Error, GraphQLError ) from trustgraph.retrieval.structured_query.service import Processor @@ -81,7 +81,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock Objects Query Service Response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}', errors=None, @@ -99,7 +99,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -121,7 +121,7 @@ class TestStructuredQueryServiceIntegration: # Verify Objects service call mock_objects_client.request.assert_called_once() objects_call_args = mock_objects_client.request.call_args[0][0] - assert isinstance(objects_call_args, ObjectsQueryRequest) + assert isinstance(objects_call_args, RowsQueryRequest) assert "customers" in objects_call_args.query assert "orders" in objects_call_args.query assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string @@ -220,7 +220,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock Objects service failure - objects_error_response = ObjectsQueryResponse( + objects_error_response = RowsQueryResponse( error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"), data=None, errors=None, @@ -237,7 +237,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -255,7 +255,7 @@ class TestStructuredQueryServiceIntegration: assert response.error is not None assert response.error.type == "structured-query-error" - assert "Objects query service error" in response.error.message + assert "Rows 
query service error" in response.error.message assert "nonexistent_table" in response.error.message @pytest.mark.asyncio @@ -298,7 +298,7 @@ class TestStructuredQueryServiceIntegration: ) ] - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=None, # No data when validation fails errors=validation_errors, @@ -315,7 +315,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -422,7 +422,7 @@ class TestStructuredQueryServiceIntegration: ] } - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=json.dumps(complex_data), errors=None, @@ -443,7 +443,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -503,7 +503,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock empty Objects response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": []}', # Empty result set errors=None, @@ -520,7 +520,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -577,7 +577,7 @@ class TestStructuredQueryServiceIntegration: confidence=0.9 ) - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=f'{{"test_{i}": [{{"id": "{i}"}}]}}', errors=None, @@ -599,7 +599,7 @@ class TestStructuredQueryServiceIntegration: if service_name == "nlp-query-request": service_call_count += 1 return nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": service_call_count += 1 return objects_client elif service_name == "response": @@ -700,7 +700,7 @@ class TestStructuredQueryServiceIntegration: ) # Mock Objects response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}', errors=None, @@ -717,7 +717,7 @@ class TestStructuredQueryServiceIntegration: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response diff --git a/tests/unit/test_embeddings/test_row_embeddings_processor.py b/tests/unit/test_embeddings/test_row_embeddings_processor.py new file mode 100644 index 00000000..47405431 --- /dev/null +++ b/tests/unit/test_embeddings/test_row_embeddings_processor.py @@ -0,0 +1,380 @@ +""" +Unit tests for trustgraph.embeddings.row_embeddings.embeddings +Tests the Stage 1 processor that computes embeddings for row index fields. 
+""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + + +class TestRowEmbeddingsProcessor(IsolatedAsyncioTestCase): + """Test row embeddings processor functionality""" + + async def test_processor_initialization(self): + """Test basic processor initialization""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-row-embeddings' + } + + processor = Processor(**config) + + assert hasattr(processor, 'schemas') + assert processor.schemas == {} + assert processor.batch_size == 10 # default + + async def test_processor_initialization_with_custom_batch_size(self): + """Test processor initialization with custom batch size""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-row-embeddings', + 'batch_size': 25 + } + + processor = Processor(**config) + + assert processor.batch_size == 25 + + async def test_get_index_names_single_index(self): + """Test getting index names with single indexed field""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import RowSchema, Field + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + schema = RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + Field(name='email', type='text', indexed=False), + ] + ) + + index_names = processor.get_index_names(schema) + + # Should include primary key and indexed field + assert 'id' in index_names + assert 'name' in index_names + assert 'email' not in index_names + + async def test_get_index_names_no_indexes(self): + """Test getting index names when no fields are indexed""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import RowSchema, Field + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + schema = RowSchema( + name='logs', + description='Log records', + fields=[ + Field(name='timestamp', type='text'), + Field(name='message', type='text'), + ] + ) + + index_names = processor.get_index_names(schema) + + assert index_names == [] + + async def test_build_index_value_single_field(self): + """Test building index value for single field""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + value_map = { + 'id': 'CUST001', + 'name': 'John Doe', + 'email': 'john@example.com' + } + + result = processor.build_index_value(value_map, 'name') + + assert result == ['John Doe'] + + async def test_build_index_value_composite_index(self): + """Test building index value for composite index""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + value_map = { + 'first_name': 'John', + 'last_name': 'Doe', + 'city': 'New York' + } + + result = processor.build_index_value(value_map, 'first_name, last_name') + + assert result == ['John', 'Doe'] + + async def test_build_index_value_missing_field(self): + """Test building index value when field is missing""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + 
config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + value_map = { + 'name': 'John Doe' + } + + result = processor.build_index_value(value_map, 'missing_field') + + assert result == [''] + + async def test_build_text_for_embedding_single_value(self): + """Test building text representation for single value""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + result = processor.build_text_for_embedding(['John Doe']) + + assert result == 'John Doe' + + async def test_build_text_for_embedding_multiple_values(self): + """Test building text representation for multiple values""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + result = processor.build_text_for_embedding(['John', 'Doe', 'NYC']) + + assert result == 'John Doe NYC' + + async def test_on_schema_config_loads_schemas(self): + """Test that schema configuration is loaded correctly""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + import json + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor', + 'config_type': 'schema' + } + + processor = Processor(**config) + + schema_def = { + 'name': 'customers', + 'description': 'Customer records', + 'fields': [ + {'name': 'id', 'type': 'text', 'primary_key': True}, + {'name': 'name', 'type': 'text', 'indexed': True}, + {'name': 'email', 'type': 'text'} + ] + } + + config_data = { + 'schema': { + 'customers': json.dumps(schema_def) + } + } + + await processor.on_schema_config(config_data, 1) + + assert 'customers' in processor.schemas + assert processor.schemas['customers'].name == 'customers' + assert len(processor.schemas['customers'].fields) == 3 + + async def test_on_schema_config_handles_missing_type(self): + """Test that missing schema type is handled gracefully""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor', + 'config_type': 'schema' + } + + processor = Processor(**config) + + config_data = { + 'other_type': {} + } + + await processor.on_schema_config(config_data, 1) + + assert processor.schemas == {} + + async def test_on_message_drops_unknown_collection(self): + """Test that messages for unknown collections are dropped""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import ExtractedObject + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + # No collections registered + + metadata = MagicMock() + metadata.user = 'unknown_user' + metadata.collection = 'unknown_collection' + metadata.id = 'doc-123' + + obj = ExtractedObject( + metadata=metadata, + schema_name='customers', + values=[{'id': '123', 'name': 'Test'}] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = obj + + mock_flow = MagicMock() + + await processor.on_message(mock_msg, MagicMock(), mock_flow) + + # Flow should not be called for output + mock_flow.assert_not_called() + + async def test_on_message_drops_unknown_schema(self): + """Test that messages for unknown schemas are dropped""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import ExtractedObject + + config = { + 'taskgroup': AsyncMock(), + 'id': 
'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + # No schemas registered + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + obj = ExtractedObject( + metadata=metadata, + schema_name='unknown_schema', + values=[{'id': '123', 'name': 'Test'}] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = obj + + mock_flow = MagicMock() + + await processor.on_message(mock_msg, MagicMock(), mock_flow) + + # Flow should not be called for output + mock_flow.assert_not_called() + + async def test_on_message_processes_embeddings(self): + """Test processing a message and computing embeddings""" + from trustgraph.embeddings.row_embeddings.embeddings import Processor + from trustgraph.schema import ExtractedObject, RowSchema, Field + import json + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor', + 'config_type': 'schema' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + # Set up schema + processor.schemas['customers'] = RowSchema( + name='customers', + description='Customer records', + fields=[ + Field(name='id', type='text', primary=True), + Field(name='name', type='text', indexed=True), + ] + ) + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + obj = ExtractedObject( + metadata=metadata, + schema_name='customers', + values=[ + {'id': 'CUST001', 'name': 'John Doe'}, + {'id': 'CUST002', 'name': 'Jane Smith'} + ] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = obj + + # Mock the flow + mock_embeddings_request = AsyncMock() + mock_embeddings_request.embed.return_value = [[0.1, 0.2, 0.3]] + + mock_output = AsyncMock() + + def flow_factory(name): + if name == 'embeddings-request': + return mock_embeddings_request + elif name == 'output': + return mock_output + return MagicMock() + + mock_flow = MagicMock(side_effect=flow_factory) + + await processor.on_message(mock_msg, MagicMock(), mock_flow) + + # Should have called embed for each unique text + # 4 values: CUST001, John Doe, CUST002, Jane Smith + assert mock_embeddings_request.embed.call_count == 4 + + # Should have sent output + mock_output.send.assert_called() + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_gateway/test_objects_import_dispatcher.py b/tests/unit/test_gateway/test_rows_import_dispatcher.py similarity index 83% rename from tests/unit/test_gateway/test_objects_import_dispatcher.py rename to tests/unit/test_gateway/test_rows_import_dispatcher.py index 0332c1a1..ab72cae1 100644 --- a/tests/unit/test_gateway/test_objects_import_dispatcher.py +++ b/tests/unit/test_gateway/test_rows_import_dispatcher.py @@ -1,7 +1,7 @@ """ -Unit tests for objects import dispatcher. +Unit tests for rows import dispatcher. -Tests the business logic of objects import dispatcher +Tests the business logic of rows import dispatcher while mocking the Publisher and websocket components. 
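The renamed dispatcher tests that follow assert how an incoming websocket payload is mapped onto an `ExtractedObject`, including defaults for optional fields. A hedged reconstruction of that mapping (field names come from the test fixtures; the real `RowsImport.receive` may differ):

```python
from trustgraph.schema import ExtractedObject, Metadata

# Sketch of the payload-to-message mapping the tests exercise. The
# source_span and metadata defaults are pinned by assertions below; the
# confidence default here is an assumption for illustration only.
def to_extracted_object(payload: dict) -> ExtractedObject:
    md = payload.get("metadata", {})
    return ExtractedObject(
        metadata=Metadata(
            id=md.get("id", ""),
            user=md.get("user", ""),
            collection=md.get("collection", ""),
            metadata=md.get("metadata", []),   # triples; defaults to empty list
        ),
        schema_name=payload["schema_name"],
        values=payload.get("values", []),      # batch: list of row dicts
        confidence=payload.get("confidence", 1.0),
        source_span=payload.get("source_span", ""),
    )
```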
""" @@ -11,7 +11,7 @@ import asyncio from unittest.mock import Mock, AsyncMock, patch, MagicMock from aiohttp import web -from trustgraph.gateway.dispatch.objects_import import ObjectsImport +from trustgraph.gateway.dispatch.rows_import import RowsImport from trustgraph.schema import Metadata, ExtractedObject @@ -92,16 +92,16 @@ def minimal_objects_message(): } -class TestObjectsImportInitialization: - """Test ObjectsImport initialization.""" +class TestRowsImportInitialization: + """Test RowsImport initialization.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): - """Test that ObjectsImport creates Publisher with correct parameters.""" + """Test that RowsImport creates Publisher with correct parameters.""" mock_publisher_instance = Mock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -116,28 +116,28 @@ class TestObjectsImportInitialization: ) # Verify instance variables are set correctly - assert objects_import.ws == mock_websocket - assert objects_import.running == mock_running - assert objects_import.publisher == mock_publisher_instance + assert rows_import.ws == mock_websocket + assert rows_import.running == mock_running + assert rows_import.publisher == mock_publisher_instance - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): - """Test that ObjectsImport stores all required references.""" - objects_import = ObjectsImport( + """Test that RowsImport stores all required references.""" + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="objects-queue" ) - assert objects_import.ws is mock_websocket - assert objects_import.running is mock_running + assert rows_import.ws is mock_websocket + assert rows_import.running is mock_running -class TestObjectsImportLifecycle: - """Test ObjectsImport lifecycle methods.""" +class TestRowsImportLifecycle: + """Test RowsImport lifecycle methods.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that start() calls publisher.start().""" @@ -145,18 +145,18 @@ class TestObjectsImportLifecycle: mock_publisher_instance.start = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="test-queue" ) - await objects_import.start() + await rows_import.start() mock_publisher_instance.start.assert_called_once() - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that destroy() properly stops publisher and closes websocket.""" @@ -164,21 +164,21 @@ class TestObjectsImportLifecycle: 
mock_publisher_instance.stop = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="test-queue" ) - await objects_import.destroy() + await rows_import.destroy() # Verify sequence of operations mock_running.stop.assert_called_once() mock_publisher_instance.stop.assert_called_once() mock_websocket.close.assert_called_once() - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running): """Test that destroy() handles None websocket gracefully.""" @@ -186,7 +186,7 @@ class TestObjectsImportLifecycle: mock_publisher_instance.stop = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=None, # None websocket running=mock_running, backend=mock_backend, @@ -194,16 +194,16 @@ class TestObjectsImportLifecycle: ) # Should not raise exception - await objects_import.destroy() + await rows_import.destroy() mock_running.stop.assert_called_once() mock_publisher_instance.stop.assert_called_once() -class TestObjectsImportMessageProcessing: - """Test ObjectsImport message processing.""" +class TestRowsImportMessageProcessing: + """Test RowsImport message processing.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message): """Test that receive() processes complete message correctly.""" @@ -211,7 +211,7 @@ class TestObjectsImportMessageProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -222,7 +222,7 @@ class TestObjectsImportMessageProcessing: mock_msg = Mock() mock_msg.json.return_value = sample_objects_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Verify publisher.send was called mock_publisher_instance.send.assert_called_once() @@ -246,7 +246,7 @@ class TestObjectsImportMessageProcessing: assert sent_object.metadata.collection == "testcollection" assert len(sent_object.metadata.metadata) == 1 # One triple in metadata - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message): """Test that receive() handles message with minimal required fields.""" @@ -254,7 +254,7 @@ class TestObjectsImportMessageProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -265,7 +265,7 @@ class TestObjectsImportMessageProcessing: mock_msg = Mock() mock_msg.json.return_value = minimal_objects_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Verify publisher.send was called 
mock_publisher_instance.send.assert_called_once() @@ -279,7 +279,7 @@ class TestObjectsImportMessageProcessing: assert sent_object.source_span == "" # Default value assert len(sent_object.metadata.metadata) == 0 # Default empty list - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() uses appropriate default values for optional fields.""" @@ -287,7 +287,7 @@ class TestObjectsImportMessageProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -309,7 +309,7 @@ class TestObjectsImportMessageProcessing: mock_msg = Mock() mock_msg.json.return_value = message_data - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Get the sent object and verify defaults sent_object = mock_publisher_instance.send.call_args[0][1] @@ -317,11 +317,11 @@ class TestObjectsImportMessageProcessing: assert sent_object.source_span == "" -class TestObjectsImportRunMethod: - """Test ObjectsImport run method.""" +class TestRowsImportRunMethod: + """Test RowsImport run method.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') - @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep') @pytest.mark.asyncio async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that run() loops while running.get() returns True.""" @@ -331,14 +331,14 @@ class TestObjectsImportRunMethod: # Set up running state to return True twice, then False mock_running.get.side_effect = [True, True, False] - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, queue="test-queue" ) - await objects_import.run() + await rows_import.run() # Verify sleep was called twice (for the two True iterations) assert mock_sleep.call_count == 2 @@ -348,10 +348,10 @@ class TestObjectsImportRunMethod: mock_websocket.close.assert_called_once() # Verify websocket was set to None - assert objects_import.ws is None + assert rows_import.ws is None - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') - @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.asyncio.sleep') @pytest.mark.asyncio async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running): """Test that run() handles None websocket gracefully.""" @@ -360,7 +360,7 @@ class TestObjectsImportRunMethod: mock_running.get.return_value = False # Exit immediately - objects_import = ObjectsImport( + rows_import = RowsImport( ws=None, # None websocket running=mock_running, backend=mock_backend, @@ -368,14 +368,14 @@ class TestObjectsImportRunMethod: ) # Should not raise exception - await objects_import.run() + await rows_import.run() # Verify websocket remains None - assert objects_import.ws is None + assert rows_import.ws is None -class TestObjectsImportBatchProcessing: - """Test ObjectsImport 
batch processing functionality.""" +class TestRowsImportBatchProcessing: + """Test RowsImport batch processing functionality.""" @pytest.fixture def batch_objects_message(self): @@ -415,7 +415,7 @@ class TestObjectsImportBatchProcessing: "source_span": "Multiple people found in document" } - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message): """Test that receive() processes batch message correctly.""" @@ -423,7 +423,7 @@ class TestObjectsImportBatchProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -434,7 +434,7 @@ class TestObjectsImportBatchProcessing: mock_msg = Mock() mock_msg.json.return_value = batch_objects_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Verify publisher.send was called mock_publisher_instance.send.assert_called_once() @@ -465,7 +465,7 @@ class TestObjectsImportBatchProcessing: assert sent_object.confidence == 0.85 assert sent_object.source_span == "Multiple people found in document" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() handles empty batch correctly.""" @@ -473,7 +473,7 @@ class TestObjectsImportBatchProcessing: mock_publisher_instance.send = AsyncMock() mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -494,7 +494,7 @@ class TestObjectsImportBatchProcessing: mock_msg = Mock() mock_msg.json.return_value = empty_batch_message - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) # Should still send the message mock_publisher_instance.send.assert_called_once() @@ -502,10 +502,10 @@ class TestObjectsImportBatchProcessing: assert len(sent_object.values) == 0 -class TestObjectsImportErrorHandling: - """Test error handling in ObjectsImport.""" +class TestRowsImportErrorHandling: + """Test error handling in RowsImport.""" - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') + @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message): """Test that receive() propagates publisher send errors.""" @@ -513,7 +513,7 @@ class TestObjectsImportErrorHandling: mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error")) mock_publisher_class.return_value = mock_publisher_instance - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -524,15 +524,15 @@ class TestObjectsImportErrorHandling: mock_msg.json.return_value = sample_objects_message with pytest.raises(Exception, match="Publisher error"): - await objects_import.receive(mock_msg) + await rows_import.receive(mock_msg) - @patch('trustgraph.gateway.dispatch.objects_import.Publisher') 
+ @patch('trustgraph.gateway.dispatch.rows_import.Publisher') @pytest.mark.asyncio async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() handles malformed JSON appropriately.""" mock_publisher_class.return_value = Mock() - objects_import = ObjectsImport( + rows_import = RowsImport( ws=mock_websocket, running=mock_running, backend=mock_backend, @@ -543,4 +543,4 @@ class TestObjectsImportErrorHandling: mock_msg.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) with pytest.raises(json.JSONDecodeError): - await objects_import.receive(mock_msg) \ No newline at end of file + await rows_import.receive(mock_msg) \ No newline at end of file diff --git a/tests/unit/test_knowledge_graph/test_object_validation.py b/tests/unit/test_knowledge_graph/test_object_validation.py index b2ac28aa..47d2e4d7 100644 --- a/tests/unit/test_knowledge_graph/test_object_validation.py +++ b/tests/unit/test_knowledge_graph/test_object_validation.py @@ -76,7 +76,7 @@ def cities_schema(): def validator(): """Create a mock processor with just the validation method""" from unittest.mock import MagicMock - from trustgraph.extract.kg.objects.processor import Processor + from trustgraph.extract.kg.rows.processor import Processor # Create a mock processor mock_processor = MagicMock() diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py index f86ae3da..80443a0c 100644 --- a/tests/unit/test_python_api_client.py +++ b/tests/unit/test_python_api_client.py @@ -167,7 +167,7 @@ class TestFlowClient: expected_methods = [ 'text_completion', 'agent', 'graph_rag', 'document_rag', 'graph_embeddings_query', 'embeddings', 'prompt', - 'triples_query', 'objects_query' + 'triples_query', 'rows_query' ] for method in expected_methods: @@ -216,7 +216,7 @@ class TestSocketClient: expected_methods = [ 'agent', 'text_completion', 'graph_rag', 'document_rag', 'prompt', 'graph_embeddings_query', 'embeddings', - 'triples_query', 'objects_query', 'mcp_tool' + 'triples_query', 'rows_query', 'mcp_tool' ] for method in expected_methods: @@ -243,7 +243,7 @@ class TestBulkClient: 'import_graph_embeddings', 'import_document_embeddings', 'import_entity_contexts', - 'import_objects' + 'import_rows' ] for method in import_methods: diff --git a/tests/unit/test_query/test_objects_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py similarity index 52% rename from tests/unit/test_query/test_objects_cassandra_query.py rename to tests/unit/test_query/test_rows_cassandra_query.py index ab11d5a1..879a81c5 100644 --- a/tests/unit/test_query/test_objects_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -1,10 +1,11 @@ """ -Unit tests for Cassandra Objects GraphQL Query Processor +Unit tests for Cassandra Rows GraphQL Query Processor (Unified Table Implementation) Tests the business logic of the GraphQL query processor including: -- GraphQL schema generation from RowSchema -- Query execution and validation -- CQL translation logic +- Schema configuration handling +- Query execution using unified rows table +- Name sanitization +- GraphQL query execution - Message processing logic """ @@ -12,119 +13,91 @@ import pytest from unittest.mock import MagicMock, AsyncMock, patch import json -import strawberry -from strawberry import Schema - -from trustgraph.query.objects.cassandra.service import Processor -from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError +from 
trustgraph.query.rows.cassandra.service import Processor +from trustgraph.schema import RowsQueryRequest, RowsQueryResponse, GraphQLError from trustgraph.schema import RowSchema, Field -class TestObjectsGraphQLQueryLogic: - """Test business logic without external dependencies""" - - def test_get_python_type_mapping(self): - """Test schema field type conversion to Python types""" - processor = MagicMock() - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - - # Basic type mappings - assert processor.get_python_type("string") == str - assert processor.get_python_type("integer") == int - assert processor.get_python_type("float") == float - assert processor.get_python_type("boolean") == bool - assert processor.get_python_type("timestamp") == str - assert processor.get_python_type("date") == str - assert processor.get_python_type("time") == str - assert processor.get_python_type("uuid") == str - - # Unknown type defaults to str - assert processor.get_python_type("unknown_type") == str - - def test_create_graphql_type_basic_fields(self): - """Test GraphQL type creation for basic field types""" - processor = MagicMock() - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Create test schema - schema = RowSchema( - name="test_table", - description="Test table", - fields=[ - Field( - name="id", - type="string", - primary=True, - required=True, - description="Primary key" - ), - Field( - name="name", - type="string", - required=True, - description="Name field" - ), - Field( - name="age", - type="integer", - required=False, - description="Optional age" - ), - Field( - name="active", - type="boolean", - required=False, - description="Status flag" - ) - ] - ) - - # Create GraphQL type - graphql_type = processor.create_graphql_type("test_table", schema) - - # Verify type was created - assert graphql_type is not None - assert hasattr(graphql_type, '__name__') - assert "TestTable" in graphql_type.__name__ or "test_table" in graphql_type.__name__.lower() +class TestRowsGraphQLQueryLogic: + """Test business logic for unified table query implementation""" def test_sanitize_name_cassandra_compatibility(self): """Test name sanitization for Cassandra field names""" processor = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - - # Test field name sanitization (matches storage processor) + + # Test field name sanitization (uses r_ prefix like storage processor) assert processor.sanitize_name("simple_field") == "simple_field" assert processor.sanitize_name("Field-With-Dashes") == "field_with_dashes" assert processor.sanitize_name("field.with.dots") == "field_with_dots" - assert processor.sanitize_name("123_field") == "o_123_field" + assert processor.sanitize_name("123_field") == "r_123_field" assert processor.sanitize_name("field with spaces") == "field_with_spaces" assert processor.sanitize_name("special!@#chars") == "special___chars" assert processor.sanitize_name("UPPERCASE") == "uppercase" assert processor.sanitize_name("CamelCase") == "camelcase" - def test_sanitize_table_name(self): - """Test table name sanitization (always gets o_ prefix)""" + def test_get_index_names(self): + """Test extraction of index names from schema""" processor = MagicMock() - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - - # Table names always get o_ prefix - assert 
processor.sanitize_table("simple_table") == "o_simple_table" - assert processor.sanitize_table("Table-Name") == "o_table_name" - assert processor.sanitize_table("123table") == "o_123table" - assert processor.sanitize_table("") == "o_" + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + + schema = RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string"), # Not indexed + Field(name="status", type="string", indexed=True) + ] + ) + + index_names = processor.get_index_names(schema) + + assert "id" in index_names + assert "category" in index_names + assert "status" in index_names + assert "name" not in index_names + assert len(index_names) == 3 + + def test_find_matching_index_exact_match(self): + """Test finding matching index for exact match query""" + processor = MagicMock() + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) + + schema = RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string") # Not indexed + ] + ) + + # Filter on indexed field should return match + filters = {"category": "electronics"} + result = processor.find_matching_index(schema, filters) + assert result is not None + assert result[0] == "category" + assert result[1] == ["electronics"] + + # Filter on non-indexed field should return None + filters = {"name": "test"} + result = processor.find_matching_index(schema, filters) + assert result is None @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configuration""" processor = MagicMock() processor.schemas = {} - processor.graphql_types = {} - processor.graphql_schema = None - processor.config_key = "schema" # Set the config key - processor.generate_graphql_schema = AsyncMock() + processor.config_key = "schema" + processor.schema_builder = MagicMock() + processor.schema_builder.clear = MagicMock() + processor.schema_builder.add_schema = MagicMock() + processor.schema_builder.build = MagicMock(return_value=MagicMock()) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) - + # Create test config schema_config = { "schema": { @@ -154,96 +127,29 @@ class TestObjectsGraphQLQueryLogic: }) } } - + # Process config await processor.on_schema_config(schema_config, version=1) - + # Verify schema was loaded assert "customer" in processor.schemas schema = processor.schemas["customer"] assert schema.name == "customer" assert len(schema.fields) == 3 - + # Verify fields id_field = next(f for f in schema.fields if f.name == "id") assert id_field.primary is True - # The field should have been created correctly from JSON - # Let's test what we can verify - that the field has the right attributes - assert hasattr(id_field, 'required') # Has the required attribute - assert hasattr(id_field, 'primary') # Has the primary attribute - + email_field = next(f for f in schema.fields if f.name == "email") assert email_field.indexed is True - + status_field = next(f for f in schema.fields if f.name == "status") assert status_field.enum_values == ["active", "inactive"] - - # Verify GraphQL schema regeneration was called - processor.generate_graphql_schema.assert_called_once() - def test_cql_query_building_basic(self): - 
"""Test basic CQL query construction""" - processor = MagicMock() - processor.session = MagicMock() - processor.connect_cassandra = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.parse_filter_key = Processor.parse_filter_key.__get__(processor, Processor) - processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) - - # Mock session execute to capture the query - mock_result = [] - processor.session.execute.return_value = mock_result - - # Create test schema - schema = RowSchema( - name="test_table", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string", indexed=True), - Field(name="status", type="string") - ] - ) - - # Test query building - asyncio = pytest.importorskip("asyncio") - - async def run_test(): - await processor.query_cassandra( - user="test_user", - collection="test_collection", - schema_name="test_table", - row_schema=schema, - filters={"name": "John", "invalid_filter": "ignored"}, - limit=10 - ) - - # Run the async test - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(run_test()) - finally: - loop.close() - - # Verify Cassandra connection and query execution - processor.connect_cassandra.assert_called_once() - processor.session.execute.assert_called_once() - - # Verify the query structure (can't easily test exact query without complex mocking) - call_args = processor.session.execute.call_args - query = call_args[0][0] # First positional argument is the query - params = call_args[0][1] # Second positional argument is parameters - - # Basic query structure checks - assert "SELECT * FROM test_user.o_test_table" in query - assert "WHERE" in query - assert "collection = %s" in query - assert "LIMIT 10" in query - - # Parameters should include collection and name filter - assert "test_collection" in params - assert "John" in params + # Verify schema builder was called + processor.schema_builder.add_schema.assert_called_once() + processor.schema_builder.build.assert_called_once() @pytest.mark.asyncio async def test_graphql_context_handling(self): @@ -251,13 +157,13 @@ class TestObjectsGraphQLQueryLogic: processor = MagicMock() processor.graphql_schema = AsyncMock() processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) - + # Mock schema execution mock_result = MagicMock() mock_result.data = {"customers": [{"id": "1", "name": "Test"}]} mock_result.errors = None processor.graphql_schema.execute.return_value = mock_result - + result = await processor.execute_graphql_query( query='{ customers { id name } }', variables={}, @@ -265,17 +171,17 @@ class TestObjectsGraphQLQueryLogic: user="test_user", collection="test_collection" ) - + # Verify schema.execute was called with correct context processor.graphql_schema.execute.assert_called_once() call_args = processor.graphql_schema.execute.call_args - + # Verify context was passed - context = call_args[1]['context_value'] # keyword argument + context = call_args[1]['context_value'] assert context["processor"] == processor assert context["user"] == "test_user" assert context["collection"] == "test_collection" - + # Verify result structure assert "data" in result assert result["data"] == {"customers": [{"id": "1", "name": "Test"}]} @@ -286,104 +192,79 @@ class TestObjectsGraphQLQueryLogic: processor = MagicMock() processor.graphql_schema = 
AsyncMock() processor.execute_graphql_query = Processor.execute_graphql_query.__get__(processor, Processor) - - # Create a simple object to simulate GraphQL error instead of MagicMock + + # Create a simple object to simulate GraphQL error class MockError: def __init__(self, message, path, extensions): self.message = message self.path = path self.extensions = extensions - + def __str__(self): return self.message - + mock_error = MockError( message="Field 'invalid_field' doesn't exist", path=["customers", "0", "invalid_field"], extensions={"code": "FIELD_NOT_FOUND"} ) - + mock_result = MagicMock() mock_result.data = None mock_result.errors = [mock_error] processor.graphql_schema.execute.return_value = mock_result - + result = await processor.execute_graphql_query( query='{ customers { invalid_field } }', variables={}, operation_name=None, - user="test_user", + user="test_user", collection="test_collection" ) - + # Verify error handling assert "errors" in result assert len(result["errors"]) == 1 - + error = result["errors"][0] assert error["message"] == "Field 'invalid_field' doesn't exist" - assert error["path"] == ["customers", "0", "invalid_field"] # Fixed to match string path + assert error["path"] == ["customers", "0", "invalid_field"] assert error["extensions"] == {"code": "FIELD_NOT_FOUND"} - def test_schema_generation_basic_structure(self): - """Test basic GraphQL schema generation structure""" - processor = MagicMock() - processor.schemas = { - "customer": RowSchema( - name="customer", - fields=[ - Field(name="id", type="string", primary=True), - Field(name="name", type="string") - ] - ) - } - processor.graphql_types = {} - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Test individual type creation (avoiding the full schema generation which has annotation issues) - graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"]) - processor.graphql_types["customer"] = graphql_type - - # Verify type was created - assert len(processor.graphql_types) == 1 - assert "customer" in processor.graphql_types - assert processor.graphql_types["customer"] is not None - @pytest.mark.asyncio async def test_message_processing_success(self): """Test successful message processing flow""" processor = MagicMock() processor.execute_graphql_query = AsyncMock() processor.on_message = Processor.on_message.__get__(processor, Processor) - + # Mock successful query result processor.execute_graphql_query.return_value = { "data": {"customers": [{"id": "1", "name": "John"}]}, "errors": [], - "extensions": {"execution_time": "0.1"} # Extensions must be strings for Map(String()) + "extensions": {} } - + # Create mock message mock_msg = MagicMock() - mock_request = ObjectsQueryRequest( + mock_request = RowsQueryRequest( user="test_user", - collection="test_collection", + collection="test_collection", query='{ customers { id name } }', variables={}, operation_name=None ) mock_msg.value.return_value = mock_request mock_msg.properties.return_value = {"id": "test-123"} - + # Mock flow mock_flow = MagicMock() mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow - + # Process message await processor.on_message(mock_msg, None, mock_flow) - + # Verify query was executed processor.execute_graphql_query.assert_called_once_with( query='{ customers { id name } }', @@ -392,13 +273,13 @@ class TestObjectsGraphQLQueryLogic: user="test_user", 
collection="test_collection" ) - + # Verify response was sent mock_response_flow.send.assert_called_once() response_call = mock_response_flow.send.call_args[0][0] - + # Verify response structure - assert isinstance(response_call, ObjectsQueryResponse) + assert isinstance(response_call, RowsQueryResponse) assert response_call.error is None assert '"customers"' in response_call.data # JSON encoded assert len(response_call.errors) == 0 @@ -409,13 +290,13 @@ class TestObjectsGraphQLQueryLogic: processor = MagicMock() processor.execute_graphql_query = AsyncMock() processor.on_message = Processor.on_message.__get__(processor, Processor) - + # Mock query execution error processor.execute_graphql_query.side_effect = RuntimeError("No schema available") - + # Create mock message mock_msg = MagicMock() - mock_request = ObjectsQueryRequest( + mock_request = RowsQueryRequest( user="test_user", collection="test_collection", query='{ invalid_query }', @@ -424,67 +305,225 @@ class TestObjectsGraphQLQueryLogic: ) mock_msg.value.return_value = mock_request mock_msg.properties.return_value = {"id": "test-456"} - + # Mock flow mock_flow = MagicMock() mock_response_flow = AsyncMock() mock_flow.return_value = mock_response_flow - + # Process message await processor.on_message(mock_msg, None, mock_flow) - + # Verify error response was sent mock_response_flow.send.assert_called_once() response_call = mock_response_flow.send.call_args[0][0] - + # Verify error response structure - assert isinstance(response_call, ObjectsQueryResponse) + assert isinstance(response_call, RowsQueryResponse) assert response_call.error is not None - assert response_call.error.type == "objects-query-error" + assert response_call.error.type == "rows-query-error" assert "No schema available" in response_call.error.message assert response_call.data is None -class TestCQLQueryGeneration: - """Test CQL query generation logic in isolation""" - - def test_partition_key_inclusion(self): - """Test that collection is always included in queries""" +class TestUnifiedTableQueries: + """Test queries against the unified rows table""" + + @pytest.mark.asyncio + async def test_query_with_index_match(self): + """Test query execution with matching index""" processor = MagicMock() + processor.session = MagicMock() + processor.connect_cassandra = MagicMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - - # Mock the query building (simplified version) - keyspace = processor.sanitize_name("test_user") - table = processor.sanitize_table("test_table") - - query = f"SELECT * FROM {keyspace}.{table}" - where_clauses = ["collection = %s"] - - assert "collection = %s" in where_clauses - assert keyspace == "test_user" - assert table == "o_test_table" - + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) + processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) + + # Mock session execute to return test data + mock_row = MagicMock() + mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"} + processor.session.execute.return_value = [mock_row] + + schema = RowSchema( + name="products", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string") + ] + ) + + # Query with filter on 
indexed field + results = await processor.query_cassandra( + user="test_user", + collection="test_collection", + schema_name="products", + row_schema=schema, + filters={"category": "electronics"}, + limit=10 + ) + + # Verify Cassandra was connected and queried + processor.connect_cassandra.assert_called_once() + processor.session.execute.assert_called_once() + + # Verify query structure - should query unified rows table + call_args = processor.session.execute.call_args + query = call_args[0][0] + params = call_args[0][1] + + assert "SELECT data, source FROM test_user.rows" in query + assert "collection = %s" in query + assert "schema_name = %s" in query + assert "index_name = %s" in query + assert "index_value = %s" in query + + assert params[0] == "test_collection" + assert params[1] == "products" + assert params[2] == "category" + assert params[3] == ["electronics"] + + # Verify results + assert len(results) == 1 + assert results[0]["id"] == "123" + assert results[0]["category"] == "electronics" + + @pytest.mark.asyncio + async def test_query_without_index_match(self): + """Test query execution without matching index (scan mode)""" + processor = MagicMock() + processor.session = MagicMock() + processor.connect_cassandra = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) + + # Mock session execute to return test data + mock_row1 = MagicMock() + mock_row1.data = {"id": "1", "name": "Product A", "price": "100"} + mock_row2 = MagicMock() + mock_row2.data = {"id": "2", "name": "Product B", "price": "200"} + processor.session.execute.return_value = [mock_row1, mock_row2] + + schema = RowSchema( + name="products", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string"), # Not indexed + Field(name="price", type="string") # Not indexed + ] + ) + + # Query with filter on non-indexed field + results = await processor.query_cassandra( + user="test_user", + collection="test_collection", + schema_name="products", + row_schema=schema, + filters={"name": "Product A"}, + limit=10 + ) + + # Query should use ALLOW FILTERING for scan + call_args = processor.session.execute.call_args + query = call_args[0][0] + + assert "ALLOW FILTERING" in query + + # Should post-filter results + assert len(results) == 1 + assert results[0]["name"] == "Product A" + + +class TestFilterMatching: + """Test filter matching logic""" + + def test_matches_filters_exact_match(self): + """Test exact match filter""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="status", type="string")]) + + row = {"status": "active", "name": "test"} + assert processor._matches_filters(row, {"status": "active"}, schema) is True + assert processor._matches_filters(row, {"status": "inactive"}, schema) is False + + def test_matches_filters_comparison_operators(self): + """Test comparison operators in filters""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="price", type="float")]) 
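+ # Note (assumption, not asserted directly): the row value is the string "100.0" because the unified rows table stores data as map<text, text>, so _matches_filters is expected to coerce it via the schema's float type before the numeric comparisons below.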
+ + row = {"price": "100.0"} + + # Greater than + assert processor._matches_filters(row, {"price_gt": 50}, schema) is True + assert processor._matches_filters(row, {"price_gt": 150}, schema) is False + + # Less than + assert processor._matches_filters(row, {"price_lt": 150}, schema) is True + assert processor._matches_filters(row, {"price_lt": 50}, schema) is False + + # Greater than or equal + assert processor._matches_filters(row, {"price_gte": 100}, schema) is True + assert processor._matches_filters(row, {"price_gte": 101}, schema) is False + + # Less than or equal + assert processor._matches_filters(row, {"price_lte": 100}, schema) is True + assert processor._matches_filters(row, {"price_lte": 99}, schema) is False + + def test_matches_filters_contains(self): + """Test contains filter""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="description", type="string")]) + + row = {"description": "A great product for everyone"} + + assert processor._matches_filters(row, {"description_contains": "great"}, schema) is True + assert processor._matches_filters(row, {"description_contains": "terrible"}, schema) is False + + def test_matches_filters_in_list(self): + """Test in-list filter""" + processor = MagicMock() + processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) + + schema = RowSchema(name="test", fields=[Field(name="status", type="string")]) + + row = {"status": "active"} + + assert processor._matches_filters(row, {"status_in": ["active", "pending"]}, schema) is True + assert processor._matches_filters(row, {"status_in": ["inactive", "deleted"]}, schema) is False + + +class TestIndexedFieldFiltering: + """Test that only indexed or primary key fields can be directly filtered""" + def test_indexed_field_filtering(self): """Test that only indexed or primary key fields can be filtered""" - # Create schema with mixed field types schema = RowSchema( name="test", fields=[ Field(name="id", type="string", primary=True), - Field(name="indexed_field", type="string", indexed=True), + Field(name="indexed_field", type="string", indexed=True), Field(name="normal_field", type="string", indexed=False), Field(name="another_field", type="string") ] ) - + filters = { "id": "test123", # Primary key - should be included "indexed_field": "value", # Indexed - should be included "normal_field": "ignored", # Not indexed - should be ignored "another_field": "also_ignored" # Not indexed - should be ignored } - + # Simulate the filtering logic from the processor valid_filters = [] for field_name, value in filters.items(): @@ -492,7 +531,7 @@ class TestCQLQueryGeneration: schema_field = next((f for f in schema.fields if f.name == field_name), None) if schema_field and (schema_field.indexed or schema_field.primary): valid_filters.append((field_name, value)) - + # Only id and indexed_field should be included assert len(valid_filters) == 2 field_names = [f[0] for f in valid_filters] @@ -500,52 +539,3 @@ class TestCQLQueryGeneration: assert "indexed_field" in field_names assert "normal_field" not in field_names assert "another_field" not in field_names - - -class TestGraphQLSchemaGeneration: - """Test GraphQL schema generation in detail""" - - def test_field_type_annotations(self): - """Test that GraphQL types have correct field annotations""" - processor = MagicMock() - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - 
processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Create schema with various field types - schema = RowSchema( - name="test", - fields=[ - Field(name="id", type="string", required=True, primary=True), - Field(name="count", type="integer", required=True), - Field(name="price", type="float", required=False), - Field(name="active", type="boolean", required=False), - Field(name="optional_text", type="string", required=False) - ] - ) - - # Create GraphQL type - graphql_type = processor.create_graphql_type("test", schema) - - # Verify type was created successfully - assert graphql_type is not None - - def test_basic_type_creation(self): - """Test that GraphQL types are created correctly""" - processor = MagicMock() - processor.schemas = { - "customer": RowSchema( - name="customer", - fields=[Field(name="id", type="string", primary=True)] - ) - } - processor.graphql_types = {} - processor.get_python_type = Processor.get_python_type.__get__(processor, Processor) - processor.create_graphql_type = Processor.create_graphql_type.__get__(processor, Processor) - - # Create GraphQL type directly - graphql_type = processor.create_graphql_type("customer", processor.schemas["customer"]) - processor.graphql_types["customer"] = graphql_type - - # Verify customer type was created - assert "customer" in processor.graphql_types - assert processor.graphql_types["customer"] is not None \ No newline at end of file diff --git a/tests/unit/test_retrieval/test_structured_query.py b/tests/unit/test_retrieval/test_structured_query.py index 27c09ca4..76bf5b08 100644 --- a/tests/unit/test_retrieval/test_structured_query.py +++ b/tests/unit/test_retrieval/test_structured_query.py @@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.schema import ( StructuredQueryRequest, StructuredQueryResponse, QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse, - ObjectsQueryRequest, ObjectsQueryResponse, + RowsQueryRequest, RowsQueryResponse, Error, GraphQLError ) from trustgraph.retrieval.structured_query.service import Processor @@ -68,7 +68,7 @@ class TestStructuredQueryProcessor: ) # Mock objects query service response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "1", "name": "John", "email": "john@example.com"}]}', errors=None, @@ -86,7 +86,7 @@ class TestStructuredQueryProcessor: def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -108,7 +108,7 @@ class TestStructuredQueryProcessor: # Verify objects query service was called correctly mock_objects_client.request.assert_called_once() objects_call_args = mock_objects_client.request.call_args[0][0] - assert isinstance(objects_call_args, ObjectsQueryRequest) + assert isinstance(objects_call_args, RowsQueryRequest) assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }' assert objects_call_args.variables == {"state": "NY"} assert objects_call_args.user == "trustgraph" @@ -224,7 +224,7 @@ class TestStructuredQueryProcessor: assert response.error is not None assert "empty GraphQL query" in response.error.message - async def test_objects_query_service_error(self, processor): + async def test_rows_query_service_error(self, processor): """Test handling of 
rows query service errors""" # Arrange request = StructuredQueryRequest( @@ -250,7 +250,7 @@ ) # Mock objects query service error - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=Error(type="graphql-execution-error", message="Table 'customers' not found"), data=None, errors=None, @@ -267,7 +267,7 @@ ) def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -284,7 +284,7 @@ response = response_call[0][0] assert response.error is not None - assert "Objects query service error" in response.error.message + assert "Rows query service error" in response.error.message assert "Table 'customers' not found" in response.error.message async def test_graphql_errors_handling(self, processor): @@ -321,7 +321,7 @@ ) ] - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=None, errors=graphql_errors, @@ -338,7 +338,7 @@ ) def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -400,7 +400,7 @@ ) # Mock objects response - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data='{"customers": [{"id": "1", "name": "Alice", "orders": [{"id": "100", "total": 150.0}]}]}', errors=None @@ -416,7 +416,7 @@ ) def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response @@ -464,7 +464,7 @@ confidence=0.9 ) - objects_response = ObjectsQueryResponse( + objects_response = RowsQueryResponse( error=None, data=None, # Null data errors=None, @@ -481,7 +481,7 @@ ) def flow_router(service_name): if service_name == "nlp-query-request": return mock_nlp_client - elif service_name == "objects-query-request": + elif service_name == "rows-query-request": return mock_objects_client elif service_name == "response": return flow_response diff --git a/tests/unit/test_storage/test_cassandra_config_integration.py b/tests/unit/test_storage/test_cassandra_config_integration.py index 754a4bb0..0956f4e7 100644 --- a/tests/unit/test_storage/test_cassandra_config_integration.py +++ b/tests/unit/test_storage/test_cassandra_config_integration.py @@ -10,7 +10,7 @@ import pytest from unittest.mock import Mock, patch, MagicMock from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter -from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter +from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery from trustgraph.storage.knowledge.store import Processor as KgStore @@ -81,10 +81,10 @@ class TestTriplesWriterConfiguration: assert
processor.cassandra_password is None -class TestObjectsWriterConfiguration: - """Test Cassandra configuration in objects writer processor.""" +class TestRowsWriterConfiguration: + """Test Cassandra configuration in rows writer processor.""" - @patch('trustgraph.storage.objects.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') def test_environment_variable_configuration(self, mock_cluster): """Test processor picks up configuration from environment variables.""" env_vars = { @@ -97,13 +97,13 @@ mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) assert processor.cassandra_host == ['obj-env-host1', 'obj-env-host2'] assert processor.cassandra_username == 'obj-env-user' assert processor.cassandra_password == 'obj-env-pass' - @patch('trustgraph.storage.objects.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') def test_cassandra_connection_with_hosts_list(self, mock_cluster): """Test that Cassandra connection uses hosts list correctly.""" env_vars = { @@ -118,7 +118,7 @@ mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Verify cluster was called with hosts list @@ -129,8 +129,8 @@ assert 'contact_points' in call_args.kwargs assert call_args.kwargs['contact_points'] == ['conn-host1', 'conn-host2', 'conn-host3'] - @patch('trustgraph.storage.objects.cassandra.write.Cluster') - @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') + @patch('trustgraph.storage.rows.cassandra.write.Cluster') + @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_authentication_configuration(self, mock_auth_provider, mock_cluster): """Test authentication is configured when credentials are provided.""" env_vars = { @@ -145,7 +145,7 @@ mock_cluster.return_value = mock_cluster_instance with patch.dict(os.environ, env_vars, clear=True): - processor = ObjectsWriter(taskgroup=MagicMock()) + processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() # Verify auth provider was created with correct credentials @@ -302,10 +302,10 @@ def test_objects_writer_add_args(self): - """Test that objects writer adds standard Cassandra arguments.""" + """Test that the rows writer adds standard Cassandra arguments.""" import argparse - from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter + from trustgraph.storage.rows.cassandra.write import Processor as RowsWriter parser = argparse.ArgumentParser() - ObjectsWriter.add_args(parser) + RowsWriter.add_args(parser) # Parse empty args to check that arguments exist args = parser.parse_args([]) diff --git a/tests/unit/test_storage/test_objects_cassandra_storage.py b/tests/unit/test_storage/test_objects_cassandra_storage.py deleted file mode 100644 index c7f5ff40..00000000 --- a/tests/unit/test_storage/test_objects_cassandra_storage.py +++ /dev/null @@ -1,533 +0,0 @@ -""" -Unit tests for Cassandra Object Storage Processor - -Tests the business logic of the object storage processor including: -- Schema configuration handling -- Type conversions -- Name sanitization -- Table structure generation -""" - -import pytest -from unittest.mock import MagicMock,
AsyncMock, patch -import json - -from trustgraph.storage.objects.cassandra.write import Processor -from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field - - -class TestObjectsCassandraStorageLogic: - """Test business logic without FlowProcessor dependencies""" - - def test_sanitize_name(self): - """Test name sanitization for Cassandra compatibility""" - processor = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - - # Test various name patterns (back to original logic) - assert processor.sanitize_name("simple_name") == "simple_name" - assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes" - assert processor.sanitize_name("name.with.dots") == "name_with_dots" - assert processor.sanitize_name("123_starts_with_number") == "o_123_starts_with_number" - assert processor.sanitize_name("name with spaces") == "name_with_spaces" - assert processor.sanitize_name("special!@#$%^chars") == "special______chars" - - def test_get_cassandra_type(self): - """Test field type conversion to Cassandra types""" - processor = MagicMock() - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - - # Basic type mappings - assert processor.get_cassandra_type("string") == "text" - assert processor.get_cassandra_type("boolean") == "boolean" - assert processor.get_cassandra_type("timestamp") == "timestamp" - assert processor.get_cassandra_type("uuid") == "uuid" - - # Integer types with size hints - assert processor.get_cassandra_type("integer", size=2) == "int" - assert processor.get_cassandra_type("integer", size=8) == "bigint" - - # Float types with size hints - assert processor.get_cassandra_type("float", size=2) == "float" - assert processor.get_cassandra_type("float", size=8) == "double" - - # Unknown type defaults to text - assert processor.get_cassandra_type("unknown_type") == "text" - - def test_convert_value(self): - """Test value conversion for different field types""" - processor = MagicMock() - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - - # Integer conversions - assert processor.convert_value("123", "integer") == 123 - assert processor.convert_value(123.5, "integer") == 123 - assert processor.convert_value(None, "integer") is None - - # Float conversions - assert processor.convert_value("123.45", "float") == 123.45 - assert processor.convert_value(123, "float") == 123.0 - - # Boolean conversions - assert processor.convert_value("true", "boolean") is True - assert processor.convert_value("false", "boolean") is False - assert processor.convert_value("1", "boolean") is True - assert processor.convert_value("0", "boolean") is False - assert processor.convert_value("yes", "boolean") is True - assert processor.convert_value("no", "boolean") is False - - # String conversions - assert processor.convert_value(123, "string") == "123" - assert processor.convert_value(True, "string") == "True" - - def test_table_creation_cql_generation(self): - """Test CQL generation for table creation""" - processor = MagicMock() - processor.schemas = {} - processor.known_keyspaces = set() - processor.known_tables = {} - processor.session = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - def mock_ensure_keyspace(keyspace): - 
processor.known_keyspaces.add(keyspace) - processor.known_tables[keyspace] = set() - processor.ensure_keyspace = mock_ensure_keyspace - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - - # Create test schema - schema = RowSchema( - name="customer_records", - description="Test customer schema", - fields=[ - Field( - name="customer_id", - type="string", - size=50, - primary=True, - required=True, - indexed=False - ), - Field( - name="email", - type="string", - size=100, - required=True, - indexed=True - ), - Field( - name="age", - type="integer", - size=4, - required=False, - indexed=False - ) - ] - ) - - # Call ensure_table - processor.ensure_table("test_user", "customer_records", schema) - - # Verify keyspace was ensured (check that it was added to known_keyspaces) - assert "test_user" in processor.known_keyspaces - - # Check the CQL that was executed (first call should be table creation) - all_calls = processor.session.execute.call_args_list - table_creation_cql = all_calls[0][0][0] # First call - - # Verify table structure (keyspace uses sanitize_name, table uses sanitize_table) - assert "CREATE TABLE IF NOT EXISTS test_user.o_customer_records" in table_creation_cql - assert "collection text" in table_creation_cql - assert "customer_id text" in table_creation_cql - assert "email text" in table_creation_cql - assert "age int" in table_creation_cql - assert "PRIMARY KEY ((collection, customer_id))" in table_creation_cql - - def test_table_creation_without_primary_key(self): - """Test table creation when no primary key is defined""" - processor = MagicMock() - processor.schemas = {} - processor.known_keyspaces = set() - processor.known_tables = {} - processor.session = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - def mock_ensure_keyspace(keyspace): - processor.known_keyspaces.add(keyspace) - processor.known_tables[keyspace] = set() - processor.ensure_keyspace = mock_ensure_keyspace - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - - # Create schema without primary key - schema = RowSchema( - name="events", - description="Event log", - fields=[ - Field(name="event_type", type="string", size=50), - Field(name="timestamp", type="timestamp", size=0) - ] - ) - - # Call ensure_table - processor.ensure_table("test_user", "events", schema) - - # Check the CQL includes synthetic_id (field names don't get o_ prefix) - executed_cql = processor.session.execute.call_args[0][0] - assert "synthetic_id uuid" in executed_cql - assert "PRIMARY KEY ((collection, synthetic_id))" in executed_cql - - @pytest.mark.asyncio - async def test_schema_config_parsing(self): - """Test parsing of schema configurations""" - processor = MagicMock() - processor.schemas = {} - processor.config_key = "schema" - processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) - - # Create test configuration - config = { - "schema": { - "customer_records": json.dumps({ - "name": "customer_records", - "description": "Customer data", - "fields": [ - { - "name": "id", - "type": "string", - "primary_key": True, - "required": True - }, - { - "name": "name", - "type": "string", - "required": True - }, - { - "name": "balance", - "type": "float", - "size": 8 - } - ] - }) - } - } - - # Process configuration - await 
processor.on_schema_config(config, version=1) - - # Verify schema was loaded - assert "customer_records" in processor.schemas - schema = processor.schemas["customer_records"] - assert schema.name == "customer_records" - assert len(schema.fields) == 3 - - # Check field properties - id_field = schema.fields[0] - assert id_field.name == "id" - assert id_field.type == "string" - assert id_field.primary is True - # Note: Field.required always returns False due to Pulsar schema limitations - # The actual required value is tracked during schema parsing - - @pytest.mark.asyncio - async def test_object_processing_logic(self): - """Test the logic for processing ExtractedObject""" - processor = MagicMock() - processor.schemas = { - "test_schema": RowSchema( - name="test_schema", - description="Test", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="value", type="integer", size=4) - ] - ) - } - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - - # Create test object - test_obj = ExtractedObject( - metadata=Metadata( - id="test-001", - user="test_user", - collection="test_collection", - metadata=[] - ), - schema_name="test_schema", - values=[{"id": "123", "value": "456"}], - confidence=0.9, - source_span="test source" - ) - - # Create mock message - msg = MagicMock() - msg.value.return_value = test_obj - - # Process object - await processor.on_object(msg, None, None) - - # Verify table was ensured - processor.ensure_table.assert_called_once_with("test_user", "test_schema", processor.schemas["test_schema"]) - - # Verify insert was executed (keyspace normal, table with o_ prefix) - processor.session.execute.assert_called_once() - insert_cql = processor.session.execute.call_args[0][0] - values = processor.session.execute.call_args[0][1] - - assert "INSERT INTO test_user.o_test_schema" in insert_cql - assert "collection" in insert_cql - assert values[0] == "test_collection" # collection value - assert values[1] == "123" # id value (from values[0]) - assert values[2] == 456 # converted integer value (from values[0]) - - def test_secondary_index_creation(self): - """Test that secondary indexes are created for indexed fields""" - processor = MagicMock() - processor.schemas = {} - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - processor.session = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor) - def mock_ensure_keyspace(keyspace): - processor.known_keyspaces.add(keyspace) - if keyspace not in processor.known_tables: - processor.known_tables[keyspace] = set() - processor.ensure_keyspace = mock_ensure_keyspace - processor.ensure_table = Processor.ensure_table.__get__(processor, Processor) - - # Create schema with indexed field - schema = RowSchema( - name="products", - 
description="Product catalog", - fields=[ - Field(name="product_id", type="string", size=50, primary=True), - Field(name="category", type="string", size=30, indexed=True), - Field(name="price", type="float", size=8, indexed=True) - ] - ) - - # Call ensure_table - processor.ensure_table("test_user", "products", schema) - - # Should have 3 calls: create table + 2 indexes - assert processor.session.execute.call_count == 3 - - # Check index creation calls (table has o_ prefix, fields don't) - calls = processor.session.execute.call_args_list - index_calls = [call[0][0] for call in calls if "CREATE INDEX" in call[0][0]] - assert len(index_calls) == 2 - assert any("o_products_category_idx" in call for call in index_calls) - assert any("o_products_price_idx" in call for call in index_calls) - - -class TestObjectsCassandraStorageBatchLogic: - """Test batch processing logic in Cassandra storage""" - - @pytest.mark.asyncio - async def test_batch_object_processing_logic(self): - """Test processing of batch ExtractedObjects""" - processor = MagicMock() - processor.schemas = { - "batch_schema": RowSchema( - name="batch_schema", - description="Test batch schema", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="name", type="string", size=100), - Field(name="value", type="integer", size=4) - ] - ) - } - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - - # Create batch object with multiple values - batch_obj = ExtractedObject( - metadata=Metadata( - id="batch-001", - user="test_user", - collection="batch_collection", - metadata=[] - ), - schema_name="batch_schema", - values=[ - {"id": "001", "name": "First", "value": "100"}, - {"id": "002", "name": "Second", "value": "200"}, - {"id": "003", "name": "Third", "value": "300"} - ], - confidence=0.95, - source_span="batch source" - ) - - # Create mock message - msg = MagicMock() - msg.value.return_value = batch_obj - - # Process batch object - await processor.on_object(msg, None, None) - - # Verify table was ensured once - processor.ensure_table.assert_called_once_with("test_user", "batch_schema", processor.schemas["batch_schema"]) - - # Verify 3 separate insert calls (one per batch item) - assert processor.session.execute.call_count == 3 - - # Check each insert call - calls = processor.session.execute.call_args_list - for i, call in enumerate(calls): - insert_cql = call[0][0] - values = call[0][1] - - assert "INSERT INTO test_user.o_batch_schema" in insert_cql - assert "collection" in insert_cql - - # Check values for each batch item - assert values[0] == "batch_collection" # collection - assert values[1] == f"00{i+1}" # id from batch item i - assert values[2] == f"First" if i == 0 else f"Second" if i == 1 else f"Third" # name - assert values[3] == (i+1) * 100 # converted integer value - - @pytest.mark.asyncio - async def test_empty_batch_processing_logic(self): - """Test processing of empty batch ExtractedObjects""" - processor = MagicMock() - processor.schemas = { - "empty_schema": RowSchema( - name="empty_schema", - fields=[Field(name="id", type="string", size=50, primary=True)] - ) - } - 
processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - - # Create empty batch object - empty_batch_obj = ExtractedObject( - metadata=Metadata( - id="empty-001", - user="test_user", - collection="empty_collection", - metadata=[] - ), - schema_name="empty_schema", - values=[], # Empty batch - confidence=1.0, - source_span="empty source" - ) - - msg = MagicMock() - msg.value.return_value = empty_batch_obj - - # Process empty batch object - await processor.on_object(msg, None, None) - - # Verify table was ensured - processor.ensure_table.assert_called_once() - - # Verify no insert calls for empty batch - processor.session.execute.assert_not_called() - - @pytest.mark.asyncio - async def test_single_item_batch_processing_logic(self): - """Test processing of single-item batch (backward compatibility)""" - processor = MagicMock() - processor.schemas = { - "single_schema": RowSchema( - name="single_schema", - fields=[ - Field(name="id", type="string", size=50, primary=True), - Field(name="data", type="string", size=100) - ] - ) - } - processor.ensure_table = MagicMock() - processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) - processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor) - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - processor.session = MagicMock() - processor.on_object = Processor.on_object.__get__(processor, Processor) - processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query - processor.known_tables = {"test_user": set()} # Pre-populate - - # Create single-item batch object (backward compatibility case) - single_batch_obj = ExtractedObject( - metadata=Metadata( - id="single-001", - user="test_user", - collection="single_collection", - metadata=[] - ), - schema_name="single_schema", - values=[{"id": "single-1", "data": "single data"}], # Array with one item - confidence=0.8, - source_span="single source" - ) - - msg = MagicMock() - msg.value.return_value = single_batch_obj - - # Process single-item batch object - await processor.on_object(msg, None, None) - - # Verify table was ensured - processor.ensure_table.assert_called_once() - - # Verify exactly one insert call - processor.session.execute.assert_called_once() - - insert_cql = processor.session.execute.call_args[0][0] - values = processor.session.execute.call_args[0][1] - - assert "INSERT INTO test_user.o_single_schema" in insert_cql - assert values[0] == "single_collection" # collection - assert values[1] == "single-1" # id value - assert values[2] == "single data" # data value - - def test_batch_value_conversion_logic(self): - """Test value conversion works correctly for batch items""" - processor = MagicMock() - processor.convert_value = Processor.convert_value.__get__(processor, Processor) - - # Test various conversion scenarios that would occur in batch processing - test_cases = [ - # Integer conversions for batch items - ("123", "integer", 123), - ("456", "integer", 456), - ("789", "integer", 789), - # Float conversions for batch 
items - ("12.5", "float", 12.5), - ("34.7", "float", 34.7), - # Boolean conversions for batch items - ("true", "boolean", True), - ("false", "boolean", False), - ("1", "boolean", True), - ("0", "boolean", False), - # String conversions for batch items - (123, "string", "123"), - (45.6, "string", "45.6"), - ] - - for input_val, field_type, expected_output in test_cases: - result = processor.convert_value(input_val, field_type) - assert result == expected_output, f"Failed for {input_val} -> {field_type}: got {result}, expected {expected_output}" \ No newline at end of file diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py new file mode 100644 index 00000000..b4c5a5b4 --- /dev/null +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -0,0 +1,435 @@ +""" +Unit tests for trustgraph.storage.row_embeddings.qdrant.write +Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + + +class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): + """Test Qdrant row embeddings storage functionality""" + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_processor_initialization_basic(self, mock_qdrant_client): + """Test basic Qdrant processor initialization""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'store_uri': 'http://localhost:6333', + 'api_key': 'test-api-key', + 'taskgroup': AsyncMock(), + 'id': 'test-qdrant-processor' + } + + processor = Processor(**config) + + mock_qdrant_client.assert_called_once_with( + url='http://localhost:6333', api_key='test-api-key' + ) + assert hasattr(processor, 'qdrant') + assert processor.qdrant == mock_qdrant_instance + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_processor_initialization_with_defaults(self, mock_qdrant_client): + """Test processor initialization with default values""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-qdrant-processor' + } + + processor = Processor(**config) + + mock_qdrant_client.assert_called_once_with( + url='http://localhost:6333', api_key=None + ) + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_sanitize_name(self, mock_qdrant_client): + """Test name sanitization for Qdrant collections""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_client.return_value = MagicMock() + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Test basic sanitization + assert processor.sanitize_name("simple") == "simple" + assert processor.sanitize_name("with-dash") == "with_dash" + assert processor.sanitize_name("with.dot") == "with_dot" + assert processor.sanitize_name("UPPERCASE") == "uppercase" + + # Test numeric prefix handling + assert processor.sanitize_name("123start") == "r_123start" + assert processor.sanitize_name("_underscore") == "r__underscore" + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def 
test_get_collection_name(self, mock_qdrant_client): + """Test Qdrant collection name generation""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_client.return_value = MagicMock() + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + collection_name = processor.get_collection_name( + user="test_user", + collection="test_collection", + schema_name="customer_data", + dimension=384 + ) + + assert collection_name == "rows_test_user_test_collection_customer_data_384" + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_ensure_collection_creates_new(self, mock_qdrant_client): + """Test that ensure_collection creates a new collection when needed""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = False + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + processor.ensure_collection("test_collection", 384) + + mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") + mock_qdrant_instance.create_collection.assert_called_once() + + # Verify the collection is cached + assert "test_collection" in processor.created_collections + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_ensure_collection_skips_existing(self, mock_qdrant_client): + """Test that ensure_collection skips creation when collection exists""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = True + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + processor.ensure_collection("existing_collection", 384) + + mock_qdrant_instance.collection_exists.assert_called_once() + mock_qdrant_instance.create_collection.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_ensure_collection_uses_cache(self, mock_qdrant_client): + """Test that ensure_collection uses cache for previously created collections""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.created_collections.add("cached_collection") + + processor.ensure_collection("cached_collection", 384) + + # Should not check or create - just return + mock_qdrant_instance.collection_exists.assert_not_called() + mock_qdrant_instance.create_collection.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') + async def test_on_embeddings_basic(self, mock_uuid, mock_qdrant_client): + """Test processing basic row embeddings message""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding, Metadata + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = True + mock_qdrant_client.return_value = mock_qdrant_instance + 
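[Editorial note: the naming and caching behavior exercised by the tests above pins down the shape of the Stage 2 writer. The following is a minimal sketch inferred from those assertions, not the shipped processor; the cosine distance metric is an assumption the tests do not pin down.]

```python
import re

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams


def sanitize_name(name: str) -> str:
    # Lowercase, map anything outside [a-z0-9_] to '_', and prefix names
    # that don't start with a letter so Qdrant accepts them.
    name = re.sub(r"[^a-z0-9_]", "_", name.lower())
    if not name or not name[0].isalpha():
        name = "r_" + name
    return name


def get_collection_name(user: str, collection: str,
                        schema_name: str, dimension: int) -> str:
    # rows_<user>_<collection>_<schema>_<dimension>, per the test assertion.
    return "_".join([
        "rows", sanitize_name(user), sanitize_name(collection),
        sanitize_name(schema_name), str(dimension),
    ])


class RowEmbeddingsWriter:

    def __init__(self, store_uri="http://localhost:6333", api_key=None):
        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
        self.created_collections = set()   # cache to skip repeat existence checks

    def ensure_collection(self, name: str, dimension: int) -> None:
        if name in self.created_collections:
            return                          # cached: no round-trip to Qdrant
        if not self.qdrant.collection_exists(name):
            self.qdrant.create_collection(
                collection_name=name,
                # Cosine distance is an assumption, not asserted by the tests.
                vectors_config=VectorParams(size=dimension, distance=Distance.COSINE),
            )
        self.created_collections.add(name)
```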
mock_uuid.uuid4.return_value = 'test-uuid-123' + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + # Create embeddings message + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + embedding = RowIndexEmbedding( + index_name='customer_id', + index_value=['CUST001'], + text='CUST001', + vectors=[[0.1, 0.2, 0.3]] + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='customers', + embeddings=[embedding] + ) + + # Mock message wrapper + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Verify upsert was called + mock_qdrant_instance.upsert.assert_called_once() + + # Verify upsert parameters + upsert_call_args = mock_qdrant_instance.upsert.call_args + assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3' + + point = upsert_call_args[1]['points'][0] + assert point.vector == [0.1, 0.2, 0.3] + assert point.payload['index_name'] == 'customer_id' + assert point.payload['index_value'] == ['CUST001'] + assert point.payload['text'] == 'CUST001' + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + @patch('trustgraph.storage.row_embeddings.qdrant.write.uuid') + async def test_on_embeddings_multiple_vectors(self, mock_uuid, mock_qdrant_client): + """Test processing embeddings with multiple vectors""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding + + mock_qdrant_instance = MagicMock() + mock_qdrant_instance.collection_exists.return_value = True + mock_qdrant_client.return_value = mock_qdrant_instance + mock_uuid.uuid4.return_value = 'test-uuid' + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + # Embedding with multiple vectors + embedding = RowIndexEmbedding( + index_name='name', + index_value=['John Doe'], + text='John Doe', + vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='people', + embeddings=[embedding] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Should be called 3 times (once per vector) + assert mock_qdrant_instance.upsert.call_count == 3 + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_on_embeddings_skips_empty_vectors(self, mock_qdrant_client): + """Test that embeddings with no vectors are skipped""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.known_collections[('test_user', 'test_collection')] = {} + + metadata = MagicMock() + metadata.user = 'test_user' + metadata.collection = 'test_collection' + metadata.id = 'doc-123' + + # Embedding with no vectors + 
embedding = RowIndexEmbedding( + index_name='id', + index_value=['123'], + text='123', + vectors=[] # Empty vectors + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='items', + embeddings=[embedding] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Should not call upsert for empty vectors + mock_qdrant_instance.upsert.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_on_embeddings_drops_unknown_collection(self, mock_qdrant_client): + """Test that messages for unknown collections are dropped""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + from trustgraph.schema import RowEmbeddings, RowIndexEmbedding + + mock_qdrant_instance = MagicMock() + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + # No collections registered + + metadata = MagicMock() + metadata.user = 'unknown_user' + metadata.collection = 'unknown_collection' + metadata.id = 'doc-123' + + embedding = RowIndexEmbedding( + index_name='id', + index_value=['123'], + text='123', + vectors=[[0.1, 0.2]] + ) + + embeddings_msg = RowEmbeddings( + metadata=metadata, + schema_name='items', + embeddings=[embedding] + ) + + mock_msg = MagicMock() + mock_msg.value.return_value = embeddings_msg + + await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) + + # Should not call upsert for unknown collection + mock_qdrant_instance.upsert.assert_not_called() + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_delete_collection(self, mock_qdrant_client): + """Test deleting all collections for a user/collection""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + + # Mock collections list + mock_coll1 = MagicMock() + mock_coll1.name = 'rows_test_user_test_collection_schema1_384' + mock_coll2 = MagicMock() + mock_coll2.name = 'rows_test_user_test_collection_schema2_384' + mock_coll3 = MagicMock() + mock_coll3.name = 'rows_other_user_other_collection_schema_384' + + mock_collections = MagicMock() + mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3] + mock_qdrant_instance.get_collections.return_value = mock_collections + + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + processor.created_collections.add('rows_test_user_test_collection_schema1_384') + + await processor.delete_collection('test_user', 'test_collection') + + # Should delete only the matching collections + assert mock_qdrant_instance.delete_collection.call_count == 2 + + # Verify the cached collection was removed + assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections + + @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') + async def test_delete_collection_schema(self, mock_qdrant_client): + """Test deleting collections for a specific schema""" + from trustgraph.storage.row_embeddings.qdrant.write import Processor + + mock_qdrant_instance = MagicMock() + + mock_coll1 = MagicMock() + mock_coll1.name = 'rows_test_user_test_collection_customers_384' + mock_coll2 = MagicMock() + mock_coll2.name = 'rows_test_user_test_collection_orders_384' + + mock_collections = MagicMock() + 
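[Editorial note: read together, the `on_embeddings` tests pin down the write path — drop messages for unregistered user/collection pairs, derive the Qdrant collection from the vector dimension, and upsert one point per vector with the index fields in the payload. A sketch consistent with those assertions, continuing the `RowEmbeddingsWriter` sketch above; `PointStruct` is the standard qdrant-client point type.]

```python
import uuid

from qdrant_client.models import PointStruct


class RowEmbeddingsStorage(RowEmbeddingsWriter):

    def __init__(self, **kw):
        super().__init__(**kw)
        self.known_collections = {}   # (user, collection) -> config

    async def on_embeddings(self, msg, consumer, flow):
        message = msg.value()
        md = message.metadata
        # Messages for unknown user/collection pairs are dropped outright.
        if (md.user, md.collection) not in self.known_collections:
            return
        for emb in message.embeddings:
            for vector in emb.vectors:        # one Qdrant point per vector
                name = get_collection_name(
                    md.user, md.collection, message.schema_name, len(vector)
                )
                self.ensure_collection(name, len(vector))
                self.qdrant.upsert(
                    collection_name=name,
                    points=[PointStruct(
                        id=str(uuid.uuid4()),
                        vector=vector,
                        payload={
                            "index_name": emb.index_name,
                            "index_value": emb.index_value,
                            "text": emb.text,
                        },
                    )],
                )
```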
mock_collections.collections = [mock_coll1, mock_coll2] + mock_qdrant_instance.get_collections.return_value = mock_collections + + mock_qdrant_client.return_value = mock_qdrant_instance + + config = { + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + await processor.delete_collection_schema( + 'test_user', 'test_collection', 'customers' + ) + + # Should only delete the customers schema collection + mock_qdrant_instance.delete_collection.assert_called_once() + call_args = mock_qdrant_instance.delete_collection.call_args[0] + assert call_args[0] == 'rows_test_user_test_collection_customers_384' + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py new file mode 100644 index 00000000..c8b81447 --- /dev/null +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -0,0 +1,474 @@ +""" +Unit tests for Cassandra Row Storage Processor (Unified Table Implementation) + +Tests the business logic of the row storage processor including: +- Schema configuration handling +- Name sanitization +- Unified table structure +- Index management +- Row storage with multi-index support +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +import json + +from trustgraph.storage.rows.cassandra.write import Processor +from trustgraph.schema import ExtractedObject, Metadata, RowSchema, Field + + +class TestRowsCassandraStorageLogic: + """Test business logic for unified table implementation""" + + def test_sanitize_name(self): + """Test name sanitization for Cassandra compatibility""" + processor = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + + # Test various name patterns + assert processor.sanitize_name("simple_name") == "simple_name" + assert processor.sanitize_name("Name-With-Dashes") == "name_with_dashes" + assert processor.sanitize_name("name.with.dots") == "name_with_dots" + assert processor.sanitize_name("123_starts_with_number") == "r_123_starts_with_number" + assert processor.sanitize_name("name with spaces") == "name_with_spaces" + assert processor.sanitize_name("special!@#$%^chars") == "special______chars" + assert processor.sanitize_name("UPPERCASE") == "uppercase" + assert processor.sanitize_name("CamelCase") == "camelcase" + assert processor.sanitize_name("_underscore_start") == "r__underscore_start" + + def test_get_index_names(self): + """Test extraction of index names from schema""" + processor = MagicMock() + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + + # Schema with primary and indexed fields + schema = RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="name", type="string"), # Not indexed + Field(name="status", type="string", indexed=True) + ] + ) + + index_names = processor.get_index_names(schema) + + # Should include primary key and indexed fields + assert "id" in index_names + assert "category" in index_names + assert "status" in index_names + assert "name" not in index_names # Not indexed + assert len(index_names) == 3 + + def test_get_index_names_no_indexes(self): + """Test schema with no indexed fields""" + processor = MagicMock() + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + + schema = RowSchema( + name="no_index_schema", + 
fields=[ + Field(name="data1", type="string"), + Field(name="data2", type="string") + ] + ) + + index_names = processor.get_index_names(schema) + assert len(index_names) == 0 + + def test_build_index_value(self): + """Test building index values from row data""" + processor = MagicMock() + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + + value_map = {"id": "123", "category": "electronics", "name": "Widget"} + + # Single field index + result = processor.build_index_value(value_map, "id") + assert result == ["123"] + + result = processor.build_index_value(value_map, "category") + assert result == ["electronics"] + + # Missing field returns empty string + result = processor.build_index_value(value_map, "missing") + assert result == [""] + + def test_build_index_value_composite(self): + """Test building composite index values""" + processor = MagicMock() + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + + value_map = {"region": "us-west", "category": "electronics", "id": "123"} + + # Composite index (comma-separated field names) + result = processor.build_index_value(value_map, "region,category") + assert result == ["us-west", "electronics"] + + @pytest.mark.asyncio + async def test_schema_config_parsing(self): + """Test parsing of schema configurations""" + processor = MagicMock() + processor.schemas = {} + processor.config_key = "schema" + processor.registered_partitions = set() + processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) + + # Create test configuration + config = { + "schema": { + "customer_records": json.dumps({ + "name": "customer_records", + "description": "Customer data", + "fields": [ + { + "name": "id", + "type": "string", + "primary_key": True, + "required": True + }, + { + "name": "name", + "type": "string", + "required": True + }, + { + "name": "category", + "type": "string", + "indexed": True + } + ] + }) + } + } + + # Process configuration + await processor.on_schema_config(config, version=1) + + # Verify schema was loaded + assert "customer_records" in processor.schemas + schema = processor.schemas["customer_records"] + assert schema.name == "customer_records" + assert len(schema.fields) == 3 + + # Check field properties + id_field = schema.fields[0] + assert id_field.name == "id" + assert id_field.type == "string" + assert id_field.primary is True + + @pytest.mark.asyncio + async def test_object_processing_stores_data_map(self): + """Test that row processing stores data as map""" + processor = MagicMock() + processor.schemas = { + "test_schema": RowSchema( + name="test_schema", + description="Test", + fields=[ + Field(name="id", type="string", size=50, primary=True), + Field(name="value", type="string", size=100) + ] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = Processor.on_object.__get__(processor, Processor) + + # Create test object + test_obj = ExtractedObject( + metadata=Metadata( + id="test-001", + user="test_user", + 
collection="test_collection", + metadata=[] + ), + schema_name="test_schema", + values=[{"id": "123", "value": "test_data"}], + confidence=0.9, + source_span="test source" + ) + + # Create mock message + msg = MagicMock() + msg.value.return_value = test_obj + + # Process object + await processor.on_object(msg, None, None) + + # Verify insert was executed + processor.session.execute.assert_called() + insert_call = processor.session.execute.call_args + insert_cql = insert_call[0][0] + values = insert_call[0][1] + + # Verify using unified rows table + assert "INSERT INTO test_user.rows" in insert_cql + + # Values should be: (collection, schema_name, index_name, index_value, data, source) + assert values[0] == "test_collection" # collection + assert values[1] == "test_schema" # schema_name + assert values[2] == "id" # index_name (primary key field) + assert values[3] == ["123"] # index_value as list + assert values[4] == {"id": "123", "value": "test_data"} # data map + assert values[5] == "" # source + + @pytest.mark.asyncio + async def test_object_processing_multiple_indexes(self): + """Test that row is written once per indexed field""" + processor = MagicMock() + processor.schemas = { + "multi_index_schema": RowSchema( + name="multi_index_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True), + Field(name="status", type="string", indexed=True) + ] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = Processor.on_object.__get__(processor, Processor) + + test_obj = ExtractedObject( + metadata=Metadata( + id="test-001", + user="test_user", + collection="test_collection", + metadata=[] + ), + schema_name="multi_index_schema", + values=[{"id": "123", "category": "electronics", "status": "active"}], + confidence=0.9, + source_span="" + ) + + msg = MagicMock() + msg.value.return_value = test_obj + + await processor.on_object(msg, None, None) + + # Should have 3 inserts (one per indexed field: id, category, status) + assert processor.session.execute.call_count == 3 + + # Check that different index_names were used + index_names_used = set() + for call in processor.session.execute.call_args_list: + values = call[0][1] + index_names_used.add(values[2]) # index_name is 3rd value + + assert index_names_used == {"id", "category", "status"} + + +class TestRowsCassandraStorageBatchLogic: + """Test batch processing logic for unified table implementation""" + + @pytest.mark.asyncio + async def test_batch_object_processing(self): + """Test processing of batch ExtractedObjects""" + processor = MagicMock() + processor.schemas = { + "batch_schema": RowSchema( + name="batch_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="name", type="string") + ] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = 
Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = Processor.on_object.__get__(processor, Processor) + + # Create batch object with multiple values + batch_obj = ExtractedObject( + metadata=Metadata( + id="batch-001", + user="test_user", + collection="batch_collection", + metadata=[] + ), + schema_name="batch_schema", + values=[ + {"id": "001", "name": "First"}, + {"id": "002", "name": "Second"}, + {"id": "003", "name": "Third"} + ], + confidence=0.95, + source_span="" + ) + + msg = MagicMock() + msg.value.return_value = batch_obj + + await processor.on_object(msg, None, None) + + # Should have 3 inserts (one per row, one index per row since only primary key) + assert processor.session.execute.call_count == 3 + + # Check each insert has different id + ids_inserted = set() + for call in processor.session.execute.call_args_list: + values = call[0][1] + ids_inserted.add(tuple(values[3])) # index_value is 4th value + + assert ids_inserted == {("001",), ("002",), ("003",)} + + @pytest.mark.asyncio + async def test_empty_batch_processing(self): + """Test processing of empty batch ExtractedObjects""" + processor = MagicMock() + processor.schemas = { + "empty_schema": RowSchema( + name="empty_schema", + fields=[Field(name="id", type="string", primary=True)] + ) + } + processor.tables_initialized = {"test_user"} + processor.registered_partitions = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) + processor.ensure_tables = MagicMock() + processor.register_partitions = MagicMock() + processor.collection_exists = MagicMock(return_value=True) + processor.on_object = Processor.on_object.__get__(processor, Processor) + + # Create empty batch object + empty_batch_obj = ExtractedObject( + metadata=Metadata( + id="empty-001", + user="test_user", + collection="empty_collection", + metadata=[] + ), + schema_name="empty_schema", + values=[], # Empty batch + confidence=1.0, + source_span="" + ) + + msg = MagicMock() + msg.value.return_value = empty_batch_obj + + await processor.on_object(msg, None, None) + + # Verify no insert calls for empty batch + processor.session.execute.assert_not_called() + + +class TestUnifiedTableStructure: + """Test the unified rows table structure""" + + def test_ensure_tables_creates_unified_structure(self): + """Test that ensure_tables creates the unified rows table""" + processor = MagicMock() + processor.known_keyspaces = {"test_user"} + processor.tables_initialized = set() + processor.session = MagicMock() + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.ensure_keyspace = MagicMock() + processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) + + processor.ensure_tables("test_user") + + # Should have 2 calls: create rows table + create row_partitions table + assert processor.session.execute.call_count == 2 + + # Check rows table creation + rows_cql = processor.session.execute.call_args_list[0][0][0] + assert "CREATE TABLE IF NOT EXISTS test_user.rows" in rows_cql + assert "collection text" in rows_cql + 
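[Editorial note: the unified-table tests — row data stored as an opaque map, one insert per index, batched `values` — imply a write path roughly like the sketch below. The helper logic mirrors the test assertions; the exact CQL text and `%s` parameter style are assumptions.]

```python
def get_index_names(schema) -> list[str]:
    # Primary-key and indexed fields each get their own index partition.
    return [f.name for f in schema.fields if f.primary or f.indexed]


def build_index_value(value_map: dict, index_name: str) -> list[str]:
    # "region,category" denotes a composite index; missing fields become "".
    return [str(value_map.get(name.strip(), "")) for name in index_name.split(",")]


def store_rows(session, keyspace: str, schema, obj) -> None:
    """Write each row once per index into the unified <keyspace>.rows table."""
    stmt = (
        f"INSERT INTO {keyspace}.rows "
        "(collection, schema_name, index_name, index_value, data, source) "
        "VALUES (%s, %s, %s, %s, %s, %s)"
    )
    for value_map in obj.values:                  # ExtractedObject batches rows
        for index_name in get_index_names(schema):
            session.execute(stmt, (
                obj.metadata.collection,
                obj.schema_name,
                index_name,
                build_index_value(value_map, index_name),  # list, even if single
                value_map,                                 # opaque text -> text map
                "",                                        # empty source: no provenance
            ))
```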
assert "schema_name text" in rows_cql + assert "index_name text" in rows_cql + assert "index_value frozen>" in rows_cql + assert "data map" in rows_cql + assert "source text" in rows_cql + assert "PRIMARY KEY ((collection, schema_name, index_name), index_value)" in rows_cql + + # Check row_partitions table creation + partitions_cql = processor.session.execute.call_args_list[1][0][0] + assert "CREATE TABLE IF NOT EXISTS test_user.row_partitions" in partitions_cql + assert "PRIMARY KEY ((collection), schema_name, index_name)" in partitions_cql + + # Verify keyspace added to initialized set + assert "test_user" in processor.tables_initialized + + def test_ensure_tables_idempotent(self): + """Test that ensure_tables is idempotent""" + processor = MagicMock() + processor.tables_initialized = {"test_user"} # Already initialized + processor.session = MagicMock() + processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) + + processor.ensure_tables("test_user") + + # Should not execute any CQL since already initialized + processor.session.execute.assert_not_called() + + +class TestPartitionRegistration: + """Test partition registration for tracking what's stored""" + + def test_register_partitions(self): + """Test registering partitions for a collection/schema pair""" + processor = MagicMock() + processor.registered_partitions = set() + processor.session = MagicMock() + processor.schemas = { + "test_schema": RowSchema( + name="test_schema", + fields=[ + Field(name="id", type="string", primary=True), + Field(name="category", type="string", indexed=True) + ] + ) + } + processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) + processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) + processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + + processor.register_partitions("test_user", "test_collection", "test_schema") + + # Should have 2 inserts (one per index: id, category) + assert processor.session.execute.call_count == 2 + + # Verify cache was updated + assert ("test_collection", "test_schema") in processor.registered_partitions + + def test_register_partitions_idempotent(self): + """Test that partition registration is idempotent""" + processor = MagicMock() + processor.registered_partitions = {("test_collection", "test_schema")} # Already registered + processor.session = MagicMock() + processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + + processor.register_partitions("test_user", "test_collection", "test_schema") + + # Should not execute any CQL since already registered + processor.session.execute.assert_not_called() diff --git a/tests/unit/test_text_completion/test_googleaistudio_processor.py b/tests/unit/test_text_completion/test_googleaistudio_processor.py index c54b3928..aa04d2a3 100644 --- a/tests/unit/test_text_completion/test_googleaistudio_processor.py +++ b/tests/unit/test_text_completion/test_googleaistudio_processor.py @@ -48,7 +48,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): assert hasattr(processor, 'client') assert hasattr(processor, 'safety_settings') assert len(processor.safety_settings) == 4 # 4 safety categories - mock_genai_class.assert_called_once_with(api_key='test-api-key') + mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False) @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ 
-208,7 +208,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): assert processor.default_model == 'gemini-1.5-pro' assert processor.temperature == 0.7 assert processor.max_output == 4096 - mock_genai_class.assert_called_once_with(api_key='custom-api-key') + mock_genai_class.assert_called_once_with(api_key='custom-api-key', vertexai=False) @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -237,7 +237,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): assert processor.default_model == 'gemini-2.0-flash-001' # default_model assert processor.temperature == 0.0 # default_temperature assert processor.max_output == 8192 # default_max_output - mock_genai_class.assert_called_once_with(api_key='test-api-key') + mock_genai_class.assert_called_once_with(api_key='test-api-key', vertexai=False) @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @@ -427,7 +427,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): # Assert # Verify Google AI Studio client was called with correct API key - mock_genai_class.assert_called_once_with(api_key='gai-test-key') + mock_genai_class.assert_called_once_with(api_key='gai-test-key', vertexai=False) # Verify processor has the client assert processor.client == mock_genai_client diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index bd9964d1..daa2cc5c 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -101,7 +101,7 @@ from .exceptions import ( LoadError, LookupError, NLPQueryError, - ObjectsQueryError, + RowsQueryError, RequestError, StructuredQueryError, UnexpectedError, @@ -161,7 +161,7 @@ __all__ = [ "LoadError", "LookupError", "NLPQueryError", - "ObjectsQueryError", + "RowsQueryError", "RequestError", "StructuredQueryError", "UnexpectedError", diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py index 76cb9f56..9a6a49c3 100644 --- a/trustgraph-base/trustgraph/api/async_bulk_client.py +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -115,15 +115,15 @@ class AsyncBulkClient: async for raw_message in websocket: yield json.loads(raw_message) - async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: - """Bulk import objects via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects" + async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import rows via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" if self.token: ws_url = f"{ws_url}?token={self.token}" async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - async for obj in objects: - await websocket.send(json.dumps(obj)) + async for row in rows: + await websocket.send(json.dumps(row)) async def aclose(self) -> None: """Close connections""" diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 38560b19..2cf1bedf 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -708,18 +708,18 @@ class AsyncFlowInstance: return await self.request("triples", request_data) - async def objects_query(self, query: str, user: 
str, collection: str, variables: Optional[Dict] = None, - operation_name: Optional[str] = None, **kwargs: Any): + async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + operation_name: Optional[str] = None, **kwargs: Any): """ - Execute a GraphQL query on stored objects. + Execute a GraphQL query on stored rows. - Queries structured data objects using GraphQL syntax. Supports complex + Queries structured data rows using GraphQL syntax. Supports complex queries with variables and named operations. Args: query: GraphQL query string user: User identifier - collection: Collection identifier containing objects + collection: Collection identifier containing rows variables: Optional GraphQL query variables operation_name: Optional operation name for multi-operation queries **kwargs: Additional service-specific parameters @@ -743,7 +743,7 @@ class AsyncFlowInstance: } ''' - result = await flow.objects_query( + result = await flow.rows_query( query=query, user="trustgraph", collection="users", @@ -765,4 +765,4 @@ class AsyncFlowInstance: request_data["operationName"] = operation_name request_data.update(kwargs) - return await self.request("objects", request_data) + return await self.request("rows", request_data) diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 53727ef6..ac83876b 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -320,9 +320,9 @@ class AsyncSocketFlowInstance: return await self.client._send_request("triples", self.flow_id, request) - async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, - operation_name: Optional[str] = None, **kwargs): - """GraphQL query""" + async def rows_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + operation_name: Optional[str] = None, **kwargs): + """GraphQL query against structured rows""" request = { "query": query, "user": user, @@ -334,7 +334,7 @@ class AsyncSocketFlowInstance: request["operationName"] = operation_name request.update(kwargs) - return await self.client._send_request("objects", self.flow_id, request) + return await self.client._send_request("rows", self.flow_id, request) async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs): """Execute MCP tool""" diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 91369ef4..3dfb0fba 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -530,45 +530,45 @@ class BulkClient: async for raw_message in websocket: yield json.loads(raw_message) - def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + def import_rows(self, flow: str, rows: Iterator[Dict[str, Any]], **kwargs: Any) -> None: """ - Bulk import structured objects into a flow. + Bulk import structured rows into a flow. - Efficiently uploads structured data objects via WebSocket streaming + Efficiently uploads structured data rows via WebSocket streaming for use in GraphQL queries. 
Args: flow: Flow identifier - objects: Iterator yielding object dictionaries + rows: Iterator yielding row dictionaries **kwargs: Additional parameters (reserved for future use) Example: ```python bulk = api.bulk() - # Generate objects to import - def object_generator(): - yield {"id": "obj1", "name": "Object 1", "value": 100} - yield {"id": "obj2", "name": "Object 2", "value": 200} - # ... more objects + # Generate rows to import + def row_generator(): + yield {"id": "row1", "name": "Row 1", "value": 100} + yield {"id": "row2", "name": "Row 2", "value": 200} + # ... more rows - bulk.import_objects( + bulk.import_rows( flow="default", - objects=object_generator() + rows=row_generator() ) ``` """ - self._run_async(self._import_objects_async(flow, objects)) + self._run_async(self._import_rows_async(flow, rows)) - async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None: - """Async implementation of objects import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects" + async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None: + """Async implementation of rows import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" if self.token: ws_url = f"{ws_url}?token={self.token}" async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: - for obj in objects: - await websocket.send(json.dumps(obj)) + for row in rows: + await websocket.send(json.dumps(row)) def close(self) -> None: """Close connections""" diff --git a/trustgraph-base/trustgraph/api/exceptions.py b/trustgraph-base/trustgraph/api/exceptions.py index 311d2651..b60e41e1 100644 --- a/trustgraph-base/trustgraph/api/exceptions.py +++ b/trustgraph-base/trustgraph/api/exceptions.py @@ -71,8 +71,8 @@ class NLPQueryError(TrustGraphException): pass -class ObjectsQueryError(TrustGraphException): - """Objects query service error""" +class RowsQueryError(TrustGraphException): + """Rows query service error""" pass @@ -103,7 +103,7 @@ ERROR_TYPE_MAPPING = { "load-error": LoadError, "lookup-error": LookupError, "nlp-query-error": NLPQueryError, - "objects-query-error": ObjectsQueryError, + "rows-query-error": RowsQueryError, "request-error": RequestError, "structured-query-error": StructuredQueryError, "unexpected-error": UnexpectedError, diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index e10ae0f7..e8da1522 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -1001,12 +1001,12 @@ class FlowInstance: input ) - def objects_query( + def rows_query( self, query, user="trustgraph", collection="default", variables=None, operation_name=None ): """ - Execute a GraphQL query against structured objects in the knowledge graph. + Execute a GraphQL query against structured rows in the knowledge graph. Queries structured data using GraphQL syntax, allowing complex queries with filtering, aggregation, and relationship traversal. 
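[Editorial note: taken together, the client renames amount to a two-call workflow — stream rows in over `import_rows`, then query them back with `rows_query`. A usage sketch, assuming `api` is a `trustgraph.api.Api` instance and `flow` a `FlowInstance` for the target flow (obtaining those handles is elided); the GraphQL query shape is illustrative only.]

```python
from typing import Any, Dict, Iterator


def load_and_query(api, flow) -> Dict[str, Any]:
    # Rows are streamed as JSON over /api/v1/flow/<flow>/import/rows.
    def row_generator() -> Iterator[Dict[str, Any]]:
        yield {"id": "row1", "name": "Row 1", "value": 100}
        yield {"id": "row2", "name": "Row 2", "value": 200}

    api.bulk().import_rows(flow="default", rows=row_generator())

    # GraphQL is routed to the rows-query service ("service/rows").
    return flow.rows_query(
        query="query { rows { id name value } }",
        user="trustgraph",
        collection="default",
    )
```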
@@ -1038,7 +1038,7 @@ class FlowInstance: } } ''' - result = flow.objects_query( + result = flow.rows_query( query=query, user="trustgraph", collection="scientists" @@ -1053,7 +1053,7 @@ class FlowInstance: } } ''' - result = flow.objects_query( + result = flow.rows_query( query=query, variables={"name": "Marie Curie"} ) @@ -1074,7 +1074,7 @@ class FlowInstance: input["operation_name"] = operation_name response = self.request( - "service/objects", + "service/rows", input ) diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 53ad1b4b..c0246612 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -789,7 +789,7 @@ class SocketFlowInstance: return self.client._send_request_sync("triples", self.flow_id, request, False) - def objects_query( + def rows_query( self, query: str, user: str, @@ -799,7 +799,7 @@ class SocketFlowInstance: **kwargs: Any ) -> Dict[str, Any]: """ - Execute a GraphQL query against structured objects. + Execute a GraphQL query against structured rows. Args: query: GraphQL query string @@ -826,7 +826,7 @@ class SocketFlowInstance: } } ''' - result = flow.objects_query( + result = flow.rows_query( query=query, user="trustgraph", collection="scientists" @@ -844,7 +844,7 @@ class SocketFlowInstance: request["operationName"] = operation_name request.update(kwargs) - return self.client._send_request_sync("objects", self.flow_id, request, False) + return self.client._send_request_sync("rows", self.flow_id, request, False) def mcp_tool( self, diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 80c5438b..4d4e3c84 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -21,7 +21,7 @@ from .translators.embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator ) -from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator +from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator @@ -113,9 +113,9 @@ TranslatorRegistry.register_service( ) TranslatorRegistry.register_service( - "objects-query", - ObjectsQueryRequestTranslator(), - ObjectsQueryResponseTranslator() + "rows-query", + RowsQueryRequestTranslator(), + RowsQueryResponseTranslator() ) TranslatorRegistry.register_service( diff --git a/trustgraph-base/trustgraph/messaging/translators/__init__.py b/trustgraph-base/trustgraph/messaging/translators/__init__.py index 5849f4ce..189265e1 100644 --- a/trustgraph-base/trustgraph/messaging/translators/__init__.py +++ b/trustgraph-base/trustgraph/messaging/translators/__init__.py @@ -17,5 +17,5 @@ from .embeddings_query import ( DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator ) -from .objects_query import ObjectsQueryRequestTranslator, ObjectsQueryResponseTranslator +from .rows_query import 
RowsQueryRequestTranslator, RowsQueryResponseTranslator from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator diff --git a/trustgraph-base/trustgraph/messaging/translators/objects_query.py b/trustgraph-base/trustgraph/messaging/translators/rows_query.py similarity index 68% rename from trustgraph-base/trustgraph/messaging/translators/objects_query.py rename to trustgraph-base/trustgraph/messaging/translators/rows_query.py index a746e0c7..6feb75a3 100644 --- a/trustgraph-base/trustgraph/messaging/translators/objects_query.py +++ b/trustgraph-base/trustgraph/messaging/translators/rows_query.py @@ -1,44 +1,44 @@ from typing import Dict, Any, Tuple, Optional -from ...schema import ObjectsQueryRequest, ObjectsQueryResponse +from ...schema import RowsQueryRequest, RowsQueryResponse from .base import MessageTranslator import json -class ObjectsQueryRequestTranslator(MessageTranslator): - """Translator for ObjectsQueryRequest schema objects""" - - def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryRequest: - return ObjectsQueryRequest( +class RowsQueryRequestTranslator(MessageTranslator): + """Translator for RowsQueryRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryRequest: + return RowsQueryRequest( user=data.get("user", "trustgraph"), collection=data.get("collection", "default"), query=data.get("query", ""), variables=data.get("variables", {}), operation_name=data.get("operation_name", None) ) - - def from_pulsar(self, obj: ObjectsQueryRequest) -> Dict[str, Any]: + + def from_pulsar(self, obj: RowsQueryRequest) -> Dict[str, Any]: result = { "user": obj.user, "collection": obj.collection, "query": obj.query, "variables": dict(obj.variables) if obj.variables else {} } - + if obj.operation_name: result["operation_name"] = obj.operation_name - + return result -class ObjectsQueryResponseTranslator(MessageTranslator): - """Translator for ObjectsQueryResponse schema objects""" - - def to_pulsar(self, data: Dict[str, Any]) -> ObjectsQueryResponse: +class RowsQueryResponseTranslator(MessageTranslator): + """Translator for RowsQueryResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> RowsQueryResponse: raise NotImplementedError("Response translation to Pulsar not typically needed") - - def from_pulsar(self, obj: ObjectsQueryResponse) -> Dict[str, Any]: + + def from_pulsar(self, obj: RowsQueryResponse) -> Dict[str, Any]: result = {} - + # Handle GraphQL response data if obj.data: try: @@ -47,7 +47,7 @@ class ObjectsQueryResponseTranslator(MessageTranslator): result["data"] = obj.data else: result["data"] = None - + # Handle GraphQL errors if obj.errors: result["errors"] = [] @@ -60,20 +60,20 @@ class ObjectsQueryResponseTranslator(MessageTranslator): if error.extensions: error_dict["extensions"] = dict(error.extensions) result["errors"].append(error_dict) - + # Handle extensions if obj.extensions: result["extensions"] = dict(obj.extensions) - + # Handle system-level error if obj.error: result["error"] = { "type": obj.error.type, "message": obj.error.message } - + return result - - def from_response_with_completion(self, obj: ObjectsQueryResponse) -> Tuple[Dict[str, Any], bool]: + + def from_response_with_completion(self, obj: RowsQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.from_pulsar(obj), True diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py 
b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index 473ec3a4..93559056 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -60,3 +60,23 @@ class StructuredObjectEmbedding: field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings ############################################################################ + +# Row embeddings are embeddings associated with indexed field values +# in structured row data. Each index gets embedded separately. + +@dataclass +class RowIndexEmbedding: + """Single row's embedding for one index""" + index_name: str = "" # The indexed field name(s) + index_value: list[str] = field(default_factory=list) # The field value(s) + text: str = "" # Text that was embedded + vectors: list[list[float]] = field(default_factory=list) + +@dataclass +class RowEmbeddings: + """Batched row embeddings for a schema""" + metadata: Metadata | None = None + schema_name: str = "" + embeddings: list[RowIndexEmbedding] = field(default_factory=list) + +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index aaeb739f..7b40ca0a 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -9,7 +9,7 @@ from .library import * from .lookup import * from .nlp_query import * from .structured_query import * -from .objects_query import * +from .rows_query import * from .diagnosis import * from .collection import * from .storage import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index dc33febe..50ec416a 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -59,4 +59,39 @@ document_embeddings_request_queue = topic( ) document_embeddings_response_queue = topic( "document-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' +) + +############################################################################ + +# Row embeddings query - for semantic/fuzzy matching on row index values + +@dataclass +class RowIndexMatch: + """A single matching row index from a semantic search""" + index_name: str = "" # The indexed field(s) + index_value: list[str] = field(default_factory=list) # The index values + text: str = "" # The text that was embedded + score: float = 0.0 # Similarity score + +@dataclass +class RowEmbeddingsRequest: + """Request for row embeddings semantic search""" + vectors: list[list[float]] = field(default_factory=list) # Query vectors + limit: int = 10 # Max results to return + user: str = "" # User/keyspace + collection: str = "" # Collection name + schema_name: str = "" # Schema name to search within + index_name: str | None = None # Optional: filter to specific index + +@dataclass +class RowEmbeddingsResponse: + """Response from row embeddings semantic search""" + error: Error | None = None + matches: list[RowIndexMatch] = field(default_factory=list) + +row_embeddings_request_queue = topic( + "row-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' +) +row_embeddings_response_queue = topic( + "row-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' ) \ No newline at end of file diff --git 
a/trustgraph-base/trustgraph/schema/services/objects_query.py b/trustgraph-base/trustgraph/schema/services/rows_query.py similarity index 91% rename from trustgraph-base/trustgraph/schema/services/objects_query.py rename to trustgraph-base/trustgraph/schema/services/rows_query.py index e24daef3..a4818329 100644 --- a/trustgraph-base/trustgraph/schema/services/objects_query.py +++ b/trustgraph-base/trustgraph/schema/services/rows_query.py @@ -6,7 +6,7 @@ from ..core.topic import topic ############################################################################ -# Objects Query Service - executes GraphQL queries against structured data +# Rows Query Service - executes GraphQL queries against structured data @dataclass class GraphQLError: @@ -15,7 +15,7 @@ class GraphQLError: extensions: dict[str, str] = field(default_factory=dict) # Additional error metadata @dataclass -class ObjectsQueryRequest: +class RowsQueryRequest: user: str = "" # Cassandra keyspace (follows pattern from TriplesQueryRequest) collection: str = "" # Data collection identifier (required for partition key) query: str = "" # GraphQL query string @@ -23,7 +23,7 @@ class ObjectsQueryRequest: operation_name: Optional[str] = None # Operation to execute for multi-operation documents @dataclass -class ObjectsQueryResponse: +class RowsQueryResponse: error: Error | None = None # System-level error (connection, timeout, etc.) data: str = "" # JSON-encoded GraphQL response data errors: list[GraphQLError] = field(default_factory=list) # GraphQL field-level errors diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 09a22bdc..49e24b8e 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -48,7 +48,7 @@ tg-invoke-graph-embeddings = "trustgraph.cli.invoke_graph_embeddings:main" tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main" tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main" tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main" -tg-invoke-objects-query = "trustgraph.cli.invoke_objects_query:main" +tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main" diff --git a/trustgraph-cli/trustgraph/cli/invoke_objects_query.py b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py similarity index 96% rename from trustgraph-cli/trustgraph/cli/invoke_objects_query.py rename to trustgraph-cli/trustgraph/cli/invoke_rows_query.py index 50c4e8c2..962f353c 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_objects_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_rows_query.py @@ -1,5 +1,5 @@ """ -Uses the ObjectsQuery service to execute GraphQL queries against structured data +Uses the RowsQuery service to execute GraphQL queries against structured data """ import argparse @@ -81,7 +81,7 @@ def format_table_data(rows, table_name, output_format): else: return json.dumps({table_name: rows}, indent=2) -def objects_query( +def rows_query( url, flow_id, query, user, collection, variables, operation_name, output_format='table' ): @@ -96,7 +96,7 @@ def objects_query( print(f"Error parsing variables JSON: {e}", file=sys.stderr) sys.exit(1) - resp = api.objects_query( + resp = api.rows_query( query=query, user=user, collection=collection, @@ -126,7 +126,7 @@ def objects_query( def main(): parser = argparse.ArgumentParser( - prog='tg-invoke-objects-query', + 
prog='tg-invoke-rows-query', description=__doc__, ) @@ -181,7 +181,7 @@ def main(): try: - objects_query( + rows_query( url=args.url, flow_id=args.flow_id, query=args.query, diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index bf112417..fa167917 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -573,19 +573,19 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample return output_records, descriptor -def _send_to_trustgraph(objects, api_url, flow, batch_size=1000, token=None): +def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None): """Send ExtractedObject records to TrustGraph using Python API""" from trustgraph.api import Api try: - total_records = len(objects) + total_records = len(rows) logger.info(f"Importing {total_records} records to TrustGraph...") # Use Python API bulk import api = Api(api_url, token=token) bulk = api.bulk() - bulk.import_objects(flow=flow, objects=iter(objects)) + bulk.import_rows(flow=flow, rows=iter(rows)) logger.info(f"Successfully imported {total_records} records to TrustGraph") diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 499aa1c5..31a22a2f 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -60,27 +60,27 @@ api-gateway = "trustgraph.gateway:run" chunker-recursive = "trustgraph.chunking.recursive:run" chunker-token = "trustgraph.chunking.token:run" config-svc = "trustgraph.config.service:run" -de-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" -de-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" -de-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" -de-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run" -de-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run" -de-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run" +doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" +doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" +doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" +doc-embeddings-write-milvus = "trustgraph.storage.doc_embeddings.milvus:run" +doc-embeddings-write-pinecone = "trustgraph.storage.doc_embeddings.pinecone:run" +doc-embeddings-write-qdrant = "trustgraph.storage.doc_embeddings.qdrant:run" document-embeddings = "trustgraph.embeddings.document_embeddings:run" document-rag = "trustgraph.retrieval.document_rag:run" embeddings-fastembed = "trustgraph.embeddings.fastembed:run" embeddings-ollama = "trustgraph.embeddings.ollama:run" -ge-query-milvus = "trustgraph.query.graph_embeddings.milvus:run" -ge-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run" -ge-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run" -ge-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run" -ge-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run" -ge-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run" +graph-embeddings-query-milvus = "trustgraph.query.graph_embeddings.milvus:run" +graph-embeddings-query-pinecone = "trustgraph.query.graph_embeddings.pinecone:run" +graph-embeddings-query-qdrant = "trustgraph.query.graph_embeddings.qdrant:run" +graph-embeddings-write-milvus = "trustgraph.storage.graph_embeddings.milvus:run" +graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:run" 
+graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run" graph-embeddings = "trustgraph.embeddings.graph_embeddings:run" graph-rag = "trustgraph.retrieval.graph_rag:run" kg-extract-agent = "trustgraph.extract.kg.agent:run" kg-extract-definitions = "trustgraph.extract.kg.definitions:run" -kg-extract-objects = "trustgraph.extract.kg.objects:run" +kg-extract-rows = "trustgraph.extract.kg.rows:run" kg-extract-relationships = "trustgraph.extract.kg.relationships:run" kg-extract-topics = "trustgraph.extract.kg.topics:run" kg-extract-ontology = "trustgraph.extract.kg.ontology:run" @@ -90,8 +90,11 @@ librarian = "trustgraph.librarian:run" mcp-tool = "trustgraph.agent.mcp_tool:run" metering = "trustgraph.metering:run" nlp-query = "trustgraph.retrieval.nlp_query:run" -objects-write-cassandra = "trustgraph.storage.objects.cassandra:run" -objects-query-cassandra = "trustgraph.query.objects.cassandra:run" +rows-write-cassandra = "trustgraph.storage.rows.cassandra:run" +rows-query-cassandra = "trustgraph.query.rows.cassandra:run" +row-embeddings = "trustgraph.embeddings.row_embeddings:run" +row-embeddings-write-qdrant = "trustgraph.storage.row_embeddings.qdrant:run" +row-embeddings-query-qdrant = "trustgraph.query.row_embeddings.qdrant:run" pdf-decoder = "trustgraph.decoding.pdf:run" pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run" prompt-template = "trustgraph.prompt.template:run" diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/__init__.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__init__.py new file mode 100644 index 00000000..40d505a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__init__.py @@ -0,0 +1,3 @@ + +from . embeddings import * + diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/__main__.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__main__.py new file mode 100644 index 00000000..a48cc4d0 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/__main__.py @@ -0,0 +1,6 @@ + +from . embeddings import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py new file mode 100644 index 00000000..84c41ff3 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/row_embeddings/embeddings.py @@ -0,0 +1,263 @@ + +""" +Row embeddings processor. Calls the embeddings service to compute embeddings +for indexed field values in extracted row data. + +Input is ExtractedObject (structured row data with schema). +Output is RowEmbeddings (row data with embeddings for indexed fields). + +This follows the two-stage pattern used by graph-embeddings and document-embeddings: + Stage 1 (this processor): Compute embeddings + Stage 2 (row-embeddings-write-*): Store embeddings +""" + +import json +import logging +from typing import Dict, List, Set + +from ... schema import ExtractedObject, RowEmbeddings, RowIndexEmbedding +from ... schema import RowSchema, Field +from ... base import FlowProcessor, EmbeddingsClientSpec, ConsumerSpec +from ... 
base import ProducerSpec, CollectionConfigHandler + +logger = logging.getLogger(__name__) + +default_ident = "row-embeddings" +default_batch_size = 10 + + +class Processor(CollectionConfigHandler, FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + self.batch_size = params.get("batch_size", default_batch_size) + + # Config key for schemas + self.config_key = params.get("config_type", "schema") + + super(Processor, self).__init__( + **params | { + "id": id, + "config_type": self.config_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="input", + schema=ExtractedObject, + handler=self.on_message, + ) + ) + + self.register_specification( + EmbeddingsClientSpec( + request_name="embeddings-request", + response_name="embeddings-response", + ) + ) + + self.register_specification( + ProducerSpec( + name="output", + schema=RowEmbeddings + ) + ) + + # Register config handlers + self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_collection_config) + + # Schema storage: name -> RowSchema + self.schemas: Dict[str, RowSchema] = {} + + async def on_schema_config(self, config, version): + """Handle schema configuration updates""" + logger.info(f"Loading schema configuration version {version}") + + # Clear existing schemas + self.schemas = {} + + # Check if our config type exists + if self.config_key not in config: + logger.warning(f"No '{self.config_key}' type in configuration") + return + + # Get the schemas dictionary for our type + schemas_config = config[self.config_key] + + # Process each schema in the schemas config + for schema_name, schema_json in schemas_config.items(): + try: + # Parse the JSON schema definition + schema_def = json.loads(schema_json) + + # Create Field objects + fields = [] + for field_def in schema_def.get("fields", []): + field = Field( + name=field_def["name"], + type=field_def["type"], + size=field_def.get("size", 0), + primary=field_def.get("primary_key", False), + description=field_def.get("description", ""), + required=field_def.get("required", False), + enum_values=field_def.get("enum", []), + indexed=field_def.get("indexed", False) + ) + fields.append(field) + + # Create RowSchema + row_schema = RowSchema( + name=schema_def.get("name", schema_name), + description=schema_def.get("description", ""), + fields=fields + ) + + self.schemas[schema_name] = row_schema + logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + + except Exception as e: + logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) + + logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + def get_index_names(self, schema: RowSchema) -> List[str]: + """Get all index names for a schema.""" + index_names = [] + for field in schema.fields: + if field.primary or field.indexed: + index_names.append(field.name) + return index_names + + def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]: + """Build the index_value list for a given index.""" + field_names = [f.strip() for f in index_name.split(',')] + values = [] + for field_name in field_names: + value = value_map.get(field_name) + values.append(str(value) if value is not None else "") + return values + + def build_text_for_embedding(self, index_value: List[str]) -> str: + """Build text representation for embedding from index values.""" + # Space-join the values for composite indexes + return " ".join(index_value) + + async def on_message(self, msg, consumer, flow): + 
"""Process incoming ExtractedObject and compute embeddings""" + + obj = msg.value() + logger.info( + f"Computing embeddings for {len(obj.values)} rows, " + f"schema {obj.schema_name}, doc {obj.metadata.id}" + ) + + # Validate collection exists before processing + if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + logger.warning( + f"Collection {obj.metadata.collection} for user {obj.metadata.user} " + f"does not exist in config. Dropping message." + ) + return + + # Get schema definition + schema = self.schemas.get(obj.schema_name) + if not schema: + logger.warning(f"No schema found for {obj.schema_name} - skipping") + return + + # Get all index names for this schema + index_names = self.get_index_names(schema) + + if not index_names: + logger.warning(f"Schema {obj.schema_name} has no indexed fields - skipping") + return + + # Track unique texts to avoid duplicate embeddings + # text -> (index_name, index_value) + texts_to_embed: Dict[str, tuple] = {} + + # Collect all texts that need embeddings + for value_map in obj.values: + for index_name in index_names: + index_value = self.build_index_value(value_map, index_name) + + # Skip empty values + if not index_value or all(v == "" for v in index_value): + continue + + text = self.build_text_for_embedding(index_value) + if text and text not in texts_to_embed: + texts_to_embed[text] = (index_name, index_value) + + if not texts_to_embed: + logger.info("No texts to embed") + return + + # Compute embeddings + embeddings_list = [] + + try: + for text, (index_name, index_value) in texts_to_embed.items(): + vectors = await flow("embeddings-request").embed(text=text) + + embeddings_list.append( + RowIndexEmbedding( + index_name=index_name, + index_value=index_value, + text=text, + vectors=vectors + ) + ) + + # Send in batches to avoid oversized messages + for i in range(0, len(embeddings_list), self.batch_size): + batch = embeddings_list[i:i + self.batch_size] + result = RowEmbeddings( + metadata=obj.metadata, + schema_name=obj.schema_name, + embeddings=batch, + ) + await flow("output").send(result) + + logger.info( + f"Computed {len(embeddings_list)} embeddings for " + f"{len(obj.values)} rows ({len(index_names)} indexes)" + ) + + except Exception as e: + logger.error("Exception during embedding computation", exc_info=True) + raise e + + async def create_collection(self, user: str, collection: str, metadata: dict): + """Collection creation notification - no action needed for embedding stage""" + logger.debug(f"Row embeddings collection notification for {user}/{collection}") + + async def delete_collection(self, user: str, collection: str): + """Collection deletion notification - no action needed for embedding stage""" + logger.debug(f"Row embeddings collection delete notification for {user}/{collection}") + + @staticmethod + def add_args(parser): + + FlowProcessor.add_args(parser) + + parser.add_argument( + '--batch-size', + type=int, + default=default_batch_size, + help=f'Maximum embeddings per output message (default: {default_batch_size})' + ) + + parser.add_argument( + '--config-type', + default='schema', + help='Configuration type prefix for schemas (default: schema)' + ) + + +def run(): + Processor.launch(default_ident, __doc__) + diff --git a/trustgraph-flow/trustgraph/extract/kg/objects/__init__.py b/trustgraph-flow/trustgraph/extract/kg/rows/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/extract/kg/objects/__init__.py rename to trustgraph-flow/trustgraph/extract/kg/rows/__init__.py diff 
--git a/trustgraph-flow/trustgraph/extract/kg/objects/__main__.py b/trustgraph-flow/trustgraph/extract/kg/rows/__main__.py similarity index 100% rename from trustgraph-flow/trustgraph/extract/kg/objects/__main__.py rename to trustgraph-flow/trustgraph/extract/kg/rows/__main__.py diff --git a/trustgraph-flow/trustgraph/extract/kg/objects/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py similarity index 98% rename from trustgraph-flow/trustgraph/extract/kg/objects/processor.py rename to trustgraph-flow/trustgraph/extract/kg/rows/processor.py index b3483240..bd7bc802 100644 --- a/trustgraph-flow/trustgraph/extract/kg/objects/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -1,5 +1,5 @@ """ -Object extraction service - extracts structured objects from text chunks +Row extraction service - extracts structured rows from text chunks based on configured schemas. """ @@ -18,7 +18,7 @@ from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base import PromptClientSpec from .... messaging.translators import row_schema_translator -default_ident = "kg-extract-objects" +default_ident = "kg-extract-rows" def convert_values_to_strings(obj: Dict[str, Any]) -> Dict[str, str]: @@ -310,5 +310,5 @@ class Processor(FlowProcessor): FlowProcessor.add_args(parser) def run(): - """Entry point for kg-extract-objects command""" + """Entry point for kg-extract-rows command""" Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 2d401cf3..d7d04f83 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -20,7 +20,7 @@ from . prompt import PromptRequestor from . graph_rag import GraphRagRequestor from . document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor -from . objects_query import ObjectsQueryRequestor +from . rows_query import RowsQueryRequestor from . nlp_query import NLPQueryRequestor from . structured_query import StructuredQueryRequestor from . structured_diag import StructuredDiagRequestor @@ -40,7 +40,7 @@ from . triples_import import TriplesImport from . graph_embeddings_import import GraphEmbeddingsImport from . document_embeddings_import import DocumentEmbeddingsImport from . entity_contexts_import import EntityContextsImport -from . objects_import import ObjectsImport +from . rows_import import RowsImport from . core_export import CoreExport from . 
core_import import CoreImport @@ -58,7 +58,7 @@ request_response_dispatchers = { "graph-embeddings": GraphEmbeddingsQueryRequestor, "document-embeddings": DocumentEmbeddingsQueryRequestor, "triples": TriplesQueryRequestor, - "objects": ObjectsQueryRequestor, + "rows": RowsQueryRequestor, "nlp-query": NLPQueryRequestor, "structured-query": StructuredQueryRequestor, "structured-diag": StructuredDiagRequestor, @@ -89,7 +89,7 @@ import_dispatchers = { "graph-embeddings": GraphEmbeddingsImport, "document-embeddings": DocumentEmbeddingsImport, "entity-contexts": EntityContextsImport, - "objects": ObjectsImport, + "rows": RowsImport, } class DispatcherWrapper: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py similarity index 97% rename from trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py rename to trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py index fc982b69..6606dc1a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_import.py @@ -12,7 +12,7 @@ from . serialize import to_subgraph # Module logger logger = logging.getLogger(__name__) -class ObjectsImport: +class RowsImport: def __init__( self, ws, running, backend, queue @@ -20,7 +20,7 @@ class ObjectsImport: self.ws = ws self.running = running - + self.publisher = Publisher( backend, topic = queue, schema = ExtractedObject ) @@ -73,4 +73,4 @@ class ObjectsImport: if self.ws: await self.ws.close() - self.ws = None \ No newline at end of file + self.ws = None diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py similarity index 69% rename from trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py rename to trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py index fb8dc81d..57435be8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/rows_query.py @@ -1,30 +1,30 @@ -from ... schema import ObjectsQueryRequest, ObjectsQueryResponse +from ... schema import RowsQueryRequest, RowsQueryResponse from ... messaging import TranslatorRegistry from . 
requestor import ServiceRequestor -class ObjectsQueryRequestor(ServiceRequestor): +class RowsQueryRequestor(ServiceRequestor): def __init__( self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): - super(ObjectsQueryRequestor, self).__init__( + super(RowsQueryRequestor, self).__init__( backend=backend, request_queue=request_queue, response_queue=response_queue, - request_schema=ObjectsQueryRequest, - response_schema=ObjectsQueryResponse, + request_schema=RowsQueryRequest, + response_schema=RowsQueryResponse, subscription = subscriber, consumer_name = consumer, timeout=timeout, ) - self.request_translator = TranslatorRegistry.get_request_translator("objects-query") - self.response_translator = TranslatorRegistry.get_response_translator("objects-query") + self.request_translator = TranslatorRegistry.get_request_translator("rows-query") + self.response_translator = TranslatorRegistry.get_response_translator("rows-query") def to_request(self, body): return self.request_translator.to_pulsar(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) \ No newline at end of file + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/query/graphql/__init__.py b/trustgraph-flow/trustgraph/query/graphql/__init__.py new file mode 100644 index 00000000..32dc6a97 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/__init__.py @@ -0,0 +1,22 @@ +""" +Shared GraphQL utilities for row query services. + +This module provides reusable GraphQL components including: +- Filter types (IntFilter, StringFilter, FloatFilter) +- Dynamic schema generation from RowSchema definitions +- Filter parsing utilities +""" + +from .types import IntFilter, StringFilter, FloatFilter, SortDirection +from .schema import GraphQLSchemaBuilder +from .filters import parse_filter_key, parse_where_clause + +__all__ = [ + "IntFilter", + "StringFilter", + "FloatFilter", + "SortDirection", + "GraphQLSchemaBuilder", + "parse_filter_key", + "parse_where_clause", +] diff --git a/trustgraph-flow/trustgraph/query/graphql/filters.py b/trustgraph-flow/trustgraph/query/graphql/filters.py new file mode 100644 index 00000000..7788e20d --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/filters.py @@ -0,0 +1,104 @@ +""" +Filter parsing utilities for GraphQL row queries. + +Provides functions to parse GraphQL filter objects into a normalized +format that can be used by different query backends. +""" + +import logging +from typing import Dict, Any, Tuple + +logger = logging.getLogger(__name__) + + +def parse_filter_key(filter_key: str) -> Tuple[str, str]: + """ + Parse GraphQL filter key into field name and operator. 
+ + Supports common GraphQL filter patterns: + - field_name -> (field_name, "eq") + - field_name_gt -> (field_name, "gt") + - field_name_gte -> (field_name, "gte") + - field_name_lt -> (field_name, "lt") + - field_name_lte -> (field_name, "lte") + - field_name_in -> (field_name, "in") + + Args: + filter_key: The filter key string from GraphQL + + Returns: + Tuple of (field_name, operator) + """ + if not filter_key: + return ("", "eq") + + operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"] + + for op_suffix in operators: + if filter_key.endswith(op_suffix): + field_name = filter_key[:-len(op_suffix)] + operator = op_suffix[1:] # Remove the leading underscore + return (field_name, operator) + + # Default to equality if no operator suffix found + return (filter_key, "eq") + + +def parse_where_clause(where_obj) -> Dict[str, Any]: + """ + Parse the idiomatic nested GraphQL filter structure into a flat dict. + + Converts Strawberry filter objects (StringFilter, IntFilter, etc.) + into a dictionary mapping field names with operators to values. + + Example: + Input: where_obj with email.eq = "foo@bar.com" + Output: {"email": "foo@bar.com"} + + Input: where_obj with age.gt = 21 + Output: {"age_gt": 21} + + Args: + where_obj: The GraphQL where clause object + + Returns: + Dictionary mapping field_operator keys to values + """ + if not where_obj: + return {} + + conditions = {} + + logger.debug(f"Parsing where clause: {where_obj}") + + for field_name, filter_obj in where_obj.__dict__.items(): + if filter_obj is None: + continue + + logger.debug(f"Processing field {field_name} with filter_obj: {filter_obj}") + + if hasattr(filter_obj, '__dict__'): + # This is a filter object (StringFilter, IntFilter, etc.) + for operator, value in filter_obj.__dict__.items(): + if value is not None: + logger.debug(f"Found operator {operator} with value {value}") + # Map GraphQL operators to our internal format + if operator == "eq": + conditions[field_name] = value + elif operator in ["gt", "gte", "lt", "lte"]: + conditions[f"{field_name}_{operator}"] = value + elif operator == "in_": + conditions[f"{field_name}_in"] = value + elif operator == "contains": + conditions[f"{field_name}_contains"] = value + elif operator == "startsWith": + conditions[f"{field_name}_startsWith"] = value + elif operator == "endsWith": + conditions[f"{field_name}_endsWith"] = value + elif operator == "not_": + conditions[f"{field_name}_not"] = value + elif operator == "not_in": + conditions[f"{field_name}_not_in"] = value + + logger.debug(f"Final parsed conditions: {conditions}") + return conditions diff --git a/trustgraph-flow/trustgraph/query/graphql/schema.py b/trustgraph-flow/trustgraph/query/graphql/schema.py new file mode 100644 index 00000000..0c97b1d9 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/schema.py @@ -0,0 +1,251 @@ +""" +Dynamic GraphQL schema generation from RowSchema definitions. + +Provides a builder class that creates Strawberry GraphQL schemas +from TrustGraph RowSchema definitions, with pluggable query backends. 
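+ +For example (illustrative; names are derived from the schema name as +shown in build() below), a RowSchema named "customers" produces a query +field of the form: + + customers(where: CustomersFilter, order_by: String, + direction: SortDirection, limit: Int): [CustomersType]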
+""" + +import logging +from typing import Dict, Any, Optional, List, Callable, Awaitable + +import strawberry +from strawberry import Schema +from strawberry.types import Info + +from .types import IntFilter, StringFilter, FloatFilter, SortDirection + +logger = logging.getLogger(__name__) + +# Type alias for query callback function +QueryCallback = Callable[ + [str, str, str, Any, Dict[str, Any], int, Optional[str], Optional[SortDirection]], + Awaitable[List[Dict[str, Any]]] +] + + +class GraphQLSchemaBuilder: + """ + Builds GraphQL schemas from RowSchema definitions. + + This class extracts the GraphQL schema generation logic so it can be + reused across different query backends (Cassandra, etc.). + + Usage: + builder = GraphQLSchemaBuilder() + + # Add schemas + for name, row_schema in schemas.items(): + builder.add_schema(name, row_schema) + + # Build with a query callback + schema = builder.build(query_callback) + """ + + def __init__(self): + self.schemas: Dict[str, Any] = {} # name -> RowSchema + self.graphql_types: Dict[str, type] = {} + self.filter_types: Dict[str, type] = {} + + def add_schema(self, name: str, row_schema) -> None: + """ + Add a RowSchema to the builder. + + Args: + name: The schema name (used as the GraphQL query field name) + row_schema: The RowSchema object defining fields + """ + self.schemas[name] = row_schema + self.graphql_types[name] = self._create_graphql_type(name, row_schema) + self.filter_types[name] = self._create_filter_type(name, row_schema) + logger.debug(f"Added schema {name} with {len(row_schema.fields)} fields") + + def clear(self) -> None: + """Clear all schemas from the builder.""" + self.schemas = {} + self.graphql_types = {} + self.filter_types = {} + + def build(self, query_callback: QueryCallback) -> Optional[Schema]: + """ + Build the GraphQL schema with the provided query callback. + + The query callback will be invoked when resolving queries, with: + - user: str + - collection: str + - schema_name: str + - row_schema: RowSchema + - filters: Dict[str, Any] + - limit: int + - order_by: Optional[str] + - direction: Optional[SortDirection] + + It should return a list of row dictionaries. 
+ + Args: + query_callback: Async function to execute queries + + Returns: + Strawberry Schema, or None if no schemas are loaded + """ + if not self.schemas: + logger.warning("No schemas loaded, cannot generate GraphQL schema") + return None + + # Create the Query class with resolvers + query_dict = {'__annotations__': {}} + + for schema_name, row_schema in self.schemas.items(): + graphql_type = self.graphql_types[schema_name] + filter_type = self.filter_types[schema_name] + + # Create resolver function for this schema + resolver_func = self._make_resolver( + schema_name, row_schema, graphql_type, filter_type, query_callback + ) + + # Add field to query dictionary + query_dict[schema_name] = strawberry.field(resolver=resolver_func) + query_dict['__annotations__'][schema_name] = List[graphql_type] + + # Create the Query class + Query = type('Query', (), query_dict) + Query = strawberry.type(Query) + + # Create the schema with auto_camel_case disabled to keep snake_case field names + schema = strawberry.Schema( + query=Query, + config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False) + ) + logger.info(f"Generated GraphQL schema with {len(self.schemas)} types") + return schema + + def _get_python_type(self, field_type: str): + """Convert schema field type to Python type for GraphQL.""" + type_mapping = { + "string": str, + "integer": int, + "float": float, + "boolean": bool, + "timestamp": str, # Use string for timestamps in GraphQL + "date": str, + "time": str, + "uuid": str + } + return type_mapping.get(field_type, str) + + def _create_graphql_type(self, schema_name: str, row_schema) -> type: + """Create a GraphQL output type from a RowSchema.""" + # Create annotations for the GraphQL type + annotations = {} + defaults = {} + + for field in row_schema.fields: + python_type = self._get_python_type(field.type) + + # Make field optional if not required + if not field.required and not field.primary: + annotations[field.name] = Optional[python_type] + defaults[field.name] = None + else: + annotations[field.name] = python_type + + # Create the class dynamically + type_name = f"{schema_name.capitalize()}Type" + graphql_class = type( + type_name, + (), + { + "__annotations__": annotations, + **defaults + } + ) + + # Apply strawberry decorator + return strawberry.type(graphql_class) + + def _create_filter_type(self, schema_name: str, row_schema) -> type: + """Create a dynamic filter input type for a schema.""" + filter_type_name = f"{schema_name.capitalize()}Filter" + + # Add __annotations__ and defaults for the fields + annotations = {} + defaults = {} + + logger.debug(f"Creating filter type {filter_type_name} for schema {schema_name}") + + for field in row_schema.fields: + logger.debug( + f"Field {field.name}: type={field.type}, " + f"indexed={field.indexed}, primary={field.primary}" + ) + + # Allow filtering on any field + if field.type == "integer": + annotations[field.name] = Optional[IntFilter] + defaults[field.name] = None + elif field.type == "float": + annotations[field.name] = Optional[FloatFilter] + defaults[field.name] = None + elif field.type == "string": + annotations[field.name] = Optional[StringFilter] + defaults[field.name] = None + + logger.debug( + f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}" + ) + + # Create the class dynamically + FilterType = type( + filter_type_name, + (), + { + "__annotations__": annotations, + **defaults + } + ) + + # Apply strawberry input decorator + FilterType = strawberry.input(FilterType) + + return 
FilterType + + def _make_resolver( + self, + schema_name: str, + row_schema, + graphql_type: type, + filter_type: type, + query_callback: QueryCallback + ): + """Create a resolver function for a schema.""" + from .filters import parse_where_clause + + async def resolver( + info: Info, + where: Optional[filter_type] = None, + order_by: Optional[str] = None, + direction: Optional[SortDirection] = None, + limit: Optional[int] = 100 + ) -> List[graphql_type]: + # Get context values + user = info.context["user"] + collection = info.context["collection"] + + # Parse the where clause + filters = parse_where_clause(where) + + # Call the query backend + results = await query_callback( + user, collection, schema_name, row_schema, + filters, limit, order_by, direction + ) + + # Convert to GraphQL types + graphql_results = [] + for row in results: + graphql_obj = graphql_type(**row) + graphql_results.append(graphql_obj) + + return graphql_results + + return resolver diff --git a/trustgraph-flow/trustgraph/query/graphql/types.py b/trustgraph-flow/trustgraph/query/graphql/types.py new file mode 100644 index 00000000..4d288bb6 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/graphql/types.py @@ -0,0 +1,56 @@ +""" +GraphQL filter and sort types for row queries. + +These types are used to build dynamic GraphQL schemas for querying +structured row data. +""" + +from typing import Optional, List +from enum import Enum + +import strawberry + + +@strawberry.input +class IntFilter: + """Filter type for integer fields.""" + eq: Optional[int] = None + gt: Optional[int] = None + gte: Optional[int] = None + lt: Optional[int] = None + lte: Optional[int] = None + in_: Optional[List[int]] = strawberry.field(name="in", default=None) + not_: Optional[int] = strawberry.field(name="not", default=None) + not_in: Optional[List[int]] = None + + +@strawberry.input +class StringFilter: + """Filter type for string fields.""" + eq: Optional[str] = None + contains: Optional[str] = None + startsWith: Optional[str] = None + endsWith: Optional[str] = None + in_: Optional[List[str]] = strawberry.field(name="in", default=None) + not_: Optional[str] = strawberry.field(name="not", default=None) + not_in: Optional[List[str]] = None + + +@strawberry.input +class FloatFilter: + """Filter type for float fields.""" + eq: Optional[float] = None + gt: Optional[float] = None + gte: Optional[float] = None + lt: Optional[float] = None + lte: Optional[float] = None + in_: Optional[List[float]] = strawberry.field(name="in", default=None) + not_: Optional[float] = strawberry.field(name="not", default=None) + not_in: Optional[List[float]] = None + + +@strawberry.enum +class SortDirection(Enum): + """Sort direction for query results.""" + ASC = "asc" + DESC = "desc" diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/service.py b/trustgraph-flow/trustgraph/query/objects/cassandra/service.py deleted file mode 100644 index a6683c40..00000000 --- a/trustgraph-flow/trustgraph/query/objects/cassandra/service.py +++ /dev/null @@ -1,738 +0,0 @@ -""" -Objects query service using GraphQL. Input is a GraphQL query with variables. -Output is GraphQL response data with any errors. 
-""" - -import json -import logging -import asyncio -from typing import Dict, Any, Optional, List, Set -from enum import Enum -from dataclasses import dataclass, field -from cassandra.cluster import Cluster -from cassandra.auth import PlainTextAuthProvider - -import strawberry -from strawberry import Schema -from strawberry.types import Info -from strawberry.scalars import JSON -from strawberry.tools import create_type - -from .... schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError -from .... schema import Error, RowSchema, Field as SchemaField -from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config - -# Module logger -logger = logging.getLogger(__name__) - -default_ident = "objects-query" - -# GraphQL filter input types -@strawberry.input -class IntFilter: - eq: Optional[int] = None - gt: Optional[int] = None - gte: Optional[int] = None - lt: Optional[int] = None - lte: Optional[int] = None - in_: Optional[List[int]] = strawberry.field(name="in", default=None) - not_: Optional[int] = strawberry.field(name="not", default=None) - not_in: Optional[List[int]] = None - -@strawberry.input -class StringFilter: - eq: Optional[str] = None - contains: Optional[str] = None - startsWith: Optional[str] = None - endsWith: Optional[str] = None - in_: Optional[List[str]] = strawberry.field(name="in", default=None) - not_: Optional[str] = strawberry.field(name="not", default=None) - not_in: Optional[List[str]] = None - -@strawberry.input -class FloatFilter: - eq: Optional[float] = None - gt: Optional[float] = None - gte: Optional[float] = None - lt: Optional[float] = None - lte: Optional[float] = None - in_: Optional[List[float]] = strawberry.field(name="in", default=None) - not_: Optional[float] = strawberry.field(name="not", default=None) - not_in: Optional[List[float]] = None - - -class Processor(FlowProcessor): - - def __init__(self, **params): - - id = params.get("id", default_ident) - - # Get Cassandra parameters - cassandra_host = params.get("cassandra_host") - cassandra_username = params.get("cassandra_username") - cassandra_password = params.get("cassandra_password") - - # Resolve configuration with environment variable fallback - hosts, username, password, keyspace = resolve_cassandra_config( - host=cassandra_host, - username=cassandra_username, - password=cassandra_password - ) - - # Store resolved configuration with proper names - self.cassandra_host = hosts # Store as list - self.cassandra_username = username - self.cassandra_password = password - - # Config key for schemas - self.config_key = params.get("config_type", "schema") - - super(Processor, self).__init__( - **params | { - "id": id, - "config_type": self.config_key, - } - ) - - self.register_specification( - ConsumerSpec( - name = "request", - schema = ObjectsQueryRequest, - handler = self.on_message - ) - ) - - self.register_specification( - ProducerSpec( - name = "response", - schema = ObjectsQueryResponse, - ) - ) - - # Register config handler for schema updates - self.register_config_handler(self.on_schema_config) - - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - - # GraphQL schema - self.graphql_schema: Optional[Schema] = None - - # GraphQL types cache - self.graphql_types: Dict[str, type] = {} - - # Cassandra session - self.cluster = None - self.session = None - - # Known keyspaces and tables - self.known_keyspaces: Set[str] = set() - self.known_tables: Dict[str, Set[str]] = {} - - 
def connect_cassandra(self): - """Connect to Cassandra cluster""" - if self.session: - return - - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) - - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise - - def sanitize_name(self, name: str) -> str: - """Sanitize names for Cassandra compatibility""" - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - if safe_name and not safe_name[0].isalpha(): - safe_name = 'o_' + safe_name - return safe_name.lower() - - def sanitize_table(self, name: str) -> str: - """Sanitize table names for Cassandra compatibility""" - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - safe_name = 'o_' + safe_name - return safe_name.lower() - - def parse_filter_key(self, filter_key: str) -> tuple[str, str]: - """Parse GraphQL filter key into field name and operator""" - if not filter_key: - return ("", "eq") - - # Support common GraphQL filter patterns: - # field_name -> (field_name, "eq") - # field_name_gt -> (field_name, "gt") - # field_name_gte -> (field_name, "gte") - # field_name_lt -> (field_name, "lt") - # field_name_lte -> (field_name, "lte") - # field_name_in -> (field_name, "in") - - operators = ["_gte", "_lte", "_gt", "_lt", "_in", "_eq"] - - for op_suffix in operators: - if filter_key.endswith(op_suffix): - field_name = filter_key[:-len(op_suffix)] - operator = op_suffix[1:] # Remove the leading underscore - return (field_name, operator) - - # Default to equality if no operator suffix found - return (filter_key, "eq") - - async def on_schema_config(self, config, version): - """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - self.graphql_types = {} - - # Check if our config type exists - if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") - return - - # Get the schemas dictionary for our type - schemas_config = config[self.config_key] - - # Process each schema in the schemas config - for schema_name, schema_json in schemas_config.items(): - try: - # Parse the JSON schema definition - schema_def = json.loads(schema_json) - - # Create Field objects - fields = [] - for field_def in schema_def.get("fields", []): - field = SchemaField( - name=field_def["name"], - type=field_def["type"], - size=field_def.get("size", 0), - primary=field_def.get("primary_key", False), - description=field_def.get("description", ""), - required=field_def.get("required", False), - enum_values=field_def.get("enum", []), - indexed=field_def.get("indexed", False) - ) - fields.append(field) - - # Create RowSchema - row_schema = RowSchema( - name=schema_def.get("name", schema_name), - description=schema_def.get("description", ""), - fields=fields - ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - - except Exception as e: - logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") - - # Regenerate 
GraphQL schema - self.generate_graphql_schema() - - def get_python_type(self, field_type: str): - """Convert schema field type to Python type for GraphQL""" - type_mapping = { - "string": str, - "integer": int, - "float": float, - "boolean": bool, - "timestamp": str, # Use string for timestamps in GraphQL - "date": str, - "time": str, - "uuid": str - } - return type_mapping.get(field_type, str) - - def create_graphql_type(self, schema_name: str, row_schema: RowSchema) -> type: - """Create a GraphQL type from a RowSchema""" - - # Create annotations for the GraphQL type - annotations = {} - defaults = {} - - for field in row_schema.fields: - python_type = self.get_python_type(field.type) - - # Make field optional if not required - if not field.required and not field.primary: - annotations[field.name] = Optional[python_type] - defaults[field.name] = None - else: - annotations[field.name] = python_type - - # Create the class dynamically - type_name = f"{schema_name.capitalize()}Type" - graphql_class = type( - type_name, - (), - { - "__annotations__": annotations, - **defaults - } - ) - - # Apply strawberry decorator - return strawberry.type(graphql_class) - - def create_filter_type_for_schema(self, schema_name: str, row_schema: RowSchema): - """Create a dynamic filter input type for a schema""" - # Create the filter type dynamically - filter_type_name = f"{schema_name.capitalize()}Filter" - - # Add __annotations__ and defaults for the fields - annotations = {} - defaults = {} - - logger.info(f"Creating filter type {filter_type_name} for schema {schema_name}") - - for field in row_schema.fields: - logger.info(f"Field {field.name}: type={field.type}, indexed={field.indexed}, primary={field.primary}") - - # Allow filtering on any field for now, not just indexed/primary - # if field.indexed or field.primary: - if field.type == "integer": - annotations[field.name] = Optional[IntFilter] - defaults[field.name] = None - logger.info(f"Added IntFilter for {field.name}") - elif field.type == "float": - annotations[field.name] = Optional[FloatFilter] - defaults[field.name] = None - logger.info(f"Added FloatFilter for {field.name}") - elif field.type == "string": - annotations[field.name] = Optional[StringFilter] - defaults[field.name] = None - logger.info(f"Added StringFilter for {field.name}") - - logger.info(f"Filter type {filter_type_name} will have fields: {list(annotations.keys())}") - - # Create the class dynamically - FilterType = type( - filter_type_name, - (), - { - "__annotations__": annotations, - **defaults - } - ) - - # Apply strawberry input decorator - FilterType = strawberry.input(FilterType) - - return FilterType - - def create_sort_direction_enum(self): - """Create sort direction enum""" - @strawberry.enum - class SortDirection(Enum): - ASC = "asc" - DESC = "desc" - - return SortDirection - - def parse_idiomatic_where_clause(self, where_obj) -> Dict[str, Any]: - """Parse the idiomatic nested filter structure""" - if not where_obj: - return {} - - conditions = {} - - logger.info(f"Parsing where clause: {where_obj}") - - for field_name, filter_obj in where_obj.__dict__.items(): - if filter_obj is None: - continue - - logger.info(f"Processing field {field_name} with filter_obj: {filter_obj}") - - if hasattr(filter_obj, '__dict__'): - # This is a filter object (StringFilter, IntFilter, etc.) 
- for operator, value in filter_obj.__dict__.items(): - if value is not None: - logger.info(f"Found operator {operator} with value {value}") - # Map GraphQL operators to our internal format - if operator == "eq": - conditions[field_name] = value - elif operator in ["gt", "gte", "lt", "lte"]: - conditions[f"{field_name}_{operator}"] = value - elif operator == "in_": - conditions[f"{field_name}_in"] = value - elif operator == "contains": - conditions[f"{field_name}_contains"] = value - - logger.info(f"Final parsed conditions: {conditions}") - return conditions - - def generate_graphql_schema(self): - """Generate GraphQL schema from loaded schemas using dynamic filter types""" - if not self.schemas: - logger.warning("No schemas loaded, cannot generate GraphQL schema") - self.graphql_schema = None - return - - # Create GraphQL types and filter types for each schema - filter_types = {} - sort_direction_enum = self.create_sort_direction_enum() - - for schema_name, row_schema in self.schemas.items(): - graphql_type = self.create_graphql_type(schema_name, row_schema) - filter_type = self.create_filter_type_for_schema(schema_name, row_schema) - - self.graphql_types[schema_name] = graphql_type - filter_types[schema_name] = filter_type - - # Create the Query class with resolvers - query_dict = {'__annotations__': {}} - - for schema_name, row_schema in self.schemas.items(): - graphql_type = self.graphql_types[schema_name] - filter_type = filter_types[schema_name] - - # Create resolver function for this schema - def make_resolver(s_name, r_schema, g_type, f_type, sort_enum): - async def resolver( - info: Info, - where: Optional[f_type] = None, - order_by: Optional[str] = None, - direction: Optional[sort_enum] = None, - limit: Optional[int] = 100 - ) -> List[g_type]: - # Get the processor instance from context - processor = info.context["processor"] - user = info.context["user"] - collection = info.context["collection"] - - # Parse the idiomatic where clause - filters = processor.parse_idiomatic_where_clause(where) - - # Query Cassandra - results = await processor.query_cassandra( - user, collection, s_name, r_schema, - filters, limit, order_by, direction - ) - - # Convert to GraphQL types - graphql_results = [] - for row in results: - graphql_obj = g_type(**row) - graphql_results.append(graphql_obj) - - return graphql_results - - return resolver - - # Add resolver to query - resolver_name = schema_name - resolver_func = make_resolver(schema_name, row_schema, graphql_type, filter_type, sort_direction_enum) - - # Add field to query dictionary - query_dict[resolver_name] = strawberry.field(resolver=resolver_func) - query_dict['__annotations__'][resolver_name] = List[graphql_type] - - # Create the Query class - Query = type('Query', (), query_dict) - Query = strawberry.type(Query) - - # Create the schema with auto_camel_case disabled to keep snake_case field names - self.graphql_schema = strawberry.Schema( - query=Query, - config=strawberry.schema.config.StrawberryConfig(auto_camel_case=False) - ) - logger.info(f"Generated GraphQL schema with {len(self.schemas)} types") - - async def query_cassandra( - self, - user: str, - collection: str, - schema_name: str, - row_schema: RowSchema, - filters: Dict[str, Any], - limit: int, - order_by: Optional[str] = None, - direction: Optional[Any] = None - ) -> List[Dict[str, Any]]: - """Execute a query against Cassandra""" - - # Connect if needed - self.connect_cassandra() - - # Build the query - keyspace = self.sanitize_name(user) - table = 
self.sanitize_table(schema_name) - - # Start with basic SELECT - query = f"SELECT * FROM {keyspace}.{table}" - - # Add WHERE clauses - where_clauses = [f"collection = %s"] - params = [collection] - - # Add filters for indexed or primary key fields - for filter_key, value in filters.items(): - if value is not None: - # Parse field name and operator from filter key - logger.debug(f"Parsing filter key: '{filter_key}' (type: {type(filter_key)})") - result = self.parse_filter_key(filter_key) - logger.debug(f"parse_filter_key returned: {result} (type: {type(result)}, len: {len(result) if hasattr(result, '__len__') else 'N/A'})") - - if not result or len(result) != 2: - logger.error(f"parse_filter_key returned invalid result: {result}") - continue # Skip this filter - - field_name, operator = result - - # Find the field in schema - schema_field = None - for f in row_schema.fields: - if f.name == field_name: - schema_field = f - break - - if schema_field: - safe_field = self.sanitize_name(field_name) - - # Build WHERE clause based on operator - if operator == "eq": - where_clauses.append(f"{safe_field} = %s") - params.append(value) - elif operator == "gt": - where_clauses.append(f"{safe_field} > %s") - params.append(value) - elif operator == "gte": - where_clauses.append(f"{safe_field} >= %s") - params.append(value) - elif operator == "lt": - where_clauses.append(f"{safe_field} < %s") - params.append(value) - elif operator == "lte": - where_clauses.append(f"{safe_field} <= %s") - params.append(value) - elif operator == "in": - if isinstance(value, list): - placeholders = ",".join(["%s"] * len(value)) - where_clauses.append(f"{safe_field} IN ({placeholders})") - params.extend(value) - else: - # Default to equality for unknown operators - where_clauses.append(f"{safe_field} = %s") - params.append(value) - - if where_clauses: - query += " WHERE " + " AND ".join(where_clauses) - - # Add ORDER BY if requested (will try Cassandra first, then fall back to post-query sort) - cassandra_order_by_added = False - if order_by and direction: - # Validate that order_by field exists in schema - order_field_exists = any(f.name == order_by for f in row_schema.fields) - if order_field_exists: - safe_order_field = self.sanitize_name(order_by) - direction_str = "ASC" if direction.value == "asc" else "DESC" - # Add ORDER BY - if Cassandra rejects it, we'll catch the error during execution - query += f" ORDER BY {safe_order_field} {direction_str}" - - # Add limit first (must come before ALLOW FILTERING) - if limit: - query += f" LIMIT {limit}" - - # Add ALLOW FILTERING for now (should optimize with proper indexes later) - query += " ALLOW FILTERING" - - # Execute query - try: - result = self.session.execute(query, params) - cassandra_order_by_added = True # If we get here, Cassandra handled ORDER BY - except Exception as e: - # If ORDER BY fails, try without it - if order_by and direction and "ORDER BY" in query: - logger.info(f"Cassandra rejected ORDER BY, falling back to post-query sorting: {e}") - # Remove ORDER BY clause and retry - query_parts = query.split(" ORDER BY ") - if len(query_parts) == 2: - query_without_order = query_parts[0] + " LIMIT " + str(limit) + " ALLOW FILTERING" if limit else " ALLOW FILTERING" - result = self.session.execute(query_without_order, params) - cassandra_order_by_added = False - else: - raise - else: - raise - - # Convert rows to dicts - results = [] - for row in result: - row_dict = {} - for field in row_schema.fields: - safe_field = self.sanitize_name(field.name) - if hasattr(row, 
safe_field): - value = getattr(row, safe_field) - # Use original field name in result - row_dict[field.name] = value - results.append(row_dict) - - # Post-query sorting if Cassandra didn't handle ORDER BY - if order_by and direction and not cassandra_order_by_added: - reverse_order = (direction.value == "desc") - try: - results.sort(key=lambda x: x.get(order_by, 0), reverse=reverse_order) - except Exception as e: - logger.warning(f"Failed to sort results by {order_by}: {e}") - - return results - - async def execute_graphql_query( - self, - query: str, - variables: Dict[str, Any], - operation_name: Optional[str], - user: str, - collection: str - ) -> Dict[str, Any]: - """Execute a GraphQL query""" - - if not self.graphql_schema: - raise RuntimeError("No GraphQL schema available - no schemas loaded") - - # Create context for the query - context = { - "processor": self, - "user": user, - "collection": collection - } - - # Execute the query - result = await self.graphql_schema.execute( - query, - variable_values=variables, - operation_name=operation_name, - context_value=context - ) - - # Build response - response = {} - - if result.data: - response["data"] = result.data - else: - response["data"] = None - - if result.errors: - response["errors"] = [ - { - "message": str(error), - "path": getattr(error, "path", []), - "extensions": getattr(error, "extensions", {}) - } - for error in result.errors - ] - else: - response["errors"] = [] - - # Add extensions if any - if hasattr(result, "extensions") and result.extensions: - response["extensions"] = result.extensions - - return response - - async def on_message(self, msg, consumer, flow): - """Handle incoming query request""" - - try: - request = msg.value() - - # Sender-produced ID - id = msg.properties()["id"] - - logger.debug(f"Handling objects query request {id}...") - - # Execute GraphQL query - result = await self.execute_graphql_query( - query=request.query, - variables=dict(request.variables) if request.variables else {}, - operation_name=request.operation_name, - user=request.user, - collection=request.collection - ) - - # Create response - graphql_errors = [] - if "errors" in result and result["errors"]: - for err in result["errors"]: - graphql_error = GraphQLError( - message=err.get("message", ""), - path=err.get("path", []), - extensions=err.get("extensions", {}) - ) - graphql_errors.append(graphql_error) - - response = ObjectsQueryResponse( - error=None, - data=json.dumps(result.get("data")) if result.get("data") else "null", - errors=graphql_errors, - extensions=result.get("extensions", {}) - ) - - logger.debug("Sending objects query response...") - await flow("response").send(response, properties={"id": id}) - - logger.debug("Objects query request completed") - - except Exception as e: - - logger.error(f"Exception in objects query service: {e}", exc_info=True) - - logger.info("Sending error response...") - - response = ObjectsQueryResponse( - error = Error( - type = "objects-query-error", - message = str(e), - ), - data = None, - errors = [], - extensions = {} - ) - - await flow("response").send(response, properties={"id": id}) - - def close(self): - """Clean up Cassandra connections""" - if self.cluster: - self.cluster.shutdown() - logger.info("Closed Cassandra connection") - - @staticmethod - def add_args(parser): - """Add command-line arguments""" - - FlowProcessor.add_args(parser) - add_cassandra_args(parser) - - parser.add_argument( - '--config-type', - default='schema', - help='Configuration type prefix for schemas (default: 
schema)' - ) - -def run(): - """Entry point for objects-query-graphql-cassandra command""" - Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/__init__.py b/trustgraph-flow/trustgraph/query/row_embeddings/__init__.py new file mode 100644 index 00000000..6c6391f5 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/__init__.py @@ -0,0 +1,3 @@ +""" +Row embeddings query modules. +""" diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__init__.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__init__.py new file mode 100644 index 00000000..a4ca1c85 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__init__.py @@ -0,0 +1,5 @@ +""" +Qdrant row embeddings query service. +""" + +from .service import Processor, run, default_ident diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__main__.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__main__.py new file mode 100644 index 00000000..66f42e76 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/__main__.py @@ -0,0 +1,4 @@ + +from .service import run + +run() diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py new file mode 100644 index 00000000..7ed6192f --- /dev/null +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -0,0 +1,209 @@ +""" +Row embeddings query service for Qdrant. + +Input is query vectors plus user/collection/schema context. +Output is matching row index information (index_name, index_value) for +use in subsequent Cassandra lookups. +""" + +import logging +import re +from typing import Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue + +from .... schema import ( + RowEmbeddingsRequest, RowEmbeddingsResponse, + RowIndexMatch, Error +) +from .... 
base import FlowProcessor, ConsumerSpec, ProducerSpec + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "row-embeddings-query" +default_store_uri = 'http://localhost:6333' + + +class Processor(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) + + super(Processor, self).__init__( + **params | { + "id": id, + "store_uri": store_uri, + "api_key": api_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="request", + schema=RowEmbeddingsRequest, + handler=self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name="response", + schema=RowEmbeddingsResponse + ) + ) + + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Qdrant collection naming""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]: + """Find the Qdrant collection for a given user/collection/schema""" + prefix = ( + f"rows_{self.sanitize_name(user)}_" + f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" + ) + + try: + all_collections = self.qdrant.get_collections().collections + matching = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if matching: + # Return first match (there should typically be only one per dimension) + return matching[0] + + except Exception as e: + logger.error(f"Failed to list Qdrant collections: {e}", exc_info=True) + + return None + + async def query_row_embeddings(self, request: RowEmbeddingsRequest): + """Execute row embeddings query""" + + matches = [] + + # Find the collection for this user/collection/schema + qdrant_collection = self.find_collection( + request.user, request.collection, request.schema_name + ) + + if not qdrant_collection: + logger.info( + f"No Qdrant collection found for " + f"{request.user}/{request.collection}/{request.schema_name}" + ) + return matches + + for vec in request.vectors: + try: + # Build optional filter for index_name + query_filter = None + if request.index_name: + query_filter = Filter( + must=[ + FieldCondition( + key="index_name", + match=MatchValue(value=request.index_name) + ) + ] + ) + + # Query Qdrant + search_result = self.qdrant.query_points( + collection_name=qdrant_collection, + query=vec, + limit=request.limit, + with_payload=True, + query_filter=query_filter, + ).points + + # Convert to RowIndexMatch objects + for point in search_result: + payload = point.payload or {} + match = RowIndexMatch( + index_name=payload.get("index_name", ""), + index_value=payload.get("index_value", []), + text=payload.get("text", ""), + score=point.score if hasattr(point, 'score') else 0.0 + ) + matches.append(match) + + except Exception as e: + logger.error(f"Failed to query Qdrant: {e}", exc_info=True) + raise + + return matches + + async def on_message(self, msg, consumer, flow): + """Handle incoming query request""" + + try: + request = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + logger.debug( + f"Handling row embeddings query for " + f"{request.user}/{request.collection}/{request.schema_name}..." 
+ ) + + # Execute query + matches = await self.query_row_embeddings(request) + + response = RowEmbeddingsResponse( + error=None, + matches=matches + ) + + logger.debug(f"Returning {len(matches)} matches") + await flow("response").send(response, properties={"id": id}) + + except Exception as e: + logger.error(f"Exception in row embeddings query: {e}", exc_info=True) + + response = RowEmbeddingsResponse( + error=Error( + type="row-embeddings-query-error", + message=str(e) + ), + matches=[] + ) + + await flow("response").send(response, properties={"id": id}) + + @staticmethod + def add_args(parser): + """Add command-line arguments""" + + FlowProcessor.add_args(parser) + + parser.add_argument( + '-t', '--store-uri', + default=default_store_uri, + help=f'Qdrant store URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help='API key for Qdrant (default: None)' + ) + + +def run(): + """Entry point for row-embeddings-query-qdrant command""" + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/query/objects/__init__.py b/trustgraph-flow/trustgraph/query/rows/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/query/objects/__init__.py rename to trustgraph-flow/trustgraph/query/rows/__init__.py diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/__init__.py b/trustgraph-flow/trustgraph/query/rows/cassandra/__init__.py similarity index 100% rename from trustgraph-flow/trustgraph/query/objects/cassandra/__init__.py rename to trustgraph-flow/trustgraph/query/rows/cassandra/__init__.py diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/__main__.py b/trustgraph-flow/trustgraph/query/rows/cassandra/__main__.py similarity index 100% rename from trustgraph-flow/trustgraph/query/objects/cassandra/__main__.py rename to trustgraph-flow/trustgraph/query/rows/cassandra/__main__.py diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py new file mode 100644 index 00000000..3808cdb0 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -0,0 +1,523 @@ +""" +Row query service using GraphQL. Input is a GraphQL query with variables. +Output is GraphQL response data with any errors. + +Queries against the unified 'rows' table with schema: + - collection: text + - schema_name: text + - index_name: text + - index_value: frozen> + - data: map + - source: text +""" + +import json +import logging +import re +from typing import Dict, Any, Optional, List, Set + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider + +from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError +from .... schema import Error, RowSchema, Field as SchemaField +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config + +from ... 
graphql import GraphQLSchemaBuilder, SortDirection + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "rows-query" + + +class Processor(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + # Get Cassandra parameters + cassandra_host = params.get("cassandra_host") + cassandra_username = params.get("cassandra_username") + cassandra_password = params.get("cassandra_password") + + # Resolve configuration with environment variable fallback + hosts, username, password, keyspace = resolve_cassandra_config( + host=cassandra_host, + username=cassandra_username, + password=cassandra_password + ) + + # Store resolved configuration with proper names + self.cassandra_host = hosts # Store as list + self.cassandra_username = username + self.cassandra_password = password + + # Config key for schemas + self.config_key = params.get("config_type", "schema") + + super(Processor, self).__init__( + **params | { + "id": id, + "config_type": self.config_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="request", + schema=RowsQueryRequest, + handler=self.on_message + ) + ) + + self.register_specification( + ProducerSpec( + name="response", + schema=RowsQueryResponse, + ) + ) + + # Register config handler for schema updates + self.register_config_handler(self.on_schema_config) + + # Schema storage: name -> RowSchema + self.schemas: Dict[str, RowSchema] = {} + + # GraphQL schema builder and generated schema + self.schema_builder = GraphQLSchemaBuilder() + self.graphql_schema = None + + # Cassandra session + self.cluster = None + self.session = None + + # Known keyspaces + self.known_keyspaces: Set[str] = set() + + def connect_cassandra(self): + """Connect to Cassandra cluster""" + if self.session: + return + + try: + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password + ) + self.cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + self.cluster = Cluster(contact_points=self.cassandra_host) + + self.session = self.cluster.connect() + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") + + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Cassandra compatibility""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + async def on_schema_config(self, config, version): + """Handle schema configuration updates""" + logger.info(f"Loading schema configuration version {version}") + + # Clear existing schemas + self.schemas = {} + self.schema_builder.clear() + + # Check if our config type exists + if self.config_key not in config: + logger.warning(f"No '{self.config_key}' type in configuration") + return + + # Get the schemas dictionary for our type + schemas_config = config[self.config_key] + + # Process each schema in the schemas config + for schema_name, schema_json in schemas_config.items(): + try: + # Parse the JSON schema definition + schema_def = json.loads(schema_json) + + # Create Field objects + fields = [] + for field_def in schema_def.get("fields", []): + field = SchemaField( + name=field_def["name"], + type=field_def["type"], + size=field_def.get("size", 0), + primary=field_def.get("primary_key", False), + 
description=field_def.get("description", ""), + required=field_def.get("required", False), + enum_values=field_def.get("enum", []), + indexed=field_def.get("indexed", False) + ) + fields.append(field) + + # Create RowSchema + row_schema = RowSchema( + name=schema_def.get("name", schema_name), + description=schema_def.get("description", ""), + fields=fields + ) + + self.schemas[schema_name] = row_schema + self.schema_builder.add_schema(schema_name, row_schema) + logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") + + except Exception as e: + logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) + + logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") + + # Regenerate GraphQL schema + self.graphql_schema = self.schema_builder.build(self.query_cassandra) + + def get_index_names(self, schema: RowSchema) -> List[str]: + """Get all index names for a schema.""" + index_names = [] + for field in schema.fields: + if field.primary or field.indexed: + index_names.append(field.name) + return index_names + + def find_matching_index( + self, + schema: RowSchema, + filters: Dict[str, Any] + ) -> Optional[tuple]: + """ + Find an index that can satisfy the query filters. + Returns (index_name, index_value) if found, None otherwise. + + For exact match queries, we need a filter on an indexed field. + """ + index_names = self.get_index_names(schema) + + # Look for an exact match filter on an indexed field + for index_name in index_names: + if index_name in filters: + value = filters[index_name] + # Single field index -> single element list + index_value = [str(value)] + return (index_name, index_value) + + return None + + async def query_cassandra( + self, + user: str, + collection: str, + schema_name: str, + row_schema: RowSchema, + filters: Dict[str, Any], + limit: int, + order_by: Optional[str] = None, + direction: Optional[SortDirection] = None + ) -> List[Dict[str, Any]]: + """ + Execute a query against the unified Cassandra rows table. + + For exact match queries on indexed fields, we can query directly. + For other queries, we need to scan and post-filter. 
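Before the query paths themselves, the index-selection rule is easiest to see on a concrete schema. The following is a minimal standalone sketch, with simplified stand-ins for the real `Field`/`RowSchema` message types and an invented `customers` schema:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

@dataclass
class Field:
    name: str
    primary: bool = False
    indexed: bool = False

@dataclass
class RowSchema:
    fields: List[Field] = field(default_factory=list)

def find_matching_index(
    schema: RowSchema, filters: Dict[str, Any]
) -> Optional[Tuple[str, List[str]]]:
    # Mirrors the service logic: a filter on any primary or indexed
    # field can be served by a direct read of the unified rows table.
    index_names = [f.name for f in schema.fields if f.primary or f.indexed]
    for index_name in index_names:
        if index_name in filters:
            return (index_name, [str(filters[index_name])])
    return None

customers = RowSchema(fields=[
    Field("customer_id", primary=True),
    Field("email", indexed=True),
    Field("age"),                       # neither primary nor indexed
])

print(find_matching_index(customers, {"email": "foo@bar.com"}))
# ('email', ['foo@bar.com'])  -> direct partition read
print(find_matching_index(customers, {"age_gt": 30}))
# None                        -> scan + post-filter fallback
```

Any filter naming a primary or indexed field yields a direct partition read; anything else, including operator-suffixed filters such as `age_gt`, falls through to the scan-and-post-filter path below.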
+ """ + # Connect if needed + self.connect_cassandra() + + safe_keyspace = self.sanitize_name(user) + + # Try to find an index that matches the filters + index_match = self.find_matching_index(row_schema, filters) + + results = [] + + if index_match: + # Direct query using index + index_name, index_value = index_match + + query = f""" + SELECT data, source FROM {safe_keyspace}.rows + WHERE collection = %s + AND schema_name = %s + AND index_name = %s + AND index_value = %s + """ + params = [collection, schema_name, index_name, index_value] + + if limit: + query += f" LIMIT {limit}" + + try: + rows = self.session.execute(query, params) + for row in rows: + # Convert data map to dict with proper field names + row_dict = dict(row.data) if row.data else {} + results.append(row_dict) + except Exception as e: + logger.error(f"Failed to query rows: {e}", exc_info=True) + raise + + else: + # No direct index match - scan all rows for this schema + # This is less efficient but necessary for non-indexed queries + logger.warning( + f"No index match for filters {filters} - scanning all indexes" + ) + + # Get all index names for this schema + index_names = self.get_index_names(row_schema) + + if not index_names: + logger.warning(f"Schema {schema_name} has no indexes") + return [] + + # Query using the first index (arbitrary choice for scan) + primary_index = index_names[0] + + # We need to scan all values for this index + # This requires ALLOW FILTERING or a different approach + query = f""" + SELECT data, source FROM {safe_keyspace}.rows + WHERE collection = %s + AND schema_name = %s + AND index_name = %s + ALLOW FILTERING + """ + params = [collection, schema_name, primary_index] + + try: + rows = self.session.execute(query, params) + + for row in rows: + row_dict = dict(row.data) if row.data else {} + + # Apply post-filters + if self._matches_filters(row_dict, filters, row_schema): + results.append(row_dict) + + if limit and len(results) >= limit: + break + + except Exception as e: + logger.error(f"Failed to scan rows: {e}", exc_info=True) + raise + + # Post-query sorting if requested + if order_by and results: + reverse_order = direction and direction.value == "desc" + try: + results.sort( + key=lambda x: x.get(order_by, ""), + reverse=reverse_order + ) + except Exception as e: + logger.warning(f"Failed to sort results by {order_by}: {e}") + + return results + + def _matches_filters( + self, + row_dict: Dict[str, Any], + filters: Dict[str, Any], + row_schema: RowSchema + ) -> bool: + """Check if a row matches the given filters.""" + for filter_key, filter_value in filters.items(): + if filter_value is None: + continue + + # Parse filter key for operator + if '_' in filter_key: + parts = filter_key.rsplit('_', 1) + if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']: + field_name = parts[0] + operator = parts[1] + else: + field_name = filter_key + operator = 'eq' + else: + field_name = filter_key + operator = 'eq' + + row_value = row_dict.get(field_name) + if row_value is None: + return False + + # Convert types for comparison + try: + if operator == 'eq': + if str(row_value) != str(filter_value): + return False + elif operator == 'gt': + if float(row_value) <= float(filter_value): + return False + elif operator == 'gte': + if float(row_value) < float(filter_value): + return False + elif operator == 'lt': + if float(row_value) >= float(filter_value): + return False + elif operator == 'lte': + if float(row_value) > float(filter_value): + return False + elif operator == 'contains': + if 
str(filter_value) not in str(row_value):
+                        return False
+                elif operator == 'in':
+                    if str(row_value) not in [str(v) for v in filter_value]:
+                        return False
+            except (ValueError, TypeError):
+                return False
+
+        return True
+
+    async def execute_graphql_query(
+        self,
+        query: str,
+        variables: Dict[str, Any],
+        operation_name: Optional[str],
+        user: str,
+        collection: str
+    ) -> Dict[str, Any]:
+        """Execute a GraphQL query"""
+
+        if not self.graphql_schema:
+            raise RuntimeError("No GraphQL schema available - no schemas loaded")
+
+        # Create context for the query
+        context = {
+            "processor": self,
+            "user": user,
+            "collection": collection
+        }
+
+        # Execute the query
+        result = await self.graphql_schema.execute(
+            query,
+            variable_values=variables,
+            operation_name=operation_name,
+            context_value=context
+        )
+
+        # Build response
+        response = {}
+
+        if result.data:
+            response["data"] = result.data
+        else:
+            response["data"] = None
+
+        if result.errors:
+            response["errors"] = [
+                {
+                    "message": str(error),
+                    "path": getattr(error, "path", []),
+                    "extensions": getattr(error, "extensions", {})
+                }
+                for error in result.errors
+            ]
+        else:
+            response["errors"] = []
+
+        # Add extensions if any
+        if hasattr(result, "extensions") and result.extensions:
+            response["extensions"] = result.extensions
+
+        return response
+
+    async def on_message(self, msg, consumer, flow):
+        """Handle incoming query request"""
+
+        try:
+            request = msg.value()
+
+            # Sender-produced ID
+            id = msg.properties()["id"]
+
+            logger.debug(f"Handling rows query request {id}...")
+
+            # Execute GraphQL query
+            result = await self.execute_graphql_query(
+                query=request.query,
+                variables=dict(request.variables) if request.variables else {},
+                operation_name=request.operation_name,
+                user=request.user,
+                collection=request.collection
+            )
+
+            # Create response
+            graphql_errors = []
+            if "errors" in result and result["errors"]:
+                for err in result["errors"]:
+                    graphql_error = GraphQLError(
+                        message=err.get("message", ""),
+                        path=err.get("path", []),
+                        extensions=err.get("extensions", {})
+                    )
+                    graphql_errors.append(graphql_error)
+
+            response = RowsQueryResponse(
+                error=None,
+                data=json.dumps(result.get("data")) if result.get("data") else "null",
+                errors=graphql_errors,
+                extensions=result.get("extensions", {})
+            )
+
+            logger.debug("Sending rows query response...")
+            await flow("response").send(response, properties={"id": id})
+
+            logger.debug("Rows query request completed")
+
+        except Exception as e:
+
+            logger.error(f"Exception in rows query service: {e}", exc_info=True)
+
+            logger.info("Sending error response...")
+
+            response = RowsQueryResponse(
+                error=Error(
+                    type="rows-query-error",
+                    message=str(e),
+                ),
+                data=None,
+                errors=[],
+                extensions={}
+            )
+
+            await flow("response").send(response, properties={"id": id})
+
+    def close(self):
+        """Clean up Cassandra connections"""
+        if self.cluster:
+            self.cluster.shutdown()
+            logger.info("Closed Cassandra connection")
+
+    @staticmethod
+    def add_args(parser):
+        """Add command-line arguments"""
+
+        FlowProcessor.add_args(parser)
+        add_cassandra_args(parser)
+
+        parser.add_argument(
+            '--config-type',
+            default='schema',
+            help='Configuration type prefix for schemas (default: schema)'
+        )
+
+
+def run():
+    """Entry point for rows-query-cassandra command"""
+    Processor.launch(default_ident, __doc__)
diff --git a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py
index 4b1a04a4..e39f9041 100644
--- 
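The operator convention in `_matches_filters` is positional: the suffix after the last underscore selects the comparison, and anything else is an exact match. A distilled sketch of just that parsing step (field names invented):

```python
def parse_filter_key(filter_key: str):
    """Split a filter argument into (field, operator), using the same
    last-underscore convention as _matches_filters."""
    if '_' in filter_key:
        field_name, _, suffix = filter_key.rpartition('_')
        if suffix in ('gt', 'gte', 'lt', 'lte', 'contains', 'in'):
            return field_name, suffix
    return filter_key, 'eq'

for key in ('email', 'age_gt', 'name_contains', 'zip_code'):
    print(key, '->', parse_filter_key(key))
# email -> ('email', 'eq')
# age_gt -> ('age', 'gt')
# name_contains -> ('name', 'contains')
# zip_code -> ('zip_code', 'eq')
```

Field names that merely contain underscores, such as `zip_code`, are not mistaken for operators because only recognised suffixes are split off.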
a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py @@ -1,6 +1,6 @@ """ Structured Query Service - orchestrates natural language question processing. -Takes a question, converts it to GraphQL via nlp-query, executes via objects-query, +Takes a question, converts it to GraphQL via nlp-query, executes via rows-query, and returns the results. """ @@ -10,7 +10,7 @@ from typing import Dict, Any, Optional from ...schema import StructuredQueryRequest, StructuredQueryResponse from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse -from ...schema import ObjectsQueryRequest, ObjectsQueryResponse +from ...schema import RowsQueryRequest, RowsQueryResponse from ...schema import Error from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, RequestResponseSpec @@ -57,13 +57,13 @@ class Processor(FlowProcessor): ) ) - # Client spec for calling objects query service + # Client spec for calling rows query service self.register_specification( RequestResponseSpec( - request_name = "objects-query-request", - response_name = "objects-query-response", - request_schema = ObjectsQueryRequest, - response_schema = ObjectsQueryResponse + request_name = "rows-query-request", + response_name = "rows-query-response", + request_schema = RowsQueryRequest, + response_schema = RowsQueryResponse ) ) @@ -112,7 +112,7 @@ class Processor(FlowProcessor): variables_as_strings[key] = str(value) # Use user/collection values from request - objects_request = ObjectsQueryRequest( + objects_request = RowsQueryRequest( user=request.user, collection=request.collection, query=nlp_response.graphql_query, @@ -120,12 +120,12 @@ class Processor(FlowProcessor): operation_name=None ) - objects_response = await flow("objects-query-request").request(objects_request) - + objects_response = await flow("rows-query-request").request(objects_request) + if objects_response.error is not None: - raise Exception(f"Objects query service error: {objects_response.error.message}") - - # Handle GraphQL errors from the objects query service + raise Exception(f"Rows query service error: {objects_response.error.message}") + + # Handle GraphQL errors from the rows query service graphql_errors = [] if objects_response.errors: for gql_error in objects_response.errors: diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 07dbf0eb..ae869413 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -13,7 +13,7 @@ from .... base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) -default_ident = "de-write" +default_ident = "doc-embeddings-write" default_store_uri = 'http://localhost:19530' class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 6d1b23ba..a0e52253 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -18,7 +18,7 @@ from .... 
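For the orchestration step above, the shape of the request is easiest to see with values filled in. A plain-dict sketch (the real service builds a `RowsQueryRequest` message; the `customers` query and all values here are invented):

```python
# GraphQL produced by nlp-query (invented for illustration)
graphql_query = """
query ($email: String) {
  customers(email: $email) { customer_id name }
}
"""

# Typed variables from the NLP stage...
variables = {"email": "foo@bar.com", "limit": 10}

# ...coerced to strings, since RowsQueryRequest carries string variables
variables_as_strings = {k: str(v) for k, v in variables.items()}

rows_query_request = {          # field names mirror RowsQueryRequest
    "user": "alice",
    "collection": "import_2024",
    "query": graphql_query,
    "variables": variables_as_strings,
    "operation_name": None,
}

print(rows_query_request["variables"])
# {'email': 'foo@bar.com', 'limit': '10'}
```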
base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) -default_ident = "de-write" +default_ident = "doc-embeddings-write" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index edfa8aa9..cb978048 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -16,7 +16,7 @@ from .... base import ConsumerMetrics, ProducerMetrics # Module logger logger = logging.getLogger(__name__) -default_ident = "de-write" +default_ident = "doc-embeddings-write" default_store_uri = 'http://localhost:6333' diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 148e866a..21aa21e6 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -27,7 +27,7 @@ def get_term_value(term): # For blank nodes or other types, use id or value return term.id or term.value -default_ident = "ge-write" +default_ident = "graph-embeddings-write" default_store_uri = 'http://localhost:19530' class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index c92d7661..c4b0065b 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -32,7 +32,7 @@ def get_term_value(term): # For blank nodes or other types, use id or value return term.id or term.value -default_ident = "ge-write" +default_ident = "graph-embeddings-write" default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index bdc5fa70..0da59bb9 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -31,7 +31,7 @@ def get_term_value(term): return term.id or term.value -default_ident = "ge-write" +default_ident = "graph-embeddings-write" default_store_uri = 'http://localhost:6333' diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/__init__.py b/trustgraph-flow/trustgraph/storage/object_embeddings/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/trustgraph-flow/trustgraph/storage/objects/__init__.py b/trustgraph-flow/trustgraph/storage/objects/__init__.py deleted file mode 100644 index 56f5f66a..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Objects storage module \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/__init__.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/__init__.py deleted file mode 100644 index 01adc061..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
write import * diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/__main__.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/__main__.py deleted file mode 100644 index 95376fee..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . write import run - -run() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py deleted file mode 100644 index bcb0d57f..00000000 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py +++ /dev/null @@ -1,538 +0,0 @@ -""" -Object writer for Cassandra. Input is ExtractedObject. -Writes structured objects to Cassandra tables based on schema definitions. -""" - -import json -import logging -from typing import Dict, Set, Optional, Any -from cassandra.cluster import Cluster -from cassandra.auth import PlainTextAuthProvider -from cassandra.cqlengine import connection -from cassandra import ConsistencyLevel - -from .... schema import ExtractedObject -from .... schema import RowSchema, Field -from .... base import FlowProcessor, ConsumerSpec, ProducerSpec -from .... base import CollectionConfigHandler -from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config - -# Module logger -logger = logging.getLogger(__name__) - -default_ident = "objects-write" - -class Processor(CollectionConfigHandler, FlowProcessor): - - def __init__(self, **params): - - id = params.get("id", default_ident) - - # Get Cassandra parameters - cassandra_host = params.get("cassandra_host") - cassandra_username = params.get("cassandra_username") - cassandra_password = params.get("cassandra_password") - - # Resolve configuration with environment variable fallback - hosts, username, password, keyspace = resolve_cassandra_config( - host=cassandra_host, - username=cassandra_username, - password=cassandra_password - ) - - # Store resolved configuration with proper names - self.cassandra_host = hosts # Store as list - self.cassandra_username = username - self.cassandra_password = password - - # Config key for schemas - self.config_key = params.get("config_type", "schema") - - super(Processor, self).__init__( - **params | { - "id": id, - "config_type": self.config_key, - } - ) - - self.register_specification( - ConsumerSpec( - name = "input", - schema = ExtractedObject, - handler = self.on_object - ) - ) - - # Register config handlers - self.register_config_handler(self.on_schema_config) - self.register_config_handler(self.on_collection_config) - - # Cache of known keyspaces/tables - self.known_keyspaces: Set[str] = set() - self.known_tables: Dict[str, Set[str]] = {} # keyspace -> set of tables - - # Schema storage: name -> RowSchema - self.schemas: Dict[str, RowSchema] = {} - - # Cassandra session - self.cluster = None - self.session = None - - def connect_cassandra(self): - """Connect to Cassandra cluster""" - if self.session: - return - - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) - - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - - except Exception as e: - logger.error(f"Failed to connect to Cassandra: 
{e}", exc_info=True) - raise - - async def on_schema_config(self, config, version): - """Handle schema configuration updates""" - logger.info(f"Loading schema configuration version {version}") - - # Clear existing schemas - self.schemas = {} - - # Check if our config type exists - if self.config_key not in config: - logger.warning(f"No '{self.config_key}' type in configuration") - return - - # Get the schemas dictionary for our type - schemas_config = config[self.config_key] - - # Process each schema in the schemas config - for schema_name, schema_json in schemas_config.items(): - try: - # Parse the JSON schema definition - schema_def = json.loads(schema_json) - - # Create Field objects - fields = [] - for field_def in schema_def.get("fields", []): - field = Field( - name=field_def["name"], - type=field_def["type"], - size=field_def.get("size", 0), - primary=field_def.get("primary_key", False), - description=field_def.get("description", ""), - required=field_def.get("required", False), - enum_values=field_def.get("enum", []), - indexed=field_def.get("indexed", False) - ) - fields.append(field) - - # Create RowSchema - row_schema = RowSchema( - name=schema_def.get("name", schema_name), - description=schema_def.get("description", ""), - fields=fields - ) - - self.schemas[schema_name] = row_schema - logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields") - - except Exception as e: - logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True) - - logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") - - def ensure_keyspace(self, keyspace: str): - """Ensure keyspace exists in Cassandra""" - if keyspace in self.known_keyspaces: - return - - # Connect if needed - self.connect_cassandra() - - # Sanitize keyspace name - safe_keyspace = self.sanitize_name(keyspace) - - # Create keyspace if not exists - create_keyspace_cql = f""" - CREATE KEYSPACE IF NOT EXISTS {safe_keyspace} - WITH REPLICATION = {{ - 'class': 'SimpleStrategy', - 'replication_factor': 1 - }} - """ - - try: - self.session.execute(create_keyspace_cql) - self.known_keyspaces.add(keyspace) - self.known_tables[keyspace] = set() - logger.info(f"Ensured keyspace exists: {safe_keyspace}") - except Exception as e: - logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True) - raise - - def get_cassandra_type(self, field_type: str, size: int = 0) -> str: - """Convert schema field type to Cassandra type""" - # Handle None size - if size is None: - size = 0 - - type_mapping = { - "string": "text", - "integer": "bigint" if size > 4 else "int", - "float": "double" if size > 4 else "float", - "boolean": "boolean", - "timestamp": "timestamp", - "date": "date", - "time": "time", - "uuid": "uuid" - } - - return type_mapping.get(field_type, "text") - - def sanitize_name(self, name: str) -> str: - """Sanitize names for Cassandra compatibility""" - # Replace non-alphanumeric characters with underscore - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - # Ensure it starts with a letter - if safe_name and not safe_name[0].isalpha(): - safe_name = 'o_' + safe_name - return safe_name.lower() - - def sanitize_table(self, name: str) -> str: - """Sanitize names for Cassandra compatibility""" - # Replace non-alphanumeric characters with underscore - import re - safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) - # Ensure it starts with a letter - safe_name = 'o_' + safe_name - return safe_name.lower() - - def ensure_table(self, keyspace: str, table_name: str, schema: RowSchema): - 
"""Ensure table exists with proper structure""" - table_key = f"{keyspace}.{table_name}" - if table_key in self.known_tables.get(keyspace, set()): - return - - # Ensure keyspace exists first - self.ensure_keyspace(keyspace) - - safe_keyspace = self.sanitize_name(keyspace) - safe_table = self.sanitize_table(table_name) - - # Build column definitions - columns = ["collection text"] # Collection is always part of table - primary_key_fields = [] - clustering_fields = [] - - for field in schema.fields: - safe_field_name = self.sanitize_name(field.name) - cassandra_type = self.get_cassandra_type(field.type, field.size) - columns.append(f"{safe_field_name} {cassandra_type}") - - if field.primary: - primary_key_fields.append(safe_field_name) - - # Build primary key - collection is always first in partition key - if primary_key_fields: - primary_key = f"PRIMARY KEY ((collection, {', '.join(primary_key_fields)}))" - else: - # If no primary key defined, use collection and a synthetic id - columns.append("synthetic_id uuid") - primary_key = "PRIMARY KEY ((collection, synthetic_id))" - - # Create table - create_table_cql = f""" - CREATE TABLE IF NOT EXISTS {safe_keyspace}.{safe_table} ( - {', '.join(columns)}, - {primary_key} - ) - """ - - try: - self.session.execute(create_table_cql) - if keyspace not in self.known_tables: - self.known_tables[keyspace] = set() - self.known_tables[keyspace].add(table_key) - logger.info(f"Ensured table exists: {safe_keyspace}.{safe_table}") - - # Create secondary indexes for indexed fields - for field in schema.fields: - if field.indexed and not field.primary: - safe_field_name = self.sanitize_name(field.name) - index_name = f"{safe_table}_{safe_field_name}_idx" - create_index_cql = f""" - CREATE INDEX IF NOT EXISTS {index_name} - ON {safe_keyspace}.{safe_table} ({safe_field_name}) - """ - try: - self.session.execute(create_index_cql) - logger.info(f"Created index: {index_name}") - except Exception as e: - logger.warning(f"Failed to create index {index_name}: {e}") - - except Exception as e: - logger.error(f"Failed to create table {safe_keyspace}.{safe_table}: {e}", exc_info=True) - raise - - def convert_value(self, value: Any, field_type: str) -> Any: - """Convert value to appropriate type for Cassandra""" - if value is None: - return None - - try: - if field_type == "integer": - return int(value) - elif field_type == "float": - return float(value) - elif field_type == "boolean": - if isinstance(value, str): - return value.lower() in ('true', '1', 'yes') - return bool(value) - elif field_type == "timestamp": - # Handle timestamp conversion if needed - return value - else: - return str(value) - except Exception as e: - logger.warning(f"Failed to convert value {value} to type {field_type}: {e}") - return str(value) - async def on_object(self, msg, consumer, flow): - """Process incoming ExtractedObject and store in Cassandra""" - - obj = msg.value() - logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}") - - # Validate collection exists before accepting writes - if not self.collection_exists(obj.metadata.user, obj.metadata.collection): - error_msg = ( - f"Collection {obj.metadata.collection} does not exist. " - f"Create it first via collection management API." 
- ) - logger.error(error_msg) - raise ValueError(error_msg) - - # Get schema definition - schema = self.schemas.get(obj.schema_name) - if not schema: - logger.warning(f"No schema found for {obj.schema_name} - skipping") - return - - # Ensure table exists - keyspace = obj.metadata.user - table_name = obj.schema_name - self.ensure_table(keyspace, table_name, schema) - - # Prepare data for insertion - safe_keyspace = self.sanitize_name(keyspace) - safe_table = self.sanitize_table(table_name) - - # Process each object in the batch - for obj_index, value_map in enumerate(obj.values): - # Build column names and values for this object - columns = ["collection"] - values = [obj.metadata.collection] - placeholders = ["%s"] - - # Check if we need a synthetic ID - has_primary_key = any(field.primary for field in schema.fields) - if not has_primary_key: - import uuid - columns.append("synthetic_id") - values.append(uuid.uuid4()) - placeholders.append("%s") - - # Process fields for this object - skip_object = False - for field in schema.fields: - safe_field_name = self.sanitize_name(field.name) - raw_value = value_map.get(field.name) - - # Handle required fields - if field.required and raw_value is None: - logger.warning(f"Required field {field.name} is missing in object {obj_index}") - # Continue anyway - Cassandra doesn't enforce NOT NULL - - # Check if primary key field is NULL - if field.primary and raw_value is None: - logger.error(f"Primary key field {field.name} cannot be NULL - skipping object {obj_index}") - skip_object = True - break - - # Convert value to appropriate type - converted_value = self.convert_value(raw_value, field.type) - - columns.append(safe_field_name) - values.append(converted_value) - placeholders.append("%s") - - # Skip this object if primary key validation failed - if skip_object: - continue - - # Build and execute insert query for this object - insert_cql = f""" - INSERT INTO {safe_keyspace}.{safe_table} ({', '.join(columns)}) - VALUES ({', '.join(placeholders)}) - """ - - # Debug: Show data being inserted - logger.debug(f"Storing {obj.schema_name} object {obj_index}: {dict(zip(columns, values))}") - - if len(columns) != len(values) or len(columns) != len(placeholders): - raise ValueError(f"Mismatch in counts - columns: {len(columns)}, values: {len(values)}, placeholders: {len(placeholders)}") - - try: - # Convert to tuple - Cassandra driver requires tuple for parameters - self.session.execute(insert_cql, tuple(values)) - except Exception as e: - logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True) - raise - - async def create_collection(self, user: str, collection: str, metadata: dict): - """Create/verify collection exists in Cassandra object store""" - # Connect if not already connected - self.connect_cassandra() - - # Sanitize names for safety - safe_keyspace = self.sanitize_name(user) - - # Ensure keyspace exists - if safe_keyspace not in self.known_keyspaces: - self.ensure_keyspace(safe_keyspace) - self.known_keyspaces.add(safe_keyspace) - - # For Cassandra objects, collection is just a property in rows - # No need to create separate tables per collection - # Just mark that we've seen this collection - logger.info(f"Collection {collection} ready for user {user} (using keyspace {safe_keyspace})") - - async def delete_collection(self, user: str, collection: str): - """Delete all data for a specific collection using schema information""" - # Connect if not already connected - self.connect_cassandra() - - # Sanitize names for safety - safe_keyspace = 
self.sanitize_name(user) - - # Check if keyspace exists - if safe_keyspace not in self.known_keyspaces: - # Query to verify keyspace exists - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) - if not result.one(): - logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") - return - self.known_keyspaces.add(safe_keyspace) - - # Iterate over schemas we manage to delete from relevant tables - tables_deleted = 0 - - for schema_name, schema in self.schemas.items(): - safe_table = self.sanitize_table(schema_name) - - # Check if table exists - table_key = f"{user}.{schema_name}" - if table_key not in self.known_tables.get(user, set()): - logger.debug(f"Table {safe_keyspace}.{safe_table} not in known tables, skipping") - continue - - try: - # Get primary key fields from schema - primary_key_fields = [field for field in schema.fields if field.primary] - - if primary_key_fields: - # Schema has primary keys: need to query for partition keys first - # Build SELECT query for primary key fields - pk_field_names = [self.sanitize_name(field.name) for field in primary_key_fields] - select_cql = f""" - SELECT {', '.join(pk_field_names)} - FROM {safe_keyspace}.{safe_table} - WHERE collection = %s - ALLOW FILTERING - """ - - rows = self.session.execute(select_cql, (collection,)) - - # Delete each row using full partition key - for row in rows: - where_clauses = ["collection = %s"] - values = [collection] - - for field_name in pk_field_names: - where_clauses.append(f"{field_name} = %s") - values.append(getattr(row, field_name)) - - delete_cql = f""" - DELETE FROM {safe_keyspace}.{safe_table} - WHERE {' AND '.join(where_clauses)} - """ - - self.session.execute(delete_cql, tuple(values)) - else: - # No primary keys, uses synthetic_id - # Need to query for synthetic_ids first - select_cql = f""" - SELECT synthetic_id - FROM {safe_keyspace}.{safe_table} - WHERE collection = %s - ALLOW FILTERING - """ - - rows = self.session.execute(select_cql, (collection,)) - - # Delete each row using collection and synthetic_id - for row in rows: - delete_cql = f""" - DELETE FROM {safe_keyspace}.{safe_table} - WHERE collection = %s AND synthetic_id = %s - """ - self.session.execute(delete_cql, (collection, row.synthetic_id)) - - tables_deleted += 1 - logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{safe_table}") - - except Exception as e: - logger.error(f"Failed to delete from table {safe_keyspace}.{safe_table}: {e}") - raise - - logger.info(f"Deleted collection {collection} from {tables_deleted} schema-based tables in keyspace {safe_keyspace}") - - def close(self): - """Clean up Cassandra connections""" - if self.cluster: - self.cluster.shutdown() - logger.info("Closed Cassandra connection") - - @staticmethod - def add_args(parser): - """Add command-line arguments""" - - FlowProcessor.add_args(parser) - add_cassandra_args(parser) - - parser.add_argument( - '--config-type', - default='schema', - help='Configuration type prefix for schemas (default: schema)' - ) - -def run(): - """Entry point for objects-write-cassandra command""" - Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/__init__.py b/trustgraph-flow/trustgraph/storage/row_embeddings/__init__.py new file mode 100644 index 00000000..16b2f154 --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/__init__.py @@ -0,0 +1,3 @@ +""" 
+Row embeddings storage modules. +""" diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__init__.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__init__.py new file mode 100644 index 00000000..65c5c514 --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__init__.py @@ -0,0 +1,5 @@ +""" +Qdrant storage for row embeddings. +""" + +from .write import Processor, run, default_ident diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__main__.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__main__.py new file mode 100644 index 00000000..a349475c --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/__main__.py @@ -0,0 +1,4 @@ + +from .write import run + +run() diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py new file mode 100644 index 00000000..29848c4c --- /dev/null +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -0,0 +1,264 @@ +""" +Row embeddings writer for Qdrant (Stage 2). + +Consumes RowEmbeddings messages (which already contain computed vectors) +and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair. + +This follows the two-stage pattern used by graph-embeddings and document-embeddings: + Stage 1 (row-embeddings): Compute embeddings + Stage 2 (this processor): Store embeddings + +Collection naming: rows_{user}_{collection}_{schema_name}_{dimension} + +Payload structure: + - index_name: The indexed field(s) this embedding represents + - index_value: The original list of values (for Cassandra lookup) + - text: The text that was embedded (for debugging/display) +""" + +import logging +import re +import uuid +from typing import Set, Tuple + +from qdrant_client import QdrantClient +from qdrant_client.models import PointStruct, Distance, VectorParams + +from .... schema import RowEmbeddings +from .... base import FlowProcessor, ConsumerSpec +from .... 
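The collection-naming scheme is deterministic, so a quick sketch shows what lands in Qdrant (user/collection/schema values invented):

```python
import re

def sanitize_name(name: str) -> str:
    # Same rule as the writer: non-alphanumerics become '_', and a
    # leading non-letter gets an 'r_' prefix.
    safe = re.sub(r'[^a-zA-Z0-9_]', '_', name)
    if safe and not safe[0].isalpha():
        safe = 'r_' + safe
    return safe.lower()

def qdrant_collection_name(user, collection, schema_name, dimension):
    return (f"rows_{sanitize_name(user)}_{sanitize_name(collection)}"
            f"_{sanitize_name(schema_name)}_{dimension}")

print(qdrant_collection_name("alice", "import-2024", "customers", 384))
# rows_alice_import_2024_customers_384
```

Baking the dimension into the name lets vectors from different embedding models, which may differ in dimension, coexist without clashing.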
base import CollectionConfigHandler + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "row-embeddings-write" +default_store_uri = 'http://localhost:6333' + + +class Processor(CollectionConfigHandler, FlowProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) + + super(Processor, self).__init__( + **params | { + "id": id, + "store_uri": store_uri, + "api_key": api_key, + } + ) + + self.register_specification( + ConsumerSpec( + name="input", + schema=RowEmbeddings, + handler=self.on_embeddings + ) + ) + + # Register config handler for collection management + self.register_config_handler(self.on_collection_config) + + # Cache of created Qdrant collections + self.created_collections: Set[str] = set() + + # Qdrant client + self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + + def sanitize_name(self, name: str) -> str: + """Sanitize names for Qdrant collection naming""" + safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if safe_name and not safe_name[0].isalpha(): + safe_name = 'r_' + safe_name + return safe_name.lower() + + def get_collection_name( + self, user: str, collection: str, schema_name: str, dimension: int + ) -> str: + """Generate Qdrant collection name""" + safe_user = self.sanitize_name(user) + safe_collection = self.sanitize_name(collection) + safe_schema = self.sanitize_name(schema_name) + return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" + + def ensure_collection(self, collection_name: str, dimension: int): + """Create Qdrant collection if it doesn't exist""" + if collection_name in self.created_collections: + return + + if not self.qdrant.collection_exists(collection_name): + logger.info( + f"Creating Qdrant collection {collection_name} " + f"with dimension {dimension}" + ) + self.qdrant.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=dimension, + distance=Distance.COSINE + ) + ) + + self.created_collections.add(collection_name) + + async def on_embeddings(self, msg, consumer, flow): + """Process incoming RowEmbeddings and write to Qdrant""" + + embeddings = msg.value() + logger.info( + f"Writing {len(embeddings.embeddings)} embeddings for schema " + f"{embeddings.schema_name} from {embeddings.metadata.id}" + ) + + # Validate collection exists in config before processing + if not self.collection_exists( + embeddings.metadata.user, embeddings.metadata.collection + ): + logger.warning( + f"Collection {embeddings.metadata.collection} for user " + f"{embeddings.metadata.user} does not exist in config. " + f"Dropping message." 
+ ) + return + + user = embeddings.metadata.user + collection = embeddings.metadata.collection + schema_name = embeddings.schema_name + + embeddings_written = 0 + qdrant_collection = None + + for row_emb in embeddings.embeddings: + if not row_emb.vectors: + logger.warning( + f"No vectors for index {row_emb.index_name} - skipping" + ) + continue + + # Use first vector (there may be multiple from different models) + for vector in row_emb.vectors: + dimension = len(vector) + + # Create/get collection name (lazily on first vector) + if qdrant_collection is None: + qdrant_collection = self.get_collection_name( + user, collection, schema_name, dimension + ) + self.ensure_collection(qdrant_collection, dimension) + + # Write to Qdrant + self.qdrant.upsert( + collection_name=qdrant_collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "index_name": row_emb.index_name, + "index_value": row_emb.index_value, + "text": row_emb.text + } + ) + ] + ) + embeddings_written += 1 + + logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") + + async def create_collection(self, user: str, collection: str, metadata: dict): + """Collection creation via config push - collections created lazily on first write""" + logger.info( + f"Row embeddings collection create request for {user}/{collection} - " + f"will be created lazily on first write" + ) + + async def delete_collection(self, user: str, collection: str): + """Delete all Qdrant collections for a given user/collection""" + try: + prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_" + + # Get all collections and filter for matches + all_collections = self.qdrant.get_collections().collections + matching_collections = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if not matching_collections: + logger.info(f"No Qdrant collections found matching prefix {prefix}") + else: + for collection_name in matching_collections: + self.qdrant.delete_collection(collection_name) + self.created_collections.discard(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + logger.info( + f"Deleted {len(matching_collections)} collection(s) " + f"for {user}/{collection}" + ) + + except Exception as e: + logger.error( + f"Failed to delete collection {user}/{collection}: {e}", + exc_info=True + ) + raise + + async def delete_collection_schema( + self, user: str, collection: str, schema_name: str + ): + """Delete Qdrant collection for a specific user/collection/schema""" + try: + prefix = ( + f"rows_{self.sanitize_name(user)}_" + f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" + ) + + # Get all collections and filter for matches + all_collections = self.qdrant.get_collections().collections + matching_collections = [ + coll.name for coll in all_collections + if coll.name.startswith(prefix) + ] + + if not matching_collections: + logger.info(f"No Qdrant collections found matching prefix {prefix}") + else: + for collection_name in matching_collections: + self.qdrant.delete_collection(collection_name) + self.created_collections.discard(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + + except Exception as e: + logger.error( + f"Failed to delete collection {user}/{collection}/{schema_name}: {e}", + exc_info=True + ) + raise + + @staticmethod + def add_args(parser): + """Add command-line arguments""" + + FlowProcessor.add_args(parser) + + parser.add_argument( + '-t', '--store-uri', + 
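The payload written with each point is exactly what is needed to get from a vector hit back to the authoritative row in Cassandra. A hedged sketch of that read path, assuming already-connected clients; `semantic_row_lookup` is a hypothetical helper and name sanitisation is elided:

```python
def semantic_row_lookup(qdrant, session, keyspace, collection,
                        schema_name, question_vector, limit=5):
    """Nearest row embeddings in Qdrant, then exact fetch of the
    underlying rows from the unified Cassandra table."""
    name = (f"rows_{keyspace}_{collection}_{schema_name}"
            f"_{len(question_vector)}")
    hits = qdrant.search(
        collection_name=name,
        query_vector=question_vector,
        limit=limit,
    )
    rows = []
    for hit in hits:
        result = session.execute(
            f"SELECT data FROM {keyspace}.rows "
            "WHERE collection = %s AND schema_name = %s "
            "AND index_name = %s AND index_value = %s",
            (collection, schema_name,
             hit.payload["index_name"], hit.payload["index_value"]),
        )
        rows.extend(dict(r.data) for r in result)
    return rows
```

This is the hybrid path the spec motivates: the vector store resolves "Chestnut Street" to the stored `index_value` for `"CHESTNUT ST"`, and the row store then answers with the exact row.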
default=default_store_uri, + help=f'Qdrant URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help='Qdrant API key (default: None)' + ) + + +def run(): + """Entry point for row-embeddings-write-qdrant command""" + Processor.launch(default_ident, __doc__) + diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index 1576b70c..d15916b6 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -1,46 +1,49 @@ - """ -Graph writer. Input is graph edge. Writes edges to Cassandra graph. +Row writer for Cassandra. Input is ExtractedObject. +Writes structured rows to a unified Cassandra table with multi-index support. + +Uses a single 'rows' table with the schema: + - collection: text + - schema_name: text + - index_name: text + - index_value: frozen> + - data: map + - source: text + +Each row is written multiple times - once per indexed field defined in the schema. """ -raise RuntimeError("This code is no longer in use") - -import pulsar -import base64 -import os -import argparse -import time +import json import logging +import re +from typing import Dict, Set, Optional, Any, List, Tuple + from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider -from ssl import SSLContext, PROTOCOL_TLSv1_2 -from .... schema import Rows -from .... log_level import LogLevel -from .... base import Consumer +from .... schema import ExtractedObject +from .... schema import RowSchema, Field +from .... base import FlowProcessor, ConsumerSpec +from .... base import CollectionConfigHandler from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config # Module logger logger = logging.getLogger(__name__) -module = "rows-write" -ssl_context = SSLContext(PROTOCOL_TLSv1_2) +default_ident = "rows-write" -default_input_queue = "rows-store" # Default queue name -default_subscriber = module -class Processor(Consumer): +class Processor(CollectionConfigHandler, FlowProcessor): def __init__(self, **params): - - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - + + id = params.get("id", default_ident) + # Get Cassandra parameters cassandra_host = params.get("cassandra_host") cassandra_username = params.get("cassandra_username") cassandra_password = params.get("cassandra_password") - + # Resolve configuration with environment variable fallback hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, @@ -48,99 +51,549 @@ class Processor(Consumer): password=cassandra_password ) + # Store resolved configuration with proper names + self.cassandra_host = hosts # Store as list + self.cassandra_username = username + self.cassandra_password = password + + # Config key for schemas + self.config_key = params.get("config_type", "schema") + super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": Rows, - "cassandra_host": ','.join(hosts), - "cassandra_username": username, - "cassandra_password": password, + "id": id, + "config_type": self.config_key, } ) - - if username and password: - auth_provider = PlainTextAuthProvider(username=username, password=password) - self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context) - else: - self.cluster = Cluster(hosts) - self.session = self.cluster.connect() - 
self.tables = set() + self.register_specification( + ConsumerSpec( + name="input", + schema=ExtractedObject, + handler=self.on_object + ) + ) - self.session.execute(""" - create keyspace if not exists trustgraph - with replication = { - 'class' : 'SimpleStrategy', - 'replication_factor' : 1 - }; - """); + # Register config handlers + self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_collection_config) - self.session.execute("use trustgraph"); + # Cache of known keyspaces and whether tables exist + self.known_keyspaces: Set[str] = set() + self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables - async def handle(self, msg): + # Cache of registered (collection, schema_name) pairs + self.registered_partitions: Set[Tuple[str, str]] = set() + + # Schema storage: name -> RowSchema + self.schemas: Dict[str, RowSchema] = {} + + # Cassandra session + self.cluster = None + self.session = None + + def connect_cassandra(self): + """Connect to Cassandra cluster""" + if self.session: + return try: - - v = msg.value() - name = v.row_schema.name - - if name not in self.tables: - - # FIXME: SQL injection? - - pkey = [] - - stmt = "create table if not exists " + name + " ( " - - for field in v.row_schema.fields: - - stmt += field.name + " text, " - - if field.primary: - pkey.append(field.name) - - stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));" - - self.session.execute(stmt) - - self.tables.add(name); - - for row in v.rows: - - field_names = [] - values = [] - - for field in v.row_schema.fields: - field_names.append(field.name) - values.append(row[field.name]) - - # FIXME: SQL injection? - stmt = ( - "insert into " + name + " (" + ", ".join(field_names) + - ") values (" + ",".join(["%s"] * len(values)) + ")" + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password ) + self.cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + self.cluster = Cluster(contact_points=self.cassandra_host) - self.session.execute(stmt, values) + self.session = self.cluster.connect() + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise - logger.error(f"Exception: {str(e)}", exc_info=True) + async def on_schema_config(self, config, version): + """Handle schema configuration updates""" + logger.info(f"Loading schema configuration version {version}") - # If there's an error make sure to do table creation etc. 
-            self.tables.remove(name)
+    async def on_schema_config(self, config, version):
+        """Handle schema configuration updates"""
+        logger.info(f"Loading schema configuration version {version}")
-            raise e
+        # Track which schemas changed so we can clear partition cache
+        old_schema_names = set(self.schemas.keys())
-
+        # Clear existing schemas
+        self.schemas = {}
+
+        # Check if our config type exists
+        if self.config_key not in config:
+            logger.warning(f"No '{self.config_key}' type in configuration")
+            return
+
+        # Get the schemas dictionary for our type
+        schemas_config = config[self.config_key]
+
+        # Process each schema in the schemas config
+        for schema_name, schema_json in schemas_config.items():
+            try:
+                # Parse the JSON schema definition
+                schema_def = json.loads(schema_json)
+
+                # Create Field objects
+                fields = []
+                for field_def in schema_def.get("fields", []):
+                    field = Field(
+                        name=field_def["name"],
+                        type=field_def["type"],
+                        size=field_def.get("size", 0),
+                        primary=field_def.get("primary_key", False),
+                        description=field_def.get("description", ""),
+                        required=field_def.get("required", False),
+                        enum_values=field_def.get("enum", []),
+                        indexed=field_def.get("indexed", False)
+                    )
+                    fields.append(field)
+
+                # Create RowSchema
+                row_schema = RowSchema(
+                    name=schema_def.get("name", schema_name),
+                    description=schema_def.get("description", ""),
+                    fields=fields
+                )
+
+                self.schemas[schema_name] = row_schema
+                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
+
+            except Exception as e:
+                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
+
+        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
+
+        # Clear partition cache for schemas that changed
+        # This ensures next write will re-register partitions
+        new_schema_names = set(self.schemas.keys())
+        changed_schemas = old_schema_names.symmetric_difference(new_schema_names)
+        if changed_schemas:
+            self.registered_partitions = {
+                (col, sch) for col, sch in self.registered_partitions
+                if sch not in changed_schemas
+            }
+            logger.info(f"Cleared partition cache for changed schemas: {changed_schemas}")
+
+    def sanitize_name(self, name: str) -> str:
+        """Sanitize names for Cassandra compatibility"""
+        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
+        # Ensure it starts with a letter
+        if safe_name and not safe_name[0].isalpha():
+            safe_name = 'r_' + safe_name
+        return safe_name.lower()
+
+    def ensure_keyspace(self, keyspace: str):
+        """Ensure keyspace exists in Cassandra"""
+        if keyspace in self.known_keyspaces:
+            return
+
+        # Connect if needed
+        self.connect_cassandra()
+
+        # Sanitize keyspace name
+        safe_keyspace = self.sanitize_name(keyspace)
+
+        # Create keyspace if not exists
+        create_keyspace_cql = f"""
+            CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
+            WITH REPLICATION = {{
+                'class': 'SimpleStrategy',
+                'replication_factor': 1
+            }}
+        """
+
+        try:
+            self.session.execute(create_keyspace_cql)
+            self.known_keyspaces.add(keyspace)
+            logger.info(f"Ensured keyspace exists: {safe_keyspace}")
+        except Exception as e:
+            logger.error(f"Failed to create keyspace {safe_keyspace}: {e}", exc_info=True)
+            raise
+
+    def ensure_tables(self, keyspace: str):
+        """Ensure unified rows and row_partitions tables exist"""
+        if keyspace in self.tables_initialized:
+            return
+
+        # Ensure keyspace exists first
+        self.ensure_keyspace(keyspace)
+
+        safe_keyspace = self.sanitize_name(keyspace)
+
+        # Create unified rows table
+        create_rows_cql = f"""
+            CREATE TABLE IF NOT EXISTS {safe_keyspace}.rows (
+                collection text,
+                schema_name text,
+                index_name text,
+                index_value frozen<list<text>>,
+                data map<text, text>,
+                source text,
+                PRIMARY KEY ((collection, schema_name, index_name), index_value)
+ ) + """ + + # Create row_partitions tracking table + create_partitions_cql = f""" + CREATE TABLE IF NOT EXISTS {safe_keyspace}.row_partitions ( + collection text, + schema_name text, + index_name text, + PRIMARY KEY ((collection), schema_name, index_name) + ) + """ + + try: + self.session.execute(create_rows_cql) + logger.info(f"Ensured rows table exists: {safe_keyspace}.rows") + + self.session.execute(create_partitions_cql) + logger.info(f"Ensured row_partitions table exists: {safe_keyspace}.row_partitions") + + self.tables_initialized.add(keyspace) + + except Exception as e: + logger.error(f"Failed to create tables in {safe_keyspace}: {e}", exc_info=True) + raise + + def get_index_names(self, schema: RowSchema) -> List[str]: + """ + Get all index names for a schema. + Returns list of index_name strings (single field names or comma-joined composites). + """ + index_names = [] + + for field in schema.fields: + # Primary key fields are treated as indexes + if field.primary: + index_names.append(field.name) + # Indexed fields + elif field.indexed: + index_names.append(field.name) + + # TODO: Support composite indexes in the future + # For now, each indexed field is a single-field index + + return index_names + + def register_partitions(self, keyspace: str, collection: str, schema_name: str): + """ + Register partition entries for a (collection, schema_name) pair. + Called once on first row for each pair. + """ + cache_key = (collection, schema_name) + if cache_key in self.registered_partitions: + return + + schema = self.schemas.get(schema_name) + if not schema: + logger.warning(f"Cannot register partitions - schema {schema_name} not found") + return + + safe_keyspace = self.sanitize_name(keyspace) + index_names = self.get_index_names(schema) + + # Insert partition entries for each index + insert_cql = f""" + INSERT INTO {safe_keyspace}.row_partitions (collection, schema_name, index_name) + VALUES (%s, %s, %s) + """ + + for index_name in index_names: + try: + self.session.execute(insert_cql, (collection, schema_name, index_name)) + except Exception as e: + logger.warning(f"Failed to register partition {collection}/{schema_name}/{index_name}: {e}") + + self.registered_partitions.add(cache_key) + logger.info(f"Registered partitions for {collection}/{schema_name}: {index_names}") + + def build_index_value(self, value_map: Dict[str, str], index_name: str) -> List[str]: + """ + Build the index_value list for a given index. + For single-field indexes, returns a single-element list. + For composite indexes (comma-separated), returns multiple elements. + """ + field_names = [f.strip() for f in index_name.split(',')] + values = [] + + for field_name in field_names: + value = value_map.get(field_name) + # Convert to string for storage + values.append(str(value) if value is not None else "") + + return values + + async def on_object(self, msg, consumer, flow): + """Process incoming ExtractedObject and store in Cassandra""" + + obj = msg.value() + logger.info( + f"Storing {len(obj.values)} rows for schema {obj.schema_name} " + f"from {obj.metadata.id}" + ) + + # Validate collection exists before accepting writes + if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + error_msg = ( + f"Collection {obj.metadata.collection} does not exist. " + f"Create it first via collection management API." 
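`build_index_value` is the piece that makes composite indexes work: the comma-joined `index_name` dictates both which fields are read and their order. A distilled version with invented row data:

```python
def build_index_value(value_map, index_name):
    # Composite index names are comma-joined field names; values are
    # stringified in the same order, with None mapped to "".
    values = []
    for field_name in (f.strip() for f in index_name.split(',')):
        value = value_map.get(field_name)
        values.append(str(value) if value is not None else "")
    return values

row = {"region": "US", "status": "active", "name": "Acme"}
print(build_index_value(row, "region,status"))   # ['US', 'active']
print(build_index_value(row, "email"))           # [''] -> writer skips it
```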
+ ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Get schema definition + schema = self.schemas.get(obj.schema_name) + if not schema: + logger.warning(f"No schema found for {obj.schema_name} - skipping") + return + + keyspace = obj.metadata.user + collection = obj.metadata.collection + schema_name = obj.schema_name + source = getattr(obj.metadata, 'source', '') or '' + + # Ensure tables exist + self.ensure_tables(keyspace) + + # Register partitions if first time seeing this (collection, schema_name) + self.register_partitions(keyspace, collection, schema_name) + + safe_keyspace = self.sanitize_name(keyspace) + + # Get all index names for this schema + index_names = self.get_index_names(schema) + + if not index_names: + logger.warning(f"Schema {schema_name} has no indexed fields - rows won't be queryable") + return + + # Prepare insert statement + insert_cql = f""" + INSERT INTO {safe_keyspace}.rows + (collection, schema_name, index_name, index_value, data, source) + VALUES (%s, %s, %s, %s, %s, %s) + """ + + # Process each row in the batch + rows_written = 0 + for row_index, value_map in enumerate(obj.values): + # Convert all values to strings for the data map + data_map = {} + for field in schema.fields: + raw_value = value_map.get(field.name) + if raw_value is not None: + data_map[field.name] = str(raw_value) + + # Write one copy per index + for index_name in index_names: + index_value = self.build_index_value(value_map, index_name) + + # Skip if index value is empty/null + if not index_value or all(v == "" for v in index_value): + logger.debug( + f"Skipping index {index_name} for row {row_index} - " + f"empty index value" + ) + continue + + try: + self.session.execute( + insert_cql, + (collection, schema_name, index_name, index_value, data_map, source) + ) + rows_written += 1 + except Exception as e: + logger.error( + f"Failed to insert row {row_index} for index {index_name}: {e}", + exc_info=True + ) + raise + + logger.info( + f"Wrote {rows_written} index entries for {len(obj.values)} rows " + f"({len(index_names)} indexes per row)" + ) + + async def create_collection(self, user: str, collection: str, metadata: dict): + """Create/verify collection exists in Cassandra row store""" + # Connect if not already connected + self.connect_cassandra() + + # Ensure tables exist + self.ensure_tables(user) + + logger.info(f"Collection {collection} ready for user {user}") + + async def delete_collection(self, user: str, collection: str): + """Delete all data for a specific collection using partition tracking""" + # Connect if not already connected + self.connect_cassandra() + + safe_keyspace = self.sanitize_name(user) + + # Check if keyspace exists + if user not in self.known_keyspaces: + check_keyspace_cql = """ + SELECT keyspace_name FROM system_schema.keyspaces + WHERE keyspace_name = %s + """ + result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) + if not result.one(): + logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") + return + self.known_keyspaces.add(user) + + # Discover all partitions for this collection + select_partitions_cql = f""" + SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions + WHERE collection = %s + """ + + try: + partitions = self.session.execute(select_partitions_cql, (collection,)) + partition_list = list(partitions) + except Exception as e: + logger.error(f"Failed to query partitions for collection {collection}: {e}") + raise + + # Delete each partition from rows table + delete_rows_cql = f""" + DELETE 
FROM {safe_keyspace}.rows + WHERE collection = %s AND schema_name = %s AND index_name = %s + """ + + partitions_deleted = 0 + for partition in partition_list: + try: + self.session.execute( + delete_rows_cql, + (collection, partition.schema_name, partition.index_name) + ) + partitions_deleted += 1 + except Exception as e: + logger.error( + f"Failed to delete partition {collection}/{partition.schema_name}/" + f"{partition.index_name}: {e}" + ) + raise + + # Clean up row_partitions entries + delete_partitions_cql = f""" + DELETE FROM {safe_keyspace}.row_partitions + WHERE collection = %s + """ + + try: + self.session.execute(delete_partitions_cql, (collection,)) + except Exception as e: + logger.error(f"Failed to clean up row_partitions for {collection}: {e}") + raise + + # Clear from local cache + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if col != collection + } + + logger.info( + f"Deleted collection {collection}: {partitions_deleted} partitions " + f"from keyspace {safe_keyspace}" + ) + + async def delete_collection_schema(self, user: str, collection: str, schema_name: str): + """Delete all data for a specific collection + schema combination""" + # Connect if not already connected + self.connect_cassandra() + + safe_keyspace = self.sanitize_name(user) + + # Discover partitions for this collection + schema + select_partitions_cql = f""" + SELECT index_name FROM {safe_keyspace}.row_partitions + WHERE collection = %s AND schema_name = %s + """ + + try: + partitions = self.session.execute(select_partitions_cql, (collection, schema_name)) + partition_list = list(partitions) + except Exception as e: + logger.error( + f"Failed to query partitions for {collection}/{schema_name}: {e}" + ) + raise + + # Delete each partition from rows table + delete_rows_cql = f""" + DELETE FROM {safe_keyspace}.rows + WHERE collection = %s AND schema_name = %s AND index_name = %s + """ + + partitions_deleted = 0 + for partition in partition_list: + try: + self.session.execute( + delete_rows_cql, + (collection, schema_name, partition.index_name) + ) + partitions_deleted += 1 + except Exception as e: + logger.error( + f"Failed to delete partition {collection}/{schema_name}/" + f"{partition.index_name}: {e}" + ) + raise + + # Clean up row_partitions entries for this schema + delete_partitions_cql = f""" + DELETE FROM {safe_keyspace}.row_partitions + WHERE collection = %s AND schema_name = %s + """ + + try: + self.session.execute(delete_partitions_cql, (collection, schema_name)) + except Exception as e: + logger.error( + f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}" + ) + raise + + # Clear from local cache + self.registered_partitions.discard((collection, schema_name)) + + logger.info( + f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " + f"from keyspace {safe_keyspace}" + ) + + def close(self): + """Clean up Cassandra connections""" + if self.cluster: + self.cluster.shutdown() + logger.info("Closed Cassandra connection") @staticmethod def add_args(parser): + """Add command-line arguments""" - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) + FlowProcessor.add_args(parser) add_cassandra_args(parser) + parser.add_argument( + '--config-type', + default='schema', + help='Configuration type prefix for schemas (default: schema)' + ) + + def run(): - - Processor.launch(module, __doc__) - + """Entry point for rows-write-cassandra command""" + Processor.launch(default_ident, __doc__)
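Putting the write path together, a toy sketch of the fan-out a single extracted row produces under the unified-table design (values invented; the real writer executes these statements against Cassandra rather than printing them):

```python
# One logical row under the "customers" schema (values invented),
# with customer_id primary and email indexed.
schema_indexes = ["customer_id", "email"]
row = {"customer_id": "CUST001", "email": "foo@bar.com", "name": "Ada"}

data_map = {k: str(v) for k, v in row.items() if v is not None}

insert_cql = ("INSERT INTO rows (collection, schema_name, index_name, "
              "index_value, data, source) VALUES (%s, %s, %s, %s, %s, %s)")

for index_name in schema_indexes:
    index_value = [str(row[index_name])]
    print(insert_cql,
          ("import_2024", "customers", index_name, index_value,
           data_map, ""))
# Two statements for one row: the same data map keyed once by
# customer_id and once by email -- the write amplification the spec
# lists as a trade-off.
```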