Collection management (#520)

* Tech spec

* Refactored Cassanda knowledge graph for single table

* Collection management, librarian services to manage metadata and collection deletion
This commit is contained in:
cybermaggedon 2025-09-18 15:57:52 +01:00 committed by GitHub
parent 48016d8fb2
commit 13ff7d765d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 2941 additions and 425 deletions

View file

@ -0,0 +1,349 @@
# Collection Management Technical Specification
## Overview
This specification describes the collection management capabilities for TrustGraph, enabling users to have explicit control over collections that are currently implicitly created during data loading and querying operations. The feature supports four primary use cases:
1. **Collection Listing**: View all existing collections in the system
2. **Collection Deletion**: Remove unwanted collections and their associated data
3. **Collection Labeling**: Associate descriptive labels with collections for better organization
4. **Collection Tagging**: Apply tags to collections for categorization and easier discovery
## Goals
- **Explicit Collection Control**: Provide users with direct management capabilities over collections beyond implicit creation
- **Collection Visibility**: Enable users to list and inspect all collections in their environment
- **Collection Cleanup**: Allow deletion of collections that are no longer needed
- **Collection Organization**: Support labels and tags for better collection tracking and discovery
- **Metadata Management**: Associate meaningful metadata with collections for operational clarity
- **Collection Discovery**: Make it easier to find specific collections through filtering and search
- **Operational Transparency**: Provide clear visibility into collection lifecycle and usage
- **Resource Management**: Enable cleanup of unused collections to optimize resource utilization
## Background
Currently, collections in TrustGraph are implicitly created during data loading operations and query execution. While this provides convenience for users, it lacks the explicit control needed for production environments and long-term data management.
Current limitations include:
- No way to list existing collections
- No mechanism to delete unwanted collections
- No ability to associate metadata with collections for tracking purposes
- Difficulty in organizing and discovering collections over time
This specification addresses these gaps by introducing explicit collection management operations. By providing collection management APIs and commands, TrustGraph can:
- Give users full control over their collection lifecycle
- Enable better organization through labels and tags
- Support collection cleanup for resource optimization
- Improve operational visibility and management
## Technical Design
### Architecture
The collection management system will be implemented within existing TrustGraph infrastructure:
1. **Librarian Service Integration**
- Collection management operations will be added to the existing librarian service
- No new service required - leverages existing authentication and access patterns
- Handles collection listing, deletion, and metadata management
Module: trustgraph-librarian
2. **Cassandra Collection Metadata Table**
- New table in the existing librarian keyspace
- Stores collection metadata with user-scoped access
- Primary key: (user_id, collection_id) for proper multi-tenancy
Module: trustgraph-librarian
3. **Collection Management CLI**
- Command-line interface for collection operations
- Provides list, delete, label, and tag management commands
- Integrates with existing CLI framework
Module: trustgraph-cli
### Data Models
#### Cassandra Collection Metadata Table
The collection metadata will be stored in a structured Cassandra table in the librarian keyspace:
```sql
CREATE TABLE collections (
user text,
collection text,
name text,
description text,
tags set<text>,
created_at timestamp,
updated_at timestamp,
PRIMARY KEY (user, collection)
);
```
Table structure:
- **user** + **collection**: Composite primary key ensuring user isolation
- **name**: Human-readable collection name
- **description**: Detailed description of collection purpose
- **tags**: Set of tags for categorization and filtering
- **created_at**: Collection creation timestamp
- **updated_at**: Last modification timestamp
This approach allows:
- Multi-tenant collection management with user isolation
- Efficient querying by user and collection
- Flexible tagging system for organization
- Lifecycle tracking for operational insights
#### Collection Lifecycle
Collections follow a lazy-creation pattern that aligns with existing TrustGraph behavior:
1. **Lazy Creation**: Collections are automatically created when first referenced during data loading or query operations. No explicit create operation is needed.
2. **Implicit Registration**: When a collection is used (data loading, querying), the system checks if a metadata record exists. If not, a new record is created with default values:
- `name`: defaults to collection_id
- `description`: empty
- `tags`: empty set
- `created_at`: current timestamp
3. **Explicit Updates**: Users can update collection metadata (name, description, tags) through management operations after lazy creation.
4. **Explicit Deletion**: Users can delete collections, which removes both the metadata record and the underlying collection data across all store types.
5. **Multi-Store Deletion**: Collection deletion cascades across all storage backends (vector stores, object stores, triple stores) as each implements lazy creation and must support collection deletion.
Operations required:
- **Collection Use Notification**: Internal operation triggered during data loading/querying to ensure metadata record exists
- **Update Collection Metadata**: User operation to modify name, description, and tags
- **Delete Collection**: User operation to remove collection and its data across all stores
- **List Collections**: User operation to view collections with filtering by tags
#### Multi-Store Collection Management
Collections exist across multiple storage backends in TrustGraph:
- **Vector Stores**: Store embeddings and vector data for collections
- **Object Stores**: Store documents and file data for collections
- **Triple Stores**: Store graph/RDF data for collections
Each store type implements:
- **Lazy Creation**: Collections are created implicitly when data is first stored
- **Collection Deletion**: Store-specific deletion operations to remove collection data
The librarian service coordinates collection operations across all store types, ensuring consistent collection lifecycle management.
### APIs
New APIs:
- **List Collections**: Retrieve collections for a user with optional tag filtering
- **Update Collection Metadata**: Modify collection name, description, and tags
- **Delete Collection**: Remove collection and associated data with confirmation, cascading to all store types
- **Collection Use Notification** (Internal): Ensure metadata record exists when collection is referenced
Store Writer APIs (Enhanced):
- **Vector Store Collection Deletion**: Remove vector data for specified user and collection
- **Object Store Collection Deletion**: Remove object/document data for specified user and collection
- **Triple Store Collection Deletion**: Remove graph/RDF data for specified user and collection
Modified APIs:
- **Data Loading APIs**: Enhanced to trigger collection use notification for lazy metadata creation
- **Query APIs**: Enhanced to trigger collection use notification and optionally include metadata in responses
### Implementation Details
The implementation will follow existing TrustGraph patterns for service integration and CLI command structure.
#### Collection Deletion Cascade
When a user initiates collection deletion through the librarian service:
1. **Metadata Validation**: Verify collection exists and user has permission to delete
2. **Store Cascade**: Librarian coordinates deletion across all store writers:
- Vector store writer: Remove embeddings and vector indexes for the user and collection
- Object store writer: Remove documents and files for the user and collection
- Triple store writer: Remove graph data and triples for the user and collection
3. **Metadata Cleanup**: Remove collection metadata record from Cassandra
4. **Error Handling**: If any store deletion fails, maintain consistency through rollback or retry mechanisms
#### Collection Management Interface
All store writers will implement a standardized collection management interface with a common schema across store types:
**Message Schema:**
```json
{
"operation": "delete-collection",
"user": "user123",
"collection": "documents-2024",
"timestamp": "2024-01-15T10:30:00Z"
}
```
**Queue Architecture:**
- **Object Store Collection Management Queue**: Handles collection operations for object/document stores
- **Vector Store Collection Management Queue**: Handles collection operations for vector/embedding stores
- **Triple Store Collection Management Queue**: Handles collection operations for graph/RDF stores
Each store writer implements:
- **Collection Management Handler**: Separate from standard data storage handlers
- **Delete Collection Operation**: Removes all data associated with the specified collection
- **Message Processing**: Consumes from dedicated collection management queue
- **Status Reporting**: Returns success/failure status for coordination
- **Idempotent Operations**: Handles cases where collection doesn't exist (no-op)
**Initial Implementation:**
Only `delete-collection` operation will be implemented initially. The interface supports future operations like `archive-collection`, `migrate-collection`, etc.
#### Cassandra Triple Store Refactor
As part of this implementation, the Cassandra triple store will be refactored from a table-per-collection model to a unified table model:
**Current Architecture:**
- Keyspace per user, separate table per collection
- Schema: `(s, p, o)` with `PRIMARY KEY (s, p, o)`
- Table names: user collections become separate Cassandra tables
**New Architecture:**
- Keyspace per user, single "triples" table for all collections
- Schema: `(collection, s, p, o)` with `PRIMARY KEY (collection, s, p, o)`
- Collection isolation through collection partitioning
**Changes Required:**
1. **TrustGraph Class Refactor** (`trustgraph/direct/cassandra.py`):
- Remove `table` parameter from constructor, use fixed "triples" table
- Add `collection` parameter to all methods
- Update schema to include collection as first column
- **Index Updates**: New indexes will be created to support all 8 query patterns:
- Index on `(s)` for subject-based queries
- Index on `(p)` for predicate-based queries
- Index on `(o)` for object-based queries
- Note: Cassandra doesn't support multi-column secondary indexes, so these are single-column indexes
- **Query Pattern Performance**:
- ✅ `get_all()` - partition scan on `collection`
- ✅ `get_s(s)` - uses primary key efficiently (`collection, s`)
- ✅ `get_p(p)` - uses `idx_p` with `collection` filtering
- ✅ `get_o(o)` - uses `idx_o` with `collection` filtering
- ✅ `get_sp(s, p)` - uses primary key efficiently (`collection, s, p`)
- ⚠️ `get_po(p, o)` - requires `ALLOW FILTERING` (uses either `idx_p` or `idx_o` plus filtering)
- ✅ `get_os(o, s)` - uses `idx_o` with additional filtering on `s`
- ✅ `get_spo(s, p, o)` - uses full primary key efficiently
- **Note on ALLOW FILTERING**: The `get_po` query pattern requires `ALLOW FILTERING` as it needs both predicate and object constraints without a suitable compound index. This is acceptable as this query pattern is less common than subject-based queries in typical triple store usage
2. **Storage Writer Updates** (`trustgraph/storage/triples/cassandra/write.py`):
- Maintain single TrustGraph connection per user instead of per (user, collection)
- Pass collection to insert operations
- Improved resource utilization with fewer connections
3. **Query Service Updates** (`trustgraph/query/triples/cassandra/service.py`):
- Single TrustGraph connection per user
- Pass collection to all query operations
- Maintain same query logic with collection parameter
**Benefits:**
- **Simplified Collection Deletion**: Simple `DELETE FROM triples WHERE collection = ?` instead of dropping tables
- **Resource Efficiency**: Fewer database connections and table objects
- **Cross-Collection Operations**: Easier to implement operations spanning multiple collections
- **Consistent Architecture**: Aligns with unified collection metadata approach
**Migration Strategy:**
Existing table-per-collection data will need migration to the new unified schema during the upgrade process.
Collection operations will be atomic where possible and provide appropriate error handling and validation.
## Security Considerations
Collection management operations require appropriate authorization to prevent unauthorized access or deletion of collections. Access control will align with existing TrustGraph security models.
## Performance Considerations
Collection listing operations may need pagination for environments with large numbers of collections. Metadata queries should be optimized for common filtering patterns.
## Testing Strategy
Comprehensive testing will cover collection lifecycle operations, metadata management, and CLI command functionality with both unit and integration tests.
## Migration Plan
This implementation requires both metadata and storage migrations:
### Collection Metadata Migration
Existing collections will need to be registered in the new Cassandra collections metadata table. A migration process will:
- Scan existing keyspaces and tables to identify collections
- Create metadata records with default values (name=collection_id, empty description/tags)
- Preserve creation timestamps where possible
### Cassandra Triple Store Migration
The Cassandra storage refactor requires data migration from table-per-collection to unified table:
- **Pre-migration**: Identify all user keyspaces and collection tables
- **Data Transfer**: Copy triples from individual collection tables to unified "triples" table with collection
- **Schema Validation**: Ensure new primary key structure maintains query performance
- **Cleanup**: Remove old collection tables after successful migration
- **Rollback Plan**: Maintain ability to restore table-per-collection structure if needed
Migration will be performed during a maintenance window to ensure data consistency.
## Implementation Status
### ✅ Completed Components
1. **Librarian Collection Management Service** (`trustgraph-flow/trustgraph/librarian/collection_service.py`)
- Complete collection CRUD operations (list, update, delete)
- Cassandra collection metadata table integration via `LibraryTableStore`
- Async request/response handling with proper error management
- Collection deletion cascade coordination across all storage types
2. **Collection Metadata Schema** (`trustgraph-base/trustgraph/schema/services/collection.py`)
- `CollectionManagementRequest` and `CollectionManagementResponse` schemas
- `CollectionMetadata` schema for collection records
- Collection request/response queue topic definitions
3. **Storage Management Schema** (`trustgraph-base/trustgraph/schema/services/storage.py`)
- `StorageManagementRequest` and `StorageManagementResponse` schemas
- Message format for storage-level collection operations
### ❌ Missing Components
1. **Storage Management Queue Topics**
- Missing topic definitions in schema for:
- `vector_storage_management_topic`
- `object_storage_management_topic`
- `triples_storage_management_topic`
- `storage_management_response_topic`
- These are referenced by the librarian service but not yet defined
2. **Store Collection Management Handlers**
- **Vector Store Writers** (Qdrant, Milvus, Pinecone): No collection deletion handlers
- **Object Store Writers** (Cassandra): No collection deletion handlers
- **Triple Store Writers** (Cassandra, Neo4j, Memgraph, FalkorDB): No collection deletion handlers
- Need to implement `StorageManagementRequest` processing in each store writer
3. **Collection Management Interface Implementation**
- Store writers need collection management message consumers
- Collection deletion operations need to be implemented per store type
- Response handling back to librarian service
### Next Implementation Steps
1. **Define Storage Management Topics** in `trustgraph-base/trustgraph/schema/services/storage.py`
2. **Implement Collection Management Handlers** in each storage writer:
- Add `StorageManagementRequest` consumers
- Implement collection deletion operations
- Add response producers for status reporting
3. **Test End-to-End Collection Deletion** across all storage types
## Timeline
Phase 1 (Storage Topics): 1-2 days
Phase 2 (Store Handlers): 1-2 weeks depending on number of storage backends
Phase 3 (Testing & Integration): 3-5 days
## Open Questions
- Should collection deletion be soft or hard delete by default?
- What metadata fields should be required vs optional?
- Should we implement storage management handlers incrementally by store type?

View file

@ -21,7 +21,7 @@ class TestEndToEndConfigurationFlow:
"""Test complete configuration flow from environment to processors."""
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra.Cluster')
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_triples_writer_env_to_connection(self, mock_cluster):
"""Test complete flow from environment variables to TrustGraph connection."""
env_vars = {
@ -117,7 +117,7 @@ class TestConfigurationPriorityEndToEnd:
"""Test configuration priority chains end-to-end."""
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra.Cluster')
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_cli_override_env_end_to_end(self, mock_cluster):
"""Test that CLI parameters override environment variables end-to-end."""
env_vars = {
@ -184,7 +184,7 @@ class TestConfigurationPriorityEndToEnd:
)
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra.Cluster')
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_no_config_defaults_end_to_end(self, mock_cluster):
"""Test that defaults are used when no configuration provided end-to-end."""
mock_cluster_instance = MagicMock()
@ -222,7 +222,7 @@ class TestNoBackwardCompatibilityEndToEnd:
"""Test that backward compatibility with old parameter names is removed."""
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra.Cluster')
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_old_graph_params_no_longer_work_end_to_end(self, mock_cluster):
"""Test that old graph_* parameters no longer work end-to-end."""
mock_cluster_instance = MagicMock()
@ -275,7 +275,7 @@ class TestNoBackwardCompatibilityEndToEnd:
)
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra.Cluster')
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_new_params_override_old_params_end_to_end(self, mock_cluster):
"""Test that new parameters override old ones when both are present end-to-end."""
mock_cluster_instance = MagicMock()
@ -334,7 +334,7 @@ class TestMultipleHostsHandling:
assert call_args.kwargs['contact_points'] == ['host1', 'host2', 'host3', 'host4', 'host5']
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra.Cluster')
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_single_host_converted_to_list(self, mock_cluster):
"""Test that single host is converted to list for TrustGraph."""
mock_cluster_instance = MagicMock()

View file

@ -13,7 +13,7 @@ import time
from unittest.mock import MagicMock
from .cassandra_test_helper import cassandra_container
from trustgraph.direct.cassandra import TrustGraph
from trustgraph.direct.cassandra_kg import KnowledgeGraph
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
@ -62,29 +62,29 @@ class TestCassandraIntegration:
print("=" * 60)
# =====================================================
# Test 1: Basic TrustGraph Operations
# Test 1: Basic KnowledgeGraph Operations
# =====================================================
print("\n1. Testing basic TrustGraph operations...")
client = TrustGraph(
print("\n1. Testing basic KnowledgeGraph operations...")
client = KnowledgeGraph(
hosts=[host],
keyspace="test_basic",
table="test_table"
keyspace="test_basic"
)
self.clients_to_close.append(client)
# Insert test data
client.insert("http://example.org/alice", "knows", "http://example.org/bob")
client.insert("http://example.org/alice", "age", "25")
client.insert("http://example.org/bob", "age", "30")
collection = "test_collection"
client.insert(collection, "http://example.org/alice", "knows", "http://example.org/bob")
client.insert(collection, "http://example.org/alice", "age", "25")
client.insert(collection, "http://example.org/bob", "age", "30")
# Test get_all
all_results = list(client.get_all(limit=10))
all_results = list(client.get_all(collection, limit=10))
assert len(all_results) == 3
print(f"✓ Stored and retrieved {len(all_results)} triples")
# Test get_s (subject query)
alice_results = list(client.get_s("http://example.org/alice", limit=10))
alice_results = list(client.get_s(collection, "http://example.org/alice", limit=10))
assert len(alice_results) == 2
alice_predicates = [r.p for r in alice_results]
assert "knows" in alice_predicates
@ -110,7 +110,7 @@ class TestCassandraIntegration:
keyspace="test_storage",
table="test_triples"
)
# Track the TrustGraph instance that will be created
# Track the KnowledgeGraph instance that will be created
self.storage_processor = storage_processor
# Create test message
@ -202,7 +202,7 @@ class TestCassandraIntegration:
# Debug: Check what was actually stored
print("Debug: Checking what was stored for Alice...")
direct_results = list(query_storage_processor.tg.get_s("http://example.org/alice", limit=10))
print(f"Direct TrustGraph results: {len(direct_results)}")
print(f"Direct KnowledgeGraph results: {len(direct_results)}")
for result in direct_results:
print(f" S=http://example.org/alice, P={result.p}, O={result.o}")

View file

@ -13,163 +13,146 @@ class TestMilvusCollectionNaming:
"""Test basic collection name creation"""
result = make_safe_collection_name(
user="test_user",
collection="test_collection",
dimension=384,
collection="test_collection",
prefix="doc"
)
assert result == "doc_test_user_test_collection_384"
assert result == "doc_test_user_test_collection"
def test_make_safe_collection_name_with_special_characters(self):
"""Test collection name creation with special characters that need sanitization"""
result = make_safe_collection_name(
user="user@domain.com",
collection="test-collection.v2",
dimension=768,
prefix="entity"
)
assert result == "entity_user_domain_com_test_collection_v2_768"
assert result == "entity_user_domain_com_test_collection_v2"
def test_make_safe_collection_name_with_unicode(self):
"""Test collection name creation with Unicode characters"""
result = make_safe_collection_name(
user="测试用户",
collection="colección_española",
dimension=512,
collection="colección_española",
prefix="doc"
)
assert result == "doc_default_colecci_n_espa_ola_512"
assert result == "doc_default_colecci_n_espa_ola"
def test_make_safe_collection_name_with_spaces(self):
"""Test collection name creation with spaces"""
result = make_safe_collection_name(
user="test user",
collection="my test collection",
dimension=256,
prefix="entity"
)
assert result == "entity_test_user_my_test_collection_256"
assert result == "entity_test_user_my_test_collection"
def test_make_safe_collection_name_with_multiple_consecutive_special_chars(self):
"""Test collection name creation with multiple consecutive special characters"""
result = make_safe_collection_name(
user="user@@@domain!!!",
collection="test---collection...v2",
dimension=384,
prefix="doc"
prefix="doc"
)
assert result == "doc_user_domain_test_collection_v2_384"
assert result == "doc_user_domain_test_collection_v2"
def test_make_safe_collection_name_with_leading_trailing_underscores(self):
"""Test collection name creation with leading/trailing special characters"""
result = make_safe_collection_name(
user="__test_user__",
collection="@@test_collection##",
dimension=128,
prefix="entity"
)
assert result == "entity_test_user_test_collection_128"
assert result == "entity_test_user_test_collection"
def test_make_safe_collection_name_empty_user(self):
"""Test collection name creation with empty user (should fallback to 'default')"""
result = make_safe_collection_name(
user="",
collection="test_collection",
dimension=384,
prefix="doc"
)
assert result == "doc_default_test_collection_384"
assert result == "doc_default_test_collection"
def test_make_safe_collection_name_empty_collection(self):
"""Test collection name creation with empty collection (should fallback to 'default')"""
result = make_safe_collection_name(
user="test_user",
collection="",
dimension=384,
prefix="doc"
)
assert result == "doc_test_user_default_384"
assert result == "doc_test_user_default"
def test_make_safe_collection_name_both_empty(self):
"""Test collection name creation with both user and collection empty"""
result = make_safe_collection_name(
user="",
collection="",
dimension=384,
prefix="doc"
)
assert result == "doc_default_default_384"
assert result == "doc_default_default"
def test_make_safe_collection_name_only_special_characters(self):
"""Test collection name creation with only special characters (should fallback to 'default')"""
result = make_safe_collection_name(
user="@@@!!!",
collection="---###",
dimension=512,
prefix="entity"
)
assert result == "entity_default_default_512"
assert result == "entity_default_default"
def test_make_safe_collection_name_whitespace_only(self):
"""Test collection name creation with whitespace-only strings"""
result = make_safe_collection_name(
user=" \n\t ",
collection=" \r\n ",
dimension=256,
prefix="doc"
)
assert result == "doc_default_default_256"
assert result == "doc_default_default"
def test_make_safe_collection_name_mixed_valid_invalid_chars(self):
"""Test collection name creation with mixed valid and invalid characters"""
result = make_safe_collection_name(
user="user123@test",
collection="coll_2023.v1",
dimension=384,
prefix="entity"
)
assert result == "entity_user123_test_coll_2023_v1_384"
assert result == "entity_user123_test_coll_2023_v1"
def test_make_safe_collection_name_different_prefixes(self):
"""Test collection name creation with different prefixes"""
user = "test_user"
collection = "test_collection"
dimension = 384
doc_result = make_safe_collection_name(user, collection, dimension, "doc")
entity_result = make_safe_collection_name(user, collection, dimension, "entity")
custom_result = make_safe_collection_name(user, collection, dimension, "custom")
assert doc_result == "doc_test_user_test_collection_384"
assert entity_result == "entity_test_user_test_collection_384"
assert custom_result == "custom_test_user_test_collection_384"
doc_result = make_safe_collection_name(user, collection, "doc")
entity_result = make_safe_collection_name(user, collection, "entity")
custom_result = make_safe_collection_name(user, collection, "custom")
assert doc_result == "doc_test_user_test_collection"
assert entity_result == "entity_test_user_test_collection"
assert custom_result == "custom_test_user_test_collection"
def test_make_safe_collection_name_different_dimensions(self):
"""Test collection name creation with different dimensions"""
"""Test collection name creation - dimension handling no longer part of function"""
user = "test_user"
collection = "test_collection"
prefix = "doc"
result_128 = make_safe_collection_name(user, collection, 128, prefix)
result_384 = make_safe_collection_name(user, collection, 384, prefix)
result_768 = make_safe_collection_name(user, collection, 768, prefix)
assert result_128 == "doc_test_user_test_collection_128"
assert result_384 == "doc_test_user_test_collection_384"
assert result_768 == "doc_test_user_test_collection_768"
# With new API, dimensions are handled separately, function always returns same result
result = make_safe_collection_name(user, collection, prefix)
assert result == "doc_test_user_test_collection"
def test_make_safe_collection_name_long_names(self):
"""Test collection name creation with very long user/collection names"""
long_user = "a" * 100
long_collection = "b" * 100
result = make_safe_collection_name(
user=long_user,
collection=long_collection,
dimension=384,
prefix="doc"
)
expected = f"doc_{long_user}_{long_collection}_384"
expected = f"doc_{long_user}_{long_collection}"
assert result == expected
assert len(result) > 200 # Verify it handles long names
@ -178,20 +161,18 @@ class TestMilvusCollectionNaming:
result = make_safe_collection_name(
user="user123",
collection="collection456",
dimension=384,
prefix="doc"
)
assert result == "doc_user123_collection456_384"
assert result == "doc_user123_collection456"
def test_make_safe_collection_name_case_sensitivity(self):
"""Test that collection name creation preserves case"""
result = make_safe_collection_name(
user="TestUser",
collection="TestCollection",
dimension=384,
prefix="Doc"
)
assert result == "Doc_TestUser_TestCollection_384"
assert result == "Doc_TestUser_TestCollection"
def test_make_safe_collection_name_realistic_examples(self):
"""Test collection name creation with realistic user/collection combinations"""
@ -202,30 +183,27 @@ class TestMilvusCollectionNaming:
("user_123", "test_collection", "user_123", "test_collection"),
("αβγ-user", "测试集合", "user", "default"),
]
for user, collection, expected_user, expected_collection in test_cases:
result = make_safe_collection_name(user, collection, 384, "doc")
assert result == f"doc_{expected_user}_{expected_collection}_384"
result = make_safe_collection_name(user, collection, "doc")
assert result == f"doc_{expected_user}_{expected_collection}"
def test_make_safe_collection_name_matches_qdrant_pattern(self):
"""Test that Milvus collection names follow similar pattern to Qdrant"""
"""Test that Milvus collection names follow similar pattern to Qdrant (but without dimension in name)"""
# Qdrant uses: "d_{user}_{collection}_{dimension}" and "t_{user}_{collection}_{dimension}"
# Milvus should use: "{prefix}_{safe_user}_{safe_collection}_{dimension}"
# New Milvus API uses: "{prefix}_{safe_user}_{safe_collection}" (dimension handled separately)
user = "test.user@domain.com"
collection = "test-collection.v2"
dimension = 384
doc_result = make_safe_collection_name(user, collection, dimension, "doc")
entity_result = make_safe_collection_name(user, collection, dimension, "entity")
# Should follow the pattern but with sanitized names
assert doc_result == "doc_test_user_domain_com_test_collection_v2_384"
assert entity_result == "entity_test_user_domain_com_test_collection_v2_384"
# Verify structure matches expected pattern (may have more underscores due to sanitization)
# The important thing is that it follows prefix_user_collection_dimension structure
doc_result = make_safe_collection_name(user, collection, "doc")
entity_result = make_safe_collection_name(user, collection, "entity")
# Should follow the pattern but with sanitized names and no dimension
assert doc_result == "doc_test_user_domain_com_test_collection_v2"
assert entity_result == "entity_test_user_domain_com_test_collection_v2"
# Verify structure matches expected pattern
assert doc_result.startswith("doc_")
assert doc_result.endswith("_384")
assert entity_result.startswith("entity_")
assert entity_result.endswith("_384")
# Dimension is no longer part of the collection name

View file

@ -32,7 +32,7 @@ class TestMilvusUserCollectionIntegration:
doc_vectors.insert(vector, "test document", user, collection)
expected_collection_name = make_safe_collection_name(
user, collection, len(vector), "doc"
user, collection, "doc"
)
# Verify collection was created with correct name
@ -58,7 +58,7 @@ class TestMilvusUserCollectionIntegration:
entity_vectors.insert(vector, "test entity", user, collection)
expected_collection_name = make_safe_collection_name(
user, collection, len(vector), "entity"
user, collection, "entity"
)
# Verify collection was created with correct name
@ -89,7 +89,7 @@ class TestMilvusUserCollectionIntegration:
result = doc_vectors.search(vector, user, collection, limit=5)
# Verify search was called with correct collection name
expected_collection_name = make_safe_collection_name(user, collection, 3, "doc")
expected_collection_name = make_safe_collection_name(user, collection, "doc")
mock_client.search.assert_called_once()
search_call = mock_client.search.call_args
assert search_call[1]["collection_name"] == expected_collection_name
@ -118,7 +118,7 @@ class TestMilvusUserCollectionIntegration:
result = entity_vectors.search(vector, user, collection, limit=5)
# Verify search was called with correct collection name
expected_collection_name = make_safe_collection_name(user, collection, 3, "entity")
expected_collection_name = make_safe_collection_name(user, collection, "entity")
mock_client.search.assert_called_once()
search_call = mock_client.search.call_args
assert search_call[1]["collection_name"] == expected_collection_name
@ -142,9 +142,9 @@ class TestMilvusUserCollectionIntegration:
collection_names = set(doc_vectors.collections.values())
expected_names = {
"doc_user1_collection1_3",
"doc_user2_collection2_3",
"doc_user1_collection2_3"
"doc_user1_collection1",
"doc_user2_collection2",
"doc_user1_collection2"
}
assert collection_names == expected_names
@ -167,9 +167,9 @@ class TestMilvusUserCollectionIntegration:
collection_names = set(entity_vectors.collections.values())
expected_names = {
"entity_user1_collection1_3",
"entity_user2_collection2_3",
"entity_user1_collection2_3"
"entity_user1_collection1",
"entity_user2_collection2",
"entity_user1_collection2"
}
assert collection_names == expected_names
@ -194,10 +194,13 @@ class TestMilvusUserCollectionIntegration:
collection_names = set(doc_vectors.collections.values())
expected_names = {
"doc_test_user_test_collection_3", # 3D
"doc_test_user_test_collection_4", # 4D
"doc_test_user_test_collection_2" # 2D
"doc_test_user_test_collection", # Same name for all dimensions
"doc_test_user_test_collection", # now stored per dimension in key
"doc_test_user_test_collection" # but collection name is the same
}
# Note: Now all dimensions use the same collection name, they are differentiated by the key
assert len(collection_names) == 1 # Only one unique collection name
assert "doc_test_user_test_collection" in collection_names
assert collection_names == expected_names
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
@ -220,7 +223,7 @@ class TestMilvusUserCollectionIntegration:
# Verify only one collection was created
assert len(doc_vectors.collections) == 1
expected_collection_name = "doc_test_user_test_collection_3"
expected_collection_name = "doc_test_user_test_collection"
assert doc_vectors.collections[(3, user, collection)] == expected_collection_name
@patch('trustgraph.direct.milvus_doc_embeddings.MilvusClient')
@ -233,10 +236,10 @@ class TestMilvusUserCollectionIntegration:
# Test various special character combinations
test_cases = [
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1_3"),
("user_123", "collection_456", "doc_user_123_collection_456_3"),
("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces_3"),
("user@@@test", "collection---test", "doc_user_test_collection_test_3"),
("user@domain.com", "test-collection.v1", "doc_user_domain_com_test_collection_v1"),
("user_123", "collection_456", "doc_user_123_collection_456"),
("user with spaces", "collection with spaces", "doc_user_with_spaces_collection_with_spaces"),
("user@@@test", "collection---test", "doc_user_test_collection_test"),
]
vector = [0.1, 0.2, 0.3]
@ -250,24 +253,24 @@ class TestMilvusUserCollectionIntegration:
def test_collection_name_backward_compatibility(self):
"""Test that new collection names don't conflict with old pattern"""
# Old pattern was: {prefix}_{dimension}
# New pattern is: {prefix}_{safe_user}_{safe_collection}_{dimension}
# New pattern is: {prefix}_{safe_user}_{safe_collection}
# The new pattern should never generate names that match the old pattern
old_pattern_examples = ["doc_384", "entity_768", "doc_512"]
test_cases = [
("user", "collection", 384, "doc"),
("test", "test", 768, "entity"),
("a", "b", 512, "doc"),
("user", "collection", "doc"),
("test", "test", "entity"),
("a", "b", "doc"),
]
for user, collection, dimension, prefix in test_cases:
new_name = make_safe_collection_name(user, collection, dimension, prefix)
# New names should have at least 4 underscores (prefix_user_collection_dimension)
for user, collection, prefix in test_cases:
new_name = make_safe_collection_name(user, collection, prefix)
# New names should have at least 2 underscores (prefix_user_collection)
# Old names had only 1 underscore (prefix_dimension)
assert new_name.count('_') >= 3, f"New name {new_name} doesn't have enough underscores"
assert new_name.count('_') >= 2, f"New name {new_name} doesn't have enough underscores"
# New names should not match old pattern
assert new_name not in old_pattern_examples, f"New name {new_name} conflicts with old pattern"
@ -286,23 +289,23 @@ class TestMilvusUserCollectionIntegration:
dimension = 384
# Generate collection names
doc_name1 = make_safe_collection_name(user1, collection1, dimension, "doc")
doc_name2 = make_safe_collection_name(user2, collection2, dimension, "doc")
entity_name1 = make_safe_collection_name(user1, collection1, dimension, "entity")
entity_name2 = make_safe_collection_name(user2, collection2, dimension, "entity")
doc_name1 = make_safe_collection_name(user1, collection1, "doc")
doc_name2 = make_safe_collection_name(user2, collection2, "doc")
entity_name1 = make_safe_collection_name(user1, collection1, "entity")
entity_name2 = make_safe_collection_name(user2, collection2, "entity")
# Verify complete isolation
assert doc_name1 != doc_name2, "Document collections should be isolated"
assert entity_name1 != entity_name2, "Entity collections should be isolated"
# Verify names match expected pattern from Qdrant
# Verify names match expected pattern from new API
# Qdrant uses: d_{user}_{collection}_{dimension}, t_{user}_{collection}_{dimension}
# Milvus uses: doc_{safe_user}_{safe_collection}_{dimension}, entity_{safe_user}_{safe_collection}_{dimension}
assert doc_name1 == "doc_my_user_test_coll_1_384"
assert doc_name2 == "doc_other_user_production_data_384"
assert entity_name1 == "entity_my_user_test_coll_1_384"
assert entity_name2 == "entity_other_user_production_data_384"
# New Milvus API uses: doc_{safe_user}_{safe_collection}, entity_{safe_user}_{safe_collection}
assert doc_name1 == "doc_my_user_test_coll_1"
assert doc_name2 == "doc_other_user_production_data"
assert entity_name1 == "entity_my_user_test_coll_1"
assert entity_name2 == "entity_other_user_production_data"
# This test would have FAILED with the old implementation that used:
# - doc_384 for all document embeddings (no user/collection differentiation)

View file

@ -120,7 +120,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
chunks = await processor.query_document_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "d-test_user-test_collection-3"
expected_index_name = "d-test_user-test_collection"
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
@ -239,7 +239,7 @@ class TestPineconeDocEmbeddingsQueryProcessor:
@pytest.mark.asyncio
async def test_query_document_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions"""
"""Test querying with vectors of different dimensions using same index"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
@ -248,37 +248,33 @@ class TestPineconeDocEmbeddingsQueryProcessor:
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
processor.pinecone.Index.side_effect = mock_index_side_effect
# Mock results for different dimensions
# Mock single index that handles all dimensions
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D index'})]
mock_index_2d.query.return_value = mock_results_2d
mock_results_2d.matches = [MagicMock(metadata={'doc': 'Document from 2D query'})]
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D index'})]
mock_index_4d.query.return_value = mock_results_4d
mock_results_4d.matches = [MagicMock(metadata={'doc': 'Document from 4D query'})]
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
chunks = await processor.query_document_embeddings(message)
# Verify different indexes were used
# Verify same index used for both vectors
expected_index_name = "d-test_user-test_collection"
assert processor.pinecone.Index.call_count == 2
mock_index_2d.query.assert_called_once()
mock_index_4d.query.assert_called_once()
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify results from both dimensions
assert 'Document from 2D index' in chunks
assert 'Document from 4D index' in chunks
assert 'Document from 2D query' in chunks
assert 'Document from 4D query' in chunks
@pytest.mark.asyncio
async def test_query_document_embeddings_empty_vectors_list(self, processor):

View file

@ -148,7 +148,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
entities = await processor.query_graph_embeddings(message)
# Verify index was accessed correctly
expected_index_name = "t-test_user-test_collection-3"
expected_index_name = "t-test_user-test_collection"
processor.pinecone.Index.assert_called_once_with(expected_index_name)
# Verify query parameters
@ -265,7 +265,7 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
@pytest.mark.asyncio
async def test_query_graph_embeddings_different_vector_dimensions(self, processor):
"""Test querying with vectors of different dimensions"""
"""Test querying with vectors of different dimensions using same index"""
message = MagicMock()
message.vectors = [
[0.1, 0.2], # 2D vector
@ -274,34 +274,30 @@ class TestPineconeGraphEmbeddingsQueryProcessor:
message.limit = 5
message.user = 'test_user'
message.collection = 'test_collection'
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
processor.pinecone.Index.side_effect = mock_index_side_effect
# Mock results for different dimensions
# Mock single index that handles all dimensions
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock results for different vector queries
mock_results_2d = MagicMock()
mock_results_2d.matches = [MagicMock(metadata={'entity': 'entity_2d'})]
mock_index_2d.query.return_value = mock_results_2d
mock_results_4d = MagicMock()
mock_results_4d.matches = [MagicMock(metadata={'entity': 'entity_4d'})]
mock_index_4d.query.return_value = mock_results_4d
mock_index.query.side_effect = [mock_results_2d, mock_results_4d]
entities = await processor.query_graph_embeddings(message)
# Verify different indexes were used
# Verify same index used for both vectors
expected_index_name = "t-test_user-test_collection"
assert processor.pinecone.Index.call_count == 2
mock_index_2d.query.assert_called_once()
mock_index_4d.query.assert_called_once()
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify both queries were made
assert mock_index.query.call_count == 2
# Verify results from both dimensions
entity_values = [e.value for e in entities]
assert 'entity_2d' in entity_values

View file

@ -70,7 +70,7 @@ class TestCassandraQueryProcessor:
assert result.is_uri is False
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_spo_query(self, mock_trustgraph):
"""Test querying triples with subject, predicate, and object specified"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -98,16 +98,15 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
# Verify TrustGraph was created with correct parameters
# Verify KnowledgeGraph was created with correct parameters
mock_trustgraph.assert_called_once_with(
hosts=['localhost'],
keyspace='test_user',
table='test_collection'
keyspace='test_user'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
'test_subject', 'test_predicate', 'test_object', limit=100
'test_collection', 'test_subject', 'test_predicate', 'test_object', limit=100
)
# Verify result contains the queried triple
@ -144,7 +143,7 @@ class TestCassandraQueryProcessor:
assert processor.table is None
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_sp_pattern(self, mock_trustgraph):
"""Test SP query pattern (subject and predicate, no object)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -170,14 +169,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', limit=50)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_s_pattern(self, mock_trustgraph):
"""Test S query pattern (subject only)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -203,14 +202,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', limit=25)
assert len(result) == 1
assert result[0].s.value == 'test_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_p_pattern(self, mock_trustgraph):
"""Test P query pattern (predicate only)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -236,14 +235,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', limit=10)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'test_predicate'
assert result[0].o.value == 'result_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_o_pattern(self, mock_trustgraph):
"""Test O query pattern (object only)"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -269,14 +268,14 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', limit=75)
assert len(result) == 1
assert result[0].s.value == 'result_subject'
assert result[0].p.value == 'result_predicate'
assert result[0].o.value == 'test_object'
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
"""Test query pattern with no constraints (get all)"""
from trustgraph.schema import TriplesQueryRequest
@ -303,7 +302,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples(query)
mock_tg_instance.get_all.assert_called_once_with(limit=1000)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.value == 'all_subject'
assert result[0].p.value == 'all_predicate'
@ -376,7 +375,7 @@ class TestCassandraQueryProcessor:
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_with_authentication(self, mock_trustgraph):
"""Test querying with username and password authentication"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -402,17 +401,16 @@ class TestCassandraQueryProcessor:
await processor.query_triples(query)
# Verify TrustGraph was created with authentication
# Verify KnowledgeGraph was created with authentication
mock_trustgraph.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='test_user',
table='test_collection',
username='authuser',
password='authpass'
)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_reuse(self, mock_trustgraph):
"""Test that TrustGraph is reused for same table"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -441,7 +439,7 @@ class TestCassandraQueryProcessor:
assert mock_trustgraph.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_table_switching(self, mock_trustgraph):
"""Test table switching creates new TrustGraph"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -463,7 +461,7 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples(query1)
assert processor.table == ('user1', 'collection1')
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
@ -476,13 +474,13 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples(query2)
assert processor.table == ('user2', 'collection2')
assert processor.table == 'user2'
# Verify TrustGraph was created twice
assert mock_trustgraph.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_exception_handling(self, mock_trustgraph):
"""Test exception handling during query execution"""
from trustgraph.schema import TriplesQueryRequest, Value
@ -506,7 +504,7 @@ class TestCassandraQueryProcessor:
await processor.query_triples(query)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
@patch('trustgraph.query.triples.cassandra.service.KnowledgeGraph')
async def test_query_triples_multiple_results(self, mock_trustgraph):
"""Test query returning multiple results"""
from trustgraph.schema import TriplesQueryRequest, Value

View file

@ -18,7 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KgStore
class TestTriplesWriterConfiguration:
"""Test Cassandra configuration in triples writer processor."""
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_environment_variable_configuration(self, mock_trust_graph):
"""Test processor picks up configuration from environment variables."""
env_vars = {
@ -34,7 +34,7 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_username == 'env-user'
assert processor.cassandra_password == 'env-pass'
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_parameter_override_environment(self, mock_trust_graph):
"""Test explicit parameters override environment variables."""
env_vars = {
@ -55,7 +55,7 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_username == 'param-user'
assert processor.cassandra_password == 'param-pass'
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_no_backward_compatibility_graph_params(self, mock_trust_graph):
"""Test that old graph_* parameter names are no longer supported."""
processor = TriplesWriter(
@ -70,7 +70,7 @@ class TestTriplesWriterConfiguration:
assert processor.cassandra_username is None
assert processor.cassandra_password is None
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_default_configuration(self, mock_trust_graph):
"""Test default configuration when no params or env vars provided."""
with patch.dict(os.environ, {}, clear=True):
@ -163,7 +163,7 @@ class TestObjectsWriterConfiguration:
class TestTriplesQueryConfiguration:
"""Test Cassandra configuration in triples query processor."""
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_environment_variable_configuration(self, mock_trust_graph):
"""Test processor picks up configuration from environment variables."""
env_vars = {
@ -179,7 +179,7 @@ class TestTriplesQueryConfiguration:
assert processor.cassandra_username == 'query-env-user'
assert processor.cassandra_password == 'query-env-pass'
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_only_new_parameters_work(self, mock_trust_graph):
"""Test that only new parameters work."""
processor = TriplesQuery(
@ -379,7 +379,7 @@ class TestCommandLineArgumentHandling:
class TestConfigurationPriorityIntegration:
"""Test complete configuration priority chain in processors."""
@patch('trustgraph.direct.cassandra.TrustGraph')
@patch('trustgraph.direct.cassandra_kg.KnowledgeGraph')
def test_complete_priority_chain(self, mock_trust_graph):
"""Test CLI params > env vars > defaults priority in actual processor."""
env_vars = {

View file

@ -135,7 +135,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
await processor.store_document_embeddings(message)
# Verify index name and operations
expected_index_name = "d-test_user-test_collection-3"
expected_index_name = "d-test_user-test_collection"
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
@ -203,7 +203,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
await processor.store_document_embeddings(message)
# Verify index creation was called
expected_index_name = "d-test_user-test_collection-3"
expected_index_name = "d-test_user-test_collection"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
@ -299,12 +299,11 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index_3d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
elif name.endswith("-3"):
return mock_index_3d
# All dimensions now use the same index name pattern
# Different dimensions will be handled within the same index
if "test_user" in name and "test_collection" in name:
return mock_index_2d # Just return one mock for all
return MagicMock()
processor.pinecone.Index.side_effect = mock_index_side_effect
processor.pinecone.has_index.return_value = True
@ -312,11 +311,10 @@ class TestPineconeDocEmbeddingsStorageProcessor:
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_document_embeddings(message)
# Verify different indexes were used for different dimensions
assert processor.pinecone.Index.call_count == 3
mock_index_2d.upsert.assert_called_once()
mock_index_4d.upsert.assert_called_once()
mock_index_3d.upsert.assert_called_once()
# Verify all vectors are now stored in the same index
# (Pinecone can handle mixed dimensions in the same index)
assert processor.pinecone.Index.call_count == 3 # Called once per vector
mock_index_2d.upsert.call_count == 3 # All upserts go to same index
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunks_list(self, processor):

View file

@ -106,7 +106,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert
# Verify collection existence was checked
expected_collection = 'd_test_user_test_collection_3'
expected_collection = 'd_test_user_test_collection'
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
@ -309,7 +309,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message)
# Assert
expected_collection = 'd_new_user_new_collection_5'
expected_collection = 'd_new_user_new_collection'
# Verify collection existence check and creation
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
@ -408,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection_3'
expected_collection = 'd_cache_user_cache_collection'
assert processor.last_collection == expected_collection
# Verify second call skipped existence check (cached)
@ -455,17 +455,16 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
await processor.store_document_embeddings(mock_message)
# Assert
# Should check existence of both collections
expected_collections = ['d_dim_user_dim_collection_2', 'd_dim_user_dim_collection_3']
actual_calls = [call.args[0] for call in mock_qdrant_instance.collection_exists.call_args_list]
assert actual_calls == expected_collections
# Should upsert to both collections
# Should check existence of the same collection (dimensions no longer create separate collections)
expected_collection = 'd_dim_user_dim_collection'
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Should upsert to the same collection for both vectors
assert mock_qdrant_instance.upsert.call_count == 2
upsert_calls = mock_qdrant_instance.upsert.call_args_list
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
assert upsert_calls[0][1]['collection_name'] == expected_collection
assert upsert_calls[1][1]['collection_name'] == expected_collection
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')

View file

@ -135,7 +135,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
await processor.store_graph_embeddings(message)
# Verify index name and operations
expected_index_name = "t-test_user-test_collection-3"
expected_index_name = "t-test_user-test_collection"
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify upsert was called for each vector
@ -203,7 +203,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
await processor.store_graph_embeddings(message)
# Verify index creation was called
expected_index_name = "t-test_user-test_collection-3"
expected_index_name = "t-test_user-test_collection"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
@ -256,12 +256,12 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
@pytest.mark.asyncio
async def test_store_graph_embeddings_different_vector_dimensions(self, processor):
"""Test storing graph embeddings with different vector dimensions"""
"""Test storing graph embeddings with different vector dimensions to same index"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[
@ -271,30 +271,21 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
]
)
message.entities = [entity]
mock_index_2d = MagicMock()
mock_index_4d = MagicMock()
mock_index_3d = MagicMock()
def mock_index_side_effect(name):
if name.endswith("-2"):
return mock_index_2d
elif name.endswith("-4"):
return mock_index_4d
elif name.endswith("-3"):
return mock_index_3d
processor.pinecone.Index.side_effect = mock_index_side_effect
# All vectors now use the same index (no dimension in name)
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
# Verify different indexes were used for different dimensions
assert processor.pinecone.Index.call_count == 3
mock_index_2d.upsert.assert_called_once()
mock_index_4d.upsert.assert_called_once()
mock_index_3d.upsert.assert_called_once()
# Verify same index was used for all dimensions
expected_index_name = 't-test_user-test_collection'
processor.pinecone.Index.assert_called_with(expected_index_name)
# Verify all vectors were upserted to the same index
assert mock_index.upsert.call_count == 3
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entities_list(self, processor):

View file

@ -69,7 +69,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
# Assert
expected_name = 't_test_user_test_collection_512'
expected_name = 't_test_user_test_collection'
assert collection_name == expected_name
assert processor.last_collection == expected_name
@ -118,7 +118,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert
# Verify collection existence was checked
expected_collection = 't_test_user_test_collection_3'
expected_collection = 't_test_user_test_collection'
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Verify upsert was called
@ -156,7 +156,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
# Assert
expected_name = 't_existing_user_existing_collection_256'
expected_name = 't_existing_user_existing_collection'
assert collection_name == expected_name
assert processor.last_collection == expected_name
@ -194,7 +194,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
# Assert
expected_name = 't_cache_user_cache_collection_128'
expected_name = 't_cache_user_cache_collection'
assert collection_name1 == expected_name
assert collection_name2 == expected_name

View file

@ -86,7 +86,7 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_username == 'new-user' # Only cassandra_* params work
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_switching_with_auth(self, mock_trustgraph):
"""Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock()
@ -107,18 +107,17 @@ class TestCassandraStorageProcessor:
await processor.store_triples(mock_message)
# Verify TrustGraph was called with auth parameters
# Verify KnowledgeGraph was called with auth parameters
mock_trustgraph.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='user1',
table='collection1',
username='testuser',
password='testpass'
)
assert processor.table == ('user1', 'collection1')
assert processor.table == 'user1'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_switching_without_auth(self, mock_trustgraph):
"""Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock()
@ -135,16 +134,15 @@ class TestCassandraStorageProcessor:
await processor.store_triples(mock_message)
# Verify TrustGraph was called without auth parameters
# Verify KnowledgeGraph was called without auth parameters
mock_trustgraph.assert_called_once_with(
hosts=['cassandra'], # Updated default
keyspace='user2',
table='collection2'
keyspace='user2'
)
assert processor.table == ('user2', 'collection2')
assert processor.table == 'user2'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_table_reuse_when_same(self, mock_trustgraph):
"""Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock()
@ -168,7 +166,7 @@ class TestCassandraStorageProcessor:
assert mock_trustgraph.call_count == 1 # Should not increase
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_triple_insertion(self, mock_trustgraph):
"""Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock()
@ -198,11 +196,11 @@ class TestCassandraStorageProcessor:
# Verify both triples were inserted
assert mock_tg_instance.insert.call_count == 2
mock_tg_instance.insert.assert_any_call('subject1', 'predicate1', 'object1')
mock_tg_instance.insert.assert_any_call('subject2', 'predicate2', 'object2')
mock_tg_instance.insert.assert_any_call('collection1', 'subject1', 'predicate1', 'object1')
mock_tg_instance.insert.assert_any_call('collection1', 'subject2', 'predicate2', 'object2')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
"""Test behavior when message has no triples"""
taskgroup_mock = MagicMock()
@ -223,7 +221,7 @@ class TestCassandraStorageProcessor:
mock_tg_instance.insert.assert_not_called()
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
"""Test exception handling during TrustGraph creation"""
@ -328,7 +326,7 @@ class TestCassandraStorageProcessor:
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
"""Test table switching when different tables are used in sequence"""
taskgroup_mock = MagicMock()
@ -345,7 +343,7 @@ class TestCassandraStorageProcessor:
mock_message1.triples = []
await processor.store_triples(mock_message1)
assert processor.table == ('user1', 'collection1')
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
@ -355,14 +353,14 @@ class TestCassandraStorageProcessor:
mock_message2.triples = []
await processor.store_triples(mock_message2)
assert processor.table == ('user2', 'collection2')
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables
assert mock_trustgraph.call_count == 2
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
"""Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock()
@ -386,13 +384,14 @@ class TestCassandraStorageProcessor:
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
'test_collection',
'subject with spaces & symbols',
'predicate:with/colons',
'object with "quotes" and unicode: ñáéíóú'
)
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
@patch('trustgraph.storage.triples.cassandra.write.KnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
"""Test that table remains unchanged when TrustGraph creation fails"""
taskgroup_mock = MagicMock()

View file

@ -86,15 +86,17 @@ class TestFalkorDBStorageProcessor:
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.create_node(test_uri, 'test_user', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
params={
"uri": test_uri,
"user": 'test_user',
"collection": 'test_collection',
},
)
@ -104,15 +106,17 @@ class TestFalkorDBStorageProcessor:
mock_result = MagicMock()
mock_result.nodes_created = 1
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.create_literal(test_value, 'test_user', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
params={
"value": test_value,
"user": 'test_user',
"collection": 'test_collection',
},
)
@ -121,23 +125,25 @@ class TestFalkorDBStorageProcessor:
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
dest_uri = 'http://example.com/dest'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri)
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection')
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
params={
"src": src_uri,
"dest": dest_uri,
"uri": pred_uri,
"user": 'test_user',
"collection": 'test_collection',
},
)
@ -146,23 +152,25 @@ class TestFalkorDBStorageProcessor:
src_uri = 'http://example.com/src'
pred_uri = 'http://example.com/pred'
literal_value = 'literal destination'
mock_result = MagicMock()
mock_result.nodes_created = 0
mock_result.run_time_ms = 5
processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value)
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection')
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
params={
"src": src_uri,
"dest": literal_value,
"uri": pred_uri,
"user": 'test_user',
"collection": 'test_collection',
},
)
@ -191,14 +199,16 @@ class TestFalkorDBStorageProcessor:
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
# Create object node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/object"}}),
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate"}}),
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
]
assert processor.io.query.call_count == 3
@ -220,14 +230,16 @@ class TestFalkorDBStorageProcessor:
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri})",), {"params": {"uri": "http://example.com/subject"}}),
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
# Create literal object
(("MERGE (n:Literal {value: $value})",), {"params": {"value": "literal object"}}),
(("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",),
{"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}),
# Create relationship
(("MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate"}}),
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
]
assert processor.io.query.call_count == 3
@ -408,12 +420,14 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_node(test_uri)
processor.create_node(test_uri, 'test_user', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri})",
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
params={
"uri": test_uri,
"user": 'test_user',
"collection": 'test_collection',
},
)
@ -426,11 +440,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_literal(test_value)
processor.create_literal(test_value, 'test_user', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value})",
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
params={
"value": test_value,
"user": 'test_user',
"collection": 'test_collection',
},
)

View file

@ -8,6 +8,7 @@ from . library import Library
from . flow import Flow
from . config import Config
from . knowledge import Knowledge
from . collection import Collection
from . exceptions import *
from . types import *
@ -68,3 +69,6 @@ class Api:
def library(self):
return Library(self)
def collection(self):
return Collection(self)

View file

@ -0,0 +1,90 @@
import datetime
import logging
from . types import CollectionMetadata
from . exceptions import *
logger = logging.getLogger(__name__)
class Collection:
def __init__(self, api):
self.api = api
def request(self, request):
return self.api.request(f"collection-management", request)
def list_collections(self, user, tag_filter=None):
input = {
"operation": "list-collections",
"user": user,
}
if tag_filter:
input["tag_filter"] = tag_filter
object = self.request(input)
try:
return [
CollectionMetadata(
user = v["user"],
collection = v["collection"],
name = v["name"],
description = v["description"],
tags = v["tags"],
created_at = v["created_at"],
updated_at = v["updated_at"]
)
for v in object["collections"]
]
except Exception as e:
logger.error("Failed to parse collection list response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def update_collection(self, user, collection, name=None, description=None, tags=None):
input = {
"operation": "update-collection",
"user": user,
"collection": collection,
}
if name is not None:
input["name"] = name
if description is not None:
input["description"] = description
if tags is not None:
input["tags"] = tags
object = self.request(input)
try:
if "collections" in object and object["collections"]:
v = object["collections"][0]
return CollectionMetadata(
user = v["user"],
collection = v["collection"],
name = v["name"],
description = v["description"],
tags = v["tags"],
created_at = v["created_at"],
updated_at = v["updated_at"]
)
return None
except Exception as e:
logger.error("Failed to parse collection update response", exc_info=True)
raise ProtocolException(f"Response not formatted correctly")
def delete_collection(self, user, collection):
input = {
"operation": "delete-collection",
"user": user,
"collection": collection,
}
object = self.request(input)
return {}

View file

@ -41,3 +41,13 @@ class ProcessingMetadata:
user : str
collection : str
tags : List[str]
@dataclasses.dataclass
class CollectionMetadata:
user : str
collection : str
name : str
description : str
tags : List[str]
created_at : str
updated_at : str

View file

@ -25,6 +25,7 @@ from .translators.objects_query import ObjectsQueryRequestTranslator, ObjectsQue
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
# Register all service translators
TranslatorRegistry.register_service(
@ -135,6 +136,12 @@ TranslatorRegistry.register_service(
StructuredDataDiagnosisResponseTranslator()
)
TranslatorRegistry.register_service(
"collection-management",
CollectionManagementRequestTranslator(),
CollectionManagementResponseTranslator()
)
# Register single-direction translators for document loading
TranslatorRegistry.register_request("document", DocumentTranslator())
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())

View file

@ -0,0 +1,112 @@
from typing import Dict, Any, List
from ...schema import CollectionManagementRequest, CollectionManagementResponse, CollectionMetadata, Error
from .base import MessageTranslator
class CollectionManagementRequestTranslator(MessageTranslator):
"""Translator for CollectionManagementRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementRequest:
return CollectionManagementRequest(
operation=data.get("operation", ""),
user=data.get("user", ""),
collection=data.get("collection", ""),
timestamp=data.get("timestamp", ""),
name=data.get("name", ""),
description=data.get("description", ""),
tags=data.get("tags", []),
created_at=data.get("created_at", ""),
updated_at=data.get("updated_at", ""),
tag_filter=data.get("tag_filter", []),
limit=data.get("limit", 50)
)
def from_pulsar(self, obj: CollectionManagementRequest) -> Dict[str, Any]:
result = {}
if obj.operation:
result["operation"] = obj.operation
if obj.user:
result["user"] = obj.user
if obj.collection:
result["collection"] = obj.collection
if obj.timestamp:
result["timestamp"] = obj.timestamp
if obj.name:
result["name"] = obj.name
if obj.description:
result["description"] = obj.description
if obj.tags:
result["tags"] = list(obj.tags)
if obj.created_at:
result["created_at"] = obj.created_at
if obj.updated_at:
result["updated_at"] = obj.updated_at
if obj.tag_filter:
result["tag_filter"] = list(obj.tag_filter)
if obj.limit:
result["limit"] = obj.limit
return result
class CollectionManagementResponseTranslator(MessageTranslator):
"""Translator for CollectionManagementResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementResponse:
# Handle error
error = None
if "error" in data and data["error"]:
error_data = data["error"]
error = Error(
type=error_data.get("type", ""),
message=error_data.get("message", "")
)
# Handle collections array
collections = []
if "collections" in data:
for coll_data in data["collections"]:
collections.append(CollectionMetadata(
user=coll_data.get("user", ""),
collection=coll_data.get("collection", ""),
name=coll_data.get("name", ""),
description=coll_data.get("description", ""),
tags=coll_data.get("tags", []),
created_at=coll_data.get("created_at", ""),
updated_at=coll_data.get("updated_at", "")
))
return CollectionManagementResponse(
success=data.get("success", ""),
error=error,
timestamp=data.get("timestamp", ""),
collections=collections
)
def from_pulsar(self, obj: CollectionManagementResponse) -> Dict[str, Any]:
result = {}
if obj.success:
result["success"] = obj.success
if obj.error:
result["error"] = {
"type": obj.error.type,
"message": obj.error.message
}
if obj.timestamp:
result["timestamp"] = obj.timestamp
if obj.collections:
result["collections"] = []
for coll in obj.collections:
result["collections"].append({
"user": coll.user,
"collection": coll.collection,
"name": coll.name,
"description": coll.description,
"tags": list(coll.tags) if coll.tags else [],
"created_at": coll.created_at,
"updated_at": coll.updated_at
})
return result

View file

@ -10,4 +10,6 @@ from .lookup import *
from .nlp_query import *
from .structured_query import *
from .objects_query import *
from .diagnosis import *
from .diagnosis import *
from .collection import *
from .storage import *

View file

@ -0,0 +1,60 @@
from pulsar.schema import Record, String, Integer, Array
from datetime import datetime
from ..core.primitives import Error
from ..core.topic import topic
############################################################################
# Collection management operations
# Collection metadata operations (for librarian service)
class CollectionMetadata(Record):
"""Collection metadata record"""
user = String()
collection = String()
name = String()
description = String()
tags = Array(String())
created_at = String() # ISO timestamp
updated_at = String() # ISO timestamp
############################################################################
class CollectionManagementRequest(Record):
"""Request for collection management operations"""
operation = String() # e.g., "delete-collection"
# For 'list-collections'
user = String()
collection = String()
timestamp = String() # ISO timestamp
name = String()
description = String()
tags = Array(String())
created_at = String() # ISO timestamp
updated_at = String() # ISO timestamp
# For list
tag_filter = Array(String()) # Optional filter by tags
limit = Integer()
class CollectionManagementResponse(Record):
"""Response for collection management operations"""
success = String() # "true" or "false"
error = Error() # Only populated if success is "false"
timestamp = String() # ISO timestamp
collections = Array(CollectionMetadata())
############################################################################
# Topics
collection_request_queue = topic(
'collection', kind='non-persistent', namespace='request'
)
collection_response_queue = topic(
'collection', kind='non-persistent', namespace='response'
)

View file

@ -0,0 +1,42 @@
from pulsar.schema import Record, String
from ..core.primitives import Error
from ..core.topic import topic
############################################################################
# Storage management operations
class StorageManagementRequest(Record):
"""Request for storage management operations sent to store processors"""
operation = String() # e.g., "delete-collection"
user = String()
collection = String()
class StorageManagementResponse(Record):
"""Response from storage processors for management operations"""
error = Error() # Only populated if there's an error, if null success
############################################################################
# Storage management topics
# Topics for sending collection management requests to different storage types
vector_storage_management_topic = topic(
'vector-storage-management', kind='non-persistent', namespace='request'
)
object_storage_management_topic = topic(
'object-storage-management', kind='non-persistent', namespace='request'
)
triples_storage_management_topic = topic(
'triples-storage-management', kind='non-persistent', namespace='request'
)
# Topic for receiving responses from storage processors
storage_management_response_topic = topic(
'storage-management', kind='non-persistent', namespace='response'
)
############################################################################

View file

@ -86,6 +86,9 @@ tg-list-config-items = "trustgraph.cli.list_config_items:main"
tg-get-config-item = "trustgraph.cli.get_config_item:main"
tg-put-config-item = "trustgraph.cli.put_config_item:main"
tg-delete-config-item = "trustgraph.cli.delete_config_item:main"
tg-list-collections = "trustgraph.cli.list_collections:main"
tg-update-collection = "trustgraph.cli.update_collection:main"
tg-delete-collection = "trustgraph.cli.delete_collection:main"
[tool.setuptools.packages.find]
include = ["trustgraph*"]

View file

@ -0,0 +1,72 @@
"""
Delete a collection and all its data
"""
import argparse
import os
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = "trustgraph"
def delete_collection(url, user, collection, confirm):
if not confirm:
response = input(f"Are you sure you want to delete collection '{collection}' and all its data? (y/N): ")
if response.lower() not in ['y', 'yes']:
print("Operation cancelled.")
return
api = Api(url).collection()
api.delete_collection(user=user, collection=collection)
print(f"Collection '{collection}' deleted successfully.")
def main():
parser = argparse.ArgumentParser(
prog='tg-delete-collection',
description=__doc__,
)
parser.add_argument(
'collection',
help='Collection ID to delete'
)
parser.add_argument(
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-y', '--yes',
action='store_true',
help='Skip confirmation prompt'
)
args = parser.parse_args()
try:
delete_collection(
url = args.api_url,
user = args.user,
collection = args.collection,
confirm = args.yes
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,85 @@
"""
List collections for a user
"""
import argparse
import os
import tabulate
from trustgraph.api import Api
import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = "trustgraph"
def list_collections(url, user, tag_filter):
api = Api(url).collection()
collections = api.list_collections(user=user, tag_filter=tag_filter)
if len(collections) == 0:
print("No collections.")
return
table = []
for collection in collections:
table.append([
collection.collection,
collection.name,
collection.description,
", ".join(collection.tags),
collection.created_at,
collection.updated_at
])
headers = ["Collection", "Name", "Description", "Tags", "Created", "Updated"]
print(tabulate.tabulate(
table,
headers=headers,
tablefmt="pretty",
stralign="left",
maxcolwidths=[20, 30, 50, 30, 19, 19],
))
def main():
parser = argparse.ArgumentParser(
prog='tg-list-collections',
description=__doc__,
)
parser.add_argument(
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-t', '--tag-filter',
action='append',
help='Filter by tags (can be specified multiple times)'
)
args = parser.parse_args()
try:
list_collections(
url = args.api_url,
user = args.user,
tag_filter = args.tag_filter
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,103 @@
"""
Update collection metadata
"""
import argparse
import os
import tabulate
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = "trustgraph"
def update_collection(url, user, collection, name, description, tags):
api = Api(url).collection()
result = api.update_collection(
user=user,
collection=collection,
name=name,
description=description,
tags=tags
)
if result:
print(f"Collection '{collection}' updated successfully.")
table = []
table.append(("Collection", result.collection))
table.append(("Name", result.name))
table.append(("Description", result.description))
table.append(("Tags", ", ".join(result.tags)))
table.append(("Updated", result.updated_at))
print(tabulate.tabulate(
table,
tablefmt="pretty",
stralign="left",
maxcolwidths=[None, 67],
))
else:
print(f"Failed to update collection '{collection}'.")
def main():
parser = argparse.ArgumentParser(
prog='tg-update-collection',
description=__doc__,
)
parser.add_argument(
'collection',
help='Collection ID to update'
)
parser.add_argument(
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-n', '--name',
help='Collection name'
)
parser.add_argument(
'-d', '--description',
help='Collection description'
)
parser.add_argument(
'-t', '--tag',
action='append',
dest='tags',
help='Collection tags (can be specified multiple times)'
)
args = parser.parse_args()
try:
update_collection(
url = args.api_url,
user = args.user,
collection = args.collection,
name = args.name,
description = args.description,
tags = args.tags
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()

View file

@ -6,18 +6,18 @@ from ssl import SSLContext, PROTOCOL_TLSv1_2
# Global list to track clusters for cleanup
_active_clusters = []
class TrustGraph:
class KnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", table="default", username=None, password=None
keyspace="trustgraph", username=None, password=None
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.table = table
self.table = "triples" # Fixed table name for unified schema
self.username = username
if username and password:
@ -55,13 +55,19 @@ class TrustGraph:
self.session.execute(f"""
create table if not exists {self.table} (
collection text,
s text,
p text,
o text,
PRIMARY KEY (s, p, o)
PRIMARY KEY (collection, s, p, o)
);
""");
self.session.execute(f"""
create index if not exists {self.table}_s
ON {self.table} (s);
""");
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
@ -72,58 +78,66 @@ class TrustGraph:
ON {self.table} (o);
""");
def insert(self, s, p, o):
def insert(self, collection, s, p, o):
self.session.execute(
f"insert into {self.table} (s, p, o) values (%s, %s, %s)",
(s, p, o)
f"insert into {self.table} (collection, s, p, o) values (%s, %s, %s, %s)",
(collection, s, p, o)
)
def get_all(self, limit=50):
def get_all(self, collection, limit=50):
return self.session.execute(
f"select s, p, o from {self.table} limit {limit}"
f"select s, p, o from {self.table} where collection = %s limit {limit}",
(collection,)
)
def get_s(self, s, limit=10):
def get_s(self, collection, s, limit=10):
return self.session.execute(
f"select p, o from {self.table} where s = %s limit {limit}",
(s,)
f"select p, o from {self.table} where collection = %s and s = %s limit {limit}",
(collection, s)
)
def get_p(self, p, limit=10):
def get_p(self, collection, p, limit=10):
return self.session.execute(
f"select s, o from {self.table} where p = %s limit {limit}",
(p,)
f"select s, o from {self.table} where collection = %s and p = %s limit {limit}",
(collection, p)
)
def get_o(self, o, limit=10):
def get_o(self, collection, o, limit=10):
return self.session.execute(
f"select s, p from {self.table} where o = %s limit {limit}",
(o,)
f"select s, p from {self.table} where collection = %s and o = %s limit {limit}",
(collection, o)
)
def get_sp(self, s, p, limit=10):
def get_sp(self, collection, s, p, limit=10):
return self.session.execute(
f"select o from {self.table} where s = %s and p = %s limit {limit}",
(s, p)
f"select o from {self.table} where collection = %s and s = %s and p = %s limit {limit}",
(collection, s, p)
)
def get_po(self, p, o, limit=10):
def get_po(self, collection, p, o, limit=10):
return self.session.execute(
f"select s from {self.table} where p = %s and o = %s limit {limit} allow filtering",
(p, o)
f"select s from {self.table} where collection = %s and p = %s and o = %s limit {limit} allow filtering",
(collection, p, o)
)
def get_os(self, o, s, limit=10):
def get_os(self, collection, o, s, limit=10):
return self.session.execute(
f"select p from {self.table} where o = %s and s = %s limit {limit}",
(o, s)
f"select p from {self.table} where collection = %s and o = %s and s = %s limit {limit} allow filtering",
(collection, o, s)
)
def get_spo(self, s, p, o, limit=10):
def get_spo(self, collection, s, p, o, limit=10):
return self.session.execute(
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
(s, p, o)
f"""select s as x from {self.table} where collection = %s and s = %s and p = %s and o = %s limit {limit}""",
(collection, s, p, o)
)
def delete_collection(self, collection):
"""Delete all triples for a specific collection"""
self.session.execute(
f"delete from {self.table} where collection = %s",
(collection,)
)
def close(self):

View file

@ -6,7 +6,7 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, dimension, prefix):
def make_safe_collection_name(user, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Milvus only allows letters, numbers, and underscores.
@ -26,7 +26,7 @@ def make_safe_collection_name(user, collection, dimension, prefix):
safe_user = sanitize(user)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}_{dimension}"
return f"{prefix}_{safe_user}_{safe_collection}"
class DocVectors:
@ -51,7 +51,7 @@ class DocVectors:
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, dimension, self.prefix)
collection_name = make_safe_collection_name(user, collection, self.prefix)
pkey_field = FieldSchema(
name="id",
@ -162,3 +162,20 @@ class DocVectors:
return res
def delete_collection(self, user, collection):
"""Delete a collection for the given user and collection"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
# Check if collection exists
if self.client.has_collection(collection_name):
# Drop the collection
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")

View file

@ -6,7 +6,7 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, dimension, prefix):
def make_safe_collection_name(user, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Milvus only allows letters, numbers, and underscores.
@ -26,7 +26,7 @@ def make_safe_collection_name(user, collection, dimension, prefix):
safe_user = sanitize(user)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}_{dimension}"
return f"{prefix}_{safe_user}_{safe_collection}"
class EntityVectors:
@ -51,7 +51,7 @@ class EntityVectors:
def init_collection(self, dimension, user, collection):
collection_name = make_safe_collection_name(user, collection, dimension, self.prefix)
collection_name = make_safe_collection_name(user, collection, self.prefix)
pkey_field = FieldSchema(
name="id",
@ -162,3 +162,20 @@ class EntityVectors:
return res
def delete_collection(self, user, collection):
"""Delete a collection for the given user and collection"""
collection_name = make_safe_collection_name(user, collection, self.prefix)
# Check if collection exists
if self.client.has_collection(collection_name):
# Drop the collection
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")

View file

@ -0,0 +1,28 @@
from ... schema import CollectionManagementRequest, CollectionManagementResponse
from ... schema import collection_request_queue, collection_response_queue
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class CollectionManagementRequestor(ServiceRequestor):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
super(CollectionManagementRequestor, self).__init__(
pulsar_client=pulsar_client,
consumer_name = consumer,
subscription = subscriber,
request_queue=collection_request_queue,
response_queue=collection_response_queue,
request_schema=CollectionManagementRequest,
response_schema=CollectionManagementResponse,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("collection-management")
self.response_translator = TranslatorRegistry.get_response_translator("collection-management")
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return self.response_translator.from_response_with_completion(message)

View file

@ -11,6 +11,7 @@ from . config import ConfigRequestor
from . flow import FlowRequestor
from . librarian import LibrarianRequestor
from . knowledge import KnowledgeRequestor
from . collection_management import CollectionManagementRequestor
from . embeddings import EmbeddingsRequestor
from . agent import AgentRequestor
@ -66,6 +67,7 @@ global_dispatchers = {
"flow": FlowRequestor,
"librarian": LibrarianRequestor,
"knowledge": KnowledgeRequestor,
"collection-management": CollectionManagementRequestor,
}
sender_dispatchers = {

View file

@ -0,0 +1,362 @@
"""
Collection management service for the librarian
"""
import asyncio
import logging
from datetime import datetime
from .. base import AsyncProcessor, Consumer, Producer
from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error
from .. schema import collection_request_queue, collection_response_queue
from .. schema import CollectionMetadata
from .. schema import StorageManagementRequest, StorageManagementResponse
from .. schema import vector_storage_management_topic, object_storage_management_topic, triples_storage_management_topic, storage_management_response_topic
from .. exceptions import RequestError
from .. tables.library import LibraryTableStore
# Module logger
logger = logging.getLogger(__name__)
default_ident = "collection-management"
default_cassandra_host = "cassandra"
keyspace = "librarian"
class Processor(AsyncProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
# Get Cassandra configuration
cassandra_host = params.get("cassandra_host", default_cassandra_host)
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
)
super(Processor, self).__init__(
**params | {
"cassandra_host": ','.join(hosts),
"cassandra_username": username
}
)
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
# Set up metrics
collection_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="collection-request"
)
collection_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="collection-response"
)
# Set up consumer for collection management requests
self.collection_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=collection_request_queue,
subscriber=id,
schema=CollectionManagementRequest,
handler=self.on_collection_request,
metrics=collection_request_metrics,
)
# Set up producer for collection management responses
self.collection_response_producer = Producer(
client=self.pulsar_client,
topic=collection_response_queue,
schema=CollectionManagementResponse,
metrics=collection_response_metrics,
)
# Set up producers for storage management requests
self.vector_storage_producer = Producer(
client=self.pulsar_client,
topic=vector_storage_management_topic,
schema=StorageManagementRequest,
)
self.object_storage_producer = Producer(
client=self.pulsar_client,
topic=object_storage_management_topic,
schema=StorageManagementRequest,
)
self.triples_storage_producer = Producer(
client=self.pulsar_client,
topic=triples_storage_management_topic,
schema=StorageManagementRequest,
)
# Set up consumer for storage management responses
storage_response_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-response"
)
self.storage_response_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=storage_management_response_topic,
subscriber=f"{id}-storage",
schema=StorageManagementResponse,
handler=self.on_storage_response,
metrics=storage_response_metrics,
)
# Initialize table store
self.table_store = LibraryTableStore(
cassandra_host=self.cassandra_host,
cassandra_username=self.cassandra_username,
cassandra_password=self.cassandra_password,
keyspace=keyspace
)
# Track pending deletion requests by user+collection
self.pending_deletions = {} # (user, collection) -> {responses_pending, responses_received, all_successful, error_messages, deletion_complete}
async def on_collection_request(self, message):
"""Handle collection management requests"""
logger.debug(f"Collection request: {message.operation}")
try:
if message.operation == "list-collections":
response = await self.handle_list_collections(message)
elif message.operation == "update-collection":
response = await self.handle_update_collection(message)
elif message.operation == "delete-collection":
response = await self.handle_delete_collection(message)
else:
response = CollectionManagementResponse(
success="false",
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
),
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error processing collection request: {e}", exc_info=True)
response = CollectionManagementResponse(
success="false",
error=Error(
type="processing_error",
message=str(e)
),
timestamp=datetime.now().isoformat()
)
await self.collection_response_producer.send(response)
async def on_storage_response(self, response):
"""Handle storage management responses"""
logger.debug(f"Received storage response: error={response.error}")
# Find matching deletion by checking all pending deletions
# Note: This is simplified correlation - assumes responses come back quickly
# In production, we'd want better correlation mechanism
for deletion_key, info in list(self.pending_deletions.items()):
if info["responses_pending"] > 0:
# Record this response
info["responses_received"].append(response)
info["responses_pending"] -= 1
# Check if this response indicates failure
if response.error and response.error.message:
info["all_successful"] = False
info["error_messages"].append(response.error.message)
logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}")
else:
logger.debug(f"Storage deletion succeeded for {deletion_key}")
# If all responses received, signal completion
if info["responses_pending"] == 0:
logger.info(f"All storage responses received for {deletion_key}")
info["deletion_complete"].set()
break # Only process for first matching deletion
# For now, we'll correlate by user+collection since we don't have deletion_id in the response
# This is a simplified approach - in production we'd want better correlation
for deletion_id, info in list(self.pending_deletions.items()):
if info["responses_pending"] > 0:
# Record this response
info["responses_received"].append(response)
info["responses_pending"] -= 1
# Check if this response indicates failure
if response.error and response.error.message:
info["all_successful"] = False
info["error_messages"].append(response.error.message)
logger.warning(f"Storage deletion failed for {deletion_id}: {response.error.message}")
# If all responses received, signal completion
if info["responses_pending"] == 0:
logger.info(f"All storage responses received for {deletion_id}")
info["deletion_complete"].set()
break # Only process for first matching deletion
async def handle_list_collections(self, message):
"""Handle list collections request"""
try:
tag_filter = list(message.tag_filter) if message.tag_filter else None
collections = await self.table_store.list_collections(message.user, tag_filter)
collection_metadata = [
CollectionMetadata(
user=coll["user"],
collection=coll["collection"],
name=coll["name"],
description=coll["description"],
tags=coll["tags"],
created_at=coll["created_at"],
updated_at=coll["updated_at"]
)
for coll in collections
]
return CollectionManagementResponse(
success="true",
collections=collection_metadata,
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error listing collections: {e}")
raise
async def handle_update_collection(self, message):
"""Handle update collection request"""
try:
# Extract fields for update
name = message.name if message.name else None
description = message.description if message.description else None
tags = list(message.tags) if message.tags else None
updated_collection = await self.table_store.update_collection(
message.user, message.collection, name, description, tags
)
collection_metadata = CollectionMetadata(
user=updated_collection["user"],
collection=updated_collection["collection"],
name=updated_collection["name"],
description=updated_collection["description"],
tags=updated_collection["tags"],
created_at="", # Not returned by update
updated_at=updated_collection["updated_at"]
)
return CollectionManagementResponse(
success="true",
collections=[collection_metadata],
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error updating collection: {e}")
raise
async def handle_delete_collection(self, message):
"""Handle delete collection request with cascade to all storage types"""
try:
deletion_key = (message.user, message.collection)
logger.info(f"Starting cascade deletion for {message.user}/{message.collection}")
# Track this deletion request
self.pending_deletions[deletion_key] = {
"responses_pending": 3, # vector, object, triples
"responses_received": [],
"all_successful": True,
"error_messages": [],
"deletion_complete": asyncio.Event()
}
# Create storage management request
storage_request = StorageManagementRequest(
operation="delete-collection",
user=message.user,
collection=message.collection
)
# Send delete requests to all three storage types
await self.vector_storage_producer.send(storage_request)
await self.object_storage_producer.send(storage_request)
await self.triples_storage_producer.send(storage_request)
logger.info(f"Storage deletion requests sent for {message.user}/{message.collection}")
# Wait for all storage responses (with timeout)
try:
await asyncio.wait_for(
self.pending_deletions[deletion_key]["deletion_complete"].wait(),
timeout=30.0 # 30 second timeout
)
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for storage responses for {deletion_key}")
self.pending_deletions[deletion_key]["all_successful"] = False
self.pending_deletions[deletion_key]["error_messages"].append("Timeout waiting for storage responses")
# Check if all storage deletions were successful
deletion_info = self.pending_deletions.pop(deletion_key, {})
if deletion_info.get("all_successful", False):
# All storage deletions succeeded, now delete metadata
await self.table_store.delete_collection_metadata(message.user, message.collection)
logger.info(f"Successfully completed cascade deletion for {message.user}/{message.collection}")
return CollectionManagementResponse(
success="true",
timestamp=datetime.now().isoformat()
)
else:
# Some storage deletions failed
error_messages = deletion_info.get("error_messages", ["Unknown storage deletion error"])
error_msg = "; ".join(error_messages)
logger.error(f"Cascade deletion failed for {deletion_key}: {error_msg}")
return CollectionManagementResponse(
success="false",
error=Error(
type="storage_deletion_error",
message=f"Storage deletion failed: {error_msg}"
),
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Error in cascade deletion: {e}")
return CollectionManagementResponse(
success="false",
error=Error(
type="deletion_error",
message=f"Failed to delete collection: {str(e)}"
),
timestamp=datetime.now().isoformat()
)
@staticmethod
def add_args(parser):
AsyncProcessor.add_args(parser)
add_cassandra_args(parser)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -95,7 +95,7 @@ class Processor(DocumentEmbeddingsQueryService):
dim = len(vec)
index_name = (
"d-" + msg.user + "-" + msg.collection + "-" + str(dim)
"d-" + msg.user + "-" + msg.collection
)
self.ensure_index_exists(index_name, dim)

View file

@ -104,7 +104,7 @@ class Processor(GraphEmbeddingsQueryService):
dim = len(vec)
index_name = (
"t-" + msg.user + "-" + msg.collection + "-" + str(dim)
"t-" + msg.user + "-" + msg.collection
)
self.ensure_index_exists(index_name, dim)

View file

@ -6,7 +6,7 @@ null. Output is a list of triples.
import logging
from .... direct.cassandra import TrustGraph
from .... direct.cassandra_kg import KnowledgeGraph
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Value, Triple
from .... base import TriplesQueryService
@ -56,21 +56,21 @@ class Processor(TriplesQueryService):
try:
table = (query.user, query.collection)
user = query.user
if table != self.table:
if user != self.table:
if self.cassandra_username and self.cassandra_password:
self.tg = TrustGraph(
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=query.user, table=query.collection,
keyspace=query.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = TrustGraph(
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=query.user, table=query.collection,
keyspace=query.user,
)
self.table = table
self.table = user
triples = []
@ -78,13 +78,13 @@ class Processor(TriplesQueryService):
if query.p is not None:
if query.o is not None:
resp = self.tg.get_spo(
query.s.value, query.p.value, query.o.value,
query.collection, query.s.value, query.p.value, query.o.value,
limit=query.limit
)
triples.append((query.s.value, query.p.value, query.o.value))
else:
resp = self.tg.get_sp(
query.s.value, query.p.value,
query.collection, query.s.value, query.p.value,
limit=query.limit
)
for t in resp:
@ -92,14 +92,14 @@ class Processor(TriplesQueryService):
else:
if query.o is not None:
resp = self.tg.get_os(
query.o.value, query.s.value,
query.collection, query.o.value, query.s.value,
limit=query.limit
)
for t in resp:
triples.append((query.s.value, t.p, query.o.value))
else:
resp = self.tg.get_s(
query.s.value,
query.collection, query.s.value,
limit=query.limit
)
for t in resp:
@ -108,14 +108,14 @@ class Processor(TriplesQueryService):
if query.p is not None:
if query.o is not None:
resp = self.tg.get_po(
query.p.value, query.o.value,
query.collection, query.p.value, query.o.value,
limit=query.limit
)
for t in resp:
triples.append((t.s, query.p.value, query.o.value))
else:
resp = self.tg.get_p(
query.p.value,
query.collection, query.p.value,
limit=query.limit
)
for t in resp:
@ -123,13 +123,14 @@ class Processor(TriplesQueryService):
else:
if query.o is not None:
resp = self.tg.get_o(
query.o.value,
query.collection, query.o.value,
limit=query.limit
)
for t in resp:
triples.append((t.s, t.p, query.o.value))
else:
resp = self.tg.get_all(
query.collection,
limit=query.limit
)
for t in resp:

View file

@ -3,8 +3,17 @@
Accepts entity/vector pairs and writes them to a Milvus store.
"""
import logging
from .... direct.milvus_doc_embeddings import DocVectors
from .... base import DocumentEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
default_ident = "de-write"
default_store_uri = 'http://localhost:19530'
@ -23,6 +32,34 @@ class Processor(DocumentEmbeddingsStoreService):
self.vecstore = DocVectors(store_uri)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_document_embeddings(self, message):
for emb in message.chunks:
@ -50,6 +87,48 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for document embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -12,6 +12,10 @@ import os
import logging
from .... base import DocumentEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -55,6 +59,34 @@ class Processor(DocumentEmbeddingsStoreService):
self.last_index_name = None
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_index(self, index_name, dim):
self.pinecone.create_index(
@ -96,7 +128,7 @@ class Processor(DocumentEmbeddingsStoreService):
dim = len(vec)
index_name = (
"d-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
"d-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
@ -160,6 +192,54 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for document embeddings"""
try:
index_name = f"d-{message.user}-{message.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
else:
logger.info(f"Index {index_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -10,6 +10,10 @@ import uuid
import logging
from .... base import DocumentEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -36,6 +40,37 @@ class Processor(DocumentEmbeddingsStoreService):
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
# (they may not be in unit tests)
if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'):
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_document_embeddings(self, message):
for emb in message.chunks:
@ -48,8 +83,7 @@ class Processor(DocumentEmbeddingsStoreService):
dim = len(vec)
collection = (
"d_" + message.metadata.user + "_" +
message.metadata.collection + "_" +
str(dim)
message.metadata.collection
)
if collection != self.last_collection:
@ -99,6 +133,54 @@ class Processor(DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for document embeddings"""
try:
collection_name = f"d_{message.user}_{message.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -3,8 +3,17 @@
Accepts entity/vector pairs and writes them to a Milvus store.
"""
import logging
from .... direct.milvus_graph_embeddings import EntityVectors
from .... base import GraphEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
default_ident = "ge-write"
default_store_uri = 'http://localhost:19530'
@ -23,6 +32,34 @@ class Processor(GraphEmbeddingsStoreService):
self.vecstore = EntityVectors(store_uri)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_graph_embeddings(self, message):
for entity in message.entities:
@ -46,6 +83,48 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for graph embeddings"""
try:
self.vecstore.delete_collection(message.user, message.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -12,6 +12,10 @@ import os
import logging
from .... base import GraphEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -55,6 +59,34 @@ class Processor(GraphEmbeddingsStoreService):
self.last_index_name = None
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_index(self, index_name, dim):
self.pinecone.create_index(
@ -95,7 +127,7 @@ class Processor(GraphEmbeddingsStoreService):
dim = len(vec)
index_name = (
"t-" + message.metadata.user + "-" + message.metadata.collection + "-" + str(dim)
"t-" + message.metadata.user + "-" + message.metadata.collection
)
if index_name != self.last_index_name:
@ -159,6 +191,54 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for graph embeddings"""
try:
index_name = f"t-{message.user}-{message.collection}"
if self.pinecone.has_index(index_name):
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
else:
logger.info(f"Index {index_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -10,6 +10,10 @@ import uuid
import logging
from .... base import GraphEmbeddingsStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import vector_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -36,10 +40,41 @@ class Processor(GraphEmbeddingsStoreService):
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
# Set up storage management if base class attributes are available
# (they may not be in unit tests)
if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'):
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=vector_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def get_collection(self, dim, user, collection):
cname = (
"t_" + user + "_" + collection + "_" + str(dim)
"t_" + user + "_" + collection
)
if cname != self.last_collection:
@ -105,6 +140,54 @@ class Processor(GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for graph embeddings"""
try:
collection_name = f"t_{message.user}_{message.collection}"
if self.qdrant.collection_exists(collection_name):
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
else:
logger.info(f"Collection {collection_name} does not exist, nothing to delete")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -13,7 +13,9 @@ from cassandra import ConsistencyLevel
from .... schema import ExtractedObject
from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... schema import StorageManagementRequest, StorageManagementResponse
from .... schema import object_storage_management_topic, storage_management_response_topic
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
# Module logger
@ -61,7 +63,38 @@ class Processor(FlowProcessor):
handler = self.on_object
)
)
# Set up storage management consumer and producer directly
# (FlowProcessor doesn't support topic-based specs outside of flows)
from .... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Create storage management consumer
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=object_storage_management_topic,
subscriber=f"{id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Create storage management response producer
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
# Register config handler for schema updates
self.register_config_handler(self.on_schema_config)
@ -390,6 +423,100 @@ class Processor(FlowProcessor):
logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True)
raise
async def on_storage_management(self, msg, consumer, flow):
"""Handle storage management requests for collection operations"""
logger.info(f"Received storage management request: {msg.operation} for {msg.user}/{msg.collection}")
try:
if msg.operation == "delete-collection":
await self.delete_collection(msg.user, msg.collection)
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {msg.user}/{msg.collection}")
else:
logger.warning(f"Unknown storage management operation: {msg.operation}")
# Send error response
from .... schema import Error
response = StorageManagementResponse(
error=Error(
type="unknown_operation",
message=f"Unknown operation: {msg.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error handling storage management request: {e}", exc_info=True)
# Send error response
from .... schema import Error
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.send("storage-response", response)
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection"""
# Connect if not already connected
self.connect_cassandra()
# Sanitize names for safety
safe_keyspace = self.sanitize_name(user)
# Check if keyspace exists
if safe_keyspace not in self.known_keyspaces:
# Query to verify keyspace exists
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
if not result.one():
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(safe_keyspace)
# Get all tables in the keyspace that might contain collection data
get_tables_cql = """
SELECT table_name FROM system_schema.tables
WHERE keyspace_name = %s
"""
tables = self.session.execute(get_tables_cql, (safe_keyspace,))
tables_deleted = 0
for row in tables:
table_name = row.table_name
# Check if the table has a collection column
check_column_cql = """
SELECT column_name FROM system_schema.columns
WHERE keyspace_name = %s AND table_name = %s AND column_name = 'collection'
"""
result = self.session.execute(check_column_cql, (safe_keyspace, table_name))
if result.one():
# Table has collection column, delete data for this collection
try:
delete_cql = f"""
DELETE FROM {safe_keyspace}.{table_name}
WHERE collection = %s
"""
self.session.execute(delete_cql, (collection,))
tables_deleted += 1
logger.info(f"Deleted collection {collection} from table {safe_keyspace}.{table_name}")
except Exception as e:
logger.error(f"Failed to delete from table {safe_keyspace}.{table_name}: {e}")
raise
logger.info(f"Deleted collection {collection} from {tables_deleted} tables in keyspace {safe_keyspace}")
def close(self):
"""Clean up Cassandra connections"""
if self.cluster:

View file

@ -10,9 +10,13 @@ import argparse
import time
import logging
from .... direct.cassandra import TrustGraph
from .... direct.cassandra_kg import KnowledgeGraph
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -50,42 +54,146 @@ class Processor(TriplesStoreService):
self.cassandra_password = password
self.table = None
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
async def store_triples(self, message):
table = (message.metadata.user, message.metadata.collection)
user = message.metadata.user
if self.table is None or self.table != table:
if self.table is None or self.table != user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = TrustGraph(
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
table=message.metadata.collection,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = TrustGraph(
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.metadata.user,
table=message.metadata.collection,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self.table = table
self.table = user
for t in message.triples:
self.tg.insert(
message.metadata.collection,
t.s.value,
t.p.value,
t.o.value
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection from the unified triples table"""
try:
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != message.user:
self.tg = None
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KnowledgeGraph(
hosts=self.cassandra_host,
keyspace=message.user,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {message.user}: {e}")
raise
self.table = message.user
# Delete all triples for this collection from the unified table
# In the unified table schema, collection is the partition key
delete_cql = """
DELETE FROM triples
WHERE collection = ?
"""
try:
self.tg.session.execute(delete_cql, (message.collection,))
logger.info(f"Deleted all triples for collection {message.collection} from keyspace {message.user}")
except Exception as e:
logger.error(f"Failed to delete collection data: {e}")
raise
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
@staticmethod
def add_args(parser):

View file

@ -13,6 +13,10 @@ import logging
from falkordb import FalkorDB
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -40,14 +44,44 @@ class Processor(TriplesStoreService):
self.io = FalkorDB.from_url(graph_url).select_graph(database)
def create_node(self, uri):
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
logger.debug(f"Create node {uri}")
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_node(self, uri, user, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
res = self.io.query(
"MERGE (n:Node {uri: $uri})",
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
params={
"uri": uri,
"user": user,
"collection": collection,
},
)
@ -56,14 +90,16 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def create_literal(self, value):
def create_literal(self, value, user, collection):
logger.debug(f"Create literal {value}")
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
res = self.io.query(
"MERGE (n:Literal {value: $value})",
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
params={
"value": value,
"user": user,
"collection": collection,
},
)
@ -72,18 +108,20 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def relate_node(self, src, uri, dest):
def relate_node(self, src, uri, dest, user, collection):
logger.debug(f"Create node rel {src} {uri} {dest}")
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Node {uri: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"collection": collection,
},
)
@ -92,18 +130,20 @@ class Processor(TriplesStoreService):
time=res.run_time_ms
))
def relate_literal(self, src, uri, dest):
def relate_literal(self, src, uri, dest, user, collection):
logger.debug(f"Create literal rel {src} {uri} {dest}")
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src}) "
"MATCH (dest:Literal {value: $dest}) "
"MERGE (src)-[:Rel {uri: $uri}]->(dest)",
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"collection": collection,
},
)
@ -113,17 +153,20 @@ class Processor(TriplesStoreService):
))
async def store_triples(self, message):
# Extract user and collection from metadata
user = message.metadata.user if message.metadata.user else "default"
collection = message.metadata.collection if message.metadata.collection else "default"
for t in message.triples:
self.create_node(t.s.value)
self.create_node(t.s.value, user, collection)
if t.o.is_uri:
self.create_node(t.o.value)
self.relate_node(t.s.value, t.p.value, t.o.value)
self.create_node(t.o.value, user, collection)
self.relate_node(t.s.value, t.p.value, t.o.value, user, collection)
else:
self.create_literal(t.o.value)
self.relate_literal(t.s.value, t.p.value, t.o.value)
self.create_literal(t.o.value, user, collection)
self.relate_literal(t.s.value, t.p.value, t.o.value, user, collection)
@staticmethod
def add_args(parser):
@ -142,6 +185,59 @@ class Processor(TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete the collection for FalkorDB triples"""
try:
# Delete all nodes and literals for this user/collection
node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
)
literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": message.user, "collection": message.collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes and {literal_result.nodes_deleted} literals for collection {message.user}/{message.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -13,6 +13,10 @@ import logging
from neo4j import GraphDatabase
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -49,6 +53,34 @@ class Processor(TriplesStoreService):
with self.io.session(database=self.db) as session:
self.create_indexes(session)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{self.id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_indexes(self, session):
# Race condition, index creation failure is ignored. Right thing
@ -285,6 +317,67 @@ class Processor(TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -12,6 +12,10 @@ import logging
from neo4j import GraphDatabase
from .... base import TriplesStoreService
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... schema import StorageManagementRequest, StorageManagementResponse, Error
from .... schema import triples_storage_management_topic, storage_management_response_topic
# Module logger
logger = logging.getLogger(__name__)
@ -49,6 +53,34 @@ class Processor(TriplesStoreService):
with self.io.session(database=self.db) as session:
self.create_indexes(session)
# Set up metrics for storage management
storage_request_metrics = ConsumerMetrics(
processor=self.id, flow=None, name="storage-request"
)
storage_response_metrics = ProducerMetrics(
processor=self.id, flow=None, name="storage-response"
)
# Set up consumer for storage management requests
self.storage_request_consumer = Consumer(
taskgroup=self.taskgroup,
client=self.pulsar_client,
flow=None,
topic=triples_storage_management_topic,
subscriber=f"{id}-storage",
schema=StorageManagementRequest,
handler=self.on_storage_management,
metrics=storage_request_metrics,
)
# Set up producer for storage management responses
self.storage_response_producer = Producer(
client=self.pulsar_client,
topic=storage_management_response_topic,
schema=StorageManagementResponse,
metrics=storage_response_metrics,
)
def create_indexes(self, session):
# Race condition, index creation failure is ignored. Right thing
@ -236,6 +268,67 @@ class Processor(TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
async def on_storage_management(self, message):
"""Handle storage management requests"""
logger.info(f"Storage management request: {message.operation} for {message.user}/{message.collection}")
try:
if message.operation == "delete-collection":
await self.handle_delete_collection(message)
else:
response = StorageManagementResponse(
error=Error(
type="invalid_operation",
message=f"Unknown operation: {message.operation}"
)
)
await self.storage_response_producer.send(response)
except Exception as e:
logger.error(f"Error processing storage management request: {e}", exc_info=True)
response = StorageManagementResponse(
error=Error(
type="processing_error",
message=str(e)
)
)
await self.storage_response_producer.send(response)
async def handle_delete_collection(self, message):
"""Delete all data for a specific collection"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"DETACH DELETE n",
user=message.user, collection=message.collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes and {literals_deleted} literals for {message.user}/{message.collection}")
# Send success response
response = StorageManagementResponse(
error=None # No error means success
)
await self.storage_response_producer.send(response)
logger.info(f"Successfully deleted collection {message.user}/{message.collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")
raise
def run():
Processor.launch(default_ident, __doc__)

View file

@ -111,6 +111,21 @@ class LibraryTableStore:
);
""");
logger.debug("collections table...")
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS collections (
user text,
collection text,
name text,
description text,
tags set<text>,
created_at timestamp,
updated_at timestamp,
PRIMARY KEY (user, collection)
);
""");
logger.info("Cassandra schema OK.")
def prepare_statements(self):
@ -187,6 +202,43 @@ class LibraryTableStore:
LIMIT 1
""")
# Collection management statements
self.insert_collection_stmt = self.cassandra.prepare("""
INSERT INTO collections
(user, collection, name, description, tags, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""")
self.update_collection_stmt = self.cassandra.prepare("""
UPDATE collections
SET name = ?, description = ?, tags = ?, updated_at = ?
WHERE user = ? AND collection = ?
""")
self.get_collection_stmt = self.cassandra.prepare("""
SELECT collection, name, description, tags, created_at, updated_at
FROM collections
WHERE user = ? AND collection = ?
""")
self.list_collections_stmt = self.cassandra.prepare("""
SELECT collection, name, description, tags, created_at, updated_at
FROM collections
WHERE user = ?
""")
self.delete_collection_stmt = self.cassandra.prepare("""
DELETE FROM collections
WHERE user = ? AND collection = ?
""")
self.collection_exists_stmt = self.cassandra.prepare("""
SELECT collection
FROM collections
WHERE user = ? AND collection = ?
LIMIT 1
""")
self.list_processing_stmt = self.cassandra.prepare("""
SELECT
id, document_id, time, flow, collection, tags
@ -521,3 +573,113 @@ class LibraryTableStore:
return lst
# Collection management methods
async def ensure_collection_exists(self, user, collection):
"""Ensure collection metadata record exists, create if not"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.collection_exists_stmt, [user, collection]
)
if resp:
return
import datetime
now = datetime.datetime.now()
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.insert_collection_stmt,
[user, collection, collection, "", set(), now, now]
)
logger.debug(f"Created collection metadata for {user}/{collection}")
except Exception as e:
logger.error(f"Error ensuring collection exists: {e}")
raise
async def list_collections(self, user, tag_filter=None):
"""List collections for a user, optionally filtered by tags"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.list_collections_stmt, [user]
)
collections = []
for row in resp:
collection_data = {
"user": user,
"collection": row[0],
"name": row[1] or row[0],
"description": row[2] or "",
"tags": list(row[3]) if row[3] else [],
"created_at": row[4].isoformat() if row[4] else "",
"updated_at": row[5].isoformat() if row[5] else ""
}
if tag_filter:
collection_tags = set(collection_data["tags"])
filter_tags = set(tag_filter)
if not filter_tags.intersection(collection_tags):
continue
collections.append(collection_data)
return collections
except Exception as e:
logger.error(f"Error listing collections: {e}")
raise
async def update_collection(self, user, collection, name=None, description=None, tags=None):
"""Update collection metadata"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.get_collection_stmt, [user, collection]
)
if not resp:
raise RequestError(f"Collection {collection} not found")
row = resp.one()
current_name = row[1] or collection
current_description = row[2] or ""
current_tags = set(row[3]) if row[3] else set()
new_name = name if name is not None else current_name
new_description = description if description is not None else current_description
new_tags = set(tags) if tags is not None else current_tags
import datetime
now = datetime.datetime.now()
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.update_collection_stmt,
[new_name, new_description, new_tags, now, user, collection]
)
return {
"user": user, "collection": collection, "name": new_name,
"description": new_description, "tags": list(new_tags),
"updated_at": now.isoformat()
}
except Exception as e:
logger.error(f"Error updating collection: {e}")
raise
async def delete_collection_metadata(self, user, collection):
"""Delete collection metadata record"""
try:
await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.delete_collection_stmt, [user, collection]
)
logger.debug(f"Deleted collection metadata for {user}/{collection}")
except Exception as e:
logger.error(f"Error deleting collection metadata: {e}")
raise
async def get_collection(self, user, collection):
"""Get collection metadata"""
try:
resp = await asyncio.get_event_loop().run_in_executor(
None, self.cassandra.execute, self.get_collection_stmt, [user, collection]
)
if not resp:
return None
row = resp.one()
return {
"user": user, "collection": row[0], "name": row[1] or row[0],
"description": row[2] or "", "tags": list(row[3]) if row[3] else [],
"created_at": row[4].isoformat() if row[4] else "",
"updated_at": row[5].isoformat() if row[5] else ""
}
except Exception as e:
logger.error(f"Error getting collection: {e}")
raise