diff --git a/docs/tech-specs/collection-management.md b/docs/tech-specs/collection-management.md new file mode 100644 index 00000000..3e3ded01 --- /dev/null +++ b/docs/tech-specs/collection-management.md @@ -0,0 +1,349 @@ +# Collection Management Technical Specification + +## Overview + +This specification describes the collection management capabilities for TrustGraph, enabling users to have explicit control over collections that are currently implicitly created during data loading and querying operations. The feature supports four primary use cases: + +1. **Collection Listing**: View all existing collections in the system +2. **Collection Deletion**: Remove unwanted collections and their associated data +3. **Collection Labeling**: Associate descriptive labels with collections for better organization +4. **Collection Tagging**: Apply tags to collections for categorization and easier discovery + +## Goals + +- **Explicit Collection Control**: Provide users with direct management capabilities over collections beyond implicit creation +- **Collection Visibility**: Enable users to list and inspect all collections in their environment +- **Collection Cleanup**: Allow deletion of collections that are no longer needed +- **Collection Organization**: Support labels and tags for better collection tracking and discovery +- **Metadata Management**: Associate meaningful metadata with collections for operational clarity +- **Collection Discovery**: Make it easier to find specific collections through filtering and search +- **Operational Transparency**: Provide clear visibility into collection lifecycle and usage +- **Resource Management**: Enable cleanup of unused collections to optimize resource utilization + +## Background + +Currently, collections in TrustGraph are implicitly created during data loading operations and query execution. While this provides convenience for users, it lacks the explicit control needed for production environments and long-term data management. 
+ +Current limitations include: +- No way to list existing collections +- No mechanism to delete unwanted collections +- No ability to associate metadata with collections for tracking purposes +- Difficulty in organizing and discovering collections over time + +This specification addresses these gaps by introducing explicit collection management operations. By providing collection management APIs and commands, TrustGraph can: +- Give users full control over their collection lifecycle +- Enable better organization through labels and tags +- Support collection cleanup for resource optimization +- Improve operational visibility and management + +## Technical Design + +### Architecture + +The collection management system will be implemented within existing TrustGraph infrastructure: + +1. **Librarian Service Integration** + - Collection management operations will be added to the existing librarian service + - No new service required - leverages existing authentication and access patterns + - Handles collection listing, deletion, and metadata management + + Module: trustgraph-librarian + +2. **Cassandra Collection Metadata Table** + - New table in the existing librarian keyspace + - Stores collection metadata with user-scoped access + - Primary key: (user_id, collection_id) for proper multi-tenancy + + Module: trustgraph-librarian + +3. 
**Collection Management CLI** + - Command-line interface for collection operations + - Provides list, delete, label, and tag management commands + - Integrates with existing CLI framework + + Module: trustgraph-cli + +### Data Models + +#### Cassandra Collection Metadata Table + +The collection metadata will be stored in a structured Cassandra table in the librarian keyspace: + +```sql +CREATE TABLE collections ( + user text, + collection text, + name text, + description text, + tags set<text>, + created_at timestamp, + updated_at timestamp, + PRIMARY KEY (user, collection) +); +``` + +Table structure: +- **user** + **collection**: Composite primary key ensuring user isolation +- **name**: Human-readable collection name +- **description**: Detailed description of collection purpose +- **tags**: Set of tags for categorization and filtering +- **created_at**: Collection creation timestamp +- **updated_at**: Last modification timestamp + +This approach allows: +- Multi-tenant collection management with user isolation +- Efficient querying by user and collection +- Flexible tagging system for organization +- Lifecycle tracking for operational insights + +#### Collection Lifecycle + +Collections follow a lazy-creation pattern that aligns with existing TrustGraph behavior: + +1. **Lazy Creation**: Collections are automatically created when first referenced during data loading or query operations. No explicit create operation is needed. + +2. **Implicit Registration**: When a collection is used (data loading, querying), the system checks if a metadata record exists. If not, a new record is created with default values: + - `name`: defaults to collection_id + - `description`: empty + - `tags`: empty set + - `created_at`: current timestamp + +3. **Explicit Updates**: Users can update collection metadata (name, description, tags) through management operations after lazy creation. + +4. 
**Explicit Deletion**: Users can delete collections, which removes both the metadata record and the underlying collection data across all store types. + +5. **Multi-Store Deletion**: Collection deletion cascades across all storage backends (vector stores, object stores, triple stores) as each implements lazy creation and must support collection deletion. + +Operations required: +- **Collection Use Notification**: Internal operation triggered during data loading/querying to ensure metadata record exists +- **Update Collection Metadata**: User operation to modify name, description, and tags +- **Delete Collection**: User operation to remove collection and its data across all stores +- **List Collections**: User operation to view collections with filtering by tags + +#### Multi-Store Collection Management + +Collections exist across multiple storage backends in TrustGraph: +- **Vector Stores**: Store embeddings and vector data for collections +- **Object Stores**: Store documents and file data for collections +- **Triple Stores**: Store graph/RDF data for collections + +Each store type implements: +- **Lazy Creation**: Collections are created implicitly when data is first stored +- **Collection Deletion**: Store-specific deletion operations to remove collection data + +The librarian service coordinates collection operations across all store types, ensuring consistent collection lifecycle management. 
+ +### APIs + +New APIs: +- **List Collections**: Retrieve collections for a user with optional tag filtering +- **Update Collection Metadata**: Modify collection name, description, and tags +- **Delete Collection**: Remove collection and associated data with confirmation, cascading to all store types +- **Collection Use Notification** (Internal): Ensure metadata record exists when collection is referenced + +Store Writer APIs (Enhanced): +- **Vector Store Collection Deletion**: Remove vector data for specified user and collection +- **Object Store Collection Deletion**: Remove object/document data for specified user and collection +- **Triple Store Collection Deletion**: Remove graph/RDF data for specified user and collection + +Modified APIs: +- **Data Loading APIs**: Enhanced to trigger collection use notification for lazy metadata creation +- **Query APIs**: Enhanced to trigger collection use notification and optionally include metadata in responses + +### Implementation Details + +The implementation will follow existing TrustGraph patterns for service integration and CLI command structure. + +#### Collection Deletion Cascade + +When a user initiates collection deletion through the librarian service: + +1. **Metadata Validation**: Verify collection exists and user has permission to delete +2. **Store Cascade**: Librarian coordinates deletion across all store writers: + - Vector store writer: Remove embeddings and vector indexes for the user and collection + - Object store writer: Remove documents and files for the user and collection + - Triple store writer: Remove graph data and triples for the user and collection +3. **Metadata Cleanup**: Remove collection metadata record from Cassandra +4. 
**Error Handling**: If any store deletion fails, maintain consistency through rollback or retry mechanisms + +#### Collection Management Interface + +All store writers will implement a standardized collection management interface with a common schema across store types: + +**Message Schema:** +```json +{ + "operation": "delete-collection", + "user": "user123", + "collection": "documents-2024", + "timestamp": "2024-01-15T10:30:00Z" +} +``` + +**Queue Architecture:** +- **Object Store Collection Management Queue**: Handles collection operations for object/document stores +- **Vector Store Collection Management Queue**: Handles collection operations for vector/embedding stores +- **Triple Store Collection Management Queue**: Handles collection operations for graph/RDF stores + +Each store writer implements: +- **Collection Management Handler**: Separate from standard data storage handlers +- **Delete Collection Operation**: Removes all data associated with the specified collection +- **Message Processing**: Consumes from dedicated collection management queue +- **Status Reporting**: Returns success/failure status for coordination +- **Idempotent Operations**: Handles cases where collection doesn't exist (no-op) + +**Initial Implementation:** +Only `delete-collection` operation will be implemented initially. The interface supports future operations like `archive-collection`, `migrate-collection`, etc. 
+ +#### Cassandra Triple Store Refactor + +As part of this implementation, the Cassandra triple store will be refactored from a table-per-collection model to a unified table model: + +**Current Architecture:** +- Keyspace per user, separate table per collection +- Schema: `(s, p, o)` with `PRIMARY KEY (s, p, o)` +- Table names: user collections become separate Cassandra tables + +**New Architecture:** +- Keyspace per user, single "triples" table for all collections +- Schema: `(collection, s, p, o)` with `PRIMARY KEY (collection, s, p, o)` +- Collection isolation through collection partitioning + +**Changes Required:** + +1. **TrustGraph Class Refactor** (`trustgraph/direct/cassandra.py`): + - Remove `table` parameter from constructor, use fixed "triples" table + - Add `collection` parameter to all methods + - Update schema to include collection as first column + - **Index Updates**: New indexes will be created to support all 8 query patterns: + - Index on `(s)` for subject-based queries + - Index on `(p)` for predicate-based queries + - Index on `(o)` for object-based queries + - Note: Cassandra doesn't support multi-column secondary indexes, so these are single-column indexes + + - **Query Pattern Performance**: + - ✅ `get_all()` - partition scan on `collection` + - ✅ `get_s(s)` - uses primary key efficiently (`collection, s`) + - ✅ `get_p(p)` - uses `idx_p` with `collection` filtering + - ✅ `get_o(o)` - uses `idx_o` with `collection` filtering + - ✅ `get_sp(s, p)` - uses primary key efficiently (`collection, s, p`) + - ⚠️ `get_po(p, o)` - requires `ALLOW FILTERING` (uses either `idx_p` or `idx_o` plus filtering) + - ✅ `get_os(o, s)` - uses `idx_o` with additional filtering on `s` + - ✅ `get_spo(s, p, o)` - uses full primary key efficiently + + - **Note on ALLOW FILTERING**: The `get_po` query pattern requires `ALLOW FILTERING` as it needs both predicate and object constraints without a suitable compound index. 
This is acceptable as this query pattern is less common than subject-based queries in typical triple store usage + +2. **Storage Writer Updates** (`trustgraph/storage/triples/cassandra/write.py`): + - Maintain single TrustGraph connection per user instead of per (user, collection) + - Pass collection to insert operations + - Improved resource utilization with fewer connections + +3. **Query Service Updates** (`trustgraph/query/triples/cassandra/service.py`): + - Single TrustGraph connection per user + - Pass collection to all query operations + - Maintain same query logic with collection parameter + +**Benefits:** +- **Simplified Collection Deletion**: Simple `DELETE FROM triples WHERE collection = ?` instead of dropping tables +- **Resource Efficiency**: Fewer database connections and table objects +- **Cross-Collection Operations**: Easier to implement operations spanning multiple collections +- **Consistent Architecture**: Aligns with unified collection metadata approach + +**Migration Strategy:** +Existing table-per-collection data will need migration to the new unified schema during the upgrade process. + +Collection operations will be atomic where possible and provide appropriate error handling and validation. + +## Security Considerations + +Collection management operations require appropriate authorization to prevent unauthorized access or deletion of collections. Access control will align with existing TrustGraph security models. + +## Performance Considerations + +Collection listing operations may need pagination for environments with large numbers of collections. Metadata queries should be optimized for common filtering patterns. + +## Testing Strategy + +Comprehensive testing will cover collection lifecycle operations, metadata management, and CLI command functionality with both unit and integration tests. 
+ +## Migration Plan + +This implementation requires both metadata and storage migrations: + +### Collection Metadata Migration +Existing collections will need to be registered in the new Cassandra collections metadata table. A migration process will: +- Scan existing keyspaces and tables to identify collections +- Create metadata records with default values (name=collection_id, empty description/tags) +- Preserve creation timestamps where possible + +### Cassandra Triple Store Migration +The Cassandra storage refactor requires data migration from table-per-collection to unified table: +- **Pre-migration**: Identify all user keyspaces and collection tables +- **Data Transfer**: Copy triples from individual collection tables to unified "triples" table with collection +- **Schema Validation**: Ensure new primary key structure maintains query performance +- **Cleanup**: Remove old collection tables after successful migration +- **Rollback Plan**: Maintain ability to restore table-per-collection structure if needed + +Migration will be performed during a maintenance window to ensure data consistency. + +## Implementation Status + +### ✅ Completed Components + +1. **Librarian Collection Management Service** (`trustgraph-flow/trustgraph/librarian/collection_service.py`) + - Complete collection CRUD operations (list, update, delete) + - Cassandra collection metadata table integration via `LibraryTableStore` + - Async request/response handling with proper error management + - Collection deletion cascade coordination across all storage types + +2. **Collection Metadata Schema** (`trustgraph-base/trustgraph/schema/services/collection.py`) + - `CollectionManagementRequest` and `CollectionManagementResponse` schemas + - `CollectionMetadata` schema for collection records + - Collection request/response queue topic definitions + +3. 
**Storage Management Schema** (`trustgraph-base/trustgraph/schema/services/storage.py`) + - `StorageManagementRequest` and `StorageManagementResponse` schemas + - Message format for storage-level collection operations + +### ❌ Missing Components + +1. **Storage Management Queue Topics** + - Missing topic definitions in schema for: + - `vector_storage_management_topic` + - `object_storage_management_topic` + - `triples_storage_management_topic` + - `storage_management_response_topic` + - These are referenced by the librarian service but not yet defined + +2. **Store Collection Management Handlers** + - **Vector Store Writers** (Qdrant, Milvus, Pinecone): No collection deletion handlers + - **Object Store Writers** (Cassandra): No collection deletion handlers + - **Triple Store Writers** (Cassandra, Neo4j, Memgraph, FalkorDB): No collection deletion handlers + - Need to implement `StorageManagementRequest` processing in each store writer + +3. **Collection Management Interface Implementation** + - Store writers need collection management message consumers + - Collection deletion operations need to be implemented per store type + - Response handling back to librarian service + +### Next Implementation Steps + +1. **Define Storage Management Topics** in `trustgraph-base/trustgraph/schema/services/storage.py` +2. **Implement Collection Management Handlers** in each storage writer: + - Add `StorageManagementRequest` consumers + - Implement collection deletion operations + - Add response producers for status reporting +3. **Test End-to-End Collection Deletion** across all storage types + +## Timeline + +Phase 1 (Storage Topics): 1-2 days +Phase 2 (Store Handlers): 1-2 weeks depending on number of storage backends +Phase 3 (Testing & Integration): 3-5 days + +## Open Questions + +- Should collection deletion be soft or hard delete by default? +- What metadata fields should be required vs optional? 
+- Should we implement storage management handlers incrementally by store type? + diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 8dc60de7..a14c521c 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -21,7 +21,7 @@ class TestEndToEndConfigurationFlow: """Test complete configuration flow from environment to processors.""" @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra.Cluster') + @patch('trustgraph.direct.cassandra_kg.Cluster') async def test_triples_writer_env_to_connection(self, mock_cluster): """Test complete flow from environment variables to TrustGraph connection.""" env_vars = { @@ -117,7 +117,7 @@ class TestConfigurationPriorityEndToEnd: """Test configuration priority chains end-to-end.""" @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra.Cluster') + @patch('trustgraph.direct.cassandra_kg.Cluster') async def test_cli_override_env_end_to_end(self, mock_cluster): """Test that CLI parameters override environment variables end-to-end.""" env_vars = { @@ -184,7 +184,7 @@ class TestConfigurationPriorityEndToEnd: ) @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra.Cluster') + @patch('trustgraph.direct.cassandra_kg.Cluster') async def test_no_config_defaults_end_to_end(self, mock_cluster): """Test that defaults are used when no configuration provided end-to-end.""" mock_cluster_instance = MagicMock() @@ -222,7 +222,7 @@ class TestNoBackwardCompatibilityEndToEnd: """Test that backward compatibility with old parameter names is removed.""" @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra.Cluster') + @patch('trustgraph.direct.cassandra_kg.Cluster') async def test_old_graph_params_no_longer_work_end_to_end(self, mock_cluster): """Test that old graph_* parameters no longer work end-to-end.""" mock_cluster_instance = MagicMock() @@ -275,7 +275,7 @@ class 
TestNoBackwardCompatibilityEndToEnd: ) @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra.Cluster') + @patch('trustgraph.direct.cassandra_kg.Cluster') async def test_new_params_override_old_params_end_to_end(self, mock_cluster): """Test that new parameters override old ones when both are present end-to-end.""" mock_cluster_instance = MagicMock() @@ -334,7 +334,7 @@ class TestMultipleHostsHandling: assert call_args.kwargs['contact_points'] == ['host1', 'host2', 'host3', 'host4', 'host5'] @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra.Cluster') + @patch('trustgraph.direct.cassandra_kg.Cluster') async def test_single_host_converted_to_list(self, mock_cluster): """Test that single host is converted to list for TrustGraph.""" mock_cluster_instance = MagicMock() diff --git a/tests/integration/test_cassandra_integration.py b/tests/integration/test_cassandra_integration.py index ce9d7fd3..560f3132 100644 --- a/tests/integration/test_cassandra_integration.py +++ b/tests/integration/test_cassandra_integration.py @@ -13,7 +13,7 @@ import time from unittest.mock import MagicMock from .cassandra_test_helper import cassandra_container -from trustgraph.direct.cassandra import TrustGraph +from trustgraph.direct.cassandra_kg import KnowledgeGraph from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest @@ -62,29 +62,29 @@ class TestCassandraIntegration: print("=" * 60) # ===================================================== - # Test 1: Basic TrustGraph Operations + # Test 1: Basic KnowledgeGraph Operations # ===================================================== - print("\n1. Testing basic TrustGraph operations...") - - client = TrustGraph( + print("\n1. 
Testing basic KnowledgeGraph operations...") + + client = KnowledgeGraph( hosts=[host], - keyspace="test_basic", - table="test_table" + keyspace="test_basic" ) self.clients_to_close.append(client) # Insert test data - client.insert("http://example.org/alice", "knows", "http://example.org/bob") - client.insert("http://example.org/alice", "age", "25") - client.insert("http://example.org/bob", "age", "30") - + collection = "test_collection" + client.insert(collection, "http://example.org/alice", "knows", "http://example.org/bob") + client.insert(collection, "http://example.org/alice", "age", "25") + client.insert(collection, "http://example.org/bob", "age", "30") + # Test get_all - all_results = list(client.get_all(limit=10)) + all_results = list(client.get_all(collection, limit=10)) assert len(all_results) == 3 print(f"✓ Stored and retrieved {len(all_results)} triples") # Test get_s (subject query) - alice_results = list(client.get_s("http://example.org/alice", limit=10)) + alice_results = list(client.get_s(collection, "http://example.org/alice", limit=10)) assert len(alice_results) == 2 alice_predicates = [r.p for r in alice_results] assert "knows" in alice_predicates @@ -110,7 +110,7 @@ class TestCassandraIntegration: keyspace="test_storage", table="test_triples" ) - # Track the TrustGraph instance that will be created + # Track the KnowledgeGraph instance that will be created self.storage_processor = storage_processor # Create test message @@ -202,7 +202,7 @@ class TestCassandraIntegration: # Debug: Check what was actually stored print("Debug: Checking what was stored for Alice...") direct_results = list(query_storage_processor.tg.get_s("http://example.org/alice", limit=10)) - print(f"Direct TrustGraph results: {len(direct_results)}") + print(f"Direct KnowledgeGraph results: {len(direct_results)}") for result in direct_results: print(f" S=http://example.org/alice, P={result.p}, O={result.o}") diff --git a/tests/unit/test_direct/test_milvus_collection_naming.py 
b/tests/unit/test_direct/test_milvus_collection_naming.py index 9c6b0a90..d948caff 100644 --- a/tests/unit/test_direct/test_milvus_collection_naming.py +++ b/tests/unit/test_direct/test_milvus_collection_naming.py @@ -13,163 +13,146 @@ class TestMilvusCollectionNaming: """Test basic collection name creation""" result = make_safe_collection_name( user="test_user", - collection="test_collection", - dimension=384, + collection="test_collection", prefix="doc" ) - assert result == "doc_test_user_test_collection_384" + assert result == "doc_test_user_test_collection" def test_make_safe_collection_name_with_special_characters(self): """Test collection name creation with special characters that need sanitization""" result = make_safe_collection_name( user="user@domain.com", collection="test-collection.v2", - dimension=768, prefix="entity" ) - assert result == "entity_user_domain_com_test_collection_v2_768" + assert result == "entity_user_domain_com_test_collection_v2" def test_make_safe_collection_name_with_unicode(self): """Test collection name creation with Unicode characters""" result = make_safe_collection_name( user="测试用户", - collection="colección_española", - dimension=512, + collection="colección_española", prefix="doc" ) - assert result == "doc_default_colecci_n_espa_ola_512" + assert result == "doc_default_colecci_n_espa_ola" def test_make_safe_collection_name_with_spaces(self): """Test collection name creation with spaces""" result = make_safe_collection_name( user="test user", collection="my test collection", - dimension=256, prefix="entity" ) - assert result == "entity_test_user_my_test_collection_256" + assert result == "entity_test_user_my_test_collection" def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self): """Test collection name creation with multiple consecutive special characters""" result = make_safe_collection_name( user="user@@@domain!!!", collection="test---collection...v2", - dimension=384, - prefix="doc" + prefix="doc" 
) - assert result == "doc_user_domain_test_collection_v2_384" + assert result == "doc_user_domain_test_collection_v2" def test_make_safe_collection_name_with_leading_trailing_underscores(self): """Test collection name creation with leading/trailing special characters""" result = make_safe_collection_name( user="__test_user__", collection="@@test_collection##", - dimension=128, prefix="entity" ) - assert result == "entity_test_user_test_collection_128" + assert result == "entity_test_user_test_collection" def test_make_safe_collection_name_empty_user(self): """Test collection name creation with empty user (should fallback to 'default')""" result = make_safe_collection_name( user="", collection="test_collection", - dimension=384, prefix="doc" ) - assert result == "doc_default_test_collection_384" + assert result == "doc_default_test_collection" def test_make_safe_collection_name_empty_collection(self): """Test collection name creation with empty collection (should fallback to 'default')""" result = make_safe_collection_name( user="test_user", collection="", - dimension=384, prefix="doc" ) - assert result == "doc_test_user_default_384" + assert result == "doc_test_user_default" def test_make_safe_collection_name_both_empty(self): """Test collection name creation with both user and collection empty""" result = make_safe_collection_name( user="", collection="", - dimension=384, prefix="doc" ) - assert result == "doc_default_default_384" + assert result == "doc_default_default" def test_make_safe_collection_name_only_special_characters(self): """Test collection name creation with only special characters (should fallback to 'default')""" result = make_safe_collection_name( user="@@@!!!", collection="---###", - dimension=512, prefix="entity" ) - assert result == "entity_default_default_512" + assert result == "entity_default_default" def test_make_safe_collection_name_whitespace_only(self): """Test collection name creation with whitespace-only strings""" result = 
make_safe_collection_name( user=" \n\t ", collection=" \r\n ", - dimension=256, prefix="doc" ) - assert result == "doc_default_default_256" + assert result == "doc_default_default" def test_make_safe_collection_name_mixed_valid_invalid_chars(self): """Test collection name creation with mixed valid and invalid characters""" result = make_safe_collection_name( user="user123@test", collection="coll_2023.v1", - dimension=384, prefix="entity" ) - assert result == "entity_user123_test_coll_2023_v1_384" + assert result == "entity_user123_test_coll_2023_v1" def test_make_safe_collection_name_different_prefixes(self): """Test collection name creation with different prefixes""" user = "test_user" collection = "test_collection" - dimension = 384 - - doc_result = make_safe_collection_name(user, collection, dimension, "doc") - entity_result = make_safe_collection_name(user, collection, dimension, "entity") - custom_result = make_safe_collection_name(user, collection, dimension, "custom") - - assert doc_result == "doc_test_user_test_collection_384" - assert entity_result == "entity_test_user_test_collection_384" - assert custom_result == "custom_test_user_test_collection_384" + + doc_result = make_safe_collection_name(user, collection, "doc") + entity_result = make_safe_collection_name(user, collection, "entity") + custom_result = make_safe_collection_name(user, collection, "custom") + + assert doc_result == "doc_test_user_test_collection" + assert entity_result == "entity_test_user_test_collection" + assert custom_result == "custom_test_user_test_collection" def test_make_safe_collection_name_different_dimensions(self): - """Test collection name creation with different dimensions""" + """Test collection name creation - dimension handling no longer part of function""" user = "test_user" collection = "test_collection" prefix = "doc" - - result_128 = make_safe_collection_name(user, collection, 128, prefix) - result_384 = make_safe_collection_name(user, collection, 384, prefix) - 
result_768 = make_safe_collection_name(user, collection, 768, prefix) - - assert result_128 == "doc_test_user_test_collection_128" - assert result_384 == "doc_test_user_test_collection_384" - assert result_768 == "doc_test_user_test_collection_768" + + # With new API, dimensions are handled separately, function always returns same result + result = make_safe_collection_name(user, collection, prefix) + + assert result == "doc_test_user_test_collection" def test_make_safe_collection_name_long_names(self): """Test collection name creation with very long user/collection names""" long_user = "a" * 100 long_collection = "b" * 100 - + result = make_safe_collection_name( user=long_user, collection=long_collection, - dimension=384, prefix="doc" ) - - expected = f"doc_{long_user}_{long_collection}_384" + + expected = f"doc_{long_user}_{long_collection}" assert result == expected assert len(result) > 200 # Verify it handles long names @@ -178,20 +161,18 @@ class TestMilvusCollectionNaming: result = make_safe_collection_name( user="user123", collection="collection456", - dimension=384, prefix="doc" ) - assert result == "doc_user123_collection456_384" + assert result == "doc_user123_collection456" def test_make_safe_collection_name_case_sensitivity(self): """Test that collection name creation preserves case""" result = make_safe_collection_name( user="TestUser", collection="TestCollection", - dimension=384, prefix="Doc" ) - assert result == "Doc_TestUser_TestCollection_384" + assert result == "Doc_TestUser_TestCollection" def test_make_safe_collection_name_realistic_examples(self): """Test collection name creation with realistic user/collection combinations""" @@ -202,30 +183,27 @@ class TestMilvusCollectionNaming: ("user_123", "test_collection", "user_123", "test_collection"), ("αβγ-user", "测试集合", "user", "default"), ] - + for user, collection, expected_user, expected_collection in test_cases: - result = make_safe_collection_name(user, collection, 384, "doc") - assert result 
== f"doc_{expected_user}_{expected_collection}_384" + result = make_safe_collection_name(user, collection, "doc") + assert result == f"doc_{expected_user}_{expected_collection}" def test_make_safe_collection_name_matches_qdrant_pattern(self): - """Test that Milvus collection names follow similar pattern to Qdrant""" + """Test that Milvus collection names follow similar pattern to Qdrant (but without dimension in name)""" # Qdrant uses: "d_{user}_{collection}_{dimension}" and "t_{user}_{collection}_{dimension}" - # Milvus should use: "{prefix}_{safe_user}_{safe_collection}_{dimension}" - + # New Milvus API uses: "{prefix}_{safe_user}_{safe_collection}" (dimension handled separately) + user = "test.user@domain.com" collection = "test-collection.v2" - dimension = 384 - - doc_result = make_safe_collection_name(user, collection, dimension, "doc") - entity_result = make_safe_collection_name(user, collection, dimension, "entity") - - # Should follow the pattern but with sanitized names - assert doc_result == "doc_test_user_domain_com_test_collection_v2_384" - assert entity_result == "entity_test_user_domain_com_test_collection_v2_384" - - # Verify structure matches expected pattern (may have more underscores due to sanitization) - # The important thing is that it follows prefix_user_collection_dimension structure + + doc_result = make_safe_collection_name(user, collection, "doc") + entity_result = make_safe_collection_name(user, collection, "entity") + + # Should follow the pattern but with sanitized names and no dimension + assert doc_result == "doc_test_user_domain_com_test_collection_v2" + assert entity_result == "entity_test_user_domain_com_test_collection_v2" + + # Verify structure matches expected pattern assert doc_result.startswith("doc_") - assert doc_result.endswith("_384") assert entity_result.startswith("entity_") - assert entity_result.endswith("_384") \ No newline at end of file + # Dimension is no longer part of the collection name \ No newline at end of 
file diff --git a/tests/unit/test_direct/test_milvus_user_collection_integration.py b/tests/unit/test_direct/test_milvus_user_collection_integration.py index 931332e4..cc45524c 100644 --- a/tests/unit/test_direct/test_milvus_user_collection_integration.py +++ b/tests/unit/test_direct/test_milvus_user_collection_integration.py @@ -32,7 +32,7 @@ class TestMilvusUserCollectionIntegration: doc_vectors.insert(vector, "test document", user, collection) expected_collection_name = make_safe_collection_name( - user, collection, len(vector), "doc" + user, collection, "doc" ) # Verify collection was created with correct name @@ -58,7 +58,7 @@ class TestMilvusUserCollectionIntegration: entity_vectors.insert(vector, "test entity", user, collection) expected_collection_name = make_safe_collection_name( - user, collection, len(vector), "entity" + user, collection, "entity" ) # Verify collection was created with correct name @@ -89,7 +89,7 @@ class TestMilvusUserCollectionIntegration: result = doc_vectors.search(vector, user, collection, limit=5) # Verify search was called with correct collection name - expected_collection_name = make_safe_collection_name(user, collection, 3, "doc") + expected_collection_name = make_safe_collection_name(user, collection, "doc") mock_client.search.assert_called_once() search_call = mock_client.search.call_args assert search_call[1]["collection_name"] == expected_collection_name @@ -118,7 +118,7 @@ class TestMilvusUserCollectionIntegration: result = entity_vectors.search(vector, user, collection, limit=5) # Verify search was called with correct collection name - expected_collection_name = make_safe_collection_name(user, collection, 3, "entity") + expected_collection_name = make_safe_collection_name(user, collection, "entity") mock_client.search.assert_called_once() search_call = mock_client.search.call_args assert search_call[1]["collection_name"] == expected_collection_name @@ -142,9 +142,9 @@ class TestMilvusUserCollectionIntegration: 
collection_names = set(doc_vectors.collections.values()) expected_names = { - "doc_user1_collection1_3", - "doc_user2_collection2_3", - "doc_user1_collection2_3" + "doc_user1_collection1", + "doc_user2_collection2", + "doc_user1_collection2" } assert collection_names == expected_names @@ -167,9 +167,9 @@ class TestMilvusUserCollectionIntegration: collection_names = set(entity_vectors.collections.values()) expected_names = { - "entity_user1_collection1_3", - "entity_user2_collection2_3", - "entity_user1_collection2_3" + "entity_user1_collection1", + "entity_user2_collection2", + "entity_user1_collection2" } assert collection_names == expected_names @@ -194,10 +194,13 @@ class TestMilvusUserCollectionIntegration: collection_names = set(doc_vectors.collections.values()) expected_names = { - "doc_test_user_test_collection_3", # 3D - "doc_test_user_test_collection_4", # 4D - "doc_test_user_test_collection_2" # 2D + "doc_test_user_test_collection", # Same name for all dimensions + "doc_test_user_test_collection", # now stored per dimension in key + "doc_test_user_test_collection" # but collection name is the same } + # Note: Now all dimensions use the same collection name, they are differentiated by the key + assert len(collection_names) == 1 # Only one unique collection name + assert "doc_test_user_test_collection" in collection_names assert collection_names == expected_names @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') @@ -220,7 +223,7 @@ class TestMilvusUserCollectionIntegration: # Verify only one collection was created assert len(doc_vectors.collections) == 1 - expected_collection_name = "doc_test_user_test_collection_3" + expected_collection_name = "doc_test_user_test_collection" assert doc_vectors.collections[(3, user, collection)] == expected_collection_name @patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient') @@ -233,10 +236,10 @@ class TestMilvusUserCollectionIntegration: # Test various special character combinations test_cases = [ - 
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1_3"), - ("user_123", "collection_456", "doc_user_123_collection_456_3"), - ("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces_3"), - ("user@@@test", "collection---test", "doc_user_test_collection_test_3"), + ("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1"), + ("user_123", "collection_456", "doc_user_123_collection_456"), + ("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces"), + ("user@@@test", "collection---test", "doc_user_test_collection_test"), ] vector = [0.1, 0.2, 0.3] @@ -250,24 +253,24 @@ class TestMilvusUserCollectionIntegration: def test_collection_name_backward_compatibility(self): """Test that new collection names don't conflict with old pattern""" # Old pattern was: {prefix}_{dimension} - # New pattern is: {prefix}_{safe_user}_{safe_collection}_{dimension} - + # New pattern is: {prefix}_{safe_user}_{safe_collection} + # The new pattern should never generate names that match the old pattern old_pattern_examples = ["doc_384", "entity_768", "doc_512"] - + test_cases = [ - ("user", "collection", 384, "doc"), - ("test", "test", 768, "entity"), - ("a", "b", 512, "doc"), + ("user", "collection", "doc"), + ("test", "test", "entity"), + ("a", "b", "doc"), ] - - for user, collection, dimension, prefix in test_cases: - new_name = make_safe_collection_name(user, collection, dimension, prefix) - - # New names should have at least 4 underscores (prefix_user_collection_dimension) + + for user, collection, prefix in test_cases: + new_name = make_safe_collection_name(user, collection, prefix) + + # New names should have at least 2 underscores (prefix_user_collection) # Old names had only 1 underscore (prefix_dimension) - assert new_name.count('_') >= 3, f"New name {new_name} doesn't have enough underscores" - + assert new_name.count('_') >= 2, f"New name {new_name} doesn't 
have enough underscores" + # New names should not match old pattern assert new_name not in old_pattern_examples, f"New name {new_name} conflicts with old pattern" @@ -286,23 +289,23 @@ class TestMilvusUserCollectionIntegration: dimension = 384 # Generate collection names - doc_name1 = make_safe_collection_name(user1, collection1, dimension, "doc") - doc_name2 = make_safe_collection_name(user2, collection2, dimension, "doc") - - entity_name1 = make_safe_collection_name(user1, collection1, dimension, "entity") - entity_name2 = make_safe_collection_name(user2, collection2, dimension, "entity") + doc_name1 = make_safe_collection_name(user1, collection1, "doc") + doc_name2 = make_safe_collection_name(user2, collection2, "doc") + + entity_name1 = make_safe_collection_name(user1, collection1, "entity") + entity_name2 = make_safe_collection_name(user2, collection2, "entity") # Verify complete isolation assert doc_name1 != doc_name2, "Document collections should be isolated" assert entity_name1 != entity_name2, "Entity collections should be isolated" - # Verify names match expected pattern from Qdrant + # Verify names match expected pattern from new API # Qdrant uses: d_{user}_{collection}_{dimension}, t_{user}_{collection}_{dimension} - # Milvus uses: doc_{safe_user}_{safe_collection}_{dimension}, entity_{safe_user}_{safe_collection}_{dimension} - assert doc_name1 == "doc_my_user_test_coll_1_384" - assert doc_name2 == "doc_other_user_production_data_384" - assert entity_name1 == "entity_my_user_test_coll_1_384" - assert entity_name2 == "entity_other_user_production_data_384" + # New Milvus API uses: doc_{safe_user}_{safe_collection}, entity_{safe_user}_{safe_collection} + assert doc_name1 == "doc_my_user_test_coll_1" + assert doc_name2 == "doc_other_user_production_data" + assert entity_name1 == "entity_my_user_test_coll_1" + assert entity_name2 == "entity_other_user_production_data" # This test would have FAILED with the old implementation that used: # - doc_384 for all 
document embeddings (no user/collection differentiation) diff --git a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py index 92551587..ce2a7431 100644 --- a/tests/unit/test_query/test_doc_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_doc_embeddings_pinecone_query.py @@ -120,7 +120,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: chunks = await processor.query_document_embeddings(message) # Verify index was accessed correctly - expected_index_name = "d-test_user-test_collection-3" + expected_index_name = "d-test_user-test_collection" processor.pinecone.Index.assert_called_once_with(expected_index_name) # Verify query parameters @@ -239,7 +239,7 @@ class TestPineconeDocEmbeddingsQueryProcessor: @pytest.mark.asyncio async def test_query_document_embeddings_different_vector_dimensions(self, processor): - """Test querying with vectors of different dimensions""" + """Test querying with vectors of different dimensions using same index""" message = MagicMock() message.vectors = [ [0.1, 0.2], # 2D vector @@ -248,37 +248,33 @@ class TestPineconeDocEmbeddingsQueryProcessor: message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' - - mock_index_2d = MagicMock() - mock_index_4d = MagicMock() - - def mock_index_side_effect(name): - if name.endswith("-2"): - return mock_index_2d - elif name.endswith("-4"): - return mock_index_4d - - processor.pinecone.Index.side_effect = mock_index_side_effect - - # Mock results for different dimensions + + # Mock single index that handles all dimensions + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # Mock results for different vector queries mock_results_2d = MagicMock() - mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})] - mock_index_2d.query.return_value = mock_results_2d - + mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D query'})] 
+ mock_results_4d = MagicMock() - mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})] - mock_index_4d.query.return_value = mock_results_4d - + mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D query'})] + + mock_index.query.side_effect = [mock_results_2d, mock_results_4d] + chunks = await processor.query_document_embeddings(message) - - # Verify different indexes were used + + # Verify same index used for both vectors + expected_index_name = "d-test_user-test_collection" assert processor.pinecone.Index.call_count == 2 - mock_index_2d.query.assert_called_once() - mock_index_4d.query.assert_called_once() - + processor.pinecone.Index.assert_called_with(expected_index_name) + + # Verify both queries were made + assert mock_index.query.call_count == 2 + # Verify results from both dimensions - assert 'Document from 2D index' in chunks - assert 'Document from 4D index' in chunks + assert 'Document from 2D query' in chunks + assert 'Document from 4D query' in chunks @pytest.mark.asyncio async def test_query_document_embeddings_empty_vectors_list(self, processor): diff --git a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py index 5352e002..dbe9b9fc 100644 --- a/tests/unit/test_query/test_graph_embeddings_pinecone_query.py +++ b/tests/unit/test_query/test_graph_embeddings_pinecone_query.py @@ -148,7 +148,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: entities = await processor.query_graph_embeddings(message) # Verify index was accessed correctly - expected_index_name = "t-test_user-test_collection-3" + expected_index_name = "t-test_user-test_collection" processor.pinecone.Index.assert_called_once_with(expected_index_name) # Verify query parameters @@ -265,7 +265,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor: @pytest.mark.asyncio async def test_query_graph_embeddings_different_vector_dimensions(self, processor): - """Test querying 
with vectors of different dimensions""" + """Test querying with vectors of different dimensions using same index""" message = MagicMock() message.vectors = [ [0.1, 0.2], # 2D vector @@ -274,34 +274,30 @@ class TestPineconeGraphEmbeddingsQueryProcessor: message.limit = 5 message.user = 'test_user' message.collection = 'test_collection' - - mock_index_2d = MagicMock() - mock_index_4d = MagicMock() - - def mock_index_side_effect(name): - if name.endswith("-2"): - return mock_index_2d - elif name.endswith("-4"): - return mock_index_4d - - processor.pinecone.Index.side_effect = mock_index_side_effect - - # Mock results for different dimensions + + # Mock single index that handles all dimensions + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index + + # Mock results for different vector queries mock_results_2d = MagicMock() mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})] - mock_index_2d.query.return_value = mock_results_2d - + mock_results_4d = MagicMock() mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})] - mock_index_4d.query.return_value = mock_results_4d - + + mock_index.query.side_effect = [mock_results_2d, mock_results_4d] + entities = await processor.query_graph_embeddings(message) - - # Verify different indexes were used + + # Verify same index used for both vectors + expected_index_name = "t-test_user-test_collection" assert processor.pinecone.Index.call_count == 2 - mock_index_2d.query.assert_called_once() - mock_index_4d.query.assert_called_once() - + processor.pinecone.Index.assert_called_with(expected_index_name) + + # Verify both queries were made + assert mock_index.query.call_count == 2 + # Verify results from both dimensions entity_values = [e.value for e in entities] assert 'entity_2d' in entity_values diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index f162f5e8..72871456 100644 --- 
a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -70,7 +70,7 @@ class TestCassandraQueryProcessor: assert result.is_uri is False @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_spo_query(self, mock_trustgraph): """Test querying triples with subject, predicate, and object specified""" from trustgraph.schema import TriplesQueryRequest, Value @@ -98,16 +98,15 @@ class TestCassandraQueryProcessor: result = await processor.query_triples(query) - # Verify TrustGraph was created with correct parameters + # Verify KnowledgeGraph was created with correct parameters mock_trustgraph.assert_called_once_with( hosts=['localhost'], - keyspace='test_user', - table='test_collection' + keyspace='test_user' ) # Verify get_spo was called with correct parameters mock_tg_instance.get_spo.assert_called_once_with( - 'test_subject', 'test_predicate', 'test_object', limit=100 + 'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100 ) # Verify result contains the queried triple @@ -144,7 +143,7 @@ class TestCassandraQueryProcessor: assert processor.table is None @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_sp_pattern(self, mock_trustgraph): """Test SP query pattern (subject and predicate, no object)""" from trustgraph.schema import TriplesQueryRequest, Value @@ -170,14 +169,14 @@ class TestCassandraQueryProcessor: result = await processor.query_triples(query) - mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50) + mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50) assert len(result) == 1 assert result[0].s.value == 
'test_subject' assert result[0].p.value == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_s_pattern(self, mock_trustgraph): """Test S query pattern (subject only)""" from trustgraph.schema import TriplesQueryRequest, Value @@ -203,14 +202,14 @@ class TestCassandraQueryProcessor: result = await processor.query_triples(query) - mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25) + mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25) assert len(result) == 1 assert result[0].s.value == 'test_subject' assert result[0].p.value == 'result_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_p_pattern(self, mock_trustgraph): """Test P query pattern (predicate only)""" from trustgraph.schema import TriplesQueryRequest, Value @@ -236,14 +235,14 @@ class TestCassandraQueryProcessor: result = await processor.query_triples(query) - mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10) + mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10) assert len(result) == 1 assert result[0].s.value == 'result_subject' assert result[0].p.value == 'test_predicate' assert result[0].o.value == 'result_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_o_pattern(self, mock_trustgraph): """Test O query pattern (object only)""" from trustgraph.schema import TriplesQueryRequest, Value @@ -269,14 +268,14 @@ class TestCassandraQueryProcessor: result = await 
processor.query_triples(query) - mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75) + mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75) assert len(result) == 1 assert result[0].s.value == 'result_subject' assert result[0].p.value == 'result_predicate' assert result[0].o.value == 'test_object' @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_get_all_pattern(self, mock_trustgraph): """Test query pattern with no constraints (get all)""" from trustgraph.schema import TriplesQueryRequest @@ -303,7 +302,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples(query) - mock_tg_instance.get_all.assert_called_once_with(limit=1000) + mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 assert result[0].s.value == 'all_subject' assert result[0].p.value == 'all_predicate' @@ -376,7 +375,7 @@ class TestCassandraQueryProcessor: mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. 
Output is a list of triples.\n') @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_with_authentication(self, mock_trustgraph): """Test querying with username and password authentication""" from trustgraph.schema import TriplesQueryRequest, Value @@ -402,17 +401,16 @@ class TestCassandraQueryProcessor: await processor.query_triples(query) - # Verify TrustGraph was created with authentication + # Verify KnowledgeGraph was created with authentication mock_trustgraph.assert_called_once_with( hosts=['cassandra'], # Updated default keyspace='test_user', - table='test_collection', username='authuser', password='authpass' ) @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_table_reuse(self, mock_trustgraph): """Test that TrustGraph is reused for same table""" from trustgraph.schema import TriplesQueryRequest, Value @@ -441,7 +439,7 @@ class TestCassandraQueryProcessor: assert mock_trustgraph.call_count == 1 # Should not increase @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_table_switching(self, mock_trustgraph): """Test table switching creates new TrustGraph""" from trustgraph.schema import TriplesQueryRequest, Value @@ -463,7 +461,7 @@ class TestCassandraQueryProcessor: ) await processor.query_triples(query1) - assert processor.table == ('user1', 'collection1') + assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( @@ -476,13 +474,13 @@ class TestCassandraQueryProcessor: ) await processor.query_triples(query2) - assert processor.table == ('user2', 'collection2') + assert processor.table == 'user2' # Verify TrustGraph was 
created twice assert mock_trustgraph.call_count == 2 @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_exception_handling(self, mock_trustgraph): """Test exception handling during query execution""" from trustgraph.schema import TriplesQueryRequest, Value @@ -506,7 +504,7 @@ class TestCassandraQueryProcessor: await processor.query_triples(query) @pytest.mark.asyncio - @patch('trustgraph.query.triples.cassandra.service.TrustGraph') + @patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph') async def test_query_triples_multiple_results(self, mock_trustgraph): """Test query returning multiple results""" from trustgraph.schema import TriplesQueryRequest, Value diff --git a/tests/unit/test_storage/test_cassandra_config_integration.py b/tests/unit/test_storage/test_cassandra_config_integration.py index 42e02d3d..754a4bb0 100644 --- a/tests/unit/test_storage/test_cassandra_config_integration.py +++ b/tests/unit/test_storage/test_cassandra_config_integration.py @@ -18,7 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KgStore class TestTriplesWriterConfiguration: """Test Cassandra configuration in triples writer processor.""" - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_environment_variable_configuration(self, mock_trust_graph): """Test processor picks up configuration from environment variables.""" env_vars = { @@ -34,7 +34,7 @@ class TestTriplesWriterConfiguration: assert processor.cassandra_username == 'env-user' assert processor.cassandra_password == 'env-pass' - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_parameter_override_environment(self, mock_trust_graph): """Test explicit parameters override environment variables.""" env_vars = { @@ -55,7 +55,7 @@ class 
TestTriplesWriterConfiguration: assert processor.cassandra_username == 'param-user' assert processor.cassandra_password == 'param-pass' - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_no_backward_compatibility_graph_params(self, mock_trust_graph): """Test that old graph_* parameter names are no longer supported.""" processor = TriplesWriter( @@ -70,7 +70,7 @@ class TestTriplesWriterConfiguration: assert processor.cassandra_username is None assert processor.cassandra_password is None - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_default_configuration(self, mock_trust_graph): """Test default configuration when no params or env vars provided.""" with patch.dict(os.environ, {}, clear=True): @@ -163,7 +163,7 @@ class TestObjectsWriterConfiguration: class TestTriplesQueryConfiguration: """Test Cassandra configuration in triples query processor.""" - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_environment_variable_configuration(self, mock_trust_graph): """Test processor picks up configuration from environment variables.""" env_vars = { @@ -179,7 +179,7 @@ class TestTriplesQueryConfiguration: assert processor.cassandra_username == 'query-env-user' assert processor.cassandra_password == 'query-env-pass' - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_only_new_parameters_work(self, mock_trust_graph): """Test that only new parameters work.""" processor = TriplesQuery( @@ -379,7 +379,7 @@ class TestCommandLineArgumentHandling: class TestConfigurationPriorityIntegration: """Test complete configuration priority chain in processors.""" - @patch('trustgraph.direct.cassandra.TrustGraph') + @patch('trustgraph.direct.cassandra_kg.KnowledgeGraph') def test_complete_priority_chain(self, mock_trust_graph): 
"""Test CLI params > env vars > defaults priority in actual processor.""" env_vars = { diff --git a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py index 6c4ddb6b..113a75cb 100644 --- a/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_pinecone_storage.py @@ -135,7 +135,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) # Verify index name and operations - expected_index_name = "d-test_user-test_collection-3" + expected_index_name = "d-test_user-test_collection" processor.pinecone.Index.assert_called_with(expected_index_name) # Verify upsert was called for each vector @@ -203,7 +203,7 @@ class TestPineconeDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) # Verify index creation was called - expected_index_name = "d-test_user-test_collection-3" + expected_index_name = "d-test_user-test_collection" processor.pinecone.create_index.assert_called_once() create_call = processor.pinecone.create_index.call_args assert create_call[1]['name'] == expected_index_name @@ -299,12 +299,11 @@ class TestPineconeDocEmbeddingsStorageProcessor: mock_index_3d = MagicMock() def mock_index_side_effect(name): - if name.endswith("-2"): - return mock_index_2d - elif name.endswith("-4"): - return mock_index_4d - elif name.endswith("-3"): - return mock_index_3d + # All dimensions now use the same index name pattern + # Different dimensions will be handled within the same index + if "test_user" in name and "test_collection" in name: + return mock_index_2d # Just return one mock for all + return MagicMock() processor.pinecone.Index.side_effect = mock_index_side_effect processor.pinecone.has_index.return_value = True @@ -312,11 +311,10 @@ class TestPineconeDocEmbeddingsStorageProcessor: with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): await 
processor.store_document_embeddings(message) - # Verify different indexes were used for different dimensions - assert processor.pinecone.Index.call_count == 3 - mock_index_2d.upsert.assert_called_once() - mock_index_4d.upsert.assert_called_once() - mock_index_3d.upsert.assert_called_once() + # Verify all vectors are now stored in the same index + # (Pinecone can handle mixed dimensions in the same index) + assert processor.pinecone.Index.call_count == 3 # Called once per vector + assert mock_index_2d.upsert.call_count == 3 # All upserts go to same index @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunks_list(self, processor): diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index 4fadc641..021b5d96 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -106,7 +106,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert # Verify collection existence was checked - expected_collection = 'd_test_user_test_collection_3' + expected_collection = 'd_test_user_test_collection' mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) # Verify upsert was called @@ -309,7 +309,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message) # Assert - expected_collection = 'd_new_user_new_collection_5' + expected_collection = 'd_new_user_new_collection' # Verify collection existence check and creation mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) @@ -408,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message2) # Assert - expected_collection = 'd_cache_user_cache_collection_3' + expected_collection = 'd_cache_user_cache_collection' assert processor.last_collection ==
expected_collection # Verify second call skipped existence check (cached) @@ -455,17 +455,16 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message) # Assert - # Should check existence of both collections - expected_collections = ['d_dim_user_dim_collection_2', 'd_dim_user_dim_collection_3'] - actual_calls = [call.args[0] for call in mock_qdrant_instance.collection_exists.call_args_list] - assert actual_calls == expected_collections - - # Should upsert to both collections + # Should check existence of the same collection (dimensions no longer create separate collections) + expected_collection = 'd_dim_user_dim_collection' + mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + + # Should upsert to the same collection for both vectors assert mock_qdrant_instance.upsert.call_count == 2 - + upsert_calls = mock_qdrant_instance.upsert.call_args_list - assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' - assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' + assert upsert_calls[0][1]['collection_name'] == expected_collection + assert upsert_calls[1][1]['collection_name'] == expected_collection @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') diff --git a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py index 91e60057..cf83e2ed 100644 --- a/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_pinecone_storage.py @@ -135,7 +135,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: await processor.store_graph_embeddings(message) # Verify index name and operations - expected_index_name = "t-test_user-test_collection-3" + expected_index_name = "t-test_user-test_collection" 
processor.pinecone.Index.assert_called_with(expected_index_name) # Verify upsert was called for each vector @@ -203,7 +203,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor: await processor.store_graph_embeddings(message) # Verify index creation was called - expected_index_name = "t-test_user-test_collection-3" + expected_index_name = "t-test_user-test_collection" processor.pinecone.create_index.assert_called_once() create_call = processor.pinecone.create_index.call_args assert create_call[1]['name'] == expected_index_name @@ -256,12 +256,12 @@ class TestPineconeGraphEmbeddingsStorageProcessor: @pytest.mark.asyncio async def test_store_graph_embeddings_different_vector_dimensions(self, processor): - """Test storing graph embeddings with different vector dimensions""" + """Test storing graph embeddings with different vector dimensions to same index""" message = MagicMock() message.metadata = MagicMock() message.metadata.user = 'test_user' message.metadata.collection = 'test_collection' - + entity = EntityEmbeddings( entity=Value(value="test_entity", is_uri=False), vectors=[ @@ -271,30 +271,21 @@ class TestPineconeGraphEmbeddingsStorageProcessor: ] ) message.entities = [entity] - - mock_index_2d = MagicMock() - mock_index_4d = MagicMock() - mock_index_3d = MagicMock() - - def mock_index_side_effect(name): - if name.endswith("-2"): - return mock_index_2d - elif name.endswith("-4"): - return mock_index_4d - elif name.endswith("-3"): - return mock_index_3d - - processor.pinecone.Index.side_effect = mock_index_side_effect + + # All vectors now use the same index (no dimension in name) + mock_index = MagicMock() + processor.pinecone.Index.return_value = mock_index processor.pinecone.has_index.return_value = True - + with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): await processor.store_graph_embeddings(message) - - # Verify different indexes were used for different dimensions - assert processor.pinecone.Index.call_count == 3 - 
mock_index_2d.upsert.assert_called_once() - mock_index_4d.upsert.assert_called_once() - mock_index_3d.upsert.assert_called_once() + + # Verify same index was used for all dimensions + expected_index_name = 't-test_user-test_collection' + processor.pinecone.Index.assert_called_with(expected_index_name) + + # Verify all vectors were upserted to the same index + assert mock_index.upsert.call_count == 3 @pytest.mark.asyncio async def test_store_graph_embeddings_empty_entities_list(self, processor): diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index 081d79cd..ee9fc0fc 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -69,7 +69,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection') # Assert - expected_name = 't_test_user_test_collection_512' + expected_name = 't_test_user_test_collection' assert collection_name == expected_name assert processor.last_collection == expected_name @@ -118,7 +118,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert # Verify collection existence was checked - expected_collection = 't_test_user_test_collection_3' + expected_collection = 't_test_user_test_collection' mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) # Verify upsert was called @@ -156,7 +156,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection') # Assert - expected_name = 't_existing_user_existing_collection_256' + expected_name = 't_existing_user_existing_collection' assert collection_name == expected_name assert processor.last_collection == expected_name @@ -194,7 +194,7 @@ class 
TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection') # Assert - expected_name = 't_cache_user_cache_collection_128' + expected_name = 't_cache_user_cache_collection' assert collection_name1 == expected_name assert collection_name2 == expected_name diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 45be3b99..a6a6a539 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -86,7 +86,7 @@ class TestCassandraStorageProcessor: assert processor.cassandra_username == 'new-user' # Only cassandra_* params work @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_table_switching_with_auth(self, mock_trustgraph): """Test table switching logic when authentication is provided""" taskgroup_mock = MagicMock() @@ -107,18 +107,17 @@ class TestCassandraStorageProcessor: await processor.store_triples(mock_message) - # Verify TrustGraph was called with auth parameters + # Verify KnowledgeGraph was called with auth parameters mock_trustgraph.assert_called_once_with( hosts=['cassandra'], # Updated default keyspace='user1', - table='collection1', username='testuser', password='testpass' ) - assert processor.table == ('user1', 'collection1') + assert processor.table == 'user1' @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_table_switching_without_auth(self, mock_trustgraph): """Test table switching logic when no authentication is provided""" taskgroup_mock = MagicMock() @@ -135,16 +134,15 @@ class TestCassandraStorageProcessor: await processor.store_triples(mock_message) - # 
Verify TrustGraph was called without auth parameters + # Verify KnowledgeGraph was called without auth parameters mock_trustgraph.assert_called_once_with( hosts=['cassandra'], # Updated default - keyspace='user2', - table='collection2' + keyspace='user2' ) - assert processor.table == ('user2', 'collection2') + assert processor.table == 'user2' @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_table_reuse_when_same(self, mock_trustgraph): """Test that TrustGraph is not recreated when table hasn't changed""" taskgroup_mock = MagicMock() @@ -168,7 +166,7 @@ class TestCassandraStorageProcessor: assert mock_trustgraph.call_count == 1 # Should not increase @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_triple_insertion(self, mock_trustgraph): """Test that triples are properly inserted into Cassandra""" taskgroup_mock = MagicMock() @@ -198,11 +196,11 @@ class TestCassandraStorageProcessor: # Verify both triples were inserted assert mock_tg_instance.insert.call_count == 2 - mock_tg_instance.insert.assert_any_call('subject1', 'predicate1', 'object1') - mock_tg_instance.insert.assert_any_call('subject2', 'predicate2', 'object2') + mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1') + mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2') @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_triple_insertion_with_empty_list(self, mock_trustgraph): """Test behavior when message has no triples""" taskgroup_mock = MagicMock() @@ -223,7 +221,7 @@ class TestCassandraStorageProcessor: mock_tg_instance.insert.assert_not_called() @pytest.mark.asyncio 
- @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') @patch('trustgraph.storage.triples.cassandra.write.time.sleep') async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph): """Test exception handling during TrustGraph creation""" @@ -328,7 +326,7 @@ class TestCassandraStorageProcessor: mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n') @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph): """Test table switching when different tables are used in sequence""" taskgroup_mock = MagicMock() @@ -345,7 +343,7 @@ class TestCassandraStorageProcessor: mock_message1.triples = [] await processor.store_triples(mock_message1) - assert processor.table == ('user1', 'collection1') + assert processor.table == 'user1' assert processor.tg == mock_tg_instance1 # Second message with different table @@ -355,14 +353,14 @@ class TestCassandraStorageProcessor: mock_message2.triples = [] await processor.store_triples(mock_message2) - assert processor.table == ('user2', 'collection2') + assert processor.table == 'user2' assert processor.tg == mock_tg_instance2 # Verify TrustGraph was created twice for different tables assert mock_trustgraph.call_count == 2 @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph): """Test storing triples with special characters and unicode""" taskgroup_mock = MagicMock() @@ -386,13 +384,14 @@ class TestCassandraStorageProcessor: # Verify the triple was inserted with special characters preserved 
mock_tg_instance.insert.assert_called_once_with( + 'test_collection', 'subject with spaces & symbols', 'predicate:with/colons', 'object with "quotes" and unicode: ñáéíóú' ) @pytest.mark.asyncio - @patch('trustgraph.storage.triples.cassandra.write.TrustGraph') + @patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph') async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph): """Test that table remains unchanged when TrustGraph creation fails""" taskgroup_mock = MagicMock() diff --git a/tests/unit/test_storage/test_triples_falkordb_storage.py b/tests/unit/test_storage/test_triples_falkordb_storage.py index 7d602b6f..f9dfbc5d 100644 --- a/tests/unit/test_storage/test_triples_falkordb_storage.py +++ b/tests/unit/test_storage/test_triples_falkordb_storage.py @@ -86,15 +86,17 @@ class TestFalkorDBStorageProcessor: mock_result = MagicMock() mock_result.nodes_created = 1 mock_result.run_time_ms = 10 - + processor.io.query.return_value = mock_result - - processor.create_node(test_uri) - + + processor.create_node(test_uri, 'test_user', 'test_collection') + processor.io.query.assert_called_once_with( - "MERGE (n:Node {uri: $uri})", + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", params={ "uri": test_uri, + "user": 'test_user', + "collection": 'test_collection', }, ) @@ -104,15 +106,17 @@ class TestFalkorDBStorageProcessor: mock_result = MagicMock() mock_result.nodes_created = 1 mock_result.run_time_ms = 10 - + processor.io.query.return_value = mock_result - - processor.create_literal(test_value) - + + processor.create_literal(test_value, 'test_user', 'test_collection') + processor.io.query.assert_called_once_with( - "MERGE (n:Literal {value: $value})", + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", params={ "value": test_value, + "user": 'test_user', + "collection": 'test_collection', }, ) @@ -121,23 +125,25 @@ class TestFalkorDBStorageProcessor: src_uri = 'http://example.com/src' 
pred_uri = 'http://example.com/pred' dest_uri = 'http://example.com/dest' - + mock_result = MagicMock() mock_result.nodes_created = 0 mock_result.run_time_ms = 5 - + processor.io.query.return_value = mock_result - - processor.relate_node(src_uri, pred_uri, dest_uri) - + + processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection') + processor.io.query.assert_called_once_with( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", params={ "src": src_uri, "dest": dest_uri, "uri": pred_uri, + "user": 'test_user', + "collection": 'test_collection', }, ) @@ -146,23 +152,25 @@ class TestFalkorDBStorageProcessor: src_uri = 'http://example.com/src' pred_uri = 'http://example.com/pred' literal_value = 'literal destination' - + mock_result = MagicMock() mock_result.nodes_created = 0 mock_result.run_time_ms = 5 - + processor.io.query.return_value = mock_result - - processor.relate_literal(src_uri, pred_uri, literal_value) - + + processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection') + processor.io.query.assert_called_once_with( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", params={ "src": src_uri, "dest": literal_value, "uri": pred_uri, + "user": 'test_user', + "collection": 'test_collection', }, ) @@ -191,14 +199,16 @@ class TestFalkorDBStorageProcessor: # Verify queries were called in the correct order expected_calls = [ # Create 
subject node - (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}), + (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), + {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), # Create object node - (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}), + (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), + {"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}), # Create relationship - (("MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)",), - {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}), + (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), ] assert processor.io.query.call_count == 3 @@ -220,14 +230,16 @@ class TestFalkorDBStorageProcessor: # Verify queries were called in the correct order expected_calls = [ # Create subject node - (("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}), + (("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), + {"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), # Create literal object - (("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}), + (("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",), + {"params": {"value": "literal object", "user": "test_user", "collection": 
"test_collection"}}), # Create relationship - (("MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)",), - {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}), + (("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), + {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), ] assert processor.io.query.call_count == 3 @@ -408,12 +420,14 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_node(test_uri) + processor.create_node(test_uri, 'test_user', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Node {uri: $uri})", + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", params={ "uri": test_uri, + "user": 'test_user', + "collection": 'test_collection', }, ) @@ -426,11 +440,13 @@ class TestFalkorDBStorageProcessor: processor.io.query.return_value = mock_result - processor.create_literal(test_value) + processor.create_literal(test_value, 'test_user', 'test_collection') processor.io.query.assert_called_once_with( - "MERGE (n:Literal {value: $value})", + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", params={ "value": test_value, + "user": 'test_user', + "collection": 'test_collection', }, ) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index b65f62ac..b0bae8ce 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -8,6 +8,7 @@ from . library import Library from . flow import Flow from . config import Config from . 
knowledge import Knowledge +from . collection import Collection from . exceptions import * from . types import * @@ -68,3 +69,6 @@ class Api: def library(self): return Library(self) + + def collection(self): + return Collection(self) diff --git a/trustgraph-base/trustgraph/api/collection.py b/trustgraph-base/trustgraph/api/collection.py new file mode 100644 index 00000000..9a826899 --- /dev/null +++ b/trustgraph-base/trustgraph/api/collection.py @@ -0,0 +1,90 @@ +import datetime +import logging + +from . types import CollectionMetadata +from . exceptions import * + +logger = logging.getLogger(__name__) + +class Collection: + + def __init__(self, api): + self.api = api + + def request(self, request): + return self.api.request(f"collection-management", request) + + def list_collections(self, user, tag_filter=None): + + input = { + "operation": "list-collections", + "user": user, + } + + if tag_filter: + input["tag_filter"] = tag_filter + + object = self.request(input) + + try: + return [ + CollectionMetadata( + user = v["user"], + collection = v["collection"], + name = v["name"], + description = v["description"], + tags = v["tags"], + created_at = v["created_at"], + updated_at = v["updated_at"] + ) + for v in object["collections"] + ] + except Exception as e: + logger.error("Failed to parse collection list response", exc_info=True) + raise ProtocolException(f"Response not formatted correctly") + + def update_collection(self, user, collection, name=None, description=None, tags=None): + + input = { + "operation": "update-collection", + "user": user, + "collection": collection, + } + + if name is not None: + input["name"] = name + if description is not None: + input["description"] = description + if tags is not None: + input["tags"] = tags + + object = self.request(input) + + try: + if "collections" in object and object["collections"]: + v = object["collections"][0] + return CollectionMetadata( + user = v["user"], + collection = v["collection"], + name = v["name"], + 
description = v["description"], + tags = v["tags"], + created_at = v["created_at"], + updated_at = v["updated_at"] + ) + return None + except Exception as e: + logger.error("Failed to parse collection update response", exc_info=True) + raise ProtocolException(f"Response not formatted correctly") + + def delete_collection(self, user, collection): + + input = { + "operation": "delete-collection", + "user": user, + "collection": collection, + } + + object = self.request(input) + + return {} \ No newline at end of file diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index fe3472b1..71b438f6 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -41,3 +41,13 @@ class ProcessingMetadata: user : str collection : str tags : List[str] + +@dataclasses.dataclass +class CollectionMetadata: + user : str + collection : str + name : str + description : str + tags : List[str] + created_at : str + updated_at : str diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 0c805967..80c5438b 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -25,6 +25,7 @@ from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQue from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator # Register all service translators TranslatorRegistry.register_service( @@ -135,6 +136,12 @@ TranslatorRegistry.register_service( StructuredDataDiagnosisResponseTranslator() 
) +TranslatorRegistry.register_service( + "collection-management", + CollectionManagementRequestTranslator(), + CollectionManagementResponseTranslator() +) + # Register single-direction translators for document loading TranslatorRegistry.register_request("document", DocumentTranslator()) TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py new file mode 100644 index 00000000..5c2a0fd4 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -0,0 +1,112 @@ +from typing import Dict, Any, List +from ...schema import CollectionManagementRequest, CollectionManagementResponse, CollectionMetadata, Error +from .base import MessageTranslator + + +class CollectionManagementRequestTranslator(MessageTranslator): + """Translator for CollectionManagementRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementRequest: + return CollectionManagementRequest( + operation=data.get("operation", ""), + user=data.get("user", ""), + collection=data.get("collection", ""), + timestamp=data.get("timestamp", ""), + name=data.get("name", ""), + description=data.get("description", ""), + tags=data.get("tags", []), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + tag_filter=data.get("tag_filter", []), + limit=data.get("limit", 50) + ) + + def from_pulsar(self, obj: CollectionManagementRequest) -> Dict[str, Any]: + result = {} + + if obj.operation: + result["operation"] = obj.operation + if obj.user: + result["user"] = obj.user + if obj.collection: + result["collection"] = obj.collection + if obj.timestamp: + result["timestamp"] = obj.timestamp + if obj.name: + result["name"] = obj.name + if obj.description: + result["description"] = obj.description + if obj.tags: + result["tags"] = list(obj.tags) + if obj.created_at: + 
result["created_at"] = obj.created_at + if obj.updated_at: + result["updated_at"] = obj.updated_at + if obj.tag_filter: + result["tag_filter"] = list(obj.tag_filter) + if obj.limit: + result["limit"] = obj.limit + + return result + + +class CollectionManagementResponseTranslator(MessageTranslator): + """Translator for CollectionManagementResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementResponse: + # Handle error + error = None + if "error" in data and data["error"]: + error_data = data["error"] + error = Error( + type=error_data.get("type", ""), + message=error_data.get("message", "") + ) + + # Handle collections array + collections = [] + if "collections" in data: + for coll_data in data["collections"]: + collections.append(CollectionMetadata( + user=coll_data.get("user", ""), + collection=coll_data.get("collection", ""), + name=coll_data.get("name", ""), + description=coll_data.get("description", ""), + tags=coll_data.get("tags", []), + created_at=coll_data.get("created_at", ""), + updated_at=coll_data.get("updated_at", "") + )) + + return CollectionManagementResponse( + success=data.get("success", ""), + error=error, + timestamp=data.get("timestamp", ""), + collections=collections + ) + + def from_pulsar(self, obj: CollectionManagementResponse) -> Dict[str, Any]: + result = {} + + if obj.success: + result["success"] = obj.success + if obj.error: + result["error"] = { + "type": obj.error.type, + "message": obj.error.message + } + if obj.timestamp: + result["timestamp"] = obj.timestamp + if obj.collections: + result["collections"] = [] + for coll in obj.collections: + result["collections"].append({ + "user": coll.user, + "collection": coll.collection, + "name": coll.name, + "description": coll.description, + "tags": list(coll.tags) if coll.tags else [], + "created_at": coll.created_at, + "updated_at": coll.updated_at + }) + + return result \ No newline at end of file diff --git 
a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index d1c5448a..aaeb739f 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -10,4 +10,6 @@ from .lookup import * from .nlp_query import * from .structured_query import * from .objects_query import * -from .diagnosis import * \ No newline at end of file +from .diagnosis import * +from .collection import * +from .storage import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py new file mode 100644 index 00000000..bb837c63 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -0,0 +1,60 @@ +from pulsar.schema import Record, String, Integer, Array +from datetime import datetime + +from ..core.primitives import Error +from ..core.topic import topic + +############################################################################ + +# Collection management operations + +# Collection metadata operations (for librarian service) + +class CollectionMetadata(Record): + """Collection metadata record""" + user = String() + collection = String() + name = String() + description = String() + tags = Array(String()) + created_at = String() # ISO timestamp + updated_at = String() # ISO timestamp + +############################################################################ + +class CollectionManagementRequest(Record): + """Request for collection management operations""" + operation = String() # e.g., "delete-collection" + + # For 'list-collections' + user = String() + collection = String() + timestamp = String() # ISO timestamp + name = String() + description = String() + tags = Array(String()) + created_at = String() # ISO timestamp + updated_at = String() # ISO timestamp + + # For list + tag_filter = Array(String()) # Optional filter by tags + limit = 
Integer() + +class CollectionManagementResponse(Record): + """Response for collection management operations""" + success = String() # "true" or "false" + error = Error() # Only populated if success is "false" + timestamp = String() # ISO timestamp + collections = Array(CollectionMetadata()) + + +############################################################################ + +# Topics + +collection_request_queue = topic( + 'collection', kind='non-persistent', namespace='request' +) +collection_response_queue = topic( + 'collection', kind='non-persistent', namespace='response' +) diff --git a/trustgraph-base/trustgraph/schema/services/storage.py b/trustgraph-base/trustgraph/schema/services/storage.py new file mode 100644 index 00000000..16791615 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/storage.py @@ -0,0 +1,42 @@ +from pulsar.schema import Record, String + +from ..core.primitives import Error +from ..core.topic import topic + +############################################################################ + +# Storage management operations + +class StorageManagementRequest(Record): + """Request for storage management operations sent to store processors""" + operation = String() # e.g., "delete-collection" + user = String() + collection = String() + +class StorageManagementResponse(Record): + """Response from storage processors for management operations""" + error = Error() # Only populated if there's an error, if null success + +############################################################################ + +# Storage management topics + +# Topics for sending collection management requests to different storage types +vector_storage_management_topic = topic( + 'vector-storage-management', kind='non-persistent', namespace='request' +) + +object_storage_management_topic = topic( + 'object-storage-management', kind='non-persistent', namespace='request' +) + +triples_storage_management_topic = topic( + 'triples-storage-management', kind='non-persistent', 
namespace='request' +) + +# Topic for receiving responses from storage processors +storage_management_response_topic = topic( + 'storage-management', kind='non-persistent', namespace='response' +) + +############################################################################ diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 86fb0831..06b1e303 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -86,6 +86,9 @@ tg-list-config-items = "trustgraph.cli.list_config_items:main" tg-get-config-item = "trustgraph.cli.get_config_item:main" tg-put-config-item = "trustgraph.cli.put_config_item:main" tg-delete-config-item = "trustgraph.cli.delete_config_item:main" +tg-list-collections = "trustgraph.cli.list_collections:main" +tg-update-collection = "trustgraph.cli.update_collection:main" +tg-delete-collection = "trustgraph.cli.delete_collection:main" [tool.setuptools.packages.find] include = ["trustgraph*"] diff --git a/trustgraph-cli/trustgraph/cli/delete_collection.py b/trustgraph-cli/trustgraph/cli/delete_collection.py new file mode 100644 index 00000000..3e19ac09 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/delete_collection.py @@ -0,0 +1,72 @@ +""" +Delete a collection and all its data +""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = "trustgraph" + +def delete_collection(url, user, collection, confirm): + + if not confirm: + response = input(f"Are you sure you want to delete collection '{collection}' and all its data? 
(y/N): ") + if response.lower() not in ['y', 'yes']: + print("Operation cancelled.") + return + + api = Api(url).collection() + + api.delete_collection(user=user, collection=collection) + + print(f"Collection '{collection}' deleted successfully.") + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-delete-collection', + description=__doc__, + ) + + parser.add_argument( + 'collection', + help='Collection ID to delete' + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-y', '--yes', + action='store_true', + help='Skip confirmation prompt' + ) + + args = parser.parse_args() + + try: + + delete_collection( + url = args.api_url, + user = args.user, + collection = args.collection, + confirm = args.yes + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/trustgraph-cli/trustgraph/cli/list_collections.py b/trustgraph-cli/trustgraph/cli/list_collections.py new file mode 100644 index 00000000..8429b0cb --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_collections.py @@ -0,0 +1,85 @@ +""" +List collections for a user +""" + +import argparse +import os +import tabulate +from trustgraph.api import Api +import json + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = "trustgraph" + +def list_collections(url, user, tag_filter): + + api = Api(url).collection() + + collections = api.list_collections(user=user, tag_filter=tag_filter) + + if len(collections) == 0: + print("No collections.") + return + + table = [] + for collection in collections: + table.append([ + collection.collection, + collection.name, + collection.description, + ", ".join(collection.tags), + collection.created_at, + collection.updated_at + ]) + + headers = 
["Collection", "Name", "Description", "Tags", "Created", "Updated"] + + print(tabulate.tabulate( + table, + headers=headers, + tablefmt="pretty", + stralign="left", + maxcolwidths=[20, 30, 50, 30, 19, 19], + )) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-list-collections', + description=__doc__, + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-t', '--tag-filter', + action='append', + help='Filter by tags (can be specified multiple times)' + ) + + args = parser.parse_args() + + try: + + list_collections( + url = args.api_url, + user = args.user, + tag_filter = args.tag_filter + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/trustgraph-cli/trustgraph/cli/update_collection.py b/trustgraph-cli/trustgraph/cli/update_collection.py new file mode 100644 index 00000000..094c033c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/update_collection.py @@ -0,0 +1,103 @@ +""" +Update collection metadata +""" + +import argparse +import os +import tabulate +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = "trustgraph" + +def update_collection(url, user, collection, name, description, tags): + + api = Api(url).collection() + + result = api.update_collection( + user=user, + collection=collection, + name=name, + description=description, + tags=tags + ) + + if result: + print(f"Collection '{collection}' updated successfully.") + + table = [] + table.append(("Collection", result.collection)) + table.append(("Name", result.name)) + table.append(("Description", result.description)) + table.append(("Tags", ", ".join(result.tags))) + table.append(("Updated", result.updated_at)) + + 
print(tabulate.tabulate( + table, + tablefmt="pretty", + stralign="left", + maxcolwidths=[None, 67], + )) + else: + print(f"Failed to update collection '{collection}'.") + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-update-collection', + description=__doc__, + ) + + parser.add_argument( + 'collection', + help='Collection ID to update' + ) + + parser.add_argument( + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-n', '--name', + help='Collection name' + ) + + parser.add_argument( + '-d', '--description', + help='Collection description' + ) + + parser.add_argument( + '-t', '--tag', + action='append', + dest='tags', + help='Collection tags (can be specified multiple times)' + ) + + args = parser.parse_args() + + try: + + update_collection( + url = args.api_url, + user = args.user, + collection = args.collection, + name = args.name, + description = args.description, + tags = args.tags + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/direct/cassandra.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py similarity index 54% rename from trustgraph-flow/trustgraph/direct/cassandra.py rename to trustgraph-flow/trustgraph/direct/cassandra_kg.py index f7ca7e5e..93e41230 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -6,18 +6,18 @@ from ssl import SSLContext, PROTOCOL_TLSv1_2 # Global list to track clusters for cleanup _active_clusters = [] -class TrustGraph: +class KnowledgeGraph: def __init__( self, hosts=None, - keyspace="trustgraph", table="default", username=None, password=None + keyspace="trustgraph", username=None, password=None ): if hosts is None: hosts = ["localhost"] 
self.keyspace = keyspace - self.table = table + self.table = "triples" # Fixed table name for unified schema self.username = username if username and password: @@ -55,13 +55,19 @@ class TrustGraph: self.session.execute(f""" create table if not exists {self.table} ( + collection text, s text, p text, o text, - PRIMARY KEY (s, p, o) + PRIMARY KEY (collection, s, p, o) ); """); + self.session.execute(f""" + create index if not exists {self.table}_s + ON {self.table} (s); + """); + self.session.execute(f""" create index if not exists {self.table}_p ON {self.table} (p); @@ -72,58 +78,66 @@ class TrustGraph: ON {self.table} (o); """); - def insert(self, s, p, o): - + def insert(self, collection, s, p, o): + self.session.execute( - f"insert into {self.table} (s, p, o) values (%s, %s, %s)", - (s, p, o) + f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)", + (collection, s, p, o) ) - def get_all(self, limit=50): + def get_all(self, collection, limit=50): return self.session.execute( - f"select s, p, o from {self.table} limit {limit}" + f"select s, p, o from {self.table} where collection = %s limit {limit}", + (collection,) ) - def get_s(self, s, limit=10): + def get_s(self, collection, s, limit=10): return self.session.execute( - f"select p, o from {self.table} where s = %s limit {limit}", - (s,) + f"select p, o from {self.table} where collection = %s and s = %s limit {limit}", + (collection, s) ) - def get_p(self, p, limit=10): + def get_p(self, collection, p, limit=10): return self.session.execute( - f"select s, o from {self.table} where p = %s limit {limit}", - (p,) + f"select s, o from {self.table} where collection = %s and p = %s limit {limit}", + (collection, p) ) - def get_o(self, o, limit=10): + def get_o(self, collection, o, limit=10): return self.session.execute( - f"select s, p from {self.table} where o = %s limit {limit}", - (o,) + f"select s, p from {self.table} where collection = %s and o = %s limit {limit}", + (collection, o) ) - def 
get_sp(self, s, p, limit=10): + def get_sp(self, collection, s, p, limit=10): return self.session.execute( - f"select o from {self.table} where s = %s and p = %s limit {limit}", - (s, p) + f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}", + (collection, s, p) ) - def get_po(self, p, o, limit=10): + def get_po(self, collection, p, o, limit=10): return self.session.execute( - f"select s from {self.table} where p = %s and o = %s limit {limit} allow filtering", - (p, o) + f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering", + (collection, p, o) ) - def get_os(self, o, s, limit=10): + def get_os(self, collection, o, s, limit=10): return self.session.execute( - f"select p from {self.table} where o = %s and s = %s limit {limit}", - (o, s) + f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering", + (collection, o, s) ) - def get_spo(self, s, p, o, limit=10): + def get_spo(self, collection, s, p, o, limit=10): return self.session.execute( - f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""", - (s, p, o) + f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""", + (collection, s, p, o) + ) + + def delete_collection(self, collection): + """Delete all triples for a specific collection""" + self.session.execute( + f"delete from {self.table} where collection = %s", + (collection,) ) def close(self): diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 220c8d7b..24ac6b23 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -6,7 +6,7 @@ import re logger = logging.getLogger(__name__) -def make_safe_collection_name(user, collection, dimension, prefix): +def make_safe_collection_name(user, 
collection, prefix): """ Create a safe Milvus collection name from user/collection parameters. Milvus only allows letters, numbers, and underscores. @@ -26,7 +26,7 @@ def make_safe_collection_name(user, collection, dimension, prefix): safe_user = sanitize(user) safe_collection = sanitize(collection) - return f"{prefix}_{safe_user}_{safe_collection}_{dimension}" + return f"{prefix}_{safe_user}_{safe_collection}" class DocVectors: @@ -51,7 +51,7 @@ class DocVectors: def init_collection(self, dimension, user, collection): - collection_name = make_safe_collection_name(user, collection, dimension, self.prefix) + collection_name = make_safe_collection_name(user, collection, self.prefix) pkey_field = FieldSchema( name="id", @@ -162,3 +162,20 @@ class DocVectors: return res + def delete_collection(self, user, collection): + """Delete a collection for the given user and collection""" + collection_name = make_safe_collection_name(user, collection, self.prefix) + + # Check if collection exists + if self.client.has_collection(collection_name): + # Drop the collection + self.client.drop_collection(collection_name) + logger.info(f"Deleted Milvus collection: {collection_name}") + + # Remove from our local cache + keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + for key in keys_to_remove: + del self.collections[key] + else: + logger.info(f"Collection {collection_name} does not exist, nothing to delete") + diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index b179c7de..85292a85 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -6,7 +6,7 @@ import re logger = logging.getLogger(__name__) -def make_safe_collection_name(user, collection, dimension, prefix): +def make_safe_collection_name(user, collection, prefix): """ Create a safe Milvus collection 
name from user/collection parameters. Milvus only allows letters, numbers, and underscores. @@ -26,7 +26,7 @@ def make_safe_collection_name(user, collection, dimension, prefix): safe_user = sanitize(user) safe_collection = sanitize(collection) - return f"{prefix}_{safe_user}_{safe_collection}_{dimension}" + return f"{prefix}_{safe_user}_{safe_collection}" class EntityVectors: @@ -51,7 +51,7 @@ class EntityVectors: def init_collection(self, dimension, user, collection): - collection_name = make_safe_collection_name(user, collection, dimension, self.prefix) + collection_name = make_safe_collection_name(user, collection, self.prefix) pkey_field = FieldSchema( name="id", @@ -162,3 +162,20 @@ class EntityVectors: return res + def delete_collection(self, user, collection): + """Delete a collection for the given user and collection""" + collection_name = make_safe_collection_name(user, collection, self.prefix) + + # Check if collection exists + if self.client.has_collection(collection_name): + # Drop the collection + self.client.drop_collection(collection_name) + logger.info(f"Deleted Milvus collection: {collection_name}") + + # Remove from our local cache + keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] + for key in keys_to_remove: + del self.collections[key] + else: + logger.info(f"Collection {collection_name} does not exist, nothing to delete") + diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py new file mode 100644 index 00000000..6e78db48 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py @@ -0,0 +1,28 @@ +from ... schema import CollectionManagementRequest, CollectionManagementResponse +from ... schema import collection_request_queue, collection_response_queue +from ... messaging import TranslatorRegistry + +from . 
requestor import ServiceRequestor + +class CollectionManagementRequestor(ServiceRequestor): + def __init__(self, pulsar_client, consumer, subscriber, timeout=120): + + super(CollectionManagementRequestor, self).__init__( + pulsar_client=pulsar_client, + consumer_name = consumer, + subscription = subscriber, + request_queue=collection_request_queue, + response_queue=collection_response_queue, + request_schema=CollectionManagementRequest, + response_schema=CollectionManagementResponse, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("collection-management") + self.response_translator = TranslatorRegistry.get_response_translator("collection-management") + + def to_request(self, body): + return self.request_translator.to_pulsar(body) + + def from_response(self, message): + return self.response_translator.from_response_with_completion(message) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 6f8649f0..a1821e84 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -11,6 +11,7 @@ from . config import ConfigRequestor from . flow import FlowRequestor from . librarian import LibrarianRequestor from . knowledge import KnowledgeRequestor +from . collection_management import CollectionManagementRequestor from . embeddings import EmbeddingsRequestor from . 
agent import AgentRequestor @@ -66,6 +67,7 @@ global_dispatchers = { "flow": FlowRequestor, "librarian": LibrarianRequestor, "knowledge": KnowledgeRequestor, + "collection-management": CollectionManagementRequestor, } sender_dispatchers = { diff --git a/trustgraph-flow/trustgraph/librarian/collection_service.py b/trustgraph-flow/trustgraph/librarian/collection_service.py new file mode 100644 index 00000000..7a4b9e6e --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/collection_service.py @@ -0,0 +1,362 @@ +""" +Collection management service for the librarian +""" + +import asyncio +import logging +from datetime import datetime + +from .. base import AsyncProcessor, Consumer, Producer +from .. base import ConsumerMetrics, ProducerMetrics +from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config + +from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error +from .. schema import collection_request_queue, collection_response_queue +from .. schema import CollectionMetadata +from .. schema import StorageManagementRequest, StorageManagementResponse +from .. schema import vector_storage_management_topic, object_storage_management_topic, triples_storage_management_topic, storage_management_response_topic + +from .. exceptions import RequestError +from .. 
tables.library import LibraryTableStore + +# Module logger +logger = logging.getLogger(__name__) + +default_ident = "collection-management" +default_cassandra_host = "cassandra" +keyspace = "librarian" + +class Processor(AsyncProcessor): + + def __init__(self, **params): + + id = params.get("id", default_ident) + + # Get Cassandra configuration + cassandra_host = params.get("cassandra_host", default_cassandra_host) + cassandra_username = params.get("cassandra_username") + cassandra_password = params.get("cassandra_password") + + # Resolve configuration with environment variable fallback + hosts, username, password = resolve_cassandra_config( + host=cassandra_host, + username=cassandra_username, + password=cassandra_password + ) + + super(Processor, self).__init__( + **params | { + "cassandra_host": ','.join(hosts), + "cassandra_username": username + } + ) + + self.cassandra_host = hosts + self.cassandra_username = username + self.cassandra_password = password + + # Set up metrics + collection_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="collection-request" + ) + collection_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="collection-response" + ) + + # Set up consumer for collection management requests + self.collection_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=collection_request_queue, + subscriber=id, + schema=CollectionManagementRequest, + handler=self.on_collection_request, + metrics=collection_request_metrics, + ) + + # Set up producer for collection management responses + self.collection_response_producer = Producer( + client=self.pulsar_client, + topic=collection_response_queue, + schema=CollectionManagementResponse, + metrics=collection_response_metrics, + ) + + # Set up producers for storage management requests + self.vector_storage_producer = Producer( + client=self.pulsar_client, + topic=vector_storage_management_topic, + 
schema=StorageManagementRequest, + ) + + self.object_storage_producer = Producer( + client=self.pulsar_client, + topic=object_storage_management_topic, + schema=StorageManagementRequest, + ) + + self.triples_storage_producer = Producer( + client=self.pulsar_client, + topic=triples_storage_management_topic, + schema=StorageManagementRequest, + ) + + # Set up consumer for storage management responses + storage_response_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + self.storage_response_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=storage_management_response_topic, + subscriber=f"{id}-storage", + schema=StorageManagementResponse, + handler=self.on_storage_response, + metrics=storage_response_metrics, + ) + + # Initialize table store + self.table_store = LibraryTableStore( + cassandra_host=self.cassandra_host, + cassandra_username=self.cassandra_username, + cassandra_password=self.cassandra_password, + keyspace=keyspace + ) + + # Track pending deletion requests by user+collection + self.pending_deletions = {} # (user, collection) -> {responses_pending, responses_received, all_successful, error_messages, deletion_complete} + + async def on_collection_request(self, message): + """Handle collection management requests""" + + logger.debug(f"Collection request: {message.operation}") + + try: + if message.operation == "list-collections": + response = await self.handle_list_collections(message) + elif message.operation == "update-collection": + response = await self.handle_update_collection(message) + elif message.operation == "delete-collection": + response = await self.handle_delete_collection(message) + else: + response = CollectionManagementResponse( + success="false", + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ), + timestamp=datetime.now().isoformat() + ) + + except Exception as e: + logger.error(f"Error processing 
collection request: {e}", exc_info=True) + response = CollectionManagementResponse( + success="false", + error=Error( + type="processing_error", + message=str(e) + ), + timestamp=datetime.now().isoformat() + ) + + await self.collection_response_producer.send(response) + + async def on_storage_response(self, response): + """Handle storage management responses""" + logger.debug(f"Received storage response: error={response.error}") + + # Find matching deletion by checking all pending deletions + # Note: This is simplified correlation - assumes responses come back quickly + # In production, we'd want better correlation mechanism + for deletion_key, info in list(self.pending_deletions.items()): + if info["responses_pending"] > 0: + # Record this response + info["responses_received"].append(response) + info["responses_pending"] -= 1 + + # Check if this response indicates failure + if response.error and response.error.message: + info["all_successful"] = False + info["error_messages"].append(response.error.message) + logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}") + else: + logger.debug(f"Storage deletion succeeded for {deletion_key}") + + # If all responses received, signal completion + if info["responses_pending"] == 0: + logger.info(f"All storage responses received for {deletion_key}") + info["deletion_complete"].set() + + break # Only process for first matching deletion + + # For now, we'll correlate by user+collection since we don't have deletion_id in the response + # This is a simplified approach - in production we'd want better correlation + for deletion_id, info in list(self.pending_deletions.items()): + if info["responses_pending"] > 0: + # Record this response + info["responses_received"].append(response) + info["responses_pending"] -= 1 + + # Check if this response indicates failure + if response.error and response.error.message: + info["all_successful"] = False + info["error_messages"].append(response.error.message) + 
logger.warning(f"Storage deletion failed for {deletion_id}: {response.error.message}") + + # If all responses received, signal completion + if info["responses_pending"] == 0: + logger.info(f"All storage responses received for {deletion_id}") + info["deletion_complete"].set() + + break # Only process for first matching deletion + + async def handle_list_collections(self, message): + """Handle list collections request""" + try: + tag_filter = list(message.tag_filter) if message.tag_filter else None + collections = await self.table_store.list_collections(message.user, tag_filter) + + collection_metadata = [ + CollectionMetadata( + user=coll["user"], + collection=coll["collection"], + name=coll["name"], + description=coll["description"], + tags=coll["tags"], + created_at=coll["created_at"], + updated_at=coll["updated_at"] + ) + for coll in collections + ] + + return CollectionManagementResponse( + success="true", + collections=collection_metadata, + timestamp=datetime.now().isoformat() + ) + + except Exception as e: + logger.error(f"Error listing collections: {e}") + raise + + async def handle_update_collection(self, message): + """Handle update collection request""" + try: + # Extract fields for update + name = message.name if message.name else None + description = message.description if message.description else None + tags = list(message.tags) if message.tags else None + + updated_collection = await self.table_store.update_collection( + message.user, message.collection, name, description, tags + ) + + collection_metadata = CollectionMetadata( + user=updated_collection["user"], + collection=updated_collection["collection"], + name=updated_collection["name"], + description=updated_collection["description"], + tags=updated_collection["tags"], + created_at="", # Not returned by update + updated_at=updated_collection["updated_at"] + ) + + return CollectionManagementResponse( + success="true", + collections=[collection_metadata], + timestamp=datetime.now().isoformat() + ) 
+ + except Exception as e: + logger.error(f"Error updating collection: {e}") + raise + + async def handle_delete_collection(self, message): + """Handle delete collection request with cascade to all storage types""" + try: + deletion_key = (message.user, message.collection) + + logger.info(f"Starting cascade deletion for {message.user}/{message.collection}") + + # Track this deletion request + self.pending_deletions[deletion_key] = { + "responses_pending": 3, # vector, object, triples + "responses_received": [], + "all_successful": True, + "error_messages": [], + "deletion_complete": asyncio.Event() + } + + # Create storage management request + storage_request = StorageManagementRequest( + operation="delete-collection", + user=message.user, + collection=message.collection + ) + + # Send delete requests to all three storage types + await self.vector_storage_producer.send(storage_request) + await self.object_storage_producer.send(storage_request) + await self.triples_storage_producer.send(storage_request) + + logger.info(f"Storage deletion requests sent for {message.user}/{message.collection}") + + # Wait for all storage responses (with timeout) + try: + await asyncio.wait_for( + self.pending_deletions[deletion_key]["deletion_complete"].wait(), + timeout=30.0 # 30 second timeout + ) + except asyncio.TimeoutError: + logger.error(f"Timeout waiting for storage responses for {deletion_key}") + self.pending_deletions[deletion_key]["all_successful"] = False + self.pending_deletions[deletion_key]["error_messages"].append("Timeout waiting for storage responses") + + # Check if all storage deletions were successful + deletion_info = self.pending_deletions.pop(deletion_key, {}) + + if deletion_info.get("all_successful", False): + # All storage deletions succeeded, now delete metadata + await self.table_store.delete_collection_metadata(message.user, message.collection) + logger.info(f"Successfully completed cascade deletion for {message.user}/{message.collection}") + + return 
CollectionManagementResponse( + success="true", + timestamp=datetime.now().isoformat() + ) + else: + # Some storage deletions failed + error_messages = deletion_info.get("error_messages", ["Unknown storage deletion error"]) + error_msg = "; ".join(error_messages) + logger.error(f"Cascade deletion failed for {deletion_key}: {error_msg}") + + return CollectionManagementResponse( + success="false", + error=Error( + type="storage_deletion_error", + message=f"Storage deletion failed: {error_msg}" + ), + timestamp=datetime.now().isoformat() + ) + + except Exception as e: + logger.error(f"Error in cascade deletion: {e}") + return CollectionManagementResponse( + success="false", + error=Error( + type="deletion_error", + message=f"Failed to delete collection: {str(e)}" + ), + timestamp=datetime.now().isoformat() + ) + + @staticmethod + def add_args(parser): + AsyncProcessor.add_args(parser) + add_cassandra_args(parser) + +def run(): + Processor.launch(default_ident, __doc__) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index 3ef3f40b..4ec91dfe 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -95,7 +95,7 @@ class Processor(DocumentEmbeddingsQueryService): dim = len(vec) index_name = ( - "d-" + msg.user + "-" + msg.collection + "-" + str(dim) + "d-" + msg.user + "-" + msg.collection ) self.ensure_index_exists(index_name, dim) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 6de08e4c..30e24bd8 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -104,7 +104,7 @@ class Processor(GraphEmbeddingsQueryService): dim = len(vec) 
index_name = ( - "t-" + msg.user + "-" + msg.collection + "-" + str(dim) + "t-" + msg.user + "-" + msg.collection ) self.ensure_index_exists(index_name, dim) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index a0dde295..cf2757af 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -6,7 +6,7 @@ null. Output is a list of triples. import logging -from .... direct.cassandra import TrustGraph +from .... direct.cassandra_kg import KnowledgeGraph from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error from .... schema import Value, Triple from .... base import TriplesQueryService @@ -56,21 +56,21 @@ class Processor(TriplesQueryService): try: - table = (query.user, query.collection) + user = query.user - if table != self.table: + if user != self.table: if self.cassandra_username and self.cassandra_password: - self.tg = TrustGraph( + self.tg = KnowledgeGraph( hosts=self.cassandra_host, - keyspace=query.user, table=query.collection, + keyspace=query.user, username=self.cassandra_username, password=self.cassandra_password ) else: - self.tg = TrustGraph( + self.tg = KnowledgeGraph( hosts=self.cassandra_host, - keyspace=query.user, table=query.collection, + keyspace=query.user, ) - self.table = table + self.table = user triples = [] @@ -78,13 +78,13 @@ class Processor(TriplesQueryService): if query.p is not None: if query.o is not None: resp = self.tg.get_spo( - query.s.value, query.p.value, query.o.value, + query.collection, query.s.value, query.p.value, query.o.value, limit=query.limit ) triples.append((query.s.value, query.p.value, query.o.value)) else: resp = self.tg.get_sp( - query.s.value, query.p.value, + query.collection, query.s.value, query.p.value, limit=query.limit ) for t in resp: @@ -92,14 +92,14 @@ class Processor(TriplesQueryService): else: if query.o is not 
None: resp = self.tg.get_os( - query.o.value, query.s.value, + query.collection, query.o.value, query.s.value, limit=query.limit ) for t in resp: triples.append((query.s.value, t.p, query.o.value)) else: resp = self.tg.get_s( - query.s.value, + query.collection, query.s.value, limit=query.limit ) for t in resp: @@ -108,14 +108,14 @@ class Processor(TriplesQueryService): if query.p is not None: if query.o is not None: resp = self.tg.get_po( - query.p.value, query.o.value, + query.collection, query.p.value, query.o.value, limit=query.limit ) for t in resp: triples.append((t.s, query.p.value, query.o.value)) else: resp = self.tg.get_p( - query.p.value, + query.collection, query.p.value, limit=query.limit ) for t in resp: @@ -123,13 +123,14 @@ class Processor(TriplesQueryService): else: if query.o is not None: resp = self.tg.get_o( - query.o.value, + query.collection, query.o.value, limit=query.limit ) for t in resp: triples.append((t.s, t.p, query.o.value)) else: resp = self.tg.get_all( + query.collection, limit=query.limit ) for t in resp: diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index b1d401aa..598183f2 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -3,8 +3,17 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ +import logging + from .... direct.milvus_doc_embeddings import DocVectors from .... base import DocumentEmbeddingsStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import vector_storage_management_topic, storage_management_response_topic + +# Module logger +logger = logging.getLogger(__name__) default_ident = "de-write" default_store_uri = 'http://localhost:19530' @@ -23,6 +32,34 @@ class Processor(DocumentEmbeddingsStoreService): self.vecstore = DocVectors(store_uri) + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=vector_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + async def store_document_embeddings(self, message): for emb in message.chunks: @@ -50,6 +87,48 @@ class Processor(DocumentEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + 
response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for document embeddings""" + try: + self.vecstore.delete_collection(message.user, message.collection) + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 1851a243..a613320a 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -12,6 +12,10 @@ import os import logging from .... base import DocumentEmbeddingsStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -55,6 +59,34 @@ class Processor(DocumentEmbeddingsStoreService): self.last_index_name = None + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=vector_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + def create_index(self, index_name, dim): self.pinecone.create_index( @@ -96,7 +128,7 @@ class Processor(DocumentEmbeddingsStoreService): dim = len(vec) index_name = ( - "d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim) + "d-" + message.metadata.user + "-" + message.metadata.collection ) if index_name != self.last_index_name: @@ -160,6 +192,54 @@ class Processor(DocumentEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: 
{message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for document embeddings""" + try: + index_name = f"d-{message.user}-{message.collection}" + + if self.pinecone.has_index(index_name): + self.pinecone.delete_index(index_name) + logger.info(f"Deleted Pinecone index: {index_name}") + else: + logger.info(f"Index {index_name} does not exist, nothing to delete") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 6005df1f..8f393b1a 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -10,6 +10,10 @@ import uuid import logging from .... base import DocumentEmbeddingsStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -36,6 +40,37 @@ class Processor(DocumentEmbeddingsStoreService): self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + # Set up storage management if base class attributes are available + # (they may not be in unit tests) + if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'): + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=vector_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + async def store_document_embeddings(self, message): for emb in message.chunks: @@ -48,8 +83,7 @@ class Processor(DocumentEmbeddingsStoreService): dim = len(vec) collection = ( "d_" + message.metadata.user + "_" + - message.metadata.collection + "_" + - str(dim) + message.metadata.collection ) if collection != self.last_collection: @@ -99,6 +133,54 @@ class Processor(DocumentEmbeddingsStoreService): help=f'Qdrant API key (default: None)' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == 
"delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for document embeddings""" + try: + collection_name = f"d_{message.user}_{message.collection}" + + if self.qdrant.collection_exists(collection_name): + self.qdrant.delete_collection(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + else: + logger.info(f"Collection {collection_name} does not exist, nothing to delete") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 68e56c0f..f94f2752 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -3,8 +3,17 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ +import logging + from .... direct.milvus_graph_embeddings import EntityVectors from .... base import GraphEmbeddingsStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... 
base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... schema import vector_storage_management_topic, storage_management_response_topic + +# Module logger +logger = logging.getLogger(__name__) default_ident = "ge-write" default_store_uri = 'http://localhost:19530' @@ -23,6 +32,34 @@ class Processor(GraphEmbeddingsStoreService): self.vecstore = EntityVectors(store_uri) + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=vector_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + async def store_graph_embeddings(self, message): for entity in message.entities: @@ -46,6 +83,48 @@ class Processor(GraphEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await 
self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for graph embeddings""" + try: + self.vecstore.delete_collection(message.user, message.collection) + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index f73cfd22..b4d9ac5e 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -12,6 +12,10 @@ import os import logging from .... base import GraphEmbeddingsStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -55,6 +59,34 @@ class Processor(GraphEmbeddingsStoreService): self.last_index_name = None + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=vector_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + def create_index(self, index_name, dim): self.pinecone.create_index( @@ -95,7 +127,7 @@ class Processor(GraphEmbeddingsStoreService): dim = len(vec) index_name = ( - "t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim) + "t-" + message.metadata.user + "-" + message.metadata.collection ) if index_name != self.last_index_name: @@ -159,6 +191,54 @@ class Processor(GraphEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: 
{message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for graph embeddings""" + try: + index_name = f"t-{message.user}-{message.collection}" + + if self.pinecone.has_index(index_name): + self.pinecone.delete_index(index_name) + logger.info(f"Deleted Pinecone index: {index_name}") + else: + logger.info(f"Index {index_name} does not exist, nothing to delete") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 903702c7..2b67adf7 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -10,6 +10,10 @@ import uuid import logging from .... base import GraphEmbeddingsStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -36,10 +40,41 @@ class Processor(GraphEmbeddingsStoreService): self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + # Set up storage management if base class attributes are available + # (they may not be in unit tests) + if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'): + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=vector_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + def get_collection(self, dim, user, collection): cname = ( - "t_" + user + "_" + collection + "_" + str(dim) + "t_" + user + "_" + collection ) if cname != self.last_collection: @@ -105,6 +140,54 @@ class Processor(GraphEmbeddingsStoreService): help=f'Qdrant API key' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown 
operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for graph embeddings""" + try: + collection_name = f"t_{message.user}_{message.collection}" + + if self.qdrant.collection_exists(collection_name): + self.qdrant.delete_collection(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + else: + logger.info(f"Collection {collection_name} does not exist, nothing to delete") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py index 8a6db66c..2ec98711 100644 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py @@ -13,7 +13,9 @@ from cassandra import ConsistencyLevel from .... schema import ExtractedObject from .... schema import RowSchema, Field -from .... base import FlowProcessor, ConsumerSpec +from .... schema import StorageManagementRequest, StorageManagementResponse +from .... schema import object_storage_management_topic, storage_management_response_topic +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... 
base.cassandra_config import add_cassandra_args, resolve_cassandra_config # Module logger @@ -61,7 +63,38 @@ class Processor(FlowProcessor): handler = self.on_object ) ) - + + # Set up storage management consumer and producer directly + # (FlowProcessor doesn't support topic-based specs outside of flows) + from .... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics + + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Create storage management consumer + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=object_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Create storage management response producer + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + # Register config handler for schema updates self.register_config_handler(self.on_schema_config) @@ -390,6 +423,100 @@ class Processor(FlowProcessor): logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True) raise + async def on_storage_management(self, msg, consumer, flow): + """Handle storage management requests for collection operations""" + logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}") + + try: + if msg.operation == "delete-collection": + await self.delete_collection(msg.user, msg.collection) + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}") + 
else: + logger.warning(f"Unknown storage management operation: {msg.operation}") + # Send error response + from .... schema import Error + response = StorageManagementResponse( + error=Error( + type="unknown_operation", + message=f"Unknown operation: {msg.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error handling storage management request: {e}", exc_info=True) + # Send error response + from .... schema import Error + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def delete_collection(self, user: str, collection: str): + """Delete all data for a specific collection""" + # Connect if not already connected + self.connect_cassandra() + + # Sanitize names for safety + safe_keyspace = self.sanitize_name(user) + + # Check if keyspace exists + if safe_keyspace not in self.known_keyspaces: + # Query to verify keyspace exists + check_keyspace_cql = """ + SELECT keyspace_name FROM system_schema.keyspaces + WHERE keyspace_name = %s + """ + result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) + if not result.one(): + logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") + return + self.known_keyspaces.add(safe_keyspace) + + # Get all tables in the keyspace that might contain collection data + get_tables_cql = """ + SELECT table_name FROM system_schema.tables + WHERE keyspace_name = %s + """ + + tables = self.session.execute(get_tables_cql, (safe_keyspace,)) + tables_deleted = 0 + + for row in tables: + table_name = row.table_name + + # Check if the table has a collection column + check_column_cql = """ + SELECT column_name FROM system_schema.columns + WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection' + """ + + result = self.session.execute(check_column_cql, (safe_keyspace, table_name)) + if result.one(): + # Table has collection 
column, delete data for this collection + try: + delete_cql = f""" + DELETE FROM {safe_keyspace}.{table_name} + WHERE collection = %s + """ + self.session.execute(delete_cql, (collection,)) + tables_deleted += 1 + logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}") + except Exception as e: + logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}") + raise + + logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}") + def close(self): """Clean up Cassandra connections""" if self.cluster: diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 06e8f4e0..e925ece0 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -10,9 +10,13 @@ import argparse import time import logging -from .... direct.cassandra import TrustGraph +from .... direct.cassandra_kg import KnowledgeGraph from .... base import TriplesStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -50,42 +54,146 @@ class Processor(TriplesStoreService): self.cassandra_password = password self.table = None + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=triples_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + async def store_triples(self, message): - table = (message.metadata.user, message.metadata.collection) + user = message.metadata.user - if self.table is None or self.table != table: + if self.table is None or self.table != user: self.tg = None try: if self.cassandra_username and self.cassandra_password: - self.tg = TrustGraph( + self.tg = KnowledgeGraph( hosts=self.cassandra_host, keyspace=message.metadata.user, - table=message.metadata.collection, username=self.cassandra_username, password=self.cassandra_password ) else: - self.tg = TrustGraph( + self.tg = KnowledgeGraph( hosts=self.cassandra_host, keyspace=message.metadata.user, - table=message.metadata.collection, ) except Exception as e: logger.error(f"Exception: {e}", exc_info=True) time.sleep(1) raise e - self.table = table + self.table = user for t in message.triples: self.tg.insert( + message.metadata.collection, t.s.value, 
t.p.value, t.o.value ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete all data for a specific collection from the unified triples table""" + try: + # Create or reuse connection for this user's keyspace + if self.table is None or self.table != message.user: + self.tg = None + + try: + if self.cassandra_username and self.cassandra_password: + self.tg = KnowledgeGraph( + hosts=self.cassandra_host, + keyspace=message.user, + username=self.cassandra_username, + password=self.cassandra_password + ) + else: + self.tg = KnowledgeGraph( + hosts=self.cassandra_host, + keyspace=message.user, + ) + except Exception as e: + logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}") + raise + + self.table = message.user + + # Delete all triples for this collection from the unified table + # In the unified table schema, collection is the partition key + delete_cql = """ + DELETE FROM triples + WHERE collection = %s 
+ """ + + try: + self.tg.session.execute(delete_cql, (message.collection,)) + logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}") + except Exception as e: + logger.error(f"Failed to delete collection data: {e}") + raise + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index b71c247b..6591bafc 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -13,6 +13,10 @@ import logging from falkordb import FalkorDB from .... base import TriplesStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -40,14 +44,44 @@ class Processor(TriplesStoreService): self.io = FalkorDB.from_url(graph_url).select_graph(database) - def create_node(self, uri): + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) - logger.debug(f"Create node {uri}") + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=triples_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + + def create_node(self, uri, user, collection): + + logger.debug(f"Create node {uri} for user={user}, collection={collection}") res = self.io.query( - "MERGE (n:Node {uri: $uri})", + "MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", params={ "uri": uri, + "user": user, + "collection": collection, }, ) @@ -56,14 +90,16 @@ class Processor(TriplesStoreService): time=res.run_time_ms )) - def create_literal(self, value): + def create_literal(self, value, user, collection): - logger.debug(f"Create literal {value}") + logger.debug(f"Create literal {value} for user={user}, collection={collection}") res = self.io.query( - "MERGE (n:Literal {value: $value})", + "MERGE (n:Literal {value: $value, user: $user, collection: $collection})", params={ "value": value, + "user": user, + "collection": 
collection, }, ) @@ -72,18 +108,20 @@ class Processor(TriplesStoreService): time=res.run_time_ms )) - def relate_node(self, src, uri, dest): + def relate_node(self, src, uri, dest, user, collection): - logger.debug(f"Create node rel {src} {uri} {dest}") + logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") res = self.io.query( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Node {uri: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", params={ "src": src, "dest": dest, "uri": uri, + "user": user, + "collection": collection, }, ) @@ -92,18 +130,20 @@ class Processor(TriplesStoreService): time=res.run_time_ms )) - def relate_literal(self, src, uri, dest): + def relate_literal(self, src, uri, dest, user, collection): - logger.debug(f"Create literal rel {src} {uri} {dest}") + logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") res = self.io.query( - "MATCH (src:Node {uri: $src}) " - "MATCH (dest:Literal {value: $dest}) " - "MERGE (src)-[:Rel {uri: $uri}]->(dest)", + "MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " + "MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " + "MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", params={ "src": src, "dest": dest, "uri": uri, + "user": user, + "collection": collection, }, ) @@ -113,17 +153,20 @@ class Processor(TriplesStoreService): )) async def store_triples(self, message): + # Extract user and collection from metadata + user = message.metadata.user if message.metadata.user else "default" + collection = message.metadata.collection if message.metadata.collection else "default" for t in message.triples: - self.create_node(t.s.value) + 
self.create_node(t.s.value, user, collection) if t.o.is_uri: - self.create_node(t.o.value) - self.relate_node(t.s.value, t.p.value, t.o.value) + self.create_node(t.o.value, user, collection) + self.relate_node(t.s.value, t.p.value, t.o.value, user, collection) else: - self.create_literal(t.o.value) - self.relate_literal(t.s.value, t.p.value, t.o.value) + self.create_literal(t.o.value, user, collection) + self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection) @staticmethod def add_args(parser): @@ -142,6 +185,59 @@ class Processor(TriplesStoreService): help=f'FalkorDB database (default: {default_database})' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete the collection for FalkorDB triples""" + try: + # Delete all nodes and literals for this user/collection + node_result = self.io.query( + "MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n", + params={"user": message.user, "collection": message.collection} + ) + + literal_result = self.io.query( + "MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n", + params={"user": message.user, "collection": message.collection} + ) + + logger.info(f"Deleted {node_result.nodes_deleted} 
nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 0996111d..04f01f3d 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -13,6 +13,10 @@ import logging from neo4j import GraphDatabase from .... base import TriplesStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... 
schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -49,6 +53,34 @@ class Processor(TriplesStoreService): with self.io.session(database=self.db) as session: self.create_indexes(session) + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=triples_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + def create_indexes(self, session): # Race condition, index creation failure is ignored. 
Right thing @@ -285,6 +317,67 @@ class Processor(TriplesStoreService): help=f'Memgraph database (default: {default_database})' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete all data for a specific collection""" + try: + with self.io.session(database=self.db) as session: + # Delete all nodes for this user and collection + node_result = session.run( + "MATCH (n:Node {user: $user, collection: $collection}) " + "DETACH DELETE n", + user=message.user, collection=message.collection + ) + nodes_deleted = node_result.consume().counters.nodes_deleted + + # Delete all literals for this user and collection + literal_result = session.run( + "MATCH (n:Literal {user: $user, collection: $collection}) " + "DETACH DELETE n", + user=message.user, collection=message.collection + ) + literals_deleted = literal_result.consume().counters.nodes_deleted + + # Note: Relationships are automatically deleted with DETACH DELETE + + logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + await self.storage_response_producer.send(response) 
+ logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index c33478eb..a59f9a7e 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -12,6 +12,10 @@ import logging from neo4j import GraphDatabase from .... base import TriplesStoreService +from .... base import AsyncProcessor, Consumer, Producer +from .... base import ConsumerMetrics, ProducerMetrics +from .... schema import StorageManagementRequest, StorageManagementResponse, Error +from .... schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -49,6 +53,34 @@ class Processor(TriplesStoreService): with self.io.session(database=self.db) as session: self.create_indexes(session) + # Set up metrics for storage management + storage_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="storage-request" + ) + storage_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="storage-response" + ) + + # Set up consumer for storage management requests + self.storage_request_consumer = Consumer( + taskgroup=self.taskgroup, + client=self.pulsar_client, + flow=None, + topic=triples_storage_management_topic, + subscriber=f"{self.id}-storage", + schema=StorageManagementRequest, + handler=self.on_storage_management, + metrics=storage_request_metrics, + ) + + # Set up producer for storage management responses + self.storage_response_producer = Producer( + client=self.pulsar_client, + topic=storage_management_response_topic, + schema=StorageManagementResponse, + metrics=storage_response_metrics, + ) + def create_indexes(self, session): # 
Race condition, index creation failure is ignored. Right thing @@ -236,6 +268,67 @@ class Processor(TriplesStoreService): help=f'Neo4j database (default: {default_database})' ) + async def on_storage_management(self, message): + """Handle storage management requests""" + logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}") + + try: + if message.operation == "delete-collection": + await self.handle_delete_collection(message) + else: + response = StorageManagementResponse( + error=Error( + type="invalid_operation", + message=f"Unknown operation: {message.operation}" + ) + ) + await self.storage_response_producer.send(response) + + except Exception as e: + logger.error(f"Error processing storage management request: {e}", exc_info=True) + response = StorageManagementResponse( + error=Error( + type="processing_error", + message=str(e) + ) + ) + await self.storage_response_producer.send(response) + + async def handle_delete_collection(self, message): + """Delete all data for a specific collection""" + try: + with self.io.session(database=self.db) as session: + # Delete all nodes for this user and collection + node_result = session.run( + "MATCH (n:Node {user: $user, collection: $collection}) " + "DETACH DELETE n", + user=message.user, collection=message.collection + ) + nodes_deleted = node_result.consume().counters.nodes_deleted + + # Delete all literals for this user and collection + literal_result = session.run( + "MATCH (n:Literal {user: $user, collection: $collection}) " + "DETACH DELETE n", + user=message.user, collection=message.collection + ) + literals_deleted = literal_result.consume().counters.nodes_deleted + + # Note: Relationships are automatically deleted with DETACH DELETE + + logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}") + + # Send success response + response = StorageManagementResponse( + error=None # No error means success + ) + 
await self.storage_response_producer.send(response) + logger.info(f"Successfully deleted collection {message.user}/{message.collection}") + + except Exception as e: + logger.error(f"Failed to delete collection: {e}") + raise + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index cb152c30..fb3d5a0e 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -111,6 +111,21 @@ class LibraryTableStore: ); """); + logger.debug("collections table...") + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS collections ( + user text, + collection text, + name text, + description text, + tags set<text>, + created_at timestamp, + updated_at timestamp, + PRIMARY KEY (user, collection) + ); + """); + logger.info("Cassandra schema OK.") def prepare_statements(self): @@ -187,6 +202,43 @@ class LibraryTableStore: LIMIT 1 """) + # Collection management statements + self.insert_collection_stmt = self.cassandra.prepare(""" + INSERT INTO collections + (user, collection, name, description, tags, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """) + + self.update_collection_stmt = self.cassandra.prepare(""" + UPDATE collections + SET name = ?, description = ?, tags = ?, updated_at = ? + WHERE user = ? AND collection = ? + """) + + self.get_collection_stmt = self.cassandra.prepare(""" + SELECT collection, name, description, tags, created_at, updated_at + FROM collections + WHERE user = ? AND collection = ? + """) + + self.list_collections_stmt = self.cassandra.prepare(""" + SELECT collection, name, description, tags, created_at, updated_at + FROM collections + WHERE user = ? + """) + + self.delete_collection_stmt = self.cassandra.prepare(""" + DELETE FROM collections + WHERE user = ? AND collection = ? 
+ """) + + self.collection_exists_stmt = self.cassandra.prepare(""" + SELECT collection + FROM collections + WHERE user = ? AND collection = ? + LIMIT 1 + """) + self.list_processing_stmt = self.cassandra.prepare(""" SELECT id, document_id, time, flow, collection, tags @@ -521,3 +573,113 @@ class LibraryTableStore: return lst + + + # Collection management methods + + async def ensure_collection_exists(self, user, collection): + """Ensure collection metadata record exists, create if not""" + try: + resp = await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.collection_exists_stmt, [user, collection] + ) + if resp: + return + import datetime + now = datetime.datetime.now() + await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.insert_collection_stmt, + [user, collection, collection, "", set(), now, now] + ) + logger.debug(f"Created collection metadata for {user}/{collection}") + except Exception as e: + logger.error(f"Error ensuring collection exists: {e}") + raise + + async def list_collections(self, user, tag_filter=None): + """List collections for a user, optionally filtered by tags""" + try: + resp = await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.list_collections_stmt, [user] + ) + collections = [] + for row in resp: + collection_data = { + "user": user, + "collection": row[0], + "name": row[1] or row[0], + "description": row[2] or "", + "tags": list(row[3]) if row[3] else [], + "created_at": row[4].isoformat() if row[4] else "", + "updated_at": row[5].isoformat() if row[5] else "" + } + if tag_filter: + collection_tags = set(collection_data["tags"]) + filter_tags = set(tag_filter) + if not filter_tags.intersection(collection_tags): + continue + collections.append(collection_data) + return collections + except Exception as e: + logger.error(f"Error listing collections: {e}") + raise + + async def update_collection(self, user, collection, name=None, 
description=None, tags=None): + """Update collection metadata""" + try: + resp = await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.get_collection_stmt, [user, collection] + ) + if not resp: + raise RequestError(f"Collection {collection} not found") + row = resp.one() + current_name = row[1] or collection + current_description = row[2] or "" + current_tags = set(row[3]) if row[3] else set() + new_name = name if name is not None else current_name + new_description = description if description is not None else current_description + new_tags = set(tags) if tags is not None else current_tags + import datetime + now = datetime.datetime.now() + await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.update_collection_stmt, + [new_name, new_description, new_tags, now, user, collection] + ) + return { + "user": user, "collection": collection, "name": new_name, + "description": new_description, "tags": list(new_tags), + "updated_at": now.isoformat() + } + except Exception as e: + logger.error(f"Error updating collection: {e}") + raise + + async def delete_collection_metadata(self, user, collection): + """Delete collection metadata record""" + try: + await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.delete_collection_stmt, [user, collection] + ) + logger.debug(f"Deleted collection metadata for {user}/{collection}") + except Exception as e: + logger.error(f"Error deleting collection metadata: {e}") + raise + + async def get_collection(self, user, collection): + """Get collection metadata""" + try: + resp = await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.get_collection_stmt, [user, collection] + ) + if not resp: + return None + row = resp.one() + return { + "user": user, "collection": row[0], "name": row[1] or row[0], + "description": row[2] or "", "tags": list(row[3]) if row[3] else [], + "created_at": row[4].isoformat() if row[4] else "", + 
"updated_at": row[5].isoformat() if row[5] else "" + } + except Exception as e: + logger.error(f"Error getting collection: {e}") + raise