diff --git a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py index b9a306c1..f4a1d977 100644 --- a/tests/unit/test_query/test_doc_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_doc_embeddings_qdrant_query.py @@ -104,7 +104,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert # Verify query was called with correct parameters - expected_collection = 'd_test_user_test_collection_3' + expected_collection = 'd_test_user_test_collection' mock_qdrant_instance.query_points.assert_called_once_with( collection_name=expected_collection, query=[0.1, 0.2, 0.3], @@ -166,7 +166,7 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): assert mock_qdrant_instance.query_points.call_count == 2 # Verify both collections were queried - expected_collection = 'd_multi_user_multi_collection_2' + expected_collection = 'd_multi_user_multi_collection' calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection assert calls[1][1]['collection_name'] == expected_collection @@ -303,11 +303,11 @@ class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase): calls = mock_qdrant_instance.query_points.call_args_list # First call should use 2D collection - assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2' + assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection' assert calls[0][1]['query'] == [0.1, 0.2] # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' + assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection' assert calls[1][1]['query'] == [0.3, 0.4, 0.5] # Verify results diff --git a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py index 11d11d35..0dd0e94e 100644 --- a/tests/unit/test_query/test_graph_embeddings_qdrant_query.py +++ b/tests/unit/test_query/test_graph_embeddings_qdrant_query.py @@ -176,7 +176,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): # Assert # Verify query was called with correct parameters - expected_collection = 't_test_user_test_collection_3' + expected_collection = 't_test_user_test_collection' mock_qdrant_instance.query_points.assert_called_once_with( collection_name=expected_collection, query=[0.1, 0.2, 0.3], @@ -236,7 +236,7 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): assert mock_qdrant_instance.query_points.call_count == 2 # Verify both collections were queried - expected_collection = 't_multi_user_multi_collection_2' + expected_collection = 't_multi_user_multi_collection' calls = mock_qdrant_instance.query_points.call_args_list assert calls[0][1]['collection_name'] == expected_collection assert calls[1][1]['collection_name'] == expected_collection @@ -374,11 +374,11 @@ class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase): calls = mock_qdrant_instance.query_points.call_args_list # First call should use 2D collection - assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2' + assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection' assert calls[0][1]['query'] == [0.1, 0.2] # Second call should use 3D collection - assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3' + assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection' assert calls[1][1]['query'] == [0.3, 0.4, 0.5] # Verify results diff --git a/trustgraph-base/trustgraph/api/collection.py b/trustgraph-base/trustgraph/api/collection.py index 9a826899..0e1abeaf 100644 --- a/trustgraph-base/trustgraph/api/collection.py +++ b/trustgraph-base/trustgraph/api/collection.py @@ -27,6 +27,14 @@ class Collection: object = self.request(input) try: + # Handle case where collections might be None or missing + if object is None or "collections" not in object: + return [] + + collections = object.get("collections", []) + if collections is None: + return [] + return [ CollectionMetadata( user = v["user"], @@ -37,7 +45,7 @@ class Collection: created_at = v["created_at"], updated_at = v["updated_at"] ) - for v in object["collections"] + for v in collections ] except Exception as e: logger.error("Failed to parse collection list response", exc_info=True) diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index 5c2a0fd4..38ac813b 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -8,43 +8,43 @@ class CollectionManagementRequestTranslator(MessageTranslator): def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementRequest: return CollectionManagementRequest( - operation=data.get("operation", ""), - user=data.get("user", ""), - collection=data.get("collection", ""), - timestamp=data.get("timestamp", ""), - name=data.get("name", ""), - description=data.get("description", ""), - tags=data.get("tags", []), - created_at=data.get("created_at", ""), - updated_at=data.get("updated_at", ""), - tag_filter=data.get("tag_filter", []), - limit=data.get("limit", 50) + operation=data.get("operation"), + user=data.get("user"), + collection=data.get("collection"), + timestamp=data.get("timestamp"), + name=data.get("name"), + description=data.get("description"), + tags=data.get("tags"), + created_at=data.get("created_at"), + updated_at=data.get("updated_at"), + tag_filter=data.get("tag_filter"), + limit=data.get("limit") ) def from_pulsar(self, obj: CollectionManagementRequest) -> Dict[str, Any]: result = {} - if obj.operation: + if obj.operation is not None: result["operation"] = obj.operation - if obj.user: + if obj.user is not None: result["user"] = obj.user - if obj.collection: + if obj.collection is not None: result["collection"] = obj.collection - if obj.timestamp: + if obj.timestamp is not None: result["timestamp"] = obj.timestamp - if obj.name: + if obj.name is not None: result["name"] = obj.name - if obj.description: + if obj.description is not None: result["description"] = obj.description - if obj.tags: + if obj.tags is not None: result["tags"] = list(obj.tags) - if obj.created_at: + if obj.created_at is not None: result["created_at"] = obj.created_at - if obj.updated_at: + if obj.updated_at is not None: result["updated_at"] = obj.updated_at - if obj.tag_filter: + if obj.tag_filter is not None: result["tag_filter"] = list(obj.tag_filter) - if obj.limit: + if obj.limit is not None: result["limit"] = obj.limit return result @@ -54,13 +54,14 @@ class CollectionManagementResponseTranslator(MessageTranslator): """Translator for CollectionManagementResponse schema objects""" def to_pulsar(self, data: Dict[str, Any]) -> CollectionManagementResponse: + # Handle error error = None if "error" in data and data["error"]: error_data = data["error"] error = Error( - type=error_data.get("type", ""), - message=error_data.get("message", "") + type=error_data.get("type"), + message=error_data.get("message") ) # Handle collections array @@ -68,35 +69,34 @@ class CollectionManagementResponseTranslator(MessageTranslator): if "collections" in data: for coll_data in data["collections"]: collections.append(CollectionMetadata( - user=coll_data.get("user", ""), - collection=coll_data.get("collection", ""), - name=coll_data.get("name", ""), - description=coll_data.get("description", ""), - tags=coll_data.get("tags", []), - created_at=coll_data.get("created_at", ""), - updated_at=coll_data.get("updated_at", "") + user=coll_data.get("user"), + collection=coll_data.get("collection"), + name=coll_data.get("name"), + description=coll_data.get("description"), + tags=coll_data.get("tags"), + created_at=coll_data.get("created_at"), + updated_at=coll_data.get("updated_at") )) return CollectionManagementResponse( - success=data.get("success", ""), error=error, - timestamp=data.get("timestamp", ""), + timestamp=data.get("timestamp"), collections=collections ) def from_pulsar(self, obj: CollectionManagementResponse) -> Dict[str, Any]: result = {} - if obj.success: - result["success"] = obj.success - if obj.error: + print("COLLECTIONMGMT", obj, flush=True) + + if obj.error is not None: result["error"] = { "type": obj.error.type, "message": obj.error.message } - if obj.timestamp: + if obj.timestamp is not None: result["timestamp"] = obj.timestamp - if obj.collections: + if obj.collections is not None: result["collections"] = [] for coll in obj.collections: result["collections"].append({ @@ -109,4 +109,6 @@ class CollectionManagementResponseTranslator(MessageTranslator): "updated_at": coll.updated_at }) - return result \ No newline at end of file + print("RESULT IS", result, flush=True) + + return result diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index bb837c63..905b2056 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -42,8 +42,7 @@ class CollectionManagementRequest(Record): class CollectionManagementResponse(Record): """Response for collection management operations""" - success = String() # "true" or "false" - error = Error() # Only populated if success is "false" + error = Error() # Only populated if there's an error timestamp = String() # ISO timestamp collections = Array(CollectionMetadata()) diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 06b1e303..70b0a1b8 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -87,7 +87,7 @@ tg-get-config-item = "trustgraph.cli.get_config_item:main" tg-put-config-item = "trustgraph.cli.put_config_item:main" tg-delete-config-item = "trustgraph.cli.delete_config_item:main" tg-list-collections = "trustgraph.cli.list_collections:main" -tg-update-collection = "trustgraph.cli.update_collection:main" +tg-set-collection = "trustgraph.cli.set_collection:main" tg-delete-collection = "trustgraph.cli.delete_collection:main" [tool.setuptools.packages.find] diff --git a/trustgraph-cli/trustgraph/cli/list_collections.py b/trustgraph-cli/trustgraph/cli/list_collections.py index 8429b0cb..56929e93 100644 --- a/trustgraph-cli/trustgraph/cli/list_collections.py +++ b/trustgraph-cli/trustgraph/cli/list_collections.py @@ -17,8 +17,9 @@ def list_collections(url, user, tag_filter): collections = api.list_collections(user=user, tag_filter=tag_filter) - if len(collections) == 0: - print("No collections.") + # Handle None or empty collections + if not collections or len(collections) == 0: + print("No collections found.") return table = [] diff --git a/trustgraph-cli/trustgraph/cli/update_collection.py b/trustgraph-cli/trustgraph/cli/set_collection.py similarity index 85% rename from trustgraph-cli/trustgraph/cli/update_collection.py rename to trustgraph-cli/trustgraph/cli/set_collection.py index 094c033c..e987c4c8 100644 --- a/trustgraph-cli/trustgraph/cli/update_collection.py +++ b/trustgraph-cli/trustgraph/cli/set_collection.py @@ -1,5 +1,5 @@ """ -Update collection metadata +Set collection metadata (creates if doesn't exist) """ import argparse @@ -10,7 +10,7 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = "trustgraph" -def update_collection(url, user, collection, name, description, tags): +def set_collection(url, user, collection, name, description, tags): api = Api(url).collection() @@ -23,7 +23,7 @@ def update_collection(url, user, collection, name, description, tags): ) if result: - print(f"Collection '{collection}' updated successfully.") + print(f"Collection '{collection}' set successfully.") table = [] table.append(("Collection", result.collection)) @@ -39,18 +39,18 @@ def update_collection(url, user, collection, name, description, tags): maxcolwidths=[None, 67], )) else: - print(f"Failed to update collection '{collection}'.") + print(f"Failed to set collection '{collection}'.") def main(): parser = argparse.ArgumentParser( - prog='tg-update-collection', + prog='tg-set-collection', description=__doc__, ) parser.add_argument( 'collection', - help='Collection ID to update' + help='Collection ID to set' ) parser.add_argument( @@ -86,7 +86,7 @@ def main(): try: - update_collection( + set_collection( url = args.api_url, user = args.user, collection = args.collection, diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index a4cf12b4..20f58b6b 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -31,9 +31,9 @@ class KnowledgeGraph: self.table = "triples" # Legacy single table else: # New optimized tables - self.subject_table = "triples_by_subject" - self.po_table = "triples_by_po" - self.object_table = "triples_by_object" + self.subject_table = "triples_s" + self.po_table = "triples_p" + self.object_table = "triples_o" if username and password: ssl_context = SSLContext(PROTOCOL_TLSv1_2) @@ -157,7 +157,7 @@ class KnowledgeGraph: # Query statements for optimized access self.get_all_stmt = self.session.prepare( - f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ?" + f"SELECT s, p, o FROM {self.subject_table} WHERE collection = ? LIMIT ? ALLOW FILTERING" ) self.get_s_stmt = self.session.prepare( @@ -182,7 +182,7 @@ class KnowledgeGraph: ) self.get_os_stmt = self.session.prepare( - f"SELECT p FROM {self.subject_table} WHERE collection = ? AND s = ? AND o = ? LIMIT ?" + f"SELECT p FROM {self.object_table} WHERE collection = ? AND o = ? AND s = ? LIMIT ?" ) self.get_spo_stmt = self.session.prepare( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py index 6e78db48..f2755ae8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py @@ -22,7 +22,9 @@ class CollectionManagementRequestor(ServiceRequestor): self.response_translator = TranslatorRegistry.get_response_translator("collection-management") def to_request(self, body): + print("REQUEST", body, flush=True) return self.request_translator.to_pulsar(body) def from_response(self, message): - return self.response_translator.from_response_with_completion(message) \ No newline at end of file + print("RESPONSE", message, flush=True) + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/librarian/collection_manager.py b/trustgraph-flow/trustgraph/librarian/collection_manager.py new file mode 100644 index 00000000..f830db28 --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/collection_manager.py @@ -0,0 +1,315 @@ +""" +Collection management for the librarian +""" + +import asyncio +import logging +from datetime import datetime +from typing import Dict, Any, List, Optional + +from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error +from .. schema import CollectionMetadata +from .. schema import StorageManagementRequest, StorageManagementResponse +from .. exceptions import RequestError +from .. tables.library import LibraryTableStore + +# Module logger +logger = logging.getLogger(__name__) + +class CollectionManager: + """Manages collection metadata and coordinates collection operations across storage types""" + + def __init__( + self, + cassandra_host, + cassandra_username, + cassandra_password, + keyspace, + vector_storage_producer=None, + object_storage_producer=None, + triples_storage_producer=None, + storage_response_consumer=None + ): + """ + Initialize the CollectionManager + + Args: + cassandra_host: Cassandra host(s) + cassandra_username: Cassandra username + cassandra_password: Cassandra password + keyspace: Cassandra keyspace for library data + vector_storage_producer: Producer for vector storage management + object_storage_producer: Producer for object storage management + triples_storage_producer: Producer for triples storage management + storage_response_consumer: Consumer for storage management responses + """ + self.table_store = LibraryTableStore( + cassandra_host, cassandra_username, cassandra_password, keyspace + ) + + # Storage management producers + self.vector_storage_producer = vector_storage_producer + self.object_storage_producer = object_storage_producer + self.triples_storage_producer = triples_storage_producer + self.storage_response_consumer = storage_response_consumer + + # Track pending deletion operations + self.pending_deletions = {} + + logger.info("Collection manager initialized") + + async def ensure_collection_exists(self, user: str, collection: str): + """ + Ensure a collection exists, creating it if necessary (lazy creation) + + Args: + user: User ID + collection: Collection ID + """ + try: + # Check if collection already exists + existing = await self.table_store.get_collection(user, collection) + if existing: + logger.debug(f"Collection {user}/{collection} already exists") + return + + # Create new collection with default metadata + logger.info(f"Creating new collection {user}/{collection}") + await self.table_store.create_collection( + user=user, + collection=collection, + name=collection, # Default name to collection ID + description="", + tags=set() + ) + + except Exception as e: + logger.error(f"Error ensuring collection exists: {e}") + # Don't fail the operation if collection creation fails + # This maintains backward compatibility + + async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse: + """ + List collections for a user with optional tag filtering + + Args: + request: Collection management request + + Returns: + CollectionManagementResponse with list of collections + """ + try: + tag_filter = list(request.tag_filter) if request.tag_filter else None + collections = await self.table_store.list_collections(request.user, tag_filter) + + collection_metadata = [ + CollectionMetadata( + user=coll["user"], + collection=coll["collection"], + name=coll["name"], + description=coll["description"], + tags=coll["tags"], + created_at=coll["created_at"], + updated_at=coll["updated_at"] + ) + for coll in collections + ] + + return CollectionManagementResponse( + error=None, + collections=collection_metadata, + timestamp=datetime.now().isoformat() + ) + + except Exception as e: + logger.error(f"Error listing collections: {e}") + raise RequestError(f"Failed to list collections: {str(e)}") + + async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse: + """ + Update collection metadata (creates if doesn't exist) + + Args: + request: Collection management request + + Returns: + CollectionManagementResponse with updated collection + """ + try: + # Check if collection exists, create if it doesn't + existing = await self.table_store.get_collection(request.user, request.collection) + if not existing: + # Create new collection with provided metadata + logger.info(f"Creating new collection {request.user}/{request.collection}") + + name = request.name if request.name else request.collection + description = request.description if request.description else "" + tags = set(request.tags) if request.tags else set() + + await self.table_store.create_collection( + user=request.user, + collection=request.collection, + name=name, + description=description, + tags=tags + ) + + # Get the newly created collection for response + created_collection = await self.table_store.get_collection(request.user, request.collection) + + collection_metadata = CollectionMetadata( + user=created_collection["user"], + collection=created_collection["collection"], + name=created_collection["name"], + description=created_collection["description"], + tags=created_collection["tags"], + created_at=created_collection["created_at"], + updated_at=created_collection["updated_at"] + ) + else: + # Collection exists, update it + name = request.name if request.name else None + description = request.description if request.description else None + tags = list(request.tags) if request.tags else None + + updated_collection = await self.table_store.update_collection( + request.user, request.collection, name, description, tags + ) + + collection_metadata = CollectionMetadata( + user=updated_collection["user"], + collection=updated_collection["collection"], + name=updated_collection["name"], + description=updated_collection["description"], + tags=updated_collection["tags"], + created_at="", # Not returned by update + updated_at=updated_collection["updated_at"] + ) + + return CollectionManagementResponse( + error=None, + collections=[collection_metadata], + timestamp=datetime.now().isoformat() + ) + + except Exception as e: + logger.error(f"Error updating collection: {e}") + raise RequestError(f"Failed to update collection: {str(e)}") + + async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse: + """ + Delete collection with cascade to all storage types + + Args: + request: Collection management request + + Returns: + CollectionManagementResponse indicating success or failure + """ + try: + deletion_key = (request.user, request.collection) + + logger.info(f"Starting cascade deletion for {request.user}/{request.collection}") + + # Track this deletion request + self.pending_deletions[deletion_key] = { + "responses_pending": 3, # vector, object, triples + "responses_received": [], + "all_successful": True, + "error_messages": [], + "deletion_complete": asyncio.Event() + } + + # Create storage management request + storage_request = StorageManagementRequest( + operation="delete-collection", + user=request.user, + collection=request.collection + ) + + # Send deletion requests to all storage types + if self.vector_storage_producer: + await self.vector_storage_producer.send(storage_request) + if self.object_storage_producer: + await self.object_storage_producer.send(storage_request) + if self.triples_storage_producer: + await self.triples_storage_producer.send(storage_request) + + # Wait for all storage deletions to complete (with timeout) + deletion_info = self.pending_deletions[deletion_key] + try: + await asyncio.wait_for( + deletion_info["deletion_complete"].wait(), + timeout=30.0 # 30 second timeout + ) + except asyncio.TimeoutError: + logger.error(f"Timeout waiting for storage deletion responses for {deletion_key}") + deletion_info["all_successful"] = False + deletion_info["error_messages"].append("Timeout waiting for storage deletion") + + # Check if all deletions succeeded + if not deletion_info["all_successful"]: + error_msg = f"Storage deletion failed: {'; '.join(deletion_info['error_messages'])}" + logger.error(error_msg) + + # Clean up tracking + del self.pending_deletions[deletion_key] + + return CollectionManagementResponse( + error=Error( + type="storage_deletion_error", + message=error_msg + ), + timestamp=datetime.now().isoformat() + ) + + # All storage deletions succeeded, now delete metadata + logger.info(f"Storage deletions complete, removing metadata for {deletion_key}") + await self.table_store.delete_collection(request.user, request.collection) + + # Clean up tracking + del self.pending_deletions[deletion_key] + + return CollectionManagementResponse( + error=None, + timestamp=datetime.now().isoformat() + ) + + except Exception as e: + logger.error(f"Error deleting collection: {e}") + # Clean up tracking on error + if deletion_key in self.pending_deletions: + del self.pending_deletions[deletion_key] + raise RequestError(f"Failed to delete collection: {str(e)}") + + async def on_storage_response(self, response: StorageManagementResponse): + """ + Handle storage management responses for deletion tracking + + Args: + response: Storage management response + """ + logger.debug(f"Received storage response: error={response.error}") + + # Find matching deletion by checking all pending deletions + # Note: This is simplified correlation - in production we'd want better correlation + for deletion_key, info in list(self.pending_deletions.items()): + if info["responses_pending"] > 0: + # Record this response + info["responses_received"].append(response) + info["responses_pending"] -= 1 + + # Check if this response indicates failure + if response.error and response.error.message: + info["all_successful"] = False + info["error_messages"].append(response.error.message) + logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}") + else: + logger.debug(f"Storage deletion succeeded for {deletion_key}") + + # If all responses received, signal completion + if info["responses_pending"] == 0: + logger.info(f"All storage responses received for {deletion_key}") + info["deletion_complete"].set() + + break # Only process for first matching deletion \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/librarian/collection_service.py b/trustgraph-flow/trustgraph/librarian/collection_service.py deleted file mode 100644 index 7a4b9e6e..00000000 --- a/trustgraph-flow/trustgraph/librarian/collection_service.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -Collection management service for the librarian -""" - -import asyncio -import logging -from datetime import datetime - -from .. base import AsyncProcessor, Consumer, Producer -from .. base import ConsumerMetrics, ProducerMetrics -from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config - -from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error -from .. schema import collection_request_queue, collection_response_queue -from .. schema import CollectionMetadata -from .. schema import StorageManagementRequest, StorageManagementResponse -from .. schema import vector_storage_management_topic, object_storage_management_topic, triples_storage_management_topic, storage_management_response_topic - -from .. exceptions import RequestError -from .. tables.library import LibraryTableStore - -# Module logger -logger = logging.getLogger(__name__) - -default_ident = "collection-management" -default_cassandra_host = "cassandra" -keyspace = "librarian" - -class Processor(AsyncProcessor): - - def __init__(self, **params): - - id = params.get("id", default_ident) - - # Get Cassandra configuration - cassandra_host = params.get("cassandra_host", default_cassandra_host) - cassandra_username = params.get("cassandra_username") - cassandra_password = params.get("cassandra_password") - - # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( - host=cassandra_host, - username=cassandra_username, - password=cassandra_password - ) - - super(Processor, self).__init__( - **params | { - "cassandra_host": ','.join(hosts), - "cassandra_username": username - } - ) - - self.cassandra_host = hosts - self.cassandra_username = username - self.cassandra_password = password - - # Set up metrics - collection_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="collection-request" - ) - collection_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="collection-response" - ) - - # Set up consumer for collection management requests - self.collection_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=collection_request_queue, - subscriber=id, - schema=CollectionManagementRequest, - handler=self.on_collection_request, - metrics=collection_request_metrics, - ) - - # Set up producer for collection management responses - self.collection_response_producer = Producer( - client=self.pulsar_client, - topic=collection_response_queue, - schema=CollectionManagementResponse, - metrics=collection_response_metrics, - ) - - # Set up producers for storage management requests - self.vector_storage_producer = Producer( - client=self.pulsar_client, - topic=vector_storage_management_topic, - schema=StorageManagementRequest, - ) - - self.object_storage_producer = Producer( - client=self.pulsar_client, - topic=object_storage_management_topic, - schema=StorageManagementRequest, - ) - - self.triples_storage_producer = Producer( - client=self.pulsar_client, - topic=triples_storage_management_topic, - schema=StorageManagementRequest, - ) - - # Set up consumer for storage management responses - storage_response_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - self.storage_response_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=storage_management_response_topic, - subscriber=f"{id}-storage", - schema=StorageManagementResponse, - handler=self.on_storage_response, - metrics=storage_response_metrics, - ) - - # Initialize table store - self.table_store = LibraryTableStore( - cassandra_host=self.cassandra_host, - cassandra_username=self.cassandra_username, - cassandra_password=self.cassandra_password, - keyspace=keyspace - ) - - # Track pending deletion requests by user+collection - self.pending_deletions = {} # (user, collection) -> {responses_pending, responses_received, all_successful, error_messages, deletion_complete} - - async def on_collection_request(self, message): - """Handle collection management requests""" - - logger.debug(f"Collection request: {message.operation}") - - try: - if message.operation == "list-collections": - response = await self.handle_list_collections(message) - elif message.operation == "update-collection": - response = await self.handle_update_collection(message) - elif message.operation == "delete-collection": - response = await self.handle_delete_collection(message) - else: - response = CollectionManagementResponse( - success="false", - error=Error( - type="invalid_operation", - message=f"Unknown operation: {message.operation}" - ), - timestamp=datetime.now().isoformat() - ) - - except Exception as e: - logger.error(f"Error processing collection request: {e}", exc_info=True) - response = CollectionManagementResponse( - success="false", - error=Error( - type="processing_error", - message=str(e) - ), - timestamp=datetime.now().isoformat() - ) - - await self.collection_response_producer.send(response) - - async def on_storage_response(self, response): - """Handle storage management responses""" - logger.debug(f"Received storage response: error={response.error}") - - # Find matching deletion by checking all pending deletions - # Note: This is simplified correlation - assumes responses come back quickly - # In production, we'd want better correlation mechanism - for deletion_key, info in list(self.pending_deletions.items()): - if info["responses_pending"] > 0: - # Record this response - info["responses_received"].append(response) - info["responses_pending"] -= 1 - - # Check if this response indicates failure - if response.error and response.error.message: - info["all_successful"] = False - info["error_messages"].append(response.error.message) - logger.warning(f"Storage deletion failed for {deletion_key}: {response.error.message}") - else: - logger.debug(f"Storage deletion succeeded for {deletion_key}") - - # If all responses received, signal completion - if info["responses_pending"] == 0: - logger.info(f"All storage responses received for {deletion_key}") - info["deletion_complete"].set() - - break # Only process for first matching deletion - - # For now, we'll correlate by user+collection since we don't have deletion_id in the response - # This is a simplified approach - in production we'd want better correlation - for deletion_id, info in list(self.pending_deletions.items()): - if info["responses_pending"] > 0: - # Record this response - info["responses_received"].append(response) - info["responses_pending"] -= 1 - - # Check if this response indicates failure - if response.error and response.error.message: - info["all_successful"] = False - info["error_messages"].append(response.error.message) - logger.warning(f"Storage deletion failed for {deletion_id}: {response.error.message}") - - # If all responses received, signal completion - if info["responses_pending"] == 0: - logger.info(f"All storage responses received for {deletion_id}") - info["deletion_complete"].set() - - break # Only process for first matching deletion - - async def handle_list_collections(self, message): - """Handle list collections request""" - try: - tag_filter = list(message.tag_filter) if message.tag_filter else None - collections = await self.table_store.list_collections(message.user, tag_filter) - - collection_metadata = [ - CollectionMetadata( - user=coll["user"], - collection=coll["collection"], - name=coll["name"], - description=coll["description"], - tags=coll["tags"], - created_at=coll["created_at"], - updated_at=coll["updated_at"] - ) - for coll in collections - ] - - return CollectionManagementResponse( - success="true", - collections=collection_metadata, - timestamp=datetime.now().isoformat() - ) - - except Exception as e: - logger.error(f"Error listing collections: {e}") - raise - - async def handle_update_collection(self, message): - """Handle update collection request""" - try: - # Extract fields for update - name = message.name if message.name else None - description = message.description if message.description else None - tags = list(message.tags) if message.tags else None - - updated_collection = await self.table_store.update_collection( - message.user, message.collection, name, description, tags - ) - - collection_metadata = CollectionMetadata( - user=updated_collection["user"], - collection=updated_collection["collection"], - name=updated_collection["name"], - description=updated_collection["description"], - tags=updated_collection["tags"], - created_at="", # Not returned by update - updated_at=updated_collection["updated_at"] - ) - - return CollectionManagementResponse( - success="true", - collections=[collection_metadata], - timestamp=datetime.now().isoformat() - ) - - except Exception as e: - logger.error(f"Error updating collection: {e}") - raise - - async def handle_delete_collection(self, message): - """Handle delete collection request with cascade to all storage types""" - try: - deletion_key = (message.user, message.collection) - - logger.info(f"Starting cascade deletion for {message.user}/{message.collection}") - - # Track this deletion request - self.pending_deletions[deletion_key] = { - "responses_pending": 3, # vector, object, triples - "responses_received": [], - "all_successful": True, - "error_messages": [], - "deletion_complete": asyncio.Event() - } - - # Create storage management request - storage_request = StorageManagementRequest( - operation="delete-collection", - user=message.user, - collection=message.collection - ) - - # Send delete requests to all three storage types - await self.vector_storage_producer.send(storage_request) - await self.object_storage_producer.send(storage_request) - await self.triples_storage_producer.send(storage_request) - - logger.info(f"Storage deletion requests sent for {message.user}/{message.collection}") - - # Wait for all storage responses (with timeout) - try: - await asyncio.wait_for( - self.pending_deletions[deletion_key]["deletion_complete"].wait(), - timeout=30.0 # 30 second timeout - ) - except asyncio.TimeoutError: - logger.error(f"Timeout waiting for storage responses for {deletion_key}") - self.pending_deletions[deletion_key]["all_successful"] = False - self.pending_deletions[deletion_key]["error_messages"].append("Timeout waiting for storage responses") - - # Check if all storage deletions were successful - deletion_info = self.pending_deletions.pop(deletion_key, {}) - - if deletion_info.get("all_successful", False): - # All storage deletions succeeded, now delete metadata - await self.table_store.delete_collection_metadata(message.user, message.collection) - logger.info(f"Successfully completed cascade deletion for {message.user}/{message.collection}") - - return CollectionManagementResponse( - success="true", - timestamp=datetime.now().isoformat() - ) - else: - # Some storage deletions failed - error_messages = deletion_info.get("error_messages", ["Unknown storage deletion error"]) - error_msg = "; ".join(error_messages) - logger.error(f"Cascade deletion failed for {deletion_key}: {error_msg}") - - return CollectionManagementResponse( - success="false", - error=Error( - type="storage_deletion_error", - message=f"Storage deletion failed: {error_msg}" - ), - timestamp=datetime.now().isoformat() - ) - - except Exception as e: - logger.error(f"Error in cascade deletion: {e}") - return CollectionManagementResponse( - success="false", - error=Error( - type="deletion_error", - message=f"Failed to delete collection: {str(e)}" - ), - timestamp=datetime.now().isoformat() - ) - - @staticmethod - def add_args(parser): - AsyncProcessor.add_args(parser) - add_cassandra_args(parser) - -def run(): - Processor.launch(default_ident, __doc__) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index d1e2ae01..00d64010 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -8,6 +8,7 @@ import asyncio import base64 import json import logging +from datetime import datetime from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber from .. base import ConsumerMetrics, ProducerMetrics @@ -15,6 +16,11 @@ from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_confi from .. schema import LibrarianRequest, LibrarianResponse, Error from .. schema import librarian_request_queue, librarian_response_queue +from .. schema import CollectionManagementRequest, CollectionManagementResponse +from .. schema import collection_request_queue, collection_response_queue +from .. schema import StorageManagementRequest, StorageManagementResponse +from .. schema import vector_storage_management_topic, object_storage_management_topic +from .. schema import triples_storage_management_topic, storage_management_response_topic from .. schema import Document, Metadata from .. schema import TextDocument, Metadata @@ -22,6 +28,7 @@ from .. schema import TextDocument, Metadata from .. exceptions import RequestError from . librarian import Librarian +from . collection_manager import CollectionManager # Module logger logger = logging.getLogger(__name__) @@ -30,6 +37,8 @@ default_ident = "librarian" default_librarian_request_queue = librarian_request_queue default_librarian_response_queue = librarian_response_queue +default_collection_request_queue = collection_request_queue +default_collection_response_queue = collection_response_queue default_minio_host = "minio:9000" default_minio_access_key = "minioadmin" @@ -57,6 +66,14 @@ class Processor(AsyncProcessor): "librarian_response_queue", default_librarian_response_queue ) + collection_request_queue = params.get( + "collection_request_queue", default_collection_request_queue + ) + + collection_response_queue = params.get( + "collection_response_queue", default_collection_response_queue + ) + minio_host = params.get("minio_host", default_minio_host) minio_access_key = params.get( "minio_access_key", @@ -87,6 +104,8 @@ class Processor(AsyncProcessor): **params | { "librarian_request_queue": librarian_request_queue, "librarian_response_queue": librarian_response_queue, + "collection_request_queue": collection_request_queue, + "collection_response_queue": collection_response_queue, "minio_host": minio_host, "minio_access_key": minio_access_key, "cassandra_host": self.cassandra_host, @@ -103,6 +122,18 @@ class Processor(AsyncProcessor): processor = self.id, flow = None, name = "librarian-response" ) + collection_request_metrics = ConsumerMetrics( + processor = self.id, flow = None, name = "collection-request" + ) + + collection_response_metrics = ProducerMetrics( + processor = self.id, flow = None, name = "collection-response" + ) + + storage_response_metrics = ConsumerMetrics( + processor = self.id, flow = None, name = "storage-response" + ) + self.librarian_request_consumer = Consumer( taskgroup = self.taskgroup, client = self.pulsar_client, @@ -121,6 +152,54 @@ class Processor(AsyncProcessor): metrics = librarian_response_metrics, ) + self.collection_request_consumer = Consumer( + taskgroup = self.taskgroup, + client = self.pulsar_client, + flow = None, + topic = collection_request_queue, + subscriber = id, + schema = CollectionManagementRequest, + handler = self.on_collection_request, + metrics = collection_request_metrics, + ) + + self.collection_response_producer = Producer( + client = self.pulsar_client, + topic = collection_response_queue, + schema = CollectionManagementResponse, + metrics = collection_response_metrics, + ) + + # Storage management producers for collection deletion + self.vector_storage_producer = Producer( + client = self.pulsar_client, + topic = vector_storage_management_topic, + schema = StorageManagementRequest, + ) + + self.object_storage_producer = Producer( + client = self.pulsar_client, + topic = object_storage_management_topic, + schema = StorageManagementRequest, + ) + + self.triples_storage_producer = Producer( + client = self.pulsar_client, + topic = triples_storage_management_topic, + schema = StorageManagementRequest, + ) + + self.storage_response_consumer = Consumer( + taskgroup = self.taskgroup, + client = self.pulsar_client, + flow = None, + topic = storage_management_response_topic, + subscriber = id, + schema = StorageManagementResponse, + handler = self.on_storage_response, + metrics = storage_response_metrics, + ) + self.librarian = Librarian( cassandra_host = self.cassandra_host, cassandra_username = self.cassandra_username, @@ -133,6 +212,17 @@ class Processor(AsyncProcessor): load_document = self.load_document, ) + self.collection_manager = CollectionManager( + cassandra_host = self.cassandra_host, + cassandra_username = self.cassandra_username, + cassandra_password = self.cassandra_password, + keyspace = keyspace, + vector_storage_producer = self.vector_storage_producer, + object_storage_producer = self.object_storage_producer, + triples_storage_producer = self.triples_storage_producer, + storage_response_consumer = self.storage_response_consumer, + ) + self.register_config_handler(self.on_librarian_config) self.flows = {} @@ -144,6 +234,12 @@ class Processor(AsyncProcessor): await super(Processor, self).start() await self.librarian_request_consumer.start() await self.librarian_response_producer.start() + await self.collection_request_consumer.start() + await self.collection_response_producer.start() + await self.vector_storage_producer.start() + await self.object_storage_producer.start() + await self.triples_storage_producer.start() + await self.storage_response_consumer.start() async def on_librarian_config(self, config, version): @@ -223,6 +319,19 @@ class Processor(AsyncProcessor): logger.debug("Document submitted") + async def add_processing_with_collection(self, request): + """ + Wrapper for add_processing that ensures collection exists + """ + # Ensure collection exists when processing is added + if hasattr(request, 'processing_metadata') and request.processing_metadata: + user = request.processing_metadata.user + collection = request.processing_metadata.collection + await self.collection_manager.ensure_collection_exists(user, collection) + + # Call the original add_processing method + return await self.librarian.add_processing(request) + async def process_request(self, v): if v.operation is None: @@ -236,7 +345,7 @@ class Processor(AsyncProcessor): "update-document": self.librarian.update_document, "get-document-metadata": self.librarian.get_document_metadata, "get-document-content": self.librarian.get_document_content, - "add-processing": self.librarian.add_processing, + "add-processing": self.add_processing_with_collection, "remove-processing": self.librarian.remove_processing, "list-documents": self.librarian.list_documents, "list-processing": self.librarian.list_processing, @@ -296,6 +405,73 @@ class Processor(AsyncProcessor): logger.debug("Librarian input processing complete") + async def process_collection_request(self, v): + """ + Process collection management requests + """ + if v.operation is None: + raise RequestError("Null operation") + + logger.debug(f"Collection request: {v.operation}") + + impls = { + "list-collections": self.collection_manager.list_collections, + "update-collection": self.collection_manager.update_collection, + "delete-collection": self.collection_manager.delete_collection, + } + + if v.operation not in impls: + raise RequestError(f"Invalid collection operation: {v.operation}") + + return await impls[v.operation](v) + + async def on_collection_request(self, msg, consumer, flow): + """ + Handle collection management request messages + """ + v = msg.value() + id = msg.properties().get("id", "unknown") + + logger.info(f"Handling collection request {id}...") + + try: + resp = await self.process_collection_request(v) + await self.collection_response_producer.send( + resp, properties={"id": id} + ) + except RequestError as e: + resp = CollectionManagementResponse( + error=Error( + type="request-error", + message=str(e), + ), + timestamp=datetime.now().isoformat() + ) + await self.collection_response_producer.send( + resp, properties={"id": id} + ) + except Exception as e: + resp = CollectionManagementResponse( + error=Error( + type="unexpected-error", + message=str(e), + ), + timestamp=datetime.now().isoformat() + ) + await self.collection_response_producer.send( + resp, properties={"id": id} + ) + + logger.debug("Collection request processing complete") + + async def on_storage_response(self, msg, consumer, flow): + """ + Handle storage management response messages + """ + v = msg.value() + logger.debug("Received storage management response") + await self.collection_manager.on_storage_response(v) + @staticmethod def add_args(parser): @@ -313,6 +489,18 @@ class Processor(AsyncProcessor): help=f'Config response queue {default_librarian_response_queue}', ) + parser.add_argument( + '--collection-request-queue', + default=default_collection_request_queue, + help=f'Collection request queue (default: {default_collection_request_queue})' + ) + + parser.add_argument( + '--collection-response-queue', + default=default_collection_response_queue, + help=f'Collection response queue (default: {default_collection_response_queue})' + ) + parser.add_argument( '--minio-host', default=default_minio_host, diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index bb07e063..f94f3b93 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -67,8 +67,7 @@ class Processor(DocumentEmbeddingsQueryService): dim = len(vec) collection = ( - "d_" + msg.user + "_" + msg.collection + "_" + - str(dim) + "d_" + msg.user + "_" + msg.collection ) self.ensure_collection_exists(collection, dim) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 756f619b..0b792566 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -74,8 +74,7 @@ class Processor(GraphEmbeddingsQueryService): dim = len(vec) collection = ( - "t_" + msg.user + "_" + msg.collection + "_" + - str(dim) + "t_" + msg.user + "_" + msg.collection ) self.ensure_collection_exists(collection, dim) diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index fb3d5a0e..839f3afa 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -654,7 +654,7 @@ class LibraryTableStore: logger.error(f"Error updating collection: {e}") raise - async def delete_collection_metadata(self, user, collection): + async def delete_collection(self, user, collection): """Delete collection metadata record""" try: await asyncio.get_event_loop().run_in_executor( @@ -683,3 +683,35 @@ class LibraryTableStore: except Exception as e: logger.error(f"Error getting collection: {e}") raise + + async def create_collection(self, user, collection, name=None, description=None, tags=None): + """Create a new collection metadata record""" + try: + import datetime + now = datetime.datetime.now() + + # Set defaults for optional parameters + name = name if name is not None else collection + description = description if description is not None else "" + tags = tags if tags is not None else set() + + await asyncio.get_event_loop().run_in_executor( + None, self.cassandra.execute, self.insert_collection_stmt, + [user, collection, name, description, tags, now, now] + ) + + logger.info(f"Created collection {user}/{collection}") + + # Return the created collection data + return { + "user": user, + "collection": collection, + "name": name, + "description": description, + "tags": list(tags) if isinstance(tags, set) else tags, + "created_at": now.isoformat(), + "updated_at": now.isoformat() + } + except Exception as e: + logger.error(f"Error creating collection: {e}") + raise