diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 48e86ea1..d02df438 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=1.7.999 + run: make update-package-versions VERSION=1.8.999 - name: Setup environment run: python3 -m venv env diff --git a/docs/apis/api-librarian.md b/docs/apis/api-librarian.md index 71f1b912..43db4258 100644 --- a/docs/apis/api-librarian.md +++ b/docs/apis/api-librarian.md @@ -1,8 +1,8 @@ # TrustGraph Librarian API -This API provides document library management for TrustGraph. It handles document storage, -metadata management, and processing orchestration using hybrid storage (MinIO for content, -Cassandra for metadata) with multi-user support. +This API provides document library management for TrustGraph. It handles document storage, +metadata management, and processing orchestration using hybrid storage (S3-compatible object +storage for content, Cassandra for metadata) with multi-user support. ## Request/response @@ -374,13 +374,14 @@ await client.add_processing( ## Features -- **Hybrid Storage**: MinIO for content, Cassandra for metadata +- **Hybrid Storage**: S3-compatible object storage (MinIO, Ceph RGW, AWS S3, etc.) for content, Cassandra for metadata - **Multi-user Support**: User-based document ownership and access control - **Rich Metadata**: RDF-style metadata triples and tagging system - **Processing Integration**: Automatic triggering of document processing workflows - **Content Types**: Support for multiple document formats (PDF, text, etc.) - **Collection Management**: Optional document grouping by collection - **Metadata Search**: Query documents by metadata criteria +- **Flexible Storage Backend**: Works with any S3-compatible storage (MinIO, Ceph RADOS Gateway, AWS S3, Cloudflare R2, etc.) ## Use Cases diff --git a/docs/tech-specs/collection-management.md b/docs/tech-specs/collection-management.md index ffc5f63f..542abdd0 100644 --- a/docs/tech-specs/collection-management.md +++ b/docs/tech-specs/collection-management.md @@ -233,9 +233,13 @@ When a user initiates collection deletion through the librarian service: #### Collection Management Interface -All store writers implement a standardized collection management interface with a common schema: +**⚠️ LEGACY APPROACH - REPLACED BY CONFIG-BASED PATTERN** -**Message Schema (`StorageManagementRequest`):** +The queue-based architecture described below has been replaced with a config-based approach using `CollectionConfigHandler`. All storage backends now receive collection updates via config push messages instead of dedicated management queues. + +~~All store writers implement a standardized collection management interface with a common schema:~~ + +~~**Message Schema (`StorageManagementRequest`):**~~ ```json { "operation": "create-collection" | "delete-collection", @@ -244,24 +248,26 @@ All store writers implement a standardized collection management interface with } ``` -**Queue Architecture:** -- **Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores -- **Object Store Management Queue** (`object-storage-management`): Object/document stores -- **Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores -- **Storage Response Queue** (`storage-management-response`): All responses sent here +~~**Queue Architecture:**~~ +- ~~**Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores~~ +- ~~**Object Store Management Queue** (`object-storage-management`): Object/document stores~~ +- ~~**Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores~~ +- ~~**Storage Response Queue** (`storage-management-response`): All responses sent here~~ -Each store writer implements: -- **Collection Management Handler**: Processes `StorageManagementRequest` messages -- **Create Collection Operation**: Establishes collection in storage backend -- **Delete Collection Operation**: Removes all data associated with collection -- **Collection State Tracking**: Maintains knowledge of which collections exist -- **Message Processing**: Consumes from dedicated management queue -- **Status Reporting**: Returns success/failure via `StorageManagementResponse` -- **Idempotent Operations**: Safe to call create/delete multiple times +**Current Implementation:** -**Supported Operations:** -- `create-collection`: Create collection in storage backend -- `delete-collection`: Remove all collection data from storage backend +All storage backends now use `CollectionConfigHandler`: +- **Config Push Integration**: Storage services register for config push notifications +- **Automatic Synchronization**: Collections created/deleted based on config changes +- **Declarative Model**: Collections defined in config service, backends sync to match +- **No Request/Response**: Eliminates coordination overhead and response tracking +- **Collection State Tracking**: Maintained via `known_collections` cache +- **Idempotent Operations**: Safe to process same config multiple times + +Each storage backend implements: +- `create_collection(user: str, collection: str, metadata: dict)` - Create collection structures +- `delete_collection(user: str, collection: str)` - Remove all collection data +- `collection_exists(user: str, collection: str) -> bool` - Validate before writes #### Cassandra Triple Store Refactor @@ -365,62 +371,33 @@ Comprehensive testing will cover: - `triples_collection` table for SPO queries and deletion tracking - Collection deletion implemented with read-then-delete pattern -### 🔄 In Progress Components +### ✅ Migration to Config-Based Pattern - COMPLETED -1. **Collection Creation Broadcast** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`) - - Update `update_collection()` to send "create-collection" to storage backends - - Wait for confirmations from all storage processors - - Handle creation failures appropriately +**All storage backends have been migrated from the queue-based pattern to the config-based `CollectionConfigHandler` pattern.** -2. **Document Submission Handler** (`trustgraph-flow/trustgraph/librarian/service.py` or similar) - - Check if collection exists when document submitted - - If not exists: Create collection with defaults before processing document - - Trigger same "create-collection" broadcast as `tg-set-collection` - - Ensure collection established before document flows to storage processors +Completed migrations: +- ✅ `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py` +- ✅ `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` -### ❌ Pending Components +All backends now: +- Inherit from `CollectionConfigHandler` +- Register for config push notifications via `self.register_config_handler(self.on_collection_config)` +- Implement `create_collection(user, collection, metadata)` and `delete_collection(user, collection)` +- Use `collection_exists(user, collection)` to validate before writes +- Automatically sync with config service changes -1. **Collection State Tracking** - Need to implement in each storage backend: - - **Cassandra Triples**: Use `triples_collection` table with marker triples - - **Neo4j/Memgraph/FalkorDB**: Create `:CollectionMetadata` nodes - - **Qdrant/Milvus/Pinecone**: Use native collection APIs - - **Cassandra Objects**: Add collection metadata tracking - -2. **Storage Management Handlers** - Need "create-collection" support in 12 files: - - `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py` - - `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py` - - `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py` - - `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py` - - `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py` - - `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py` - - `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py` - - `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py` - - `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py` - - `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py` - - `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` - - Plus any other storage implementations - -3. **Write Operation Validation** - Add collection existence checks to all `store_*` methods - -4. **Query Operation Handling** - Update queries to return empty for non-existent collections - -### Next Implementation Steps - -**Phase 1: Core Infrastructure (2-3 days)** -1. Add collection state tracking methods to all storage backends -2. Implement `collection_exists()` and `create_collection()` methods - -**Phase 2: Storage Handlers (1 week)** -3. Add "create-collection" handlers to all storage processors -4. Add write validation to reject non-existent collections -5. Update query handling for non-existent collections - -**Phase 3: Collection Manager (2-3 days)** -6. Update collection_manager to broadcast creates -7. Implement response tracking and error handling - -**Phase 4: Testing (3-5 days)** -8. End-to-end testing of explicit creation workflow -9. Test all storage backends -10. Validate error handling and edge cases +Legacy queue-based infrastructure removed: +- ✅ Removed `StorageManagementRequest` and `StorageManagementResponse` schemas +- ✅ Removed storage management queue topic definitions +- ✅ Removed storage management consumer/producer from all backends +- ✅ Removed `on_storage_management` handlers from all backends diff --git a/docs/tech-specs/logging-strategy.md b/docs/tech-specs/logging-strategy.md index b05b7c59..84f1eac8 100644 --- a/docs/tech-specs/logging-strategy.md +++ b/docs/tech-specs/logging-strategy.md @@ -2,17 +2,29 @@ ## Overview -TrustGraph uses Python's built-in `logging` module for all logging operations. This provides a standardized, flexible approach to logging across all components of the system. +TrustGraph uses Python's built-in `logging` module for all logging operations, with centralized configuration and optional Loki integration for log aggregation. This provides a standardized, flexible approach to logging across all components of the system. ## Default Configuration ### Logging Level - **Default Level**: `INFO` -- **Debug Mode**: `DEBUG` (enabled via command-line argument) -- **Production**: `WARNING` or `ERROR` as appropriate +- **Configurable via**: `--log-level` command-line argument +- **Choices**: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` -### Output Destination -All logs should be written to **standard output (stdout)** to ensure compatibility with containerized environments and log aggregation systems. +### Output Destinations +1. **Console (stdout)**: Always enabled - ensures compatibility with containerized environments +2. **Loki**: Optional centralized log aggregation (enabled by default, can be disabled) + +## Centralized Logging Module + +All logging configuration is managed by `trustgraph.base.logging` module, which provides: +- `add_logging_args(parser)` - Adds standard logging CLI arguments +- `setup_logging(args)` - Configures logging from parsed arguments + +This module is used by all server-side components: +- AsyncProcessor-based services +- API Gateway +- MCP Server ## Implementation Guidelines @@ -26,39 +38,80 @@ import logging logger = logging.getLogger(__name__) ``` -### 2. Centralized Configuration +The logger name is automatically used as a label in Loki for filtering and searching. -The logging configuration should be centralized in `async_processor.py` (or a dedicated logging configuration module) since it's inherited by much of the codebase: +### 2. Service Initialization + +All server-side services automatically get logging configuration through the centralized module: ```python -import logging +from trustgraph.base import add_logging_args, setup_logging import argparse -def setup_logging(log_level='INFO'): - """Configure logging for the entire application""" - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] - ) - -def parse_args(): +def main(): parser = argparse.ArgumentParser() - parser.add_argument( - '--log-level', - default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], - help='Set the logging level (default: INFO)' - ) - return parser.parse_args() -# In main execution -if __name__ == '__main__': - args = parse_args() - setup_logging(args.log_level) + # Add standard logging arguments (includes Loki configuration) + add_logging_args(parser) + + # Add your service-specific arguments + parser.add_argument('--port', type=int, default=8080) + + args = parser.parse_args() + args = vars(args) + + # Setup logging early in startup + setup_logging(args) + + # Rest of your service initialization + logger = logging.getLogger(__name__) + logger.info("Service starting...") ``` -### 3. Logging Best Practices +### 3. Command-Line Arguments + +All services support these logging arguments: + +**Log Level:** +```bash +--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL} +``` + +**Loki Configuration:** +```bash +--loki-enabled # Enable Loki (default) +--no-loki-enabled # Disable Loki +--loki-url URL # Loki push URL (default: http://loki:3100/loki/api/v1/push) +--loki-username USERNAME # Optional authentication +--loki-password PASSWORD # Optional authentication +``` + +**Examples:** +```bash +# Default - INFO level, Loki enabled +./my-service + +# Debug mode, console only +./my-service --log-level DEBUG --no-loki-enabled + +# Custom Loki server with auth +./my-service --loki-url http://loki.prod:3100/loki/api/v1/push \ + --loki-username admin --loki-password secret +``` + +### 4. Environment Variables + +Loki configuration supports environment variable fallbacks: + +```bash +export LOKI_URL=http://loki.prod:3100/loki/api/v1/push +export LOKI_USERNAME=admin +export LOKI_PASSWORD=secret +``` + +Command-line arguments take precedence over environment variables. + +### 5. Logging Best Practices #### Log Levels Usage - **DEBUG**: Detailed information for diagnosing problems (variable values, function entry/exit) @@ -89,20 +142,25 @@ if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Debug data: {debug_data}") ``` -### 4. Structured Logging +### 6. Structured Logging with Loki -For complex data, use structured logging: +For complex data, use structured logging with extra tags for Loki: ```python logger.info("Request processed", extra={ - 'request_id': request_id, - 'duration_ms': duration, - 'status_code': status_code, - 'user_id': user_id + 'tags': { + 'request_id': request_id, + 'user_id': user_id, + 'status': 'success' + } }) ``` -### 5. Exception Logging +These tags become searchable labels in Loki, in addition to automatic labels: +- `severity` - Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) +- `logger` - Module name (from `__name__`) + +### 7. Exception Logging Always include stack traces for exceptions: @@ -114,9 +172,13 @@ except Exception as e: raise ``` -### 6. Async Logging Considerations +### 8. Async Logging Considerations -For async code, ensure thread-safe logging: +The logging system uses non-blocking queued handlers for Loki: +- Console output is synchronous (fast) +- Loki output is queued with 500-message buffer +- Background thread handles Loki transmission +- No blocking of main application code ```python import asyncio @@ -124,46 +186,165 @@ import logging async def async_operation(): logger = logging.getLogger(__name__) + # Logging is thread-safe and won't block async operations logger.info(f"Starting async operation in task: {asyncio.current_task().get_name()}") ``` -## Environment Variables +## Loki Integration -Support environment-based configuration as a fallback: +### Architecture + +The logging system uses Python's built-in `QueueHandler` and `QueueListener` for non-blocking Loki integration: + +1. **QueueHandler**: Logs are placed in a 500-message queue (non-blocking) +2. **Background Thread**: QueueListener sends logs to Loki asynchronously +3. **Graceful Degradation**: If Loki is unavailable, console logging continues + +### Automatic Labels + +Every log sent to Loki includes: +- `processor`: Processor identity (e.g., `config-svc`, `text-completion`, `embeddings`) +- `severity`: Log level (DEBUG, INFO, etc.) +- `logger`: Module name (e.g., `trustgraph.gateway.service`, `trustgraph.agent.react.service`) + +### Custom Labels + +Add custom labels via the `extra` parameter: ```python -import os - -log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', 'INFO') +logger.info("User action", extra={ + 'tags': { + 'user_id': user_id, + 'action': 'document_upload', + 'collection': collection_name + } +}) ``` +### Querying Logs in Loki + +```logql +# All logs from a specific processor (recommended - matches Prometheus metrics) +{processor="config-svc"} +{processor="text-completion"} +{processor="embeddings"} + +# Error logs from a specific processor +{processor="config-svc", severity="ERROR"} + +# Error logs from all processors +{severity="ERROR"} + +# Logs from a specific processor with text filter +{processor="text-completion"} |= "Processing" + +# All logs from API gateway +{processor="api-gateway"} + +# Logs from processors matching pattern +{processor=~".*-completion"} + +# Logs with custom tags +{processor="api-gateway"} | json | user_id="12345" +``` + +### Graceful Degradation + +If Loki is unavailable or `python-logging-loki` is not installed: +- Warning message printed to console +- Console logging continues normally +- Application continues running +- No retry logic for Loki connection (fail fast, degrade gracefully) + ## Testing During tests, consider using a different logging configuration: ```python # In test setup -logging.getLogger().setLevel(logging.WARNING) # Reduce noise during tests +import logging + +# Reduce noise during tests +logging.getLogger().setLevel(logging.WARNING) + +# Or disable Loki for tests +setup_logging({'log_level': 'WARNING', 'loki_enabled': False}) ``` ## Monitoring Integration -Ensure log format is compatible with monitoring tools: -- Include timestamps in ISO format -- Use consistent field names -- Include correlation IDs where applicable -- Structure logs for easy parsing (JSON format for production) +### Standard Format +All logs use consistent format: +``` +2025-01-09 10:30:45,123 - trustgraph.gateway.service - INFO - Request processed +``` + +Format components: +- Timestamp (ISO format with milliseconds) +- Logger name (module path) +- Log level +- Message + +### Loki Queries for Monitoring + +Common monitoring queries: + +```logql +# Error rate by processor +rate({severity="ERROR"}[5m]) by (processor) + +# Top error-producing processors +topk(5, count_over_time({severity="ERROR"}[1h]) by (processor)) + +# Recent errors with processor name +{severity="ERROR"} | line_format "{{.processor}}: {{.message}}" + +# All agent processors +{processor=~".*agent.*"} |= "exception" + +# Specific processor error count +count_over_time({processor="config-svc", severity="ERROR"}[1h]) +``` ## Security Considerations -- Never log sensitive information (passwords, API keys, personal data) -- Sanitize user input before logging -- Use placeholders for sensitive fields: `user_id=****1234` +- **Never log sensitive information** (passwords, API keys, personal data, tokens) +- **Sanitize user input** before logging +- **Use placeholders** for sensitive fields: `user_id=****1234` +- **Loki authentication**: Use `--loki-username` and `--loki-password` for secure deployments +- **Secure transport**: Use HTTPS for Loki URL in production: `https://loki.prod:3100/loki/api/v1/push` + +## Dependencies + +The centralized logging module requires: +- `python-logging-loki` - For Loki integration (optional, graceful degradation if missing) + +Already included in `trustgraph-base/pyproject.toml` and `requirements.txt`. ## Migration Path -For existing code using print statements: -1. Replace `print()` with appropriate logger calls -2. Choose appropriate log levels based on message importance -3. Add context to make logs more useful -4. Test logging output at different levels \ No newline at end of file +For existing code: + +1. **Services already using AsyncProcessor**: No changes needed, Loki support is automatic +2. **Services not using AsyncProcessor** (api-gateway, mcp-server): Already updated +3. **CLI tools**: Out of scope - continue using print() or simple logging + +### From print() to logging: +```python +# Before +print(f"Processing document {doc_id}") + +# After +logger = logging.getLogger(__name__) +logger.info(f"Processing document {doc_id}") +``` + +## Configuration Summary + +| Argument | Default | Environment Variable | Description | +|----------|---------|---------------------|-------------| +| `--log-level` | `INFO` | - | Console and Loki log level | +| `--loki-enabled` | `True` | - | Enable Loki logging | +| `--loki-url` | `http://loki:3100/loki/api/v1/push` | `LOKI_URL` | Loki push endpoint | +| `--loki-username` | `None` | `LOKI_USERNAME` | Loki auth username | +| `--loki-password` | `None` | `LOKI_PASSWORD` | Loki auth password | diff --git a/docs/tech-specs/minio-to-s3-migration.md b/docs/tech-specs/minio-to-s3-migration.md new file mode 100644 index 00000000..91daf105 --- /dev/null +++ b/docs/tech-specs/minio-to-s3-migration.md @@ -0,0 +1,258 @@ +# Tech Spec: S3-Compatible Storage Backend Support + +## Overview + +The Librarian service uses S3-compatible object storage for document blob storage. This spec documents the implementation that enables support for any S3-compatible backend including MinIO, Ceph RADOS Gateway (RGW), AWS S3, Cloudflare R2, DigitalOcean Spaces, and others. + +## Architecture + +### Storage Components +- **Blob Storage**: S3-compatible object storage via `minio` Python client library +- **Metadata Storage**: Cassandra (stores object_id mapping and document metadata) +- **Affected Component**: Librarian service only +- **Storage Pattern**: Hybrid storage with metadata in Cassandra, content in S3-compatible storage + +### Implementation +- **Library**: `minio` Python client (supports any S3-compatible API) +- **Location**: `trustgraph-flow/trustgraph/librarian/blob_store.py` +- **Operations**: + - `add()` - Store blob with UUID object_id + - `get()` - Retrieve blob by object_id + - `remove()` - Delete blob by object_id + - `ensure_bucket()` - Create bucket if not exists +- **Bucket**: `library` +- **Object Path**: `doc/{object_id}` +- **Supported MIME Types**: `text/plain`, `application/pdf` + +### Key Files +1. `trustgraph-flow/trustgraph/librarian/blob_store.py` - BlobStore implementation +2. `trustgraph-flow/trustgraph/librarian/librarian.py` - BlobStore initialization +3. `trustgraph-flow/trustgraph/librarian/service.py` - Service configuration +4. `trustgraph-flow/pyproject.toml` - Dependencies (`minio` package) +5. `docs/apis/api-librarian.md` - API documentation + +## Supported Storage Backends + +The implementation works with any S3-compatible object storage system: + +### Tested/Supported +- **Ceph RADOS Gateway (RGW)** - Distributed storage system with S3 API (default configuration) +- **MinIO** - Lightweight self-hosted object storage +- **Garage** - Lightweight geo-distributed S3-compatible storage + +### Should Work (S3-Compatible) +- **AWS S3** - Amazon's cloud object storage +- **Cloudflare R2** - Cloudflare's S3-compatible storage +- **DigitalOcean Spaces** - DigitalOcean's object storage +- **Wasabi** - S3-compatible cloud storage +- **Backblaze B2** - S3-compatible backup storage +- Any other service implementing the S3 REST API + +## Configuration + +### CLI Arguments + +```bash +librarian \ + --object-store-endpoint \ + --object-store-access-key \ + --object-store-secret-key \ + [--object-store-use-ssl] \ + [--object-store-region ] +``` + +**Note:** Do not include `http://` or `https://` in the endpoint. Use `--object-store-use-ssl` to enable HTTPS. + +### Environment Variables (Alternative) + +```bash +OBJECT_STORE_ENDPOINT= +OBJECT_STORE_ACCESS_KEY= +OBJECT_STORE_SECRET_KEY= +OBJECT_STORE_USE_SSL=true|false # Optional, default: false +OBJECT_STORE_REGION= # Optional +``` + +### Examples + +**Ceph RADOS Gateway (default):** +```bash +--object-store-endpoint ceph-rgw:7480 \ +--object-store-access-key object-user \ +--object-store-secret-key object-password +``` + +**MinIO:** +```bash +--object-store-endpoint minio:9000 \ +--object-store-access-key minioadmin \ +--object-store-secret-key minioadmin +``` + +**Garage (S3-compatible):** +```bash +--object-store-endpoint garage:3900 \ +--object-store-access-key GK000000000000000000000001 \ +--object-store-secret-key b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427 +``` + +**AWS S3 with SSL:** +```bash +--object-store-endpoint s3.amazonaws.com \ +--object-store-access-key AKIAIOSFODNN7EXAMPLE \ +--object-store-secret-key wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY \ +--object-store-use-ssl \ +--object-store-region us-east-1 +``` + +## Authentication + +All S3-compatible backends require AWS Signature Version 4 (or v2) authentication: + +- **Access Key** - Public identifier (like username) +- **Secret Key** - Private signing key (like password) + +The MinIO Python client handles all signature calculation automatically. + +### Creating Credentials + +**For MinIO:** +```bash +# Use default credentials or create user via MinIO Console +minioadmin / minioadmin +``` + +**For Ceph RGW:** +```bash +radosgw-admin user create --uid="trustgraph" --display-name="TrustGraph Service" +# Returns access_key and secret_key +``` + +**For AWS S3:** +- Create IAM user with S3 permissions +- Generate access key in AWS Console + +## Library Selection: MinIO Python Client + +**Rationale:** +- Lightweight (~500KB vs boto3's ~50MB) +- S3-compatible - works with any S3 API endpoint +- Simpler API than boto3 for basic operations +- Already in use, no migration needed +- Battle-tested with MinIO and other S3 systems + +## BlobStore Implementation + +**Location:** `trustgraph-flow/trustgraph/librarian/blob_store.py` + +```python +from minio import Minio +import io +import logging + +logger = logging.getLogger(__name__) + +class BlobStore: + """ + S3-compatible blob storage for document content. + Supports MinIO, Ceph RGW, AWS S3, and other S3-compatible backends. + """ + + def __init__(self, endpoint, access_key, secret_key, bucket_name, + use_ssl=False, region=None): + """ + Initialize S3-compatible blob storage. + + Args: + endpoint: S3 endpoint (e.g., "minio:9000", "ceph-rgw:7480") + access_key: S3 access key + secret_key: S3 secret key + bucket_name: Bucket name for storage + use_ssl: Use HTTPS instead of HTTP (default: False) + region: S3 region (optional, e.g., "us-east-1") + """ + self.client = Minio( + endpoint=endpoint, + access_key=access_key, + secret_key=secret_key, + secure=use_ssl, + region=region, + ) + + self.bucket_name = bucket_name + + protocol = "https" if use_ssl else "http" + logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}") + + self.ensure_bucket() + + def ensure_bucket(self): + """Create bucket if it doesn't exist""" + found = self.client.bucket_exists(bucket_name=self.bucket_name) + if not found: + self.client.make_bucket(bucket_name=self.bucket_name) + logger.info(f"Created bucket {self.bucket_name}") + else: + logger.debug(f"Bucket {self.bucket_name} already exists") + + async def add(self, object_id, blob, kind): + """Store blob in S3-compatible storage""" + self.client.put_object( + bucket_name=self.bucket_name, + object_name=f"doc/{object_id}", + length=len(blob), + data=io.BytesIO(blob), + content_type=kind, + ) + logger.debug("Add blob complete") + + async def remove(self, object_id): + """Delete blob from S3-compatible storage""" + self.client.remove_object( + bucket_name=self.bucket_name, + object_name=f"doc/{object_id}", + ) + logger.debug("Remove blob complete") + + async def get(self, object_id): + """Retrieve blob from S3-compatible storage""" + resp = self.client.get_object( + bucket_name=self.bucket_name, + object_name=f"doc/{object_id}", + ) + return resp.read() +``` + +## Key Benefits + +1. **No Vendor Lock-in** - Works with any S3-compatible storage +2. **Lightweight** - MinIO client is only ~500KB +3. **Simple Configuration** - Just endpoint + credentials +4. **No Data Migration** - Drop-in replacement between backends +5. **Battle-Tested** - MinIO client works with all major S3 implementations + +## Implementation Status + +All code has been updated to use generic S3 parameter names: + +- ✅ `blob_store.py` - Updated to accept `endpoint`, `access_key`, `secret_key` +- ✅ `librarian.py` - Updated parameter names +- ✅ `service.py` - Updated CLI arguments and configuration +- ✅ Documentation updated + +## Future Enhancements + +1. **SSL/TLS Support** - Add `--s3-use-ssl` flag for HTTPS +2. **Retry Logic** - Implement exponential backoff for transient failures +3. **Presigned URLs** - Generate temporary upload/download URLs +4. **Multi-region Support** - Replicate blobs across regions +5. **CDN Integration** - Serve blobs via CDN +6. **Storage Classes** - Use S3 storage classes for cost optimization +7. **Lifecycle Policies** - Automatic archival/deletion +8. **Versioning** - Store multiple versions of blobs + +## References + +- MinIO Python Client: https://min.io/docs/minio/linux/developers/python/API.html +- Ceph RGW S3 API: https://docs.ceph.com/en/latest/radosgw/s3/ +- S3 API Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/Welcome.html diff --git a/docs/tech-specs/multi-tenant-support.md b/docs/tech-specs/multi-tenant-support.md new file mode 100644 index 00000000..dc0555c1 --- /dev/null +++ b/docs/tech-specs/multi-tenant-support.md @@ -0,0 +1,772 @@ +# Technical Specification: Multi-Tenant Support + +## Overview + +Enable multi-tenant deployments by fixing parameter name mismatches that prevent queue customization and adding Cassandra keyspace parameterization. + +## Architecture Context + +### Flow-Based Queue Resolution + +The TrustGraph system uses a **flow-based architecture** for dynamic queue resolution, which inherently supports multi-tenancy: + +- **Flow Definitions** are stored in Cassandra and specify queue names via interface definitions +- **Queue names use templates** with `{id}` variables that are replaced with flow instance IDs +- **Services dynamically resolve queues** by looking up flow configurations at request time +- **Each tenant can have unique flows** with different queue names, providing isolation + +Example flow interface definition: +```json +{ + "interfaces": { + "triples-store": "persistent://tg/flow/triples-store:{id}", + "graph-embeddings-store": "persistent://tg/flow/graph-embeddings-store:{id}" + } +} +``` + +When tenant A starts flow `tenant-a-prod` and tenant B starts flow `tenant-b-prod`, they automatically get isolated queues: +- `persistent://tg/flow/triples-store:tenant-a-prod` +- `persistent://tg/flow/triples-store:tenant-b-prod` + +**Services correctly designed for multi-tenancy:** +- ✅ **Knowledge Management (cores)** - Dynamically resolves queues from flow configuration passed in requests + +**Services needing fixes:** +- 🔴 **Config Service** - Parameter name mismatch prevents queue customization +- 🔴 **Librarian Service** - Hardcoded storage management topics (discussed below) +- 🔴 **All Services** - Cannot customize Cassandra keyspace + +## Problem Statement + +### Issue #1: Parameter Name Mismatch in AsyncProcessor +- **CLI defines:** `--config-queue` (unclear naming) +- **Argparse converts to:** `config_queue` (in params dict) +- **Code looks for:** `config_push_queue` +- **Result:** Parameter is ignored, defaults to `persistent://tg/config/config` +- **Impact:** Affects all 32+ services inheriting from AsyncProcessor +- **Blocks:** Multi-tenant deployments cannot use tenant-specific config queues +- **Solution:** Rename CLI parameter to `--config-push-queue` for clarity (breaking change acceptable since feature is currently broken) + +### Issue #2: Parameter Name Mismatch in Config Service +- **CLI defines:** `--push-queue` (ambiguous naming) +- **Argparse converts to:** `push_queue` (in params dict) +- **Code looks for:** `config_push_queue` +- **Result:** Parameter is ignored +- **Impact:** Config service cannot use custom push queue +- **Solution:** Rename CLI parameter to `--config-push-queue` for consistency and clarity (breaking change acceptable) + +### Issue #3: Hardcoded Cassandra Keyspace +- **Current:** Keyspace hardcoded as `"config"`, `"knowledge"`, `"librarian"` in various services +- **Result:** Cannot customize keyspace for multi-tenant deployments +- **Impact:** Config, cores, and librarian services +- **Blocks:** Multiple tenants cannot use separate Cassandra keyspaces + +### Issue #4: Collection Management Architecture ✅ COMPLETED +- **Previous:** Collections stored in Cassandra librarian keyspace via separate collections table +- **Previous:** Librarian used 4 hardcoded storage management topics to coordinate collection create/delete: + - `vector_storage_management_topic` + - `object_storage_management_topic` + - `triples_storage_management_topic` + - `storage_management_response_topic` +- **Problems (Resolved):** + - Hardcoded topics could not be customized for multi-tenant deployments + - Complex async coordination between librarian and 4+ storage services + - Separate Cassandra table and management infrastructure + - Non-persistent request/response queues for critical operations +- **Solution Implemented:** Migrated collections to config service storage, use config push for distribution +- **Status:** All storage backends migrated to `CollectionConfigHandler` pattern + +## Solution + +This spec addresses Issues #1, #2, #3, and #4. + +### Part 1: Fix Parameter Name Mismatches + +#### Change 1: AsyncProcessor Base Class - Rename CLI Parameter +**File:** `trustgraph-base/trustgraph/base/async_processor.py` +**Line:** 260-264 + +**Current:** +```python +parser.add_argument( + '--config-queue', + default=default_config_queue, + help=f'Config push queue {default_config_queue}', +) +``` + +**Fixed:** +```python +parser.add_argument( + '--config-push-queue', + default=default_config_queue, + help=f'Config push queue (default: {default_config_queue})', +) +``` + +**Rationale:** +- Clearer, more explicit naming +- Matches the internal variable name `config_push_queue` +- Breaking change acceptable since feature is currently non-functional +- No code change needed in params.get() - it already looks for the correct name + +#### Change 2: Config Service - Rename CLI Parameter +**File:** `trustgraph-flow/trustgraph/config/service/service.py` +**Line:** 276-279 + +**Current:** +```python +parser.add_argument( + '--push-queue', + default=default_config_push_queue, + help=f'Config push queue (default: {default_config_push_queue})' +) +``` + +**Fixed:** +```python +parser.add_argument( + '--config-push-queue', + default=default_config_push_queue, + help=f'Config push queue (default: {default_config_push_queue})' +) +``` + +**Rationale:** +- Clearer naming - "config-push-queue" is more explicit than just "push-queue" +- Matches the internal variable name `config_push_queue` +- Consistent with AsyncProcessor's `--config-push-queue` parameter +- Breaking change acceptable since feature is currently non-functional +- No code change needed in params.get() - it already looks for the correct name + +### Part 2: Add Cassandra Keyspace Parameterization + +#### Change 3: Add Keyspace Parameter to cassandra_config Module +**File:** `trustgraph-base/trustgraph/base/cassandra_config.py` + +**Add CLI argument** (in `add_cassandra_args()` function): +```python +parser.add_argument( + '--cassandra-keyspace', + default=None, + help='Cassandra keyspace (default: service-specific)' +) +``` + +**Add environment variable support** (in `resolve_cassandra_config()` function): +```python +keyspace = params.get( + "cassandra_keyspace", + os.environ.get("CASSANDRA_KEYSPACE") +) +``` + +**Update return value** of `resolve_cassandra_config()`: +- Currently returns: `(hosts, username, password)` +- Change to return: `(hosts, username, password, keyspace)` + +**Rationale:** +- Consistent with existing Cassandra configuration pattern +- Available to all services via `add_cassandra_args()` +- Supports both CLI and environment variable configuration + +#### Change 4: Config Service - Use Parameterized Keyspace +**File:** `trustgraph-flow/trustgraph/config/service/service.py` + +**Line 30** - Remove hardcoded keyspace: +```python +# DELETE THIS LINE: +keyspace = "config" +``` + +**Lines 69-73** - Update cassandra config resolution: + +**Current:** +```python +cassandra_host, cassandra_username, cassandra_password = \ + resolve_cassandra_config(params) +``` + +**Fixed:** +```python +cassandra_host, cassandra_username, cassandra_password, keyspace = \ + resolve_cassandra_config(params, default_keyspace="config") +``` + +**Rationale:** +- Maintains backward compatibility with "config" as default +- Allows override via `--cassandra-keyspace` or `CASSANDRA_KEYSPACE` + +#### Change 5: Cores/Knowledge Service - Use Parameterized Keyspace +**File:** `trustgraph-flow/trustgraph/cores/service.py` + +**Line 37** - Remove hardcoded keyspace: +```python +# DELETE THIS LINE: +keyspace = "knowledge" +``` + +**Update cassandra config resolution** (similar location as config service): +```python +cassandra_host, cassandra_username, cassandra_password, keyspace = \ + resolve_cassandra_config(params, default_keyspace="knowledge") +``` + +#### Change 6: Librarian Service - Use Parameterized Keyspace +**File:** `trustgraph-flow/trustgraph/librarian/service.py` + +**Line 51** - Remove hardcoded keyspace: +```python +# DELETE THIS LINE: +keyspace = "librarian" +``` + +**Update cassandra config resolution** (similar location as config service): +```python +cassandra_host, cassandra_username, cassandra_password, keyspace = \ + resolve_cassandra_config(params, default_keyspace="librarian") +``` + +### Part 3: Migrate Collection Management to Config Service + +#### Overview +Migrate collections from Cassandra librarian keyspace to config service storage. This eliminates hardcoded storage management topics and simplifies the architecture by using the existing config push mechanism for distribution. + +#### Current Architecture +``` +API Request → Gateway → Librarian Service + ↓ + CollectionManager + ↓ + Cassandra Collections Table (librarian keyspace) + ↓ + Broadcast to 4 Storage Management Topics (hardcoded) + ↓ + Wait for 4+ Storage Service Responses + ↓ + Response to Gateway +``` + +#### New Architecture +``` +API Request → Gateway → Librarian Service + ↓ + CollectionManager + ↓ + Config Service API (put/delete/getvalues) + ↓ + Cassandra Config Table (class='collections', key='user:collection') + ↓ + Config Push (to all subscribers on config-push-queue) + ↓ + All Storage Services receive config update independently +``` + +#### Change 7: Collection Manager - Use Config Service API +**File:** `trustgraph-flow/trustgraph/librarian/collection_manager.py` + +**Remove:** +- `LibraryTableStore` usage (Lines 33, 40-41) +- Storage management producers initialization (Lines 86-140) +- `on_storage_response` method (Lines 400-430) +- `pending_deletions` tracking (Lines 57, 90-96, and usage throughout) + +**Add:** +- Config service client for API calls (request/response pattern) + +**Config Client Setup:** +```python +# In __init__, add config request/response producers/consumers +from trustgraph.schema.services.config import ConfigRequest, ConfigResponse + +# Producer for config requests +self.config_request_producer = Producer( + client=pulsar_client, + topic=config_request_queue, + schema=ConfigRequest, +) + +# Consumer for config responses (with correlation ID) +self.config_response_consumer = Consumer( + taskgroup=taskgroup, + client=pulsar_client, + flow=None, + topic=config_response_queue, + subscriber=f"{id}-config", + schema=ConfigResponse, + handler=self.on_config_response, +) + +# Tracking for pending config requests +self.pending_config_requests = {} # request_id -> asyncio.Event +``` + +**Modify `list_collections` (Lines 145-180):** +```python +async def list_collections(self, user, tag_filter=None, limit=None): + """List collections from config service""" + # Send getvalues request to config service + request = ConfigRequest( + id=str(uuid.uuid4()), + operation='getvalues', + type='collections', + ) + + # Send request and wait for response + response = await self.send_config_request(request) + + # Parse collections from response + collections = [] + for key, value_json in response.values.items(): + if ":" in key: + coll_user, collection = key.split(":", 1) + if coll_user == user: + metadata = json.loads(value_json) + collections.append(CollectionMetadata(**metadata)) + + # Apply tag filtering in-memory (as before) + if tag_filter: + collections = [c for c in collections if any(tag in c.tags for tag in tag_filter)] + + # Apply limit + if limit: + collections = collections[:limit] + + return collections + +async def send_config_request(self, request): + """Send config request and wait for response""" + event = asyncio.Event() + self.pending_config_requests[request.id] = event + + await self.config_request_producer.send(request) + await event.wait() + + return self.pending_config_requests.pop(request.id + "_response") + +async def on_config_response(self, message, consumer, flow): + """Handle config response""" + response = message.value() + if response.id in self.pending_config_requests: + self.pending_config_requests[response.id + "_response"] = response + self.pending_config_requests[response.id].set() +``` + +**Modify `update_collection` (Lines 182-312):** +```python +async def update_collection(self, user, collection, name, description, tags): + """Update collection via config service""" + # Create metadata + metadata = CollectionMetadata( + user=user, + collection=collection, + name=name, + description=description, + tags=tags, + ) + + # Send put request to config service + request = ConfigRequest( + id=str(uuid.uuid4()), + operation='put', + type='collections', + key=f'{user}:{collection}', + value=json.dumps(metadata.to_dict()), + ) + + response = await self.send_config_request(request) + + if response.error: + raise RuntimeError(f"Config update failed: {response.error.message}") + + # Config service will trigger config push automatically + # Storage services will receive update and create collections +``` + +**Modify `delete_collection` (Lines 314-398):** +```python +async def delete_collection(self, user, collection): + """Delete collection via config service""" + # Send delete request to config service + request = ConfigRequest( + id=str(uuid.uuid4()), + operation='delete', + type='collections', + key=f'{user}:{collection}', + ) + + response = await self.send_config_request(request) + + if response.error: + raise RuntimeError(f"Config delete failed: {response.error.message}") + + # Config service will trigger config push automatically + # Storage services will receive update and delete collections +``` + +**Collection Metadata Format:** +- Stored in config table as: `class='collections', key='user:collection'` +- Value is JSON-serialized CollectionMetadata (without timestamp fields) +- Fields: `user`, `collection`, `name`, `description`, `tags` +- Example: `class='collections', key='alice:my-docs', value='{"user":"alice","collection":"my-docs","name":"My Documents","description":"...","tags":["work"]}'` + +#### Change 8: Librarian Service - Remove Storage Management Infrastructure +**File:** `trustgraph-flow/trustgraph/librarian/service.py` + +**Remove:** +- Storage management producers (Lines 173-190): + - `vector_storage_management_producer` + - `object_storage_management_producer` + - `triples_storage_management_producer` +- Storage response consumer (Lines 192-201) +- `on_storage_response` handler (Lines 467-473) + +**Modify:** +- CollectionManager initialization (Lines 215-224) - remove storage producer parameters + +**Note:** External collection API remains unchanged: +- `list-collections` +- `update-collection` +- `delete-collection` + +#### Change 9: Remove Collections Table from LibraryTableStore +**File:** `trustgraph-flow/trustgraph/tables/library.py` + +**Delete:** +- Collections table CREATE statement (Lines 114-127) +- Collections prepared statements (Lines 205-240) +- All collection methods (Lines 578-717): + - `ensure_collection_exists` + - `list_collections` + - `update_collection` + - `delete_collection` + - `get_collection` + - `create_collection` + +**Rationale:** +- Collections now stored in config table +- Breaking change acceptable - no data migration needed +- Simplifies librarian service significantly + +#### Change 10: Storage Services - Config-Based Collection Management ✅ COMPLETED + +**Status:** All 11 storage backends have been migrated to use `CollectionConfigHandler`. + +**Affected Services (11 total):** +- Document embeddings: milvus, pinecone, qdrant +- Graph embeddings: milvus, pinecone, qdrant +- Object storage: cassandra +- Triples storage: cassandra, falkordb, memgraph, neo4j + +**Files:** +- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py` +- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py` +- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py` +- `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py` +- `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py` +- `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py` +- `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py` +- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py` +- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py` +- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py` +- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py` + +**Implementation Pattern (all services):** + +1. **Register config handler in `__init__`:** +```python +# Add after AsyncProcessor initialization +self.register_config_handler(self.on_collection_config) +self.known_collections = set() # Track (user, collection) tuples +``` + +2. **Implement config handler:** +```python +async def on_collection_config(self, config, version): + """Handle collection configuration updates""" + logger.info(f"Collection config version: {version}") + + if "collections" not in config: + return + + # Parse collections from config + # Key format: "user:collection" in config["collections"] + config_collections = set() + for key in config["collections"].keys(): + if ":" in key: + user, collection = key.split(":", 1) + config_collections.add((user, collection)) + + # Determine changes + to_create = config_collections - self.known_collections + to_delete = self.known_collections - config_collections + + # Create new collections (idempotent) + for user, collection in to_create: + try: + await self.create_collection_internal(user, collection) + self.known_collections.add((user, collection)) + logger.info(f"Created collection: {user}/{collection}") + except Exception as e: + logger.error(f"Failed to create {user}/{collection}: {e}") + + # Delete removed collections (idempotent) + for user, collection in to_delete: + try: + await self.delete_collection_internal(user, collection) + self.known_collections.discard((user, collection)) + logger.info(f"Deleted collection: {user}/{collection}") + except Exception as e: + logger.error(f"Failed to delete {user}/{collection}: {e}") +``` + +3. **Initialize known collections on startup:** +```python +async def start(self): + """Start the processor""" + await super().start() + await self.sync_known_collections() + +async def sync_known_collections(self): + """Query backend to populate known_collections set""" + # Backend-specific implementation: + # - Milvus/Pinecone/Qdrant: List collections/indexes matching naming pattern + # - Cassandra: Query keyspaces or collection metadata + # - Neo4j/Memgraph/FalkorDB: Query CollectionMetadata nodes + pass +``` + +4. **Refactor existing handler methods:** +```python +# Rename and remove response sending: +# handle_create_collection → create_collection_internal +# handle_delete_collection → delete_collection_internal + +async def create_collection_internal(self, user, collection): + """Create collection (idempotent)""" + # Same logic as current handle_create_collection + # But remove response producer calls + # Handle "already exists" gracefully + pass + +async def delete_collection_internal(self, user, collection): + """Delete collection (idempotent)""" + # Same logic as current handle_delete_collection + # But remove response producer calls + # Handle "not found" gracefully + pass +``` + +5. **Remove storage management infrastructure:** + - Remove `self.storage_request_consumer` setup and start + - Remove `self.storage_response_producer` setup + - Remove `on_storage_management` dispatcher method + - Remove metrics for storage management + - Remove imports: `StorageManagementRequest`, `StorageManagementResponse` + +**Backend-Specific Considerations:** + +- **Vector stores (Milvus, Pinecone, Qdrant):** Track logical `(user, collection)` in `known_collections`, but may create multiple backend collections per dimension. Continue lazy creation pattern. Delete operations must remove all dimension variants. + +- **Cassandra Objects:** Collections are row properties, not structures. Track keyspace-level information. + +- **Graph stores (Neo4j, Memgraph, FalkorDB):** Query `CollectionMetadata` nodes on startup. Create/delete metadata nodes on sync. + +- **Cassandra Triples:** Use `KnowledgeGraph` API for collection operations. + +**Key Design Points:** + +- **Eventual consistency:** No request/response mechanism, config push is broadcast +- **Idempotency:** All create/delete operations must be safe to retry +- **Error handling:** Log errors but don't block config updates +- **Self-healing:** Failed operations will retry on next config push +- **Collection key format:** `"user:collection"` in `config["collections"]` + +#### Change 11: Update Collection Schema - Remove Timestamps +**File:** `trustgraph-base/trustgraph/schema/services/collection.py` + +**Modify CollectionMetadata (Lines 13-21):** +Remove `created_at` and `updated_at` fields: +```python +class CollectionMetadata(Record): + user = String() + collection = String() + name = String() + description = String() + tags = Array(String()) + # Remove: created_at = String() + # Remove: updated_at = String() +``` + +**Modify CollectionManagementRequest (Lines 25-47):** +Remove timestamp fields: +```python +class CollectionManagementRequest(Record): + operation = String() + user = String() + collection = String() + timestamp = String() + name = String() + description = String() + tags = Array(String()) + # Remove: created_at = String() + # Remove: updated_at = String() + tag_filter = Array(String()) + limit = Integer() +``` + +**Rationale:** +- Timestamps don't add value for collections +- Config service maintains its own version tracking +- Simplifies schema and reduces storage + +#### Benefits of Config Service Migration + +1. ✅ **Eliminates hardcoded storage management topics** - Solves multi-tenant blocker +2. ✅ **Simpler coordination** - No complex async waiting for 4+ storage responses +3. ✅ **Eventual consistency** - Storage services update independently via config push +4. ✅ **Better reliability** - Persistent config push vs non-persistent request/response +5. ✅ **Unified configuration model** - Collections treated as configuration +6. ✅ **Reduces complexity** - Removes ~300 lines of coordination code +7. ✅ **Multi-tenant ready** - Config already supports tenant isolation via keyspace +8. ✅ **Version tracking** - Config service version mechanism provides audit trail + +## Implementation Notes + +### Backward Compatibility + +**Parameter Changes:** +- CLI parameter renames are breaking changes but acceptable (feature currently non-functional) +- Services work without parameters (use defaults) +- Default keyspaces preserved: "config", "knowledge", "librarian" +- Default queue: `persistent://tg/config/config` + +**Collection Management:** +- **Breaking change:** Collections table removed from librarian keyspace +- **No data migration provided** - acceptable for this phase +- External collection API unchanged (list/update/delete operations) +- Collection metadata format simplified (timestamps removed) + +### Testing Requirements + +**Parameter Testing:** +1. Verify `--config-push-queue` parameter works on graph-embeddings service +2. Verify `--config-push-queue` parameter works on text-completion service +3. Verify `--config-push-queue` parameter works on config service +4. Verify `--cassandra-keyspace` parameter works for config service +5. Verify `--cassandra-keyspace` parameter works for cores service +6. Verify `--cassandra-keyspace` parameter works for librarian service +7. Verify services work without parameters (uses defaults) +8. Verify multi-tenant deployment with custom queue names and keyspace + +**Collection Management Testing:** +9. Verify `list-collections` operation via config service +10. Verify `update-collection` creates/updates in config table +11. Verify `delete-collection` removes from config table +12. Verify config push is triggered on collection updates +13. Verify tag filtering works with config-based storage +14. Verify collection operations work without timestamp fields + +### Multi-Tenant Deployment Example +```bash +# Tenant: tg-dev +graph-embeddings \ + -p pulsar+ssl://broker:6651 \ + --pulsar-api-key \ + --config-push-queue persistent://tg-dev/config/config + +config-service \ + -p pulsar+ssl://broker:6651 \ + --pulsar-api-key \ + --config-push-queue persistent://tg-dev/config/config \ + --cassandra-keyspace tg_dev_config +``` + +## Impact Analysis + +### Services Affected by Change 1-2 (CLI Parameter Rename) +All services inheriting from AsyncProcessor or FlowProcessor: +- config-service +- cores-service +- librarian-service +- graph-embeddings +- document-embeddings +- text-completion-* (all providers) +- extract-* (all extractors) +- query-* (all query services) +- retrieval-* (all RAG services) +- storage-* (all storage services) +- And 20+ more services + +### Services Affected by Changes 3-6 (Cassandra Keyspace) +- config-service +- cores-service +- librarian-service + +### Services Affected by Changes 7-11 (Collection Management) + +**Immediate Changes:** +- librarian-service (collection_manager.py, service.py) +- tables/library.py (collections table removal) +- schema/services/collection.py (timestamp removal) + +**Completed Changes (Change 10):** ✅ +- All storage services (11 total) - migrated to config push for collection updates via `CollectionConfigHandler` +- Storage management schema removed from `storage.py` + +## Future Considerations + +### Per-User Keyspace Model + +Some services use **per-user keyspaces** dynamically, where each user gets their own Cassandra keyspace: + +**Services with per-user keyspaces:** +1. **Triples Query Service** (`trustgraph-flow/trustgraph/query/triples/cassandra/service.py:65`) + - Uses `keyspace=query.user` +2. **Objects Query Service** (`trustgraph-flow/trustgraph/query/objects/cassandra/service.py:479`) + - Uses `keyspace=self.sanitize_name(user)` +3. **KnowledgeGraph Direct Access** (`trustgraph-flow/trustgraph/direct/cassandra_kg.py:18`) + - Default parameter `keyspace="trustgraph"` + +**Status:** These are **not modified** in this specification. + +**Future Review Required:** +- Evaluate whether per-user keyspace model creates tenant isolation issues +- Consider if multi-tenant deployments need keyspace prefix patterns (e.g., `tenant_a_user1`) +- Review for potential user ID collision across tenants +- Assess if single shared keyspace per tenant with user-based row isolation is preferable + +**Note:** This does not block the current multi-tenant implementation but should be reviewed before production multi-tenant deployments. + +## Implementation Phases + +### Phase 1: Parameter Fixes (Changes 1-6) +- Fix `--config-push-queue` parameter naming +- Add `--cassandra-keyspace` parameter support +- **Outcome:** Multi-tenant queue and keyspace configuration enabled + +### Phase 2: Collection Management Migration (Changes 7-9, 11) +- Migrate collection storage to config service +- Remove collections table from librarian +- Update collection schema (remove timestamps) +- **Outcome:** Eliminates hardcoded storage management topics, simplifies librarian + +### Phase 3: Storage Service Updates (Change 10) ✅ COMPLETED +- Updated all storage services to use config push for collections via `CollectionConfigHandler` +- Removed storage management request/response infrastructure +- Removed legacy schema definitions +- **Outcome:** Complete config-based collection management achieved + +## References +- GitHub Issue: https://github.com/trustgraph-ai/trustgraph/issues/582 +- Related Files: + - `trustgraph-base/trustgraph/base/async_processor.py` + - `trustgraph-base/trustgraph/base/cassandra_config.py` + - `trustgraph-base/trustgraph/schema/core/topic.py` + - `trustgraph-base/trustgraph/schema/services/collection.py` + - `trustgraph-flow/trustgraph/config/service/service.py` + - `trustgraph-flow/trustgraph/cores/service.py` + - `trustgraph-flow/trustgraph/librarian/service.py` + - `trustgraph-flow/trustgraph/librarian/collection_manager.py` + - `trustgraph-flow/trustgraph/tables/library.py` diff --git a/docs/tech-specs/ontology-extract-phase-2.md b/docs/tech-specs/ontology-extract-phase-2.md new file mode 100644 index 00000000..ac1a0543 --- /dev/null +++ b/docs/tech-specs/ontology-extract-phase-2.md @@ -0,0 +1,761 @@ +# Ontology Knowledge Extraction - Phase 2 Refactor + +**Status**: Draft +**Author**: Analysis Session 2025-12-03 +**Related**: `ontology.md`, `ontorag.md` + +## Overview + +This document identifies inconsistencies in the current ontology-based knowledge extraction system and proposes a refactor to improve LLM performance and reduce information loss. + +## Current Implementation + +### How It Works Now + +1. **Ontology Loading** (`ontology_loader.py`) + - Loads ontology JSON with keys like `"fo/Recipe"`, `"fo/Food"`, `"fo/produces"` + - Class IDs include namespace prefix in the key itself + - Example from `food.ontology`: + ```json + "classes": { + "fo/Recipe": { + "uri": "http://purl.org/ontology/fo/Recipe", + "rdfs:comment": "A Recipe is a combination..." + } + } + ``` + +2. **Prompt Construction** (`extract.py:299-307`, `ontology-prompt.md`) + - Template receives `classes`, `object_properties`, `datatype_properties` dicts + - Template iterates: `{% for class_id, class_def in classes.items() %}` + - LLM sees: `**fo/Recipe**: A Recipe is a combination...` + - Example output format shows: + ```json + {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"} + {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"} + ``` + +3. **Response Parsing** (`extract.py:382-428`) + - Expects JSON array: `[{"subject": "...", "predicate": "...", "object": "..."}]` + - Validates against ontology subset + - Expands URIs via `expand_uri()` (extract.py:473-521) + +4. **URI Expansion** (`extract.py:473-521`) + - Checks if value is in `ontology_subset.classes` dict + - If found, extracts URI from class definition + - If not found, constructs URI: `f"https://trustgraph.ai/ontology/{ontology_id}#{value}"` + +### Data Flow Example + +**Ontology JSON → Loader → Prompt:** +``` +"fo/Recipe" → classes["fo/Recipe"] → LLM sees "**fo/Recipe**" +``` + +**LLM → Parser → Output:** +``` +"Recipe" → not in classes["fo/Recipe"] → constructs URI → LOSES original URI +"fo/Recipe" → found in classes → uses original URI → PRESERVES URI +``` + +## Problems Identified + +### 1. **Inconsistent Examples in Prompt** + +**Issue**: The prompt template shows class IDs with prefixes (`fo/Recipe`) but the example output uses unprefixed class names (`Recipe`). + +**Location**: `ontology-prompt.md:5-52` + +```markdown +## Ontology Classes: +- **fo/Recipe**: A Recipe is... + +## Example Output: +{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"} +``` + +**Impact**: LLM receives conflicting signals about what format to use. + +### 2. **Information Loss in URI Expansion** + +**Issue**: When LLM returns unprefixed class names following the example, `expand_uri()` can't find them in the ontology dict and constructs fallback URIs, losing the original proper URIs. + +**Location**: `extract.py:494-500` + +```python +if value in ontology_subset.classes: # Looks for "Recipe" + class_def = ontology_subset.classes[value] # But key is "fo/Recipe" + if isinstance(class_def, dict) and 'uri' in class_def: + return class_def['uri'] # Never reached! +return f"https://trustgraph.ai/ontology/{ontology_id}#{value}" # Fallback +``` + +**Impact**: +- Original URI: `http://purl.org/ontology/fo/Recipe` +- Constructed URI: `https://trustgraph.ai/ontology/food#Recipe` +- Semantic meaning lost, breaks interoperability + +### 3. **Ambiguous Entity Instance Format** + +**Issue**: No clear guidance on entity instance URI format. + +**Examples in prompt**: +- `"recipe:cornish-pasty"` (namespace-like prefix) +- `"ingredient:flour"` (different prefix) + +**Actual behavior** (extract.py:517-520): +```python +# Treat as entity instance - construct unique URI +normalized = value.replace(" ", "-").lower() +return f"https://trustgraph.ai/{ontology_id}/{normalized}" +``` + +**Impact**: LLM must guess prefixing convention with no ontology context. + +### 4. **No Namespace Prefix Guidance** + +**Issue**: The ontology JSON contains namespace definitions (line 10-25 in food.ontology): +```json +"namespaces": { + "fo": "http://purl.org/ontology/fo/", + "rdf": "http://www.w3.org/1999/02/22-rdf-syntax-ns#", + ... +} +``` + +But these are never surfaced to the LLM. The LLM doesn't know: +- What "fo" means +- What prefix to use for entities +- Which namespace applies to which elements + +### 5. **Labels Not Used in Prompt** + +**Issue**: Every class has `rdfs:label` fields (e.g., `{"value": "Recipe", "lang": "en-gb"}`), but the prompt template doesn't use them. + +**Current**: Shows only `class_id` and `comment` +```jinja +- **{{class_id}}**{% if class_def.comment %}: {{class_def.comment}}{% endif %} +``` + +**Available but unused**: +```python +"rdfs:label": [{"value": "Recipe", "lang": "en-gb"}] +``` + +**Impact**: Could provide human-readable names alongside technical IDs. + +## Proposed Solutions + +### Option A: Normalize to Unprefixed IDs + +**Approach**: Strip prefixes from class IDs before showing to LLM. + +**Changes**: +1. Modify `build_extraction_variables()` to transform keys: + ```python + classes_for_prompt = { + k.split('/')[-1]: v # "fo/Recipe" → "Recipe" + for k, v in ontology_subset.classes.items() + } + ``` + +2. Update prompt example to match (already uses unprefixed names) + +3. Modify `expand_uri()` to handle both formats: + ```python + # Try exact match first + if value in ontology_subset.classes: + return ontology_subset.classes[value]['uri'] + + # Try with prefix + for prefix in ['fo/', 'rdf:', 'rdfs:']: + prefixed = f"{prefix}{value}" + if prefixed in ontology_subset.classes: + return ontology_subset.classes[prefixed]['uri'] + ``` + +**Pros**: +- Cleaner, more human-readable +- Matches existing prompt examples +- LLMs work better with simpler tokens + +**Cons**: +- Class name collisions if multiple ontologies have same class name +- Loses namespace information +- Requires fallback logic for lookups + +### Option B: Use Full Prefixed IDs Consistently + +**Approach**: Update examples to use prefixed IDs matching what's shown in the class list. + +**Changes**: +1. Update prompt example (ontology-prompt.md:46-52): + ```json + [ + {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "fo/Recipe"}, + {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"}, + {"subject": "recipe:cornish-pasty", "predicate": "fo/produces", "object": "food:cornish-pasty"}, + {"subject": "food:cornish-pasty", "predicate": "rdf:type", "object": "fo/Food"} + ] + ``` + +2. Add namespace explanation to prompt: + ```markdown + ## Namespace Prefixes: + - **fo/**: Food Ontology (http://purl.org/ontology/fo/) + - **rdf:**: RDF Schema + - **rdfs:**: RDF Schema + + Use these prefixes exactly as shown when referencing classes and properties. + ``` + +3. Keep `expand_uri()` as-is (works correctly when matches found) + +**Pros**: +- Input = Output consistency +- No information loss +- Preserves namespace semantics +- Works with multiple ontologies + +**Cons**: +- More verbose tokens for LLM +- Requires LLM to track prefixes + +### Option C: Hybrid - Show Both Label and ID + +**Approach**: Enhance prompt to show both human-readable labels and technical IDs. + +**Changes**: +1. Update prompt template: + ```jinja + {% for class_id, class_def in classes.items() %} + - **{{class_id}}** (label: "{{class_def.labels[0].value if class_def.labels else class_id}}"){% if class_def.comment %}: {{class_def.comment}}{% endif %} + {% endfor %} + ``` + + Example output: + ```markdown + - **fo/Recipe** (label: "Recipe"): A Recipe is a combination... + ``` + +2. Update instructions: + ```markdown + When referencing classes: + - Use the full prefixed ID (e.g., "fo/Recipe") in JSON output + - The label (e.g., "Recipe") is for human understanding only + ``` + +**Pros**: +- Clearest for LLM +- Preserves all information +- Explicit about what to use + +**Cons**: +- Longer prompt +- More complex template + +## Implemented Approach + +**Simplified Entity-Relationship-Attribute Format** - completely replaces the old triple-based format. + +The new approach was chosen because: + +1. **No Information Loss**: Original URIs preserved correctly +2. **Simpler Logic**: No transformation needed, direct dict lookups work +3. **Namespace Safety**: Handles multiple ontologies without collisions +4. **Semantic Correctness**: Maintains RDF/OWL semantics + +## Implementation Complete + +### What Was Built: + +1. **New Prompt Template** (`prompts/ontology-extract-v2.txt`) + - ✅ Clear sections: Entity Types, Relationships, Attributes + - ✅ Example using full type identifiers (`fo/Recipe`, `fo/has_ingredient`) + - ✅ Instructions to use exact identifiers from schema + - ✅ New JSON format with entities/relationships/attributes arrays + +2. **Entity Normalization** (`entity_normalizer.py`) + - ✅ `normalize_entity_name()` - Converts names to URI-safe format + - ✅ `normalize_type_identifier()` - Handles slashes in types (`fo/Recipe` → `fo-recipe`) + - ✅ `build_entity_uri()` - Creates unique URIs using (name, type) tuple + - ✅ `EntityRegistry` - Tracks entities for deduplication + +3. **JSON Parser** (`simplified_parser.py`) + - ✅ Parses new format: `{entities: [...], relationships: [...], attributes: [...]}` + - ✅ Supports kebab-case and snake_case field names + - ✅ Returns structured dataclasses + - ✅ Graceful error handling with logging + +4. **Triple Converter** (`triple_converter.py`) + - ✅ `convert_entity()` - Generates type + label triples automatically + - ✅ `convert_relationship()` - Connects entity URIs via properties + - ✅ `convert_attribute()` - Adds literal values + - ✅ Looks up full URIs from ontology definitions + +5. **Updated Main Processor** (`extract.py`) + - ✅ Removed old triple-based extraction code + - ✅ Added `extract_with_simplified_format()` method + - ✅ Now exclusively uses new simplified format + - ✅ Calls prompt with `extract-with-ontologies-v2` ID + +## Test Cases + +### Test 1: URI Preservation +```python +# Given ontology class +classes = {"fo/Recipe": {"uri": "http://purl.org/ontology/fo/Recipe", ...}} + +# When LLM returns +llm_output = {"subject": "x", "predicate": "rdf:type", "object": "fo/Recipe"} + +# Then expanded URI should be +assert expanded == "http://purl.org/ontology/fo/Recipe" +# Not: "https://trustgraph.ai/ontology/food#Recipe" +``` + +### Test 2: Multi-Ontology Collision +```python +# Given two ontologies +ont1 = {"fo/Recipe": {...}} +ont2 = {"cooking/Recipe": {...}} + +# LLM should use full prefix to disambiguate +llm_output = {"object": "fo/Recipe"} # Not just "Recipe" +``` + +### Test 3: Entity Instance Format +```python +# Given prompt with food ontology +# LLM should create instances like +{"subject": "recipe:cornish-pasty"} # Namespace-style +{"subject": "food:beef"} # Consistent prefix +``` + +## Open Questions + +1. **Should entity instances use namespace prefixes?** + - Current: `"recipe:cornish-pasty"` (arbitrary) + - Alternative: Use ontology prefix `"fo:cornish-pasty"`? + - Alternative: No prefix, expand in URI `"cornish-pasty"` → full URI? + +2. **How to handle domain/range in prompt?** + - Currently shows: `(Recipe → Food)` + - Should it be: `(fo/Recipe → fo/Food)`? + +3. **Should we validate domain/range constraints?** + - TODO comment at extract.py:470 + - Would catch more errors but more complex + +4. **What about inverse properties and equivalences?** + - Ontology has `owl:inverseOf`, `owl:equivalentClass` + - Not currently used in extraction + - Should they be? + +## Success Metrics + +- ✅ Zero URI information loss (100% preservation of original URIs) +- ✅ LLM output format matches input format +- ✅ No ambiguous examples in prompt +- ✅ Tests pass with multiple ontologies +- ✅ Improved extraction quality (measured by valid triple %) + +## Alternative Approach: Simplified Extraction Format + +### Philosophy + +Instead of asking the LLM to understand RDF/OWL semantics, ask it to do what it's good at: **find entities and relationships in text**. + +Let the code handle URI construction, RDF conversion, and semantic web formalities. + +### Example: Entity Classification + +**Input Text:** +``` +Cornish pasty is a traditional British pastry filled with meat and vegetables. +``` + +**Ontology Schema (shown to LLM):** +```markdown +## Entity Types: +- Recipe: A recipe is a combination of ingredients and a method +- Food: A food is something that can be eaten +- Ingredient: An ingredient combines a quantity and a food +``` + +**What LLM Returns (Simple JSON):** +```json +{ + "entities": [ + { + "entity": "Cornish pasty", + "type": "Recipe" + } + ] +} +``` + +**What Code Produces (RDF Triples):** +```python +# 1. Normalize entity name + type to ID (type prevents collisions) +entity_id = "recipe-cornish-pasty" # normalize("Cornish pasty", "Recipe") +entity_uri = "https://trustgraph.ai/food/recipe-cornish-pasty" + +# Note: Same name, different type = different URI +# "Cornish pasty" (Recipe) → recipe-cornish-pasty +# "Cornish pasty" (Food) → food-cornish-pasty + +# 2. Generate triples +triples = [ + # Type triple + Triple( + s=Value(value=entity_uri, is_uri=True), + p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True), + o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True) + ), + # Label triple (automatic) + Triple( + s=Value(value=entity_uri, is_uri=True), + p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True), + o=Value(value="Cornish pasty", is_uri=False) + ) +] +``` + +### Benefits + +1. **LLM doesn't need to:** + - Understand URI syntax + - Invent identifier prefixes (`recipe:`, `ingredient:`) + - Know about `rdf:type` or `rdfs:label` + - Construct semantic web identifiers + +2. **LLM just needs to:** + - Find entities in text + - Map them to ontology classes + - Extract relationships and attributes + +3. **Code handles:** + - URI normalization and construction + - RDF triple generation + - Automatic label assignment + - Namespace management + +### Why This Works Better + +- **Simpler prompt** = less confusion = fewer errors +- **Consistent IDs** = code controls normalization rules +- **Auto-generated labels** = no missing rdfs:label triples +- **LLM focuses on extraction** = what it's actually good at + +### Example: Entity Relationships + +**Input Text:** +``` +Cornish pasty is a traditional British pastry filled with beef and potatoes. +``` + +**Ontology Schema (shown to LLM):** +```markdown +## Entity Types: +- Recipe: A recipe is a combination of ingredients and a method +- Food: A food is something that can be eaten +- Ingredient: An ingredient combines a quantity and a food + +## Relationships: +- has_ingredient: Relates a recipe to an ingredient it uses (Recipe → Ingredient) +- food: Relates an ingredient to the food that is required (Ingredient → Food) +``` + +**What LLM Returns (Simple JSON):** +```json +{ + "entities": [ + { + "entity": "Cornish pasty", + "type": "Recipe" + }, + { + "entity": "beef", + "type": "Food" + }, + { + "entity": "potatoes", + "type": "Food" + } + ], + "relationships": [ + { + "subject": "Cornish pasty", + "subject-type": "Recipe", + "relation": "has_ingredient", + "object": "beef", + "object-type": "Food" + }, + { + "subject": "Cornish pasty", + "subject-type": "Recipe", + "relation": "has_ingredient", + "object": "potatoes", + "object-type": "Food" + } + ] +} +``` + +**What Code Produces (RDF Triples):** +```python +# Normalize entity names to URIs +cornish_pasty_uri = "https://trustgraph.ai/food/cornish-pasty" +beef_uri = "https://trustgraph.ai/food/beef" +potatoes_uri = "https://trustgraph.ai/food/potatoes" + +# Look up relation URI from ontology +has_ingredient_uri = "http://purl.org/ontology/fo/ingredients" # from fo/has_ingredient + +triples = [ + # Entity type triples (as before) + Triple(s=cornish_pasty_uri, p=rdf_type, o="http://purl.org/ontology/fo/Recipe"), + Triple(s=cornish_pasty_uri, p=rdfs_label, o="Cornish pasty"), + + Triple(s=beef_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"), + Triple(s=beef_uri, p=rdfs_label, o="beef"), + + Triple(s=potatoes_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"), + Triple(s=potatoes_uri, p=rdfs_label, o="potatoes"), + + # Relationship triples + Triple( + s=Value(value=cornish_pasty_uri, is_uri=True), + p=Value(value=has_ingredient_uri, is_uri=True), + o=Value(value=beef_uri, is_uri=True) + ), + Triple( + s=Value(value=cornish_pasty_uri, is_uri=True), + p=Value(value=has_ingredient_uri, is_uri=True), + o=Value(value=potatoes_uri, is_uri=True) + ) +] +``` + +**Key Points:** +- LLM returns natural language entity names: `"Cornish pasty"`, `"beef"`, `"potatoes"` +- LLM includes types to disambiguate: `subject-type`, `object-type` +- LLM uses relation name from schema: `"has_ingredient"` +- Code derives consistent IDs using (name, type): `("Cornish pasty", "Recipe")` → `recipe-cornish-pasty` +- Code looks up relation URI from ontology: `fo/has_ingredient` → full URI +- Same (name, type) tuple always gets same URI (deduplication) + +### Example: Entity Name Disambiguation + +**Problem:** Same name can refer to different entity types. + +**Real-world case:** +``` +"Cornish pasty" can be: +- A Recipe (instructions for making it) +- A Food (the dish itself) +``` + +**How It's Handled:** + +LLM returns both as separate entities: +```json +{ + "entities": [ + {"entity": "Cornish pasty", "type": "Recipe"}, + {"entity": "Cornish pasty", "type": "Food"} + ], + "relationships": [ + { + "subject": "Cornish pasty", + "subject-type": "Recipe", + "relation": "produces", + "object": "Cornish pasty", + "object-type": "Food" + } + ] +} +``` + +**Code Resolution:** +```python +# Different types → different URIs +recipe_uri = normalize("Cornish pasty", "Recipe") +# → "https://trustgraph.ai/food/recipe-cornish-pasty" + +food_uri = normalize("Cornish pasty", "Food") +# → "https://trustgraph.ai/food/food-cornish-pasty" + +# Relationship connects them correctly +triple = Triple( + s=recipe_uri, # The Recipe + p="http://purl.org/ontology/fo/produces", + o=food_uri # The Food +) +``` + +**Why This Works:** +- Type is included in ALL references (entities, relationships, attributes) +- Code uses `(name, type)` tuple as lookup key +- No ambiguity, no collisions + +### Example: Entity Attributes + +**Input Text:** +``` +This Cornish pasty recipe serves 4-6 people and takes 45 minutes to prepare. +``` + +**Ontology Schema (shown to LLM):** +```markdown +## Entity Types: +- Recipe: A recipe is a combination of ingredients and a method + +## Attributes: +- serves: Indicates what the recipe is intended to serve (Recipe → text) +- preparation_time: Time needed to prepare the recipe (Recipe → text) +``` + +**What LLM Returns (Simple JSON):** +```json +{ + "entities": [ + { + "entity": "Cornish pasty recipe", + "type": "Recipe" + } + ], + "attributes": [ + { + "entity": "Cornish pasty recipe", + "entity-type": "Recipe", + "attribute": "serves", + "value": "4-6 people" + }, + { + "entity": "Cornish pasty recipe", + "entity-type": "Recipe", + "attribute": "preparation_time", + "value": "45 minutes" + } + ] +} +``` + +**What Code Produces (RDF Triples):** +```python +# Normalize entity name to URI +recipe_uri = "https://trustgraph.ai/food/cornish-pasty-recipe" + +# Look up attribute URIs from ontology +serves_uri = "http://purl.org/ontology/fo/serves" # from fo/serves +prep_time_uri = "http://purl.org/ontology/fo/preparation_time" # from fo/preparation_time + +triples = [ + # Entity type triple + Triple( + s=Value(value=recipe_uri, is_uri=True), + p=Value(value=rdf_type, is_uri=True), + o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True) + ), + + # Label triple (automatic) + Triple( + s=Value(value=recipe_uri, is_uri=True), + p=Value(value=rdfs_label, is_uri=True), + o=Value(value="Cornish pasty recipe", is_uri=False) + ), + + # Attribute triples (objects are literals, not URIs) + Triple( + s=Value(value=recipe_uri, is_uri=True), + p=Value(value=serves_uri, is_uri=True), + o=Value(value="4-6 people", is_uri=False) # Literal value! + ), + Triple( + s=Value(value=recipe_uri, is_uri=True), + p=Value(value=prep_time_uri, is_uri=True), + o=Value(value="45 minutes", is_uri=False) # Literal value! + ) +] +``` + +**Key Points:** +- LLM extracts literal values: `"4-6 people"`, `"45 minutes"` +- LLM includes entity type for disambiguation: `entity-type` +- LLM uses attribute name from schema: `"serves"`, `"preparation_time"` +- Code looks up attribute URI from ontology datatype properties +- **Object is literal** (`is_uri=False`), not a URI reference +- Values stay as natural text, no normalization needed + +**Difference from Relationships:** +- Relationships: both subject and object are entities (URIs) +- Attributes: subject is entity (URI), object is literal value (string/number) + +### Complete Example: Entities + Relationships + Attributes + +**Input Text:** +``` +Cornish pasty is a savory pastry filled with beef and potatoes. +This recipe serves 4 people. +``` + +**What LLM Returns:** +```json +{ + "entities": [ + { + "entity": "Cornish pasty", + "type": "Recipe" + }, + { + "entity": "beef", + "type": "Food" + }, + { + "entity": "potatoes", + "type": "Food" + } + ], + "relationships": [ + { + "subject": "Cornish pasty", + "subject-type": "Recipe", + "relation": "has_ingredient", + "object": "beef", + "object-type": "Food" + }, + { + "subject": "Cornish pasty", + "subject-type": "Recipe", + "relation": "has_ingredient", + "object": "potatoes", + "object-type": "Food" + } + ], + "attributes": [ + { + "entity": "Cornish pasty", + "entity-type": "Recipe", + "attribute": "serves", + "value": "4 people" + } + ] +} +``` + +**Result:** 11 RDF triples generated: +- 3 entity type triples (rdf:type) +- 3 entity label triples (rdfs:label) - automatic +- 2 relationship triples (has_ingredient) +- 1 attribute triple (serves) + +All from simple, natural language extractions by the LLM! + +## References + +- Current implementation: `trustgraph-flow/trustgraph/extract/kg/ontology/extract.py` +- Prompt template: `ontology-prompt.md` +- Test cases: `tests/unit/test_extract/test_ontology/` +- Example ontology: `e2e/test-data/food.ontology` diff --git a/docs/tech-specs/pubsub.md b/docs/tech-specs/pubsub.md new file mode 100644 index 00000000..38836838 --- /dev/null +++ b/docs/tech-specs/pubsub.md @@ -0,0 +1,958 @@ +# Pub/Sub Infrastructure + +## Overview + +This document catalogs all connections between the TrustGraph codebase and the pub/sub infrastructure. Currently, the system is hardcoded to use Apache Pulsar. This analysis identifies all integration points to inform future refactoring toward a configurable pub/sub abstraction. + +## Current State: Pulsar Integration Points + +### 1. Direct Pulsar Client Usage + +**Location:** `trustgraph-flow/trustgraph/gateway/service.py` + +The API gateway directly imports and instantiates the Pulsar client: + +- **Line 20:** `import pulsar` +- **Lines 54-61:** Direct instantiation of `pulsar.Client()` with optional `pulsar.AuthenticationToken()` +- **Lines 33-35:** Default Pulsar host configuration from environment variables +- **Lines 178-192:** CLI arguments for `--pulsar-host`, `--pulsar-api-key`, and `--pulsar-listener` +- **Lines 78, 124:** Passes `pulsar_client` to `ConfigReceiver` and `DispatcherManager` + +This is the only location that directly instantiates a Pulsar client outside of the abstraction layer. + +### 2. Base Processor Framework + +**Location:** `trustgraph-base/trustgraph/base/async_processor.py` + +The base class for all processors provides Pulsar connectivity: + +- **Line 9:** `import _pulsar` (for exception handling) +- **Line 18:** `from . pubsub import PulsarClient` +- **Line 38:** Creates `pulsar_client_object = PulsarClient(**params)` +- **Lines 104-108:** Properties exposing `pulsar_host` and `pulsar_client` +- **Line 250:** Static method `add_args()` calls `PulsarClient.add_args(parser)` for CLI arguments +- **Lines 223-225:** Exception handling for `_pulsar.Interrupted` + +All processors inherit from `AsyncProcessor`, making this the central integration point. + +### 3. Consumer Abstraction + +**Location:** `trustgraph-base/trustgraph/base/consumer.py` + +Consumes messages from queues and invokes handler functions: + +**Pulsar imports:** +- **Line 12:** `from pulsar.schema import JsonSchema` +- **Line 13:** `import pulsar` +- **Line 14:** `import _pulsar` + +**Pulsar-specific usage:** +- **Lines 100, 102:** `pulsar.InitialPosition.Earliest` / `pulsar.InitialPosition.Latest` +- **Line 108:** `JsonSchema(self.schema)` wrapper +- **Line 110:** `pulsar.ConsumerType.Shared` +- **Lines 104-111:** `self.client.subscribe()` with Pulsar-specific parameters +- **Lines 143, 150, 65:** `consumer.unsubscribe()` and `consumer.close()` methods +- **Line 162:** `_pulsar.Timeout` exception +- **Lines 182, 205, 232:** `consumer.acknowledge()` / `consumer.negative_acknowledge()` + +**Spec file:** `trustgraph-base/trustgraph/base/consumer_spec.py` +- **Line 22:** References `processor.pulsar_client` + +### 4. Producer Abstraction + +**Location:** `trustgraph-base/trustgraph/base/producer.py` + +Sends messages to queues: + +**Pulsar imports:** +- **Line 2:** `from pulsar.schema import JsonSchema` + +**Pulsar-specific usage:** +- **Line 49:** `JsonSchema(self.schema)` wrapper +- **Lines 47-51:** `self.client.create_producer()` with Pulsar-specific parameters (topic, schema, chunking_enabled) +- **Lines 31, 76:** `producer.close()` method +- **Lines 64-65:** `producer.send()` with message and properties + +**Spec file:** `trustgraph-base/trustgraph/base/producer_spec.py` +- **Line 18:** References `processor.pulsar_client` + +### 5. Publisher Abstraction + +**Location:** `trustgraph-base/trustgraph/base/publisher.py` + +Asynchronous message publishing with queue buffering: + +**Pulsar imports:** +- **Line 2:** `from pulsar.schema import JsonSchema` +- **Line 6:** `import pulsar` + +**Pulsar-specific usage:** +- **Line 52:** `JsonSchema(self.schema)` wrapper +- **Lines 50-54:** `self.client.create_producer()` with Pulsar-specific parameters +- **Lines 101, 103:** `producer.send()` with message and optional properties +- **Lines 106-107:** `producer.flush()` and `producer.close()` methods + +### 6. Subscriber Abstraction + +**Location:** `trustgraph-base/trustgraph/base/subscriber.py` + +Provides multi-recipient message distribution from queues: + +**Pulsar imports:** +- **Line 6:** `from pulsar.schema import JsonSchema` +- **Line 8:** `import _pulsar` + +**Pulsar-specific usage:** +- **Line 55:** `JsonSchema(self.schema)` wrapper +- **Line 57:** `self.client.subscribe(**subscribe_args)` +- **Lines 101, 136, 160, 167-172:** Pulsar exceptions: `_pulsar.Timeout`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed` +- **Lines 159, 166, 170:** Consumer methods: `negative_acknowledge()`, `unsubscribe()`, `close()` +- **Lines 247, 251:** Message acknowledgment: `acknowledge()`, `negative_acknowledge()` + +**Spec file:** `trustgraph-base/trustgraph/base/subscriber_spec.py` +- **Line 19:** References `processor.pulsar_client` + +### 7. Schema System (Heart of Darkness) + +**Location:** `trustgraph-base/trustgraph/schema/` + +Every message schema in the system is defined using Pulsar's schema framework. + +**Core primitives:** `schema/core/primitives.py` +- **Line 2:** `from pulsar.schema import Record, String, Boolean, Array, Integer` +- All schemas inherit from Pulsar's `Record` base class +- All field types are Pulsar types: `String()`, `Integer()`, `Boolean()`, `Array()`, `Map()`, `Double()` + +**Example schemas:** +- `schema/services/llm.py` (Line 2): `from pulsar.schema import Record, String, Array, Double, Integer, Boolean` +- `schema/services/config.py` (Line 2): `from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer` + +**Topic naming:** `schema/core/topic.py` +- **Lines 2-3:** Topic format: `{kind}://{tenant}/{namespace}/{topic}` +- This URI structure is Pulsar-specific (e.g., `persistent://tg/flow/config`) + +**Impact:** +- All request/response message definitions throughout the codebase use Pulsar schemas +- This includes services for: config, flow, llm, prompt, query, storage, agent, collection, diagnosis, library, lookup, nlp_query, objects_query, retrieval, structured_query +- Schema definitions are imported and used extensively across all processors and services + +## Summary + +### Pulsar Dependencies by Category + +1. **Client instantiation:** + - Direct: `gateway/service.py` + - Abstracted: `async_processor.py` → `pubsub.py` (PulsarClient) + +2. **Message transport:** + - Consumer: `consumer.py`, `consumer_spec.py` + - Producer: `producer.py`, `producer_spec.py` + - Publisher: `publisher.py` + - Subscriber: `subscriber.py`, `subscriber_spec.py` + +3. **Schema system:** + - Base types: `schema/core/primitives.py` + - All service schemas: `schema/services/*.py` + - Topic naming: `schema/core/topic.py` + +4. **Pulsar-specific concepts required:** + - Topic-based messaging + - Schema system (Record, field types) + - Shared subscriptions + - Message acknowledgment (positive/negative) + - Consumer positioning (earliest/latest) + - Message properties + - Initial positions and consumer types + - Chunking support + - Persistent vs non-persistent topics + +### Refactoring Challenges + +The good news: The abstraction layer (Consumer, Producer, Publisher, Subscriber) provides a clean encapsulation of most Pulsar interactions. + +The challenges: +1. **Schema system pervasiveness:** Every message definition uses `pulsar.schema.Record` and Pulsar field types +2. **Pulsar-specific enums:** `InitialPosition`, `ConsumerType` +3. **Pulsar exceptions:** `_pulsar.Timeout`, `_pulsar.Interrupted`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed` +4. **Method signatures:** `acknowledge()`, `negative_acknowledge()`, `subscribe()`, `create_producer()`, etc. +5. **Topic URI format:** Pulsar's `kind://tenant/namespace/topic` structure + +### Next Steps + +To make the pub/sub infrastructure configurable, we need to: + +1. Create an abstraction interface for the client/schema system +2. Abstract Pulsar-specific enums and exceptions +3. Create schema wrappers or alternative schema definitions +4. Implement the interface for both Pulsar and alternative systems (Kafka, RabbitMQ, Redis Streams, etc.) +5. Update `pubsub.py` to be configurable and support multiple backends +6. Provide migration path for existing deployments + +## Approach Draft 1: Adapter Pattern with Schema Translation Layer + +### Key Insight +The **schema system** is the deepest integration point - everything else flows from it. We need to solve this first, or we'll be rewriting the entire codebase. + +### Strategy: Minimal Disruption with Adapters + +**1. Keep Pulsar schemas as the internal representation** +- Don't rewrite all the schema definitions +- Schemas remain `pulsar.schema.Record` internally +- Use adapters to translate at the boundary between our code and the pub/sub backend + +**2. Create a pub/sub abstraction layer:** + +``` +┌─────────────────────────────────────┐ +│ Existing Code (unchanged) │ +│ - Uses Pulsar schemas internally │ +│ - Consumer/Producer/Publisher │ +└──────────────┬──────────────────────┘ + │ +┌──────────────┴──────────────────────┐ +│ PubSubFactory (configurable) │ +│ - Creates backend-specific client │ +└──────────────┬──────────────────────┘ + │ + ┌──────┴──────┐ + │ │ +┌───────▼─────┐ ┌────▼─────────┐ +│ PulsarAdapter│ │ KafkaAdapter │ etc... +│ (passthrough)│ │ (translates) │ +└──────────────┘ └──────────────┘ +``` + +**3. Define abstract interfaces:** +- `PubSubClient` - client connection +- `PubSubProducer` - sending messages +- `PubSubConsumer` - receiving messages +- `SchemaAdapter` - translating Pulsar schemas to/from JSON or backend-specific formats + +**4. Implementation details:** + +For **Pulsar adapter**: Nearly passthrough, minimal translation + +For **other backends** (Kafka, RabbitMQ, etc.): +- Serialize Pulsar Record objects to JSON/bytes +- Map concepts like: + - `InitialPosition.Earliest/Latest` → Kafka's auto.offset.reset + - `acknowledge()` → Kafka's commit + - `negative_acknowledge()` → Re-queue or DLQ pattern + - Topic URIs → Backend-specific topic names + +### Analysis + +**Pros:** +- ✅ Minimal code changes to existing services +- ✅ Schemas stay as-is (no massive rewrite) +- ✅ Gradual migration path +- ✅ Pulsar users see no difference +- ✅ New backends added via adapters + +**Cons:** +- ⚠️ Still carries Pulsar dependency (for schema definitions) +- ⚠️ Some impedance mismatch translating concepts + +### Alternative Consideration + +Create a **TrustGraph schema system** that's pub/sub agnostic (using dataclasses or Pydantic), then generate Pulsar/Kafka/etc schemas from it. This requires rewriting every schema file and potentially breaking changes. + +### Recommendation for Draft 1 + +Start with the **adapter approach** because: +1. It's pragmatic - works with existing code +2. Proves the concept with minimal risk +3. Can evolve to a native schema system later if needed +4. Configuration-driven: one env var switches backends + +## Approach Draft 2: Backend-Agnostic Schema System with Dataclasses + +### Core Concept + +Use Python **dataclasses** as the neutral schema definition format. Each pub/sub backend provides its own serialization/deserialization for dataclasses, eliminating the need for Pulsar schemas to remain in the codebase. + +### Schema Polymorphism at the Factory Level + +Instead of translating Pulsar schemas, **each backend provides its own schema handling** that works with standard Python dataclasses. + +### Publisher Flow + +```python +# 1. Get the configured backend from factory +pubsub = get_pubsub() # Returns PulsarBackend, MQTTBackend, etc. + +# 2. Get schema class from the backend +# (Can be imported directly - backend-agnostic) +from trustgraph.schema.services.llm import TextCompletionRequest + +# 3. Create a producer/publisher for a specific topic +producer = pubsub.create_producer( + topic="text-completion-requests", + schema=TextCompletionRequest # Tells backend what schema to use +) + +# 4. Create message instances (same API regardless of backend) +request = TextCompletionRequest( + system="You are helpful", + prompt="Hello world", + streaming=False +) + +# 5. Send the message +producer.send(request) # Backend serializes appropriately +``` + +### Consumer Flow + +```python +# 1. Get the configured backend +pubsub = get_pubsub() + +# 2. Create a consumer +consumer = pubsub.subscribe( + topic="text-completion-requests", + schema=TextCompletionRequest # Tells backend how to deserialize +) + +# 3. Receive and deserialize +msg = consumer.receive() +request = msg.value() # Returns TextCompletionRequest dataclass instance + +# 4. Use the data (type-safe access) +print(request.system) # "You are helpful" +print(request.prompt) # "Hello world" +print(request.streaming) # False +``` + +### What Happens Behind the Scenes + +**For Pulsar backend:** +- `create_producer()` → creates Pulsar producer with JSON schema or dynamically generated Record +- `send(request)` → serializes dataclass to JSON/Pulsar format, sends to Pulsar +- `receive()` → gets Pulsar message, deserializes back to dataclass + +**For MQTT backend:** +- `create_producer()` → connects to MQTT broker, no schema registration needed +- `send(request)` → converts dataclass to JSON, publishes to MQTT topic +- `receive()` → subscribes to MQTT topic, deserializes JSON to dataclass + +**For Kafka backend:** +- `create_producer()` → creates Kafka producer, registers Avro schema if needed +- `send(request)` → serializes dataclass to Avro format, sends to Kafka +- `receive()` → gets Kafka message, deserializes Avro back to dataclass + +### Key Design Points + +1. **Schema object creation**: The dataclass instance (`TextCompletionRequest(...)`) is identical regardless of backend +2. **Backend handles encoding**: Each backend knows how to serialize its dataclass to the wire format +3. **Schema definition at creation**: When creating producer/consumer, you specify the schema type +4. **Type safety preserved**: You get back a proper `TextCompletionRequest` object, not a dict +5. **No backend leakage**: Application code never imports backend-specific libraries + +### Example Transformation + +**Current (Pulsar-specific):** +```python +# schema/services/llm.py +from pulsar.schema import Record, String, Boolean, Integer + +class TextCompletionRequest(Record): + system = String() + prompt = String() + streaming = Boolean() +``` + +**New (Backend-agnostic):** +```python +# schema/services/llm.py +from dataclasses import dataclass + +@dataclass +class TextCompletionRequest: + system: str + prompt: str + streaming: bool = False +``` + +### Backend Integration + +Each backend handles serialization/deserialization of dataclasses: + +**Pulsar backend:** +- Dynamically generate `pulsar.schema.Record` classes from dataclasses +- Or serialize dataclasses to JSON and use Pulsar's JSON schema +- Maintains compatibility with existing Pulsar deployments + +**MQTT/Redis backend:** +- Direct JSON serialization of dataclass instances +- Use `dataclasses.asdict()` / `from_dict()` +- Lightweight, no schema registry needed + +**Kafka backend:** +- Generate Avro schemas from dataclass definitions +- Use Confluent's schema registry +- Type-safe serialization with schema evolution support + +### Architecture + +``` +┌─────────────────────────────────────┐ +│ Application Code │ +│ - Uses dataclass schemas │ +│ - Backend-agnostic │ +└──────────────┬──────────────────────┘ + │ +┌──────────────┴──────────────────────┐ +│ PubSubFactory (configurable) │ +│ - get_pubsub() returns backend │ +└──────────────┬──────────────────────┘ + │ + ┌──────┴──────┐ + │ │ +┌───────▼─────────┐ ┌────▼──────────────┐ +│ PulsarBackend │ │ MQTTBackend │ +│ - JSON schema │ │ - JSON serialize │ +│ - or dynamic │ │ - Simple queues │ +│ Record gen │ │ │ +└─────────────────┘ └───────────────────┘ +``` + +### Implementation Details + +**1. Schema definitions:** Plain dataclasses with type hints + - `str`, `int`, `bool`, `float` for primitives + - `list[T]` for arrays + - `dict[str, T]` for maps + - Nested dataclasses for complex types + +**2. Each backend provides:** + - Serializer: `dataclass → bytes/wire format` + - Deserializer: `bytes/wire format → dataclass` + - Schema registration (if needed, like Pulsar/Kafka) + +**3. Consumer/Producer abstraction:** + - Already exists (consumer.py, producer.py) + - Update to use backend's serialization + - Remove direct Pulsar imports + +**4. Type mappings:** + - Pulsar `String()` → Python `str` + - Pulsar `Integer()` → Python `int` + - Pulsar `Boolean()` → Python `bool` + - Pulsar `Array(T)` → Python `list[T]` + - Pulsar `Map(K, V)` → Python `dict[K, V]` + - Pulsar `Double()` → Python `float` + - Pulsar `Bytes()` → Python `bytes` + +### Migration Path + +1. **Create dataclass versions** of all schemas in `trustgraph/schema/` +2. **Update backend classes** (Consumer, Producer, Publisher, Subscriber) to use backend-provided serialization +3. **Implement PulsarBackend** with JSON schema or dynamic Record generation +4. **Test with Pulsar** to ensure backward compatibility with existing deployments +5. **Add new backends** (MQTT, Kafka, Redis, etc.) as needed +6. **Remove Pulsar imports** from schema files + +### Benefits + +✅ **No pub/sub dependency** in schema definitions +✅ **Standard Python** - easy to understand, type-check, document +✅ **Modern tooling** - works with mypy, IDE autocomplete, linters +✅ **Backend-optimized** - each backend uses native serialization +✅ **No translation overhead** - direct serialization, no adapters +✅ **Type safety** - real objects with proper types +✅ **Easy validation** - can use Pydantic if needed + +### Challenges & Solutions + +**Challenge:** Pulsar's `Record` has runtime field validation +**Solution:** Use Pydantic dataclasses for validation if needed, or Python 3.10+ dataclass features with `__post_init__` + +**Challenge:** Some Pulsar-specific features (like `Bytes` type) +**Solution:** Map to `bytes` type in dataclass, backend handles encoding appropriately + +**Challenge:** Topic naming (`persistent://tenant/namespace/topic`) +**Solution:** Abstract topic names in schema definitions, backend converts to proper format + +**Challenge:** Schema evolution and versioning +**Solution:** Each backend handles this according to its capabilities (Pulsar schema versions, Kafka schema registry, etc.) + +**Challenge:** Nested complex types +**Solution:** Use nested dataclasses, backends recursively serialize/deserialize + +### Design Decisions + +1. **Plain dataclasses or Pydantic?** + - ✅ **Decision: Use plain Python dataclasses** + - Simpler, no additional dependencies + - Validation not required in practice + - Easier to understand and maintain + +2. **Schema evolution:** + - ✅ **Decision: No versioning mechanism needed** + - Schemas are stable and long-lasting + - Updates typically add new fields (backward compatible) + - Backends handle schema evolution according to their capabilities + +3. **Backward compatibility:** + - ✅ **Decision: Major version change, no backward compatibility required** + - Will be a breaking change with migration instructions + - Clean break allows for better design + - Migration guide will be provided for existing deployments + +4. **Nested types and complex structures:** + - ✅ **Decision: Use nested dataclasses naturally** + - Python dataclasses handle nesting perfectly + - `list[T]` for arrays, `dict[K, V]` for maps + - Backends recursively serialize/deserialize + - Example: + ```python + @dataclass + class Value: + value: str + is_uri: bool + + @dataclass + class Triple: + s: Value # Nested dataclass + p: Value + o: Value + + @dataclass + class GraphQuery: + triples: list[Triple] # Array of nested dataclasses + metadata: dict[str, str] + ``` + +5. **Default values and optional fields:** + - ✅ **Decision: Mix of required, defaults, and optional fields** + - Required fields: No default value + - Fields with defaults: Always present, have sensible default + - Truly optional fields: `T | None = None`, omitted from serialization when `None` + - Example: + ```python + @dataclass + class TextCompletionRequest: + system: str # Required, no default + prompt: str # Required, no default + streaming: bool = False # Optional with default value + metadata: dict | None = None # Truly optional, can be absent + ``` + + **Important serialization semantics:** + + When `metadata = None`: + ```json + { + "system": "...", + "prompt": "...", + "streaming": false + // metadata field NOT PRESENT + } + ``` + + When `metadata = {}` (explicitly empty): + ```json + { + "system": "...", + "prompt": "...", + "streaming": false, + "metadata": {} // Field PRESENT but empty + } + ``` + + **Key distinction:** + - `None` → field absent from JSON (not serialized) + - Empty value (`{}`, `[]`, `""`) → field present with empty value + - This matters semantically: "not provided" vs "explicitly empty" + - Serialization backends must skip `None` fields, not encode as `null` + +## Approach Draft 3: Implementation Details + +### Generic Queue Naming Format + +Replace backend-specific queue names with a generic format that backends can map appropriately. + +**Format:** `{qos}/{tenant}/{namespace}/{queue-name}` + +Where: +- `qos`: Quality of Service level + - `q0` = best-effort (fire and forget, no acknowledgment) + - `q1` = at-least-once (requires acknowledgment) + - `q2` = exactly-once (two-phase acknowledgment) +- `tenant`: Logical grouping for multi-tenancy +- `namespace`: Sub-grouping within tenant +- `queue-name`: Actual queue/topic name + +**Examples:** +``` +q1/tg/flow/text-completion-requests +q2/tg/config/config-push +q0/tg/metrics/stats +``` + +### Backend Topic Mapping + +Each backend maps the generic format to its native format: + +**Pulsar Backend:** +```python +def map_topic(self, generic_topic: str) -> str: + # Parse: q1/tg/flow/text-completion-requests + qos, tenant, namespace, queue = generic_topic.split('/', 3) + + # Map QoS to persistence + persistence = 'persistent' if qos in ['q1', 'q2'] else 'non-persistent' + + # Return Pulsar URI: persistent://tg/flow/text-completion-requests + return f"{persistence}://{tenant}/{namespace}/{queue}" +``` + +**MQTT Backend:** +```python +def map_topic(self, generic_topic: str) -> tuple[str, int]: + # Parse: q1/tg/flow/text-completion-requests + qos, tenant, namespace, queue = generic_topic.split('/', 3) + + # Map QoS level + qos_level = {'q0': 0, 'q1': 1, 'q2': 2}[qos] + + # Build MQTT topic including tenant/namespace for proper namespacing + mqtt_topic = f"{tenant}/{namespace}/{queue}" + + return mqtt_topic, qos_level +``` + +### Updated Topic Helper Function + +```python +# schema/core/topic.py +def topic(queue_name, qos='q1', tenant='tg', namespace='flow'): + """ + Create a generic topic identifier that can be mapped by backends. + + Args: + queue_name: The queue/topic name + qos: Quality of service + - 'q0' = best-effort (no ack) + - 'q1' = at-least-once (ack required) + - 'q2' = exactly-once (two-phase ack) + tenant: Tenant identifier for multi-tenancy + namespace: Namespace within tenant + + Returns: + Generic topic string: qos/tenant/namespace/queue_name + + Examples: + topic('my-queue') # q1/tg/flow/my-queue + topic('config', qos='q2', namespace='config') # q2/tg/config/config + """ + return f"{qos}/{tenant}/{namespace}/{queue_name}" +``` + +### Configuration and Initialization + +**Command-Line Arguments + Environment Variables:** + +```python +# In base/async_processor.py - add_args() method +@staticmethod +def add_args(parser): + # Pub/sub backend selection + parser.add_argument( + '--pubsub-backend', + default=os.getenv('PUBSUB_BACKEND', 'pulsar'), + choices=['pulsar', 'mqtt'], + help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)' + ) + + # Pulsar-specific configuration + parser.add_argument( + '--pulsar-host', + default=os.getenv('PULSAR_HOST', 'pulsar://localhost:6650'), + help='Pulsar host (default: pulsar://localhost:6650, env: PULSAR_HOST)' + ) + + parser.add_argument( + '--pulsar-api-key', + default=os.getenv('PULSAR_API_KEY', None), + help='Pulsar API key (env: PULSAR_API_KEY)' + ) + + parser.add_argument( + '--pulsar-listener', + default=os.getenv('PULSAR_LISTENER', None), + help='Pulsar listener name (env: PULSAR_LISTENER)' + ) + + # MQTT-specific configuration + parser.add_argument( + '--mqtt-host', + default=os.getenv('MQTT_HOST', 'localhost'), + help='MQTT broker host (default: localhost, env: MQTT_HOST)' + ) + + parser.add_argument( + '--mqtt-port', + type=int, + default=int(os.getenv('MQTT_PORT', '1883')), + help='MQTT broker port (default: 1883, env: MQTT_PORT)' + ) + + parser.add_argument( + '--mqtt-username', + default=os.getenv('MQTT_USERNAME', None), + help='MQTT username (env: MQTT_USERNAME)' + ) + + parser.add_argument( + '--mqtt-password', + default=os.getenv('MQTT_PASSWORD', None), + help='MQTT password (env: MQTT_PASSWORD)' + ) +``` + +**Factory Function:** + +```python +# In base/pubsub.py or base/pubsub_factory.py +def get_pubsub(**config) -> PubSubBackend: + """ + Create and return a pub/sub backend based on configuration. + + Args: + config: Configuration dict from command-line args + Must include 'pubsub_backend' key + + Returns: + Backend instance (PulsarBackend, MQTTBackend, etc.) + """ + backend_type = config.get('pubsub_backend', 'pulsar') + + if backend_type == 'pulsar': + return PulsarBackend( + host=config.get('pulsar_host'), + api_key=config.get('pulsar_api_key'), + listener=config.get('pulsar_listener'), + ) + elif backend_type == 'mqtt': + return MQTTBackend( + host=config.get('mqtt_host'), + port=config.get('mqtt_port'), + username=config.get('mqtt_username'), + password=config.get('mqtt_password'), + ) + else: + raise ValueError(f"Unknown pub/sub backend: {backend_type}") +``` + +**Usage in AsyncProcessor:** + +```python +# In async_processor.py +class AsyncProcessor: + def __init__(self, **params): + self.id = params.get("id") + + # Create backend from config (replaces PulsarClient) + self.pubsub = get_pubsub(**params) + + # Rest of initialization... +``` + +### Backend Interface + +```python +class PubSubBackend(Protocol): + """Protocol defining the interface all pub/sub backends must implement.""" + + def create_producer(self, topic: str, schema: type, **options) -> BackendProducer: + """ + Create a producer for a topic. + + Args: + topic: Generic topic format (qos/tenant/namespace/queue) + schema: Dataclass type for messages + options: Backend-specific options (e.g., chunking_enabled) + + Returns: + Backend-specific producer instance + """ + ... + + def create_consumer( + self, + topic: str, + subscription: str, + schema: type, + initial_position: str = 'latest', + consumer_type: str = 'shared', + **options + ) -> BackendConsumer: + """ + Create a consumer for a topic. + + Args: + topic: Generic topic format (qos/tenant/namespace/queue) + subscription: Subscription/consumer group name + schema: Dataclass type for messages + initial_position: 'earliest' or 'latest' (MQTT may ignore) + consumer_type: 'shared', 'exclusive', 'failover' (MQTT may ignore) + options: Backend-specific options + + Returns: + Backend-specific consumer instance + """ + ... + + def close(self) -> None: + """Close the backend connection.""" + ... +``` + +```python +class BackendProducer(Protocol): + """Protocol for backend-specific producer.""" + + def send(self, message: Any, properties: dict = {}) -> None: + """Send a message (dataclass instance) with optional properties.""" + ... + + def flush(self) -> None: + """Flush any buffered messages.""" + ... + + def close(self) -> None: + """Close the producer.""" + ... +``` + +```python +class BackendConsumer(Protocol): + """Protocol for backend-specific consumer.""" + + def receive(self, timeout_millis: int = 2000) -> Message: + """ + Receive a message from the topic. + + Raises: + TimeoutError: If no message received within timeout + """ + ... + + def acknowledge(self, message: Message) -> None: + """Acknowledge successful processing of a message.""" + ... + + def negative_acknowledge(self, message: Message) -> None: + """Negative acknowledge - triggers redelivery.""" + ... + + def unsubscribe(self) -> None: + """Unsubscribe from the topic.""" + ... + + def close(self) -> None: + """Close the consumer.""" + ... +``` + +```python +class Message(Protocol): + """Protocol for a received message.""" + + def value(self) -> Any: + """Get the deserialized message (dataclass instance).""" + ... + + def properties(self) -> dict: + """Get message properties/metadata.""" + ... +``` + +### Existing Classes Refactoring + +The existing `Consumer`, `Producer`, `Publisher`, `Subscriber` classes remain largely intact: + +**Current responsibilities (keep):** +- Async threading model and taskgroups +- Reconnection logic and retry handling +- Metrics collection +- Rate limiting +- Concurrency management + +**Changes needed:** +- Remove direct Pulsar imports (`pulsar.schema`, `pulsar.InitialPosition`, etc.) +- Accept `BackendProducer`/`BackendConsumer` instead of Pulsar client +- Delegate actual pub/sub operations to backend instances +- Map generic concepts to backend calls + +**Example refactoring:** + +```python +# OLD - consumer.py +class Consumer: + def __init__(self, client, topic, subscriber, schema, ...): + self.client = client # Direct Pulsar client + # ... + + async def consumer_run(self): + # Uses pulsar.InitialPosition, pulsar.ConsumerType + self.consumer = self.client.subscribe( + topic=self.topic, + schema=JsonSchema(self.schema), + initial_position=pulsar.InitialPosition.Earliest, + consumer_type=pulsar.ConsumerType.Shared, + ) + +# NEW - consumer.py +class Consumer: + def __init__(self, backend_consumer, schema, ...): + self.backend_consumer = backend_consumer # Backend-specific consumer + self.schema = schema + # ... + + async def consumer_run(self): + # Backend consumer already created with right settings + # Just use it directly + while self.running: + msg = await asyncio.to_thread( + self.backend_consumer.receive, + timeout_millis=2000 + ) + await self.handle_message(msg) +``` + +### Backend-Specific Behaviors + +**Pulsar Backend:** +- Maps `q0` → `non-persistent://`, `q1`/`q2` → `persistent://` +- Supports all consumer types (shared, exclusive, failover) +- Supports initial position (earliest/latest) +- Native message acknowledgment +- Schema registry support + +**MQTT Backend:** +- Maps `q0`/`q1`/`q2` → MQTT QoS levels 0/1/2 +- Includes tenant/namespace in topic path for namespacing +- Auto-generates client IDs from subscription names +- Ignores initial position (no message history in basic MQTT) +- Ignores consumer type (MQTT uses client IDs, not consumer groups) +- Simple publish/subscribe model + +### Design Decisions Summary + +1. ✅ **Generic queue naming**: `qos/tenant/namespace/queue-name` format +2. ✅ **QoS in queue ID**: Determined by queue definition, not configuration +3. ✅ **Reconnection**: Handled by Consumer/Producer classes, not backends +4. ✅ **MQTT topics**: Include tenant/namespace for proper namespacing +5. ✅ **Message history**: MQTT ignores `initial_position` parameter (future enhancement) +6. ✅ **Client IDs**: MQTT backend auto-generates from subscription name + +### Future Enhancements + +**MQTT message history:** +- Could add optional persistence layer (e.g., retained messages, external store) +- Would allow supporting `initial_position='earliest'` +- Not required for initial implementation + diff --git a/docs/tech-specs/python-api-refactor.md b/docs/tech-specs/python-api-refactor.md new file mode 100644 index 00000000..6fcf2f22 --- /dev/null +++ b/docs/tech-specs/python-api-refactor.md @@ -0,0 +1,1508 @@ +# Python API Refactor Technical Specification + +## Overview + +This specification describes a comprehensive refactor of the TrustGraph Python API client library to achieve feature parity with the API Gateway and add support for modern real-time communication patterns. + +The refactor addresses four primary use cases: + +1. **Streaming LLM Interactions**: Enable real-time streaming of LLM responses (agent, graph RAG, document RAG, text completion, prompts) with ~60x lower latency (500ms vs 30s for first token) +2. **Bulk Data Operations**: Support efficient bulk import/export of triples, graph embeddings, and document embeddings for large-scale knowledge graph management +3. **Feature Parity**: Ensure every API Gateway endpoint has a corresponding Python API method, including graph embeddings query +4. **Persistent Connections**: Enable WebSocket-based communication for multiplexed requests and reduced connection overhead + +## Goals + +- **Feature Parity**: Every Gateway API service has a corresponding Python API method +- **Streaming Support**: All streaming-capable services (agent, RAG, text completion, prompt) support streaming in Python API +- **WebSocket Transport**: Add optional WebSocket transport layer for persistent connections and multiplexing +- **Bulk Operations**: Add efficient bulk import/export for triples, graph embeddings, and document embeddings +- **Full Async Support**: Complete async/await implementation for all interfaces (REST, WebSocket, bulk operations, metrics) +- **Backward Compatibility**: Existing code continues to work without modification +- **Type Safety**: Maintain type-safe interfaces with dataclasses and type hints +- **Progressive Enhancement**: Streaming and async are opt-in via explicit interface selection +- **Performance**: Achieve 60x latency improvement for streaming operations +- **Modern Python**: Support for both sync and async paradigms for maximum flexibility + +## Background + +### Current State + +The Python API (`trustgraph-base/trustgraph/api/`) is a REST-only client library with the following modules: + +- `flow.py`: Flow management and flow-scoped services (50 methods) +- `library.py`: Document library operations (9 methods) +- `knowledge.py`: KG core management (4 methods) +- `collection.py`: Collection metadata (3 methods) +- `config.py`: Configuration management (6 methods) +- `types.py`: Data type definitions (5 dataclasses) + +**Total Operations**: 50/59 (85% coverage) + +### Current Limitations + +**Missing Operations**: +- Graph embeddings query (semantic search over graph entities) +- Bulk import/export for triples, graph embeddings, document embeddings, entity contexts, objects +- Metrics endpoint + +**Missing Capabilities**: +- Streaming support for LLM services +- WebSocket transport +- Multiplexed concurrent requests +- Persistent connections + +**Performance Issues**: +- High latency for LLM interactions (~30s time-to-first-token) +- Inefficient bulk data transfer (REST request per item) +- Connection overhead for multiple sequential operations + +**User Experience Issues**: +- No real-time feedback during LLM generation +- Cannot cancel long-running LLM operations +- Poor scalability for bulk operations + +### Impact + +The November 2024 streaming enhancement to the Gateway API provided 60x latency improvement (500ms vs 30s first token) for LLM interactions, but Python API users cannot leverage this capability. This creates a significant experience gap between Python and non-Python users. + +## Technical Design + +### Architecture + +The refactored Python API uses a **modular interface approach** with separate objects for different communication patterns. All interfaces are available in both **synchronous and asynchronous** variants: + +1. **REST Interface** (existing, enhanced) + - **Sync**: `api.flow()`, `api.library()`, `api.knowledge()`, `api.collection()`, `api.config()` + - **Async**: `api.async_flow()` + - Synchronous/asynchronous request/response + - Simple connection model + - Default for backward compatibility + +2. **WebSocket Interface** (new) + - **Sync**: `api.socket()` + - **Async**: `api.async_socket()` + - Persistent connection + - Multiplexed requests + - Streaming support + - Same method signatures as REST where functionality overlaps + +3. **Bulk Operations Interface** (new) + - **Sync**: `api.bulk()` + - **Async**: `api.async_bulk()` + - WebSocket-based for efficiency + - Iterator/AsyncIterator-based import/export + - Handles large datasets + +4. **Metrics Interface** (new) + - **Sync**: `api.metrics()` + - **Async**: `api.async_metrics()` + - Prometheus metrics access + +```python +import asyncio + +# Synchronous interfaces +api = Api(url="http://localhost:8088/") + +# REST (existing, unchanged) +flow = api.flow().id("default") +response = flow.agent(question="...", user="...") + +# WebSocket (new) +socket_flow = api.socket().flow("default") +response = socket_flow.agent(question="...", user="...") +for chunk in socket_flow.agent(question="...", user="...", streaming=True): + print(chunk) + +# Bulk operations (new) +bulk = api.bulk() +bulk.import_triples(flow="default", triples=triple_generator()) + +# Asynchronous interfaces +async def main(): + api = Api(url="http://localhost:8088/") + + # Async REST (new) + flow = api.async_flow().id("default") + response = await flow.agent(question="...", user="...") + + # Async WebSocket (new) + socket_flow = api.async_socket().flow("default") + async for chunk in socket_flow.agent(question="...", streaming=True): + print(chunk) + + # Async bulk operations (new) + bulk = api.async_bulk() + await bulk.import_triples(flow="default", triples=async_triple_generator()) + +asyncio.run(main()) +``` + +**Key Design Principles**: +- **Same URL for all interfaces**: `Api(url="http://localhost:8088/")` works for all +- **Sync/Async symmetry**: Every interface has both sync and async variants with identical method signatures +- **Identical signatures**: Where functionality overlaps, method signatures are identical between REST and WebSocket, sync and async +- **Progressive enhancement**: Choose interface based on needs (REST for simple, WebSocket for streaming, Bulk for large datasets, async for modern frameworks) +- **Explicit intent**: `api.socket()` signals WebSocket, `api.async_socket()` signals async WebSocket +- **Backward compatible**: Existing code unchanged + +### Components + +#### 1. Core API Class (Modified) + +Module: `trustgraph-base/trustgraph/api/api.py` + +**Enhanced API Class**: + +```python +class Api: + def __init__(self, url: str, timeout: int = 60, token: Optional[str] = None): + self.url = url + self.timeout = timeout + self.token = token # Optional bearer token for REST, query param for WebSocket + self._socket_client = None + self._bulk_client = None + self._async_flow = None + self._async_socket_client = None + self._async_bulk_client = None + + # Existing synchronous methods (unchanged) + def flow(self) -> Flow: + """Synchronous REST-based flow interface""" + pass + + def library(self) -> Library: + """Synchronous REST-based library interface""" + pass + + def knowledge(self) -> Knowledge: + """Synchronous REST-based knowledge interface""" + pass + + def collection(self) -> Collection: + """Synchronous REST-based collection interface""" + pass + + def config(self) -> Config: + """Synchronous REST-based config interface""" + pass + + # New synchronous methods + def socket(self) -> SocketClient: + """Synchronous WebSocket-based interface for streaming operations""" + if self._socket_client is None: + self._socket_client = SocketClient(self.url, self.timeout, self.token) + return self._socket_client + + def bulk(self) -> BulkClient: + """Synchronous bulk operations interface for import/export""" + if self._bulk_client is None: + self._bulk_client = BulkClient(self.url, self.timeout, self.token) + return self._bulk_client + + def metrics(self) -> Metrics: + """Synchronous metrics interface""" + return Metrics(self.url, self.timeout, self.token) + + # New asynchronous methods + def async_flow(self) -> AsyncFlow: + """Asynchronous REST-based flow interface""" + if self._async_flow is None: + self._async_flow = AsyncFlow(self.url, self.timeout, self.token) + return self._async_flow + + def async_socket(self) -> AsyncSocketClient: + """Asynchronous WebSocket-based interface for streaming operations""" + if self._async_socket_client is None: + self._async_socket_client = AsyncSocketClient(self.url, self.timeout, self.token) + return self._async_socket_client + + def async_bulk(self) -> AsyncBulkClient: + """Asynchronous bulk operations interface for import/export""" + if self._async_bulk_client is None: + self._async_bulk_client = AsyncBulkClient(self.url, self.timeout, self.token) + return self._async_bulk_client + + def async_metrics(self) -> AsyncMetrics: + """Asynchronous metrics interface""" + return AsyncMetrics(self.url, self.timeout, self.token) + + # Resource management + def close(self) -> None: + """Close all synchronous connections""" + if self._socket_client: + self._socket_client.close() + if self._bulk_client: + self._bulk_client.close() + + async def aclose(self) -> None: + """Close all asynchronous connections""" + if self._async_socket_client: + await self._async_socket_client.aclose() + if self._async_bulk_client: + await self._async_bulk_client.aclose() + if self._async_flow: + await self._async_flow.aclose() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + await self.aclose() +``` + +#### 2. Synchronous WebSocket Client + +Module: `trustgraph-base/trustgraph/api/socket_client.py` (new) + +**SocketClient Class**: + +```python +class SocketClient: + """Synchronous WebSocket client""" + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = self._convert_to_ws_url(url) + self.timeout = timeout + self.token = token + self._connection = None + self._request_counter = 0 + + def flow(self, flow_id: str) -> SocketFlowInstance: + """Get flow instance for WebSocket operations""" + return SocketFlowInstance(self, flow_id) + + def _connect(self) -> WebSocket: + """Establish WebSocket connection (lazy)""" + # Uses asyncio.run() internally to wrap async websockets library + pass + + def _send_request( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any], + streaming: bool = False + ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]: + """Send request and handle response/streaming""" + # Synchronous wrapper around async WebSocket calls + pass + + def close(self) -> None: + """Close WebSocket connection""" + pass + +class SocketFlowInstance: + """Synchronous WebSocket flow instance with same interface as REST FlowInstance""" + def __init__(self, client: SocketClient, flow_id: str): + self.client = client + self.flow_id = flow_id + + # Same method signatures as FlowInstance + def agent( + self, + question: str, + user: str, + state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, + history: Optional[List[Dict[str, Any]]] = None, + streaming: bool = False, + **kwargs + ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]: + """Agent with optional streaming""" + pass + + def text_completion( + self, + system: str, + prompt: str, + streaming: bool = False, + **kwargs + ) -> Union[str, Iterator[str]]: + """Text completion with optional streaming""" + pass + + # ... similar for graph_rag, document_rag, prompt, etc. +``` + +**Key Features**: +- Lazy connection (only connects when first request sent) +- Request multiplexing (up to 15 concurrent) +- Automatic reconnection on disconnect +- Streaming response parsing +- Thread-safe operation +- Synchronous wrapper around async websockets library + +#### 3. Asynchronous WebSocket Client + +Module: `trustgraph-base/trustgraph/api/async_socket_client.py` (new) + +**AsyncSocketClient Class**: + +```python +class AsyncSocketClient: + """Asynchronous WebSocket client""" + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = self._convert_to_ws_url(url) + self.timeout = timeout + self.token = token + self._connection = None + self._request_counter = 0 + + def flow(self, flow_id: str) -> AsyncSocketFlowInstance: + """Get async flow instance for WebSocket operations""" + return AsyncSocketFlowInstance(self, flow_id) + + async def _connect(self) -> WebSocket: + """Establish WebSocket connection (lazy)""" + # Native async websockets library + pass + + async def _send_request( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any], + streaming: bool = False + ) -> Union[Dict[str, Any], AsyncIterator[Dict[str, Any]]]: + """Send request and handle response/streaming""" + pass + + async def aclose(self) -> None: + """Close WebSocket connection""" + pass + +class AsyncSocketFlowInstance: + """Asynchronous WebSocket flow instance""" + def __init__(self, client: AsyncSocketClient, flow_id: str): + self.client = client + self.flow_id = flow_id + + # Same method signatures as FlowInstance (but async) + async def agent( + self, + question: str, + user: str, + state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, + history: Optional[List[Dict[str, Any]]] = None, + streaming: bool = False, + **kwargs + ) -> Union[Dict[str, Any], AsyncIterator[Dict[str, Any]]]: + """Agent with optional streaming""" + pass + + async def text_completion( + self, + system: str, + prompt: str, + streaming: bool = False, + **kwargs + ) -> Union[str, AsyncIterator[str]]: + """Text completion with optional streaming""" + pass + + # ... similar for graph_rag, document_rag, prompt, etc. +``` + +**Key Features**: +- Native async/await support +- Efficient for async applications (FastAPI, aiohttp) +- No thread blocking +- Same interface as sync version +- AsyncIterator for streaming + +#### 4. Synchronous Bulk Operations Client + +Module: `trustgraph-base/trustgraph/api/bulk_client.py` (new) + +**BulkClient Class**: + +```python +class BulkClient: + """Synchronous bulk operations client""" + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = self._convert_to_ws_url(url) + self.timeout = timeout + self.token = token + + def import_triples( + self, + flow: str, + triples: Iterator[Triple], + **kwargs + ) -> None: + """Bulk import triples via WebSocket""" + pass + + def export_triples( + self, + flow: str, + **kwargs + ) -> Iterator[Triple]: + """Bulk export triples via WebSocket""" + pass + + def import_graph_embeddings( + self, + flow: str, + embeddings: Iterator[Dict[str, Any]], + **kwargs + ) -> None: + """Bulk import graph embeddings via WebSocket""" + pass + + def export_graph_embeddings( + self, + flow: str, + **kwargs + ) -> Iterator[Dict[str, Any]]: + """Bulk export graph embeddings via WebSocket""" + pass + + # ... similar for document embeddings, entity contexts, objects + + def close(self) -> None: + """Close connections""" + pass +``` + +**Key Features**: +- Iterator-based for constant memory usage +- Dedicated WebSocket connections per operation +- Progress tracking (optional callback) +- Error handling with partial success reporting + +#### 5. Asynchronous Bulk Operations Client + +Module: `trustgraph-base/trustgraph/api/async_bulk_client.py` (new) + +**AsyncBulkClient Class**: + +```python +class AsyncBulkClient: + """Asynchronous bulk operations client""" + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = self._convert_to_ws_url(url) + self.timeout = timeout + self.token = token + + async def import_triples( + self, + flow: str, + triples: AsyncIterator[Triple], + **kwargs + ) -> None: + """Bulk import triples via WebSocket""" + pass + + async def export_triples( + self, + flow: str, + **kwargs + ) -> AsyncIterator[Triple]: + """Bulk export triples via WebSocket""" + pass + + async def import_graph_embeddings( + self, + flow: str, + embeddings: AsyncIterator[Dict[str, Any]], + **kwargs + ) -> None: + """Bulk import graph embeddings via WebSocket""" + pass + + async def export_graph_embeddings( + self, + flow: str, + **kwargs + ) -> AsyncIterator[Dict[str, Any]]: + """Bulk export graph embeddings via WebSocket""" + pass + + # ... similar for document embeddings, entity contexts, objects + + async def aclose(self) -> None: + """Close connections""" + pass +``` + +**Key Features**: +- AsyncIterator-based for constant memory usage +- Efficient for async applications +- Native async/await support +- Same interface as sync version + +#### 6. REST Flow API (Synchronous - Unchanged) + +Module: `trustgraph-base/trustgraph/api/flow.py` + +The REST Flow API remains **completely unchanged** for backward compatibility. All existing methods continue to work: + +- `Flow.list()`, `Flow.start()`, `Flow.stop()`, etc. +- `FlowInstance.agent()`, `FlowInstance.text_completion()`, `FlowInstance.graph_rag()`, etc. +- All existing signatures and return types preserved + +**New**: Add `graph_embeddings_query()` to REST FlowInstance for feature parity: + +```python +class FlowInstance: + # All existing methods unchanged... + + # New: Graph embeddings query (REST) + def graph_embeddings_query( + self, + text: str, + user: str, + collection: str, + limit: int = 10, + **kwargs + ) -> List[Dict[str, Any]]: + """Query graph embeddings for semantic search""" + # Calls POST /api/v1/flow/{flow}/service/graph-embeddings + pass +``` + +#### 7. Asynchronous REST Flow API + +Module: `trustgraph-base/trustgraph/api/async_flow.py` (new) + +**AsyncFlow and AsyncFlowInstance Classes**: + +```python +class AsyncFlow: + """Asynchronous REST-based flow interface""" + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = url + self.timeout = timeout + self.token = token + + async def list(self) -> List[Dict[str, Any]]: + """List all flows""" + pass + + async def get(self, id: str) -> Dict[str, Any]: + """Get flow definition""" + pass + + async def start(self, class_name: str, id: str, description: str, parameters: Dict) -> None: + """Start a flow""" + pass + + async def stop(self, id: str) -> None: + """Stop a flow""" + pass + + def id(self, flow_id: str) -> AsyncFlowInstance: + """Get async flow instance""" + return AsyncFlowInstance(self.url, self.timeout, self.token, flow_id) + + async def aclose(self) -> None: + """Close connection""" + pass + +class AsyncFlowInstance: + """Asynchronous REST flow instance""" + + async def agent( + self, + question: str, + user: str, + state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, + history: Optional[List[Dict[str, Any]]] = None, + **kwargs + ) -> Dict[str, Any]: + """Async agent execution""" + pass + + async def text_completion( + self, + system: str, + prompt: str, + **kwargs + ) -> str: + """Async text completion""" + pass + + async def graph_rag( + self, + question: str, + user: str, + collection: str, + **kwargs + ) -> str: + """Async graph RAG""" + pass + + # ... all other FlowInstance methods as async versions +``` + +**Key Features**: +- Native async HTTP using `aiohttp` or `httpx` +- Same method signatures as sync REST API +- No streaming (use `async_socket()` for streaming) +- Efficient for async applications + +#### 8. Metrics API + +Module: `trustgraph-base/trustgraph/api/metrics.py` (new) + +**Synchronous Metrics**: + +```python +class Metrics: + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = url + self.timeout = timeout + self.token = token + + def get(self) -> str: + """Get Prometheus metrics as text""" + # Call GET /api/metrics + pass +``` + +**Asynchronous Metrics**: + +```python +class AsyncMetrics: + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = url + self.timeout = timeout + self.token = token + + async def get(self) -> str: + """Get Prometheus metrics as text""" + # Call GET /api/metrics + pass +``` + +#### 9. Enhanced Types + +Module: `trustgraph-base/trustgraph/api/types.py` (modified) + +**New Types**: + +```python +from typing import Iterator, Union, Dict, Any +import dataclasses + +@dataclasses.dataclass +class StreamingChunk: + """Base class for streaming chunks""" + content: str + end_of_message: bool = False + +@dataclasses.dataclass +class AgentThought(StreamingChunk): + """Agent reasoning chunk""" + chunk_type: str = "thought" + +@dataclasses.dataclass +class AgentObservation(StreamingChunk): + """Agent tool observation chunk""" + chunk_type: str = "observation" + +@dataclasses.dataclass +class AgentAnswer(StreamingChunk): + """Agent final answer chunk""" + chunk_type: str = "final-answer" + end_of_dialog: bool = False + +@dataclasses.dataclass +class RAGChunk(StreamingChunk): + """RAG streaming chunk""" + end_of_stream: bool = False + error: Optional[Dict[str, str]] = None + +# Type aliases for clarity +AgentStream = Iterator[Union[AgentThought, AgentObservation, AgentAnswer]] +RAGStream = Iterator[RAGChunk] +CompletionStream = Iterator[str] +``` + +#### 6. Metrics API + +Module: `trustgraph-base/trustgraph/api/metrics.py` (new) + +```python +class Metrics: + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = url + self.timeout = timeout + self.token = token + + def get(self) -> str: + """Get Prometheus metrics as text""" + # Call GET /api/metrics + pass +``` + +### Implementation Approach + +#### Phase 1: Core API Enhancement (Week 1) + +1. Add `socket()`, `bulk()`, and `metrics()` methods to `Api` class +2. Implement lazy initialization for WebSocket and bulk clients +3. Add context manager support (`__enter__`, `__exit__`) +4. Add `close()` method for cleanup +5. Add unit tests for API class enhancements +6. Verify backward compatibility + +**Backward Compatibility**: Zero breaking changes. New methods only. + +#### Phase 2: WebSocket Client (Week 2-3) + +1. Implement `SocketClient` class with connection management +2. Implement `SocketFlowInstance` with same method signatures as `FlowInstance` +3. Add request multiplexing support (up to 15 concurrent) +4. Add streaming response parsing for different chunk types +5. Add automatic reconnection logic +6. Add unit and integration tests +7. Document WebSocket usage patterns + +**Backward Compatibility**: New interface only. Zero impact on existing code. + +#### Phase 3: Streaming Support (Week 3-4) + +1. Add streaming chunk type classes (`AgentThought`, `AgentObservation`, `AgentAnswer`, `RAGChunk`) +2. Implement streaming response parsing in `SocketClient` +3. Add streaming parameter to all LLM methods in `SocketFlowInstance` +4. Handle error cases during streaming +5. Add unit and integration tests for streaming +6. Add streaming examples to documentation + +**Backward Compatibility**: New interface only. Existing REST API unchanged. + +#### Phase 4: Bulk Operations (Week 4-5) + +1. Implement `BulkClient` class +2. Add bulk import/export methods for triples, embeddings, contexts, objects +3. Implement iterator-based processing for constant memory +4. Add progress tracking (optional callback) +5. Add error handling with partial success reporting +6. Add unit and integration tests +7. Add bulk operation examples + +**Backward Compatibility**: New interface only. Zero impact on existing code. + +#### Phase 5: Feature Parity & Polish (Week 5) + +1. Add `graph_embeddings_query()` to REST `FlowInstance` +2. Implement `Metrics` class +3. Add comprehensive integration tests +4. Performance benchmarking +5. Update all documentation +6. Create migration guide + +**Backward Compatibility**: New methods only. Zero impact on existing code. + +### Data Models + +#### Interface Selection + +```python +# Single API instance, same URL for all interfaces +api = Api(url="http://localhost:8088/") + +# Synchronous interfaces +rest_flow = api.flow().id("default") # Sync REST +socket_flow = api.socket().flow("default") # Sync WebSocket +bulk = api.bulk() # Sync bulk operations +metrics = api.metrics() # Sync metrics + +# Asynchronous interfaces +async_rest_flow = api.async_flow().id("default") # Async REST +async_socket_flow = api.async_socket().flow("default") # Async WebSocket +async_bulk = api.async_bulk() # Async bulk operations +async_metrics = api.async_metrics() # Async metrics +``` + +#### Streaming Response Types + +**Agent Streaming**: + +```python +api = Api(url="http://localhost:8088/") + +# REST interface - non-streaming (existing) +rest_flow = api.flow().id("default") +response = rest_flow.agent(question="What is ML?", user="user123") +print(response["response"]) + +# WebSocket interface - non-streaming (same signature) +socket_flow = api.socket().flow("default") +response = socket_flow.agent(question="What is ML?", user="user123") +print(response["response"]) + +# WebSocket interface - streaming (new) +for chunk in socket_flow.agent(question="What is ML?", user="user123", streaming=True): + if isinstance(chunk, AgentThought): + print(f"Thinking: {chunk.content}") + elif isinstance(chunk, AgentObservation): + print(f"Observed: {chunk.content}") + elif isinstance(chunk, AgentAnswer): + print(f"Answer: {chunk.content}") + if chunk.end_of_dialog: + break +``` + +**RAG Streaming**: + +```python +api = Api(url="http://localhost:8088/") + +# REST interface - non-streaming (existing) +rest_flow = api.flow().id("default") +response = rest_flow.graph_rag(question="What is Python?", user="user123", collection="default") +print(response) + +# WebSocket interface - streaming (new) +socket_flow = api.socket().flow("default") +for chunk in socket_flow.graph_rag( + question="What is Python?", + user="user123", + collection="default", + streaming=True +): + print(chunk.content, end="", flush=True) + if chunk.end_of_stream: + break +``` + +**Bulk Operations (Synchronous)**: + +```python +api = Api(url="http://localhost:8088/") + +# Bulk import triples +def triple_generator(): + yield Triple(s="http://ex.com/alice", p="http://ex.com/type", o="Person") + yield Triple(s="http://ex.com/alice", p="http://ex.com/name", o="Alice") + yield Triple(s="http://ex.com/bob", p="http://ex.com/type", o="Person") + +bulk = api.bulk() +bulk.import_triples(flow="default", triples=triple_generator()) + +# Bulk export triples +for triple in bulk.export_triples(flow="default"): + print(f"{triple.s} -> {triple.p} -> {triple.o}") +``` + +**Bulk Operations (Asynchronous)**: + +```python +import asyncio + +async def main(): + api = Api(url="http://localhost:8088/") + + # Async bulk import triples + async def async_triple_generator(): + yield Triple(s="http://ex.com/alice", p="http://ex.com/type", o="Person") + yield Triple(s="http://ex.com/alice", p="http://ex.com/name", o="Alice") + yield Triple(s="http://ex.com/bob", p="http://ex.com/type", o="Person") + + bulk = api.async_bulk() + await bulk.import_triples(flow="default", triples=async_triple_generator()) + + # Async bulk export triples + async for triple in bulk.export_triples(flow="default"): + print(f"{triple.s} -> {triple.p} -> {triple.o}") + +asyncio.run(main()) +``` + +**Async REST Example**: + +```python +import asyncio + +async def main(): + api = Api(url="http://localhost:8088/") + + # Async REST flow operations + flow = api.async_flow().id("default") + response = await flow.agent(question="What is ML?", user="user123") + print(response["response"]) + +asyncio.run(main()) +``` + +**Async WebSocket Streaming Example**: + +```python +import asyncio + +async def main(): + api = Api(url="http://localhost:8088/") + + # Async WebSocket streaming + socket = api.async_socket() + flow = socket.flow("default") + + async for chunk in flow.agent(question="What is ML?", user="user123", streaming=True): + if isinstance(chunk, AgentAnswer): + print(chunk.content, end="", flush=True) + if chunk.end_of_dialog: + break + +asyncio.run(main()) +``` + +### APIs + +#### New APIs + +1. **Core API Class**: + - **Synchronous**: + - `Api.socket()` - Get synchronous WebSocket client + - `Api.bulk()` - Get synchronous bulk operations client + - `Api.metrics()` - Get synchronous metrics client + - `Api.close()` - Close all synchronous connections + - Context manager support (`__enter__`, `__exit__`) + - **Asynchronous**: + - `Api.async_flow()` - Get asynchronous REST flow client + - `Api.async_socket()` - Get asynchronous WebSocket client + - `Api.async_bulk()` - Get asynchronous bulk operations client + - `Api.async_metrics()` - Get asynchronous metrics client + - `Api.aclose()` - Close all asynchronous connections + - Async context manager support (`__aenter__`, `__aexit__`) + +2. **Synchronous WebSocket Client**: + - `SocketClient.flow(flow_id)` - Get WebSocket flow instance + - `SocketFlowInstance.agent(..., streaming: bool = False)` - Agent with optional streaming + - `SocketFlowInstance.text_completion(..., streaming: bool = False)` - Text completion with optional streaming + - `SocketFlowInstance.graph_rag(..., streaming: bool = False)` - Graph RAG with optional streaming + - `SocketFlowInstance.document_rag(..., streaming: bool = False)` - Document RAG with optional streaming + - `SocketFlowInstance.prompt(..., streaming: bool = False)` - Prompt with optional streaming + - `SocketFlowInstance.graph_embeddings_query()` - Graph embeddings query + - All other FlowInstance methods with identical signatures + +3. **Asynchronous WebSocket Client**: + - `AsyncSocketClient.flow(flow_id)` - Get async WebSocket flow instance + - `AsyncSocketFlowInstance.agent(..., streaming: bool = False)` - Async agent with optional streaming + - `AsyncSocketFlowInstance.text_completion(..., streaming: bool = False)` - Async text completion with optional streaming + - `AsyncSocketFlowInstance.graph_rag(..., streaming: bool = False)` - Async graph RAG with optional streaming + - `AsyncSocketFlowInstance.document_rag(..., streaming: bool = False)` - Async document RAG with optional streaming + - `AsyncSocketFlowInstance.prompt(..., streaming: bool = False)` - Async prompt with optional streaming + - `AsyncSocketFlowInstance.graph_embeddings_query()` - Async graph embeddings query + - All other FlowInstance methods as async versions + +4. **Synchronous Bulk Operations Client**: + - `BulkClient.import_triples(flow, triples)` - Bulk triple import + - `BulkClient.export_triples(flow)` - Bulk triple export + - `BulkClient.import_graph_embeddings(flow, embeddings)` - Bulk graph embeddings import + - `BulkClient.export_graph_embeddings(flow)` - Bulk graph embeddings export + - `BulkClient.import_document_embeddings(flow, embeddings)` - Bulk document embeddings import + - `BulkClient.export_document_embeddings(flow)` - Bulk document embeddings export + - `BulkClient.import_entity_contexts(flow, contexts)` - Bulk entity contexts import + - `BulkClient.export_entity_contexts(flow)` - Bulk entity contexts export + - `BulkClient.import_objects(flow, objects)` - Bulk objects import + +5. **Asynchronous Bulk Operations Client**: + - `AsyncBulkClient.import_triples(flow, triples)` - Async bulk triple import + - `AsyncBulkClient.export_triples(flow)` - Async bulk triple export + - `AsyncBulkClient.import_graph_embeddings(flow, embeddings)` - Async bulk graph embeddings import + - `AsyncBulkClient.export_graph_embeddings(flow)` - Async bulk graph embeddings export + - `AsyncBulkClient.import_document_embeddings(flow, embeddings)` - Async bulk document embeddings import + - `AsyncBulkClient.export_document_embeddings(flow)` - Async bulk document embeddings export + - `AsyncBulkClient.import_entity_contexts(flow, contexts)` - Async bulk entity contexts import + - `AsyncBulkClient.export_entity_contexts(flow)` - Async bulk entity contexts export + - `AsyncBulkClient.import_objects(flow, objects)` - Async bulk objects import + +6. **Asynchronous REST Flow Client**: + - `AsyncFlow.list()` - Async list all flows + - `AsyncFlow.get(id)` - Async get flow definition + - `AsyncFlow.start(...)` - Async start flow + - `AsyncFlow.stop(id)` - Async stop flow + - `AsyncFlow.id(flow_id)` - Get async flow instance + - `AsyncFlowInstance.agent(...)` - Async agent execution + - `AsyncFlowInstance.text_completion(...)` - Async text completion + - `AsyncFlowInstance.graph_rag(...)` - Async graph RAG + - All other FlowInstance methods as async versions + +7. **Metrics Clients**: + - `Metrics.get()` - Synchronous Prometheus metrics + - `AsyncMetrics.get()` - Asynchronous Prometheus metrics + +8. **REST Flow API Enhancement**: + - `FlowInstance.graph_embeddings_query()` - Graph embeddings query (sync feature parity) + - `AsyncFlowInstance.graph_embeddings_query()` - Graph embeddings query (async feature parity) + +#### Modified APIs + +1. **Constructor** (minor enhancement): + ```python + Api(url: str, timeout: int = 60, token: Optional[str] = None) + ``` + - Added `token` parameter (optional, for authentication) + - If `None` (default): No authentication used + - If specified: Used as bearer token for REST (`Authorization: Bearer `), query param for WebSocket (`?token=`) + - No other changes - fully backward compatible + +2. **No Breaking Changes**: + - All existing REST API methods unchanged + - All existing signatures preserved + - All existing return types preserved + +### Implementation Details + +#### Error Handling + +**WebSocket Connection Errors**: +```python +try: + api = Api(url="http://localhost:8088/") + socket = api.socket() + socket_flow = socket.flow("default") + response = socket_flow.agent(question="...", user="user123") +except ConnectionError as e: + print(f"WebSocket connection failed: {e}") + print("Hint: Ensure Gateway is running and WebSocket endpoint is accessible") +``` + +**Graceful Fallback**: +```python +api = Api(url="http://localhost:8088/") + +try: + # Try WebSocket streaming first + socket_flow = api.socket().flow("default") + for chunk in socket_flow.agent(question="...", user="...", streaming=True): + print(chunk.content) +except ConnectionError: + # Fall back to REST non-streaming + print("WebSocket unavailable, falling back to REST") + rest_flow = api.flow().id("default") + response = rest_flow.agent(question="...", user="...") + print(response["response"]) +``` + +**Partial Streaming Errors**: +```python +api = Api(url="http://localhost:8088/") +socket_flow = api.socket().flow("default") + +accumulated = [] +try: + for chunk in socket_flow.graph_rag(question="...", streaming=True): + accumulated.append(chunk.content) + if chunk.error: + print(f"Error occurred: {chunk.error}") + print(f"Partial response: {''.join(accumulated)}") + break +except Exception as e: + print(f"Streaming error: {e}") + print(f"Partial response: {''.join(accumulated)}") +``` + +#### Resource Management + +**Context Manager Support**: +```python +# Automatic cleanup +with Api(url="http://localhost:8088/") as api: + socket_flow = api.socket().flow("default") + response = socket_flow.agent(question="...", user="user123") +# All connections automatically closed + +# Manual cleanup +api = Api(url="http://localhost:8088/") +try: + socket_flow = api.socket().flow("default") + response = socket_flow.agent(question="...", user="user123") +finally: + api.close() # Explicitly close all connections (WebSocket, bulk, etc.) +``` + +#### Threading and Concurrency + +**Thread Safety**: +- Each `Api` instance maintains its own connection +- WebSocket transport uses locks for thread-safe request multiplexing +- Multiple threads can share an `Api` instance safely +- Streaming iterators are not thread-safe (consume from single thread) + +**Async Support** (future consideration): +```python +# Phase 2 enhancement (not in initial scope) +import asyncio + +async def main(): + api = await AsyncApi(url="ws://localhost:8088/") + flow = api.flow().id("default") + + async for chunk in flow.agent(question="...", streaming=True): + print(chunk.content) + + await api.close() + +asyncio.run(main()) +``` + +## Security Considerations + +### Authentication + +**Token Parameter**: +```python +# No authentication (default) +api = Api(url="http://localhost:8088/") + +# With authentication +api = Api(url="http://localhost:8088/", token="mytoken") +``` + +**REST Transport**: +- Bearer token via `Authorization` header +- Applied automatically to all REST requests +- Format: `Authorization: Bearer ` + +**WebSocket Transport**: +- Token via query parameter appended to WebSocket URL +- Applied automatically during connection establishment +- Format: `ws://localhost:8088/api/v1/socket?token=` + +**Implementation**: +```python +class SocketClient: + def _connect(self) -> WebSocket: + # Construct WebSocket URL with optional token + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + # Connect to WebSocket + return websocket.connect(ws_url) +``` + +**Example**: +```python +# REST with auth +api = Api(url="http://localhost:8088/", token="mytoken") +flow = api.flow().id("default") +# All REST calls include: Authorization: Bearer mytoken + +# WebSocket with auth +socket = api.socket() +# Connects to: ws://localhost:8088/api/v1/socket?token=mytoken +``` + +### Secure Communication + +- Support both WS (WebSocket) and WSS (WebSocket Secure) schemes +- TLS certificate validation for WSS connections +- Optional certificate verification disable for development (with warning) + +### Input Validation + +- Validate URL schemes (http, https, ws, wss) +- Validate transport parameter values +- Validate streaming parameter combinations +- Validate bulk import data types + +## Performance Considerations + +### Latency Improvements + +**Streaming LLM Operations**: +- **Time-to-first-token**: ~500ms (vs ~30s non-streaming) +- **Improvement**: 60x faster perceived performance +- **Applicable to**: Agent, Graph RAG, Document RAG, Text Completion, Prompt + +**Persistent Connections**: +- **Connection overhead**: Eliminated for subsequent requests +- **WebSocket handshake**: One-time cost (~100ms) +- **Applicable to**: All operations when using WebSocket transport + +### Throughput Improvements + +**Bulk Operations**: +- **Triples import**: ~10,000 triples/second (vs ~100/second with REST per-item) +- **Embeddings import**: ~5,000 embeddings/second (vs ~50/second with REST per-item) +- **Improvement**: 100x throughput for bulk operations + +**Request Multiplexing**: +- **Concurrent requests**: Up to 15 simultaneous requests over single connection +- **Connection reuse**: No connection overhead for concurrent operations + +### Memory Considerations + +**Streaming Responses**: +- Constant memory usage (process chunks as they arrive) +- No buffering of complete response +- Suitable for very long outputs (>1MB) + +**Bulk Operations**: +- Iterator-based processing (constant memory) +- No loading of entire dataset into memory +- Suitable for datasets with millions of items + +### Benchmarks (Expected) + +| Operation | REST (existing) | WebSocket (streaming) | Improvement | +|-----------|----------------|----------------------|-------------| +| Agent (time-to-first-token) | 30s | 0.5s | 60x | +| Graph RAG (time-to-first-token) | 25s | 0.5s | 50x | +| Import 10K triples | 100s | 1s | 100x | +| Import 1M triples | 10,000s (2.7h) | 100s (1.6m) | 100x | +| 10 concurrent small requests | 5s (sequential) | 0.5s (parallel) | 10x | + +## Testing Strategy + +### Unit Tests + +**Transport Layer** (`test_transport.py`): +- Test REST transport request/response +- Test WebSocket transport connection +- Test WebSocket transport reconnection +- Test request multiplexing +- Test streaming response parsing +- Mock WebSocket server for deterministic tests + +**API Methods** (`test_flow.py`, `test_library.py`, etc.): +- Test new methods with mocked transport +- Test streaming parameter handling +- Test bulk operation iterators +- Test error handling + +**Types** (`test_types.py`): +- Test new streaming chunk types +- Test type serialization/deserialization + +### Integration Tests + +**End-to-End REST** (`test_integration_rest.py`): +- Test all operations against real Gateway (REST mode) +- Verify backward compatibility +- Test error conditions + +**End-to-End WebSocket** (`test_integration_websocket.py`): +- Test all operations against real Gateway (WebSocket mode) +- Test streaming operations +- Test bulk operations +- Test concurrent requests +- Test connection recovery + +**Streaming Services** (`test_streaming_integration.py`): +- Test agent streaming (thoughts, observations, answers) +- Test RAG streaming (incremental chunks) +- Test text completion streaming (token-by-token) +- Test prompt streaming +- Test error handling during streaming + +**Bulk Operations** (`test_bulk_integration.py`): +- Test bulk import/export of triples (1K, 10K, 100K items) +- Test bulk import/export of embeddings +- Test memory usage during bulk operations +- Test progress tracking + +### Performance Tests + +**Latency Benchmarks** (`test_performance_latency.py`): +- Measure time-to-first-token (streaming vs non-streaming) +- Measure connection overhead (REST vs WebSocket) +- Compare against expected benchmarks + +**Throughput Benchmarks** (`test_performance_throughput.py`): +- Measure bulk import throughput +- Measure request multiplexing efficiency +- Compare against expected benchmarks + +### Compatibility Tests + +**Backward Compatibility** (`test_backward_compatibility.py`): +- Run existing test suite against refactored API +- Verify zero breaking changes +- Test migration path for common patterns + +## Migration Plan + +### Phase 1: Transparent Migration (Default) + +**No code changes required**. Existing code continues to work: + +```python +# Existing code works unchanged +api = Api(url="http://localhost:8088/") +flow = api.flow().id("default") +response = flow.agent(question="What is ML?", user="user123") +``` + +### Phase 2: Opt-in Streaming (Simple) + +**Use `api.socket()` interface** to enable streaming: + +```python +# Before: Non-streaming REST +api = Api(url="http://localhost:8088/") +rest_flow = api.flow().id("default") +response = rest_flow.agent(question="What is ML?", user="user123") +print(response["response"]) + +# After: Streaming WebSocket (same parameters!) +api = Api(url="http://localhost:8088/") # Same URL +socket_flow = api.socket().flow("default") + +for chunk in socket_flow.agent(question="What is ML?", user="user123", streaming=True): + if isinstance(chunk, AgentAnswer): + print(chunk.content, end="", flush=True) +``` + +**Key Points**: +- Same URL for both REST and WebSocket +- Same method signatures (easy migration) +- Just add `.socket()` and `streaming=True` + +### Phase 3: Bulk Operations (New Capability) + +**Use `api.bulk()` interface** for large datasets: + +```python +# Before: Inefficient per-item operations +api = Api(url="http://localhost:8088/") +flow = api.flow().id("default") + +for triple in my_large_triple_list: + # Slow per-item operations + # (no direct bulk insert in REST API) + pass + +# After: Efficient bulk loading +api = Api(url="http://localhost:8088/") # Same URL +bulk = api.bulk() + +# This is fast (10,000 triples/second) +bulk.import_triples(flow="default", triples=iter(my_large_triple_list)) +``` + +### Documentation Updates + +1. **README.md**: Add streaming and WebSocket examples +2. **API Reference**: Document all new methods and parameters +3. **Migration Guide**: Step-by-step guide for enabling streaming +4. **Examples**: Add example scripts for common patterns +5. **Performance Guide**: Document expected performance improvements + +### Deprecation Policy + +**No deprecations**. All existing APIs remain supported. This is a pure enhancement. + +## Timeline + +### Week 1: Foundation +- Transport abstraction layer +- Refactor existing REST code +- Unit tests for transport layer +- Backward compatibility verification + +### Week 2: WebSocket Transport +- WebSocket transport implementation +- Connection management and reconnection +- Request multiplexing +- Unit and integration tests + +### Week 3: Streaming Support +- Add streaming parameter to LLM methods +- Implement streaming response parsing +- Add streaming chunk types +- Streaming integration tests + +### Week 4: Bulk Operations +- Add bulk import/export methods +- Implement iterator-based operations +- Performance testing +- Bulk operation integration tests + +### Week 5: Feature Parity & Documentation +- Add graph embeddings query +- Add metrics API +- Comprehensive documentation +- Migration guide +- Release candidate + +### Week 6: Release +- Final integration testing +- Performance benchmarking +- Release documentation +- Community announcement + +**Total Duration**: 6 weeks + +## Open Questions + +### API Design Questions + +1. **Async Support**: ✅ **RESOLVED** - Full async support included in initial release + - All interfaces have async variants: `async_flow()`, `async_socket()`, `async_bulk()`, `async_metrics()` + - Provides complete symmetry between sync and async APIs + - Essential for modern async frameworks (FastAPI, aiohttp) + +2. **Progress Tracking**: Should bulk operations support progress callbacks? + ```python + def progress_callback(processed: int, total: Optional[int]): + print(f"Processed {processed} items") + + bulk.import_triples(flow="default", triples=triples, on_progress=progress_callback) + ``` + - **Recommendation**: Add in Phase 2. Not critical for initial release. + +3. **Streaming Timeout**: How should we handle timeouts for streaming operations? + - **Recommendation**: Use same timeout as non-streaming, but reset on each chunk received. + +4. **Chunk Buffering**: Should we buffer chunks or yield immediately? + - **Recommendation**: Yield immediately for lowest latency. + +5. **Global Services via WebSocket**: Should `api.socket()` support global services (library, knowledge, collection, config) or only flow-scoped services? + - **Recommendation**: Start with flow-scoped only (where streaming matters). Add global services if needed in Phase 2. + +### Implementation Questions + +1. **WebSocket Library**: Should we use `websockets`, `websocket-client`, or `aiohttp`? + - **Recommendation**: `websockets` (async, mature, well-maintained). Wrap in sync interface using `asyncio.run()`. + +2. **Connection Pooling**: Should we support multiple concurrent `Api` instances sharing a connection pool? + - **Recommendation**: Defer to Phase 2. Each `Api` instance has its own connections initially. + +3. **Connection Reuse**: Should `SocketClient` and `BulkClient` share the same WebSocket connection, or use separate connections? + - **Recommendation**: Separate connections. Simpler implementation, clearer separation of concerns. + +4. **Lazy vs Eager Connection**: Should WebSocket connection be established in `api.socket()` or on first request? + - **Recommendation**: Lazy (on first request). Avoids connection overhead if user only uses REST methods. + +### Testing Questions + +1. **Mock Gateway**: Should we create a lightweight mock Gateway for testing, or test against real Gateway? + - **Recommendation**: Both. Use mocks for unit tests, real Gateway for integration tests. + +2. **Performance Regression Tests**: Should we add automated performance regression testing to CI? + - **Recommendation**: Yes, but with generous thresholds to account for CI environment variability. + +## References + +### Related Tech Specs +- `docs/tech-specs/streaming-llm-responses.md` - Streaming implementation in Gateway +- `docs/tech-specs/rag-streaming-support.md` - RAG streaming support + +### Implementation Files +- `trustgraph-base/trustgraph/api/` - Python API source +- `trustgraph-flow/trustgraph/gateway/` - Gateway source +- `trustgraph-flow/trustgraph/gateway/dispatch/mux.py` - WebSocket multiplexer reference implementation + +### Documentation +- `docs/apiSpecification.md` - Complete API reference +- `docs/api-status-summary.md` - API status summary +- `README.websocket` - WebSocket protocol documentation +- `STREAMING-IMPLEMENTATION-NOTES.txt` - Streaming implementation notes + +### External Libraries +- `websockets` - Python WebSocket library (https://websockets.readthedocs.io/) +- `requests` - Python HTTP library (existing) diff --git a/ontology-prompt.md b/ontology-prompt.md new file mode 100644 index 00000000..6be255b7 --- /dev/null +++ b/ontology-prompt.md @@ -0,0 +1,54 @@ +You are a knowledge extraction expert. Extract structured triples from text using ONLY the provided ontology elements. + +## Ontology Classes: + +{% for class_id, class_def in classes.items() %} +- **{{class_id}}**{% if class_def.subclass_of %} (subclass of {{class_def.subclass_of}}){% endif %}{% if class_def.comment %}: {{class_def.comment}}{% endif %} +{% endfor %} + +## Object Properties (connect entities): + +{% for prop_id, prop_def in object_properties.items() %} +- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %} +{% endfor %} + +## Datatype Properties (entity attributes): + +{% for prop_id, prop_def in datatype_properties.items() %} +- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %} +{% endfor %} + +## Text to Analyze: + +{{text}} + +## Extraction Rules: + +1. Only use classes defined above for entity types +2. Only use properties defined above for relationships and attributes +3. Respect domain and range constraints where specified +4. For class instances, use `rdf:type` as the predicate +5. Include `rdfs:label` for new entities to provide human-readable names +6. Extract all relevant triples that can be inferred from the text +7. Use entity URIs or meaningful identifiers as subjects/objects + +## Output Format: + +Return ONLY a valid JSON array (no markdown, no code blocks) containing objects with these fields: +- "subject": the subject entity (URI or identifier) +- "predicate": the property (from ontology or rdf:type/rdfs:label) +- "object": the object entity or literal value + +Important: Return raw JSON only, with no markdown formatting, no code blocks, and no backticks. + +## Example Output: + +[ + {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}, + {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"}, + {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"}, + {"subject": "ingredient:flour", "predicate": "rdf:type", "object": "Ingredient"}, + {"subject": "ingredient:flour", "predicate": "rdfs:label", "object": "plain flour"} +] + +Now extract triples from the text above. diff --git a/requirements.txt b/requirements.txt index 68c21e1d..4e3305d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ prometheus-client pyarrow boto3 ollama +python-logging-loki diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..7c6ffc13 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,60 @@ +""" +Global pytest configuration for all tests. + +This conftest.py applies to all test directories. +""" + +import pytest +# import asyncio +# import tracemalloc +# import warnings +from unittest.mock import MagicMock + +# Uncomment the lines below to enable asyncio debug mode and tracemalloc +# for tracing unawaited coroutines and their creation points +# tracemalloc.start() +# asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) +# warnings.simplefilter("always", ResourceWarning) +# warnings.simplefilter("always", RuntimeWarning) + + +@pytest.fixture(scope="session", autouse=True) +def mock_loki_handler(session_mocker=None): + """ + Mock LokiHandler to prevent connection attempts during tests. + + This fixture runs once per test session and prevents the logging + module from trying to connect to a Loki server that doesn't exist + in the test environment. + """ + # Try to import logging_loki and mock it if available + try: + import logging_loki + # Create a mock LokiHandler that does nothing + original_loki_handler = logging_loki.LokiHandler + + class MockLokiHandler: + """Mock LokiHandler that doesn't make network calls.""" + def __init__(self, *args, **kwargs): + pass + + def emit(self, record): + pass + + def flush(self): + pass + + def close(self): + pass + + # Replace the real LokiHandler with our mock + logging_loki.LokiHandler = MockLokiHandler + + yield + + # Restore original after tests + logging_loki.LokiHandler = original_loki_handler + + except ImportError: + # If logging_loki isn't installed, no need to mock + yield diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index 972bf1f0..6b10bd2f 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -257,7 +257,6 @@ class TestAgentMessageContracts: # Act request = AgentRequest( question="What comes next?", - plan="Multi-step plan", state="processing", history=history_steps ) @@ -588,7 +587,6 @@ class TestSerializationContracts: request = AgentRequest( question="Test with array", - plan="Test plan", state="Test state", history=steps ) diff --git a/tests/contract/test_objects_cassandra_contracts.py b/tests/contract/test_objects_cassandra_contracts.py index 3966a3fc..bb8aec8a 100644 --- a/tests/contract/test_objects_cassandra_contracts.py +++ b/tests/contract/test_objects_cassandra_contracts.py @@ -189,6 +189,7 @@ class TestObjectsCassandraContracts: assert result == expected_val assert isinstance(result, expected_type) or result is None + @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type") def test_extracted_object_serialization_contract(self): """Test that ExtractedObject can be serialized/deserialized correctly""" # Create test object @@ -408,6 +409,7 @@ class TestObjectsCassandraContractsBatch: assert isinstance(single_batch_object.values[0], dict) assert single_batch_object.values[0]["customer_id"] == "CUST999" + @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type") def test_extracted_object_batch_serialization_contract(self): """Test that batched ExtractedObject can be serialized/deserialized correctly""" # Create batch object diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index af5dda5b..7e18f0de 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -480,11 +480,15 @@ def streaming_chunk_collector(): class ChunkCollector: def __init__(self): self.chunks = [] + self.end_of_stream_flags = [] self.complete = False - async def collect(self, chunk): - """Async callback to collect chunks""" + async def collect(self, chunk, end_of_stream=False): + """Async callback to collect chunks with end_of_stream flag""" self.chunks.append(chunk) + self.end_of_stream_flags.append(end_of_stream) + if end_of_stream: + self.complete = True def get_full_text(self): """Concatenate all chunk content""" @@ -496,6 +500,14 @@ def streaming_chunk_collector(): return [c.get("chunk_type") for c in self.chunks] return [] + def verify_streaming_protocol(self): + """Verify that streaming protocol is correct""" + assert len(self.chunks) > 0, "Should have received at least one chunk" + assert len(self.chunks) == len(self.end_of_stream_flags), "Each chunk should have an end_of_stream flag" + assert self.end_of_stream_flags.count(True) == 1, "Exactly one chunk should have end_of_stream=True" + assert self.end_of_stream_flags[-1] is True, "Last chunk should have end_of_stream=True" + assert self.complete is True, "Should be marked complete after final chunk" + return ChunkCollector diff --git a/tests/integration/test_agent_streaming_integration.py b/tests/integration/test_agent_streaming_integration.py index 0971d30c..d6004c21 100644 --- a/tests/integration/test_agent_streaming_integration.py +++ b/tests/integration/test_agent_streaming_integration.py @@ -47,8 +47,9 @@ Args: { "}" ] - for chunk in chunks: - await chunk_callback(chunk) + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) return full_text else: @@ -312,8 +313,10 @@ Final Answer: AI is the simulation of human intelligence in machines.""" call_count += 1 if streaming and chunk_callback: - for chunk in response.split(): - await chunk_callback(chunk + " ") + chunks = response.split() + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk + " ", is_final) return response return response diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index e706b76a..a06ec509 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -373,13 +373,13 @@ class TestMultipleHostsHandling: from trustgraph.base.cassandra_config import resolve_cassandra_config # Test various whitespace scenarios - hosts1, _, _ = resolve_cassandra_config(host='host1, host2 , host3') + hosts1, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3') assert hosts1 == ['host1', 'host2', 'host3'] - hosts2, _, _ = resolve_cassandra_config(host='host1,host2,host3,') + hosts2, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,') assert hosts2 == ['host1', 'host2', 'host3'] - hosts3, _, _ = resolve_cassandra_config(host=' host1 , host2 ') + hosts3, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ') assert hosts3 == ['host1', 'host2'] diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index 4b792443..84061add 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -46,9 +46,16 @@ class TestDocumentRagStreaming: full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." if streaming and chunk_callback: - # Simulate streaming chunks + # Simulate streaming chunks with end_of_stream flags + chunks = [] async for chunk in mock_streaming_llm_response(): - await chunk_callback(chunk) + chunks.append(chunk) + + # Send all chunks with end_of_stream=False except the last + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) + return full_text else: # Non-streaming response - same text @@ -89,6 +96,9 @@ class TestDocumentRagStreaming: assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) + # Verify streaming protocol compliance + collector.verify_streaming_protocol() + # Verify full response matches concatenated chunks full_from_chunks = collector.get_full_text() assert result == full_from_chunks @@ -117,7 +127,7 @@ class TestDocumentRagStreaming: # Act - Streaming streaming_chunks = [] - async def collect(chunk): + async def collect(chunk, end_of_stream): streaming_chunks.append(chunk) streaming_result = await document_rag_streaming.query( diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 92da6527..47dd84b6 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -59,9 +59,16 @@ class TestGraphRagStreaming: full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data." if streaming and chunk_callback: - # Simulate streaming chunks + # Simulate streaming chunks with end_of_stream flags + chunks = [] async for chunk in mock_streaming_llm_response(): - await chunk_callback(chunk) + chunks.append(chunk) + + # Send all chunks with end_of_stream=False except the last + for i, chunk in enumerate(chunks): + is_final = (i == len(chunks) - 1) + await chunk_callback(chunk, is_final) + return full_text else: # Non-streaming response - same text @@ -102,6 +109,9 @@ class TestGraphRagStreaming: assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) + # Verify streaming protocol compliance + collector.verify_streaming_protocol() + # Verify full response matches concatenated chunks full_from_chunks = collector.get_full_text() assert result == full_from_chunks @@ -128,7 +138,7 @@ class TestGraphRagStreaming: # Act - Streaming streaming_chunks = [] - async def collect(chunk): + async def collect(chunk, end_of_stream): streaming_chunks.append(chunk) streaming_result = await graph_rag_streaming.query( diff --git a/tests/integration/test_import_export_graceful_shutdown.py b/tests/integration/test_import_export_graceful_shutdown.py index b802cd10..30197731 100644 --- a/tests/integration/test_import_export_graceful_shutdown.py +++ b/tests/integration/test_import_export_graceful_shutdown.py @@ -59,17 +59,17 @@ class MockWebSocket: @pytest.fixture -def mock_pulsar_client(): - """Mock Pulsar client for integration testing.""" - client = MagicMock() - +def mock_backend(): + """Mock backend for integration testing.""" + backend = MagicMock() + # Mock producer producer = MagicMock() producer.send = MagicMock() producer.flush = MagicMock() producer.close = MagicMock() - client.create_producer.return_value = producer - + backend.create_producer.return_value = producer + # Mock consumer consumer = MagicMock() consumer.receive = AsyncMock() @@ -78,33 +78,31 @@ def mock_pulsar_client(): consumer.pause_message_listener = MagicMock() consumer.unsubscribe = MagicMock() consumer.close = MagicMock() - client.subscribe.return_value = consumer - - return client + backend.create_consumer.return_value = consumer + + return backend @pytest.mark.asyncio -async def test_import_graceful_shutdown_integration(): +async def test_import_graceful_shutdown_integration(mock_backend): """Test import path handles shutdown gracefully with real message flow.""" - mock_client = MagicMock() - mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer - + mock_producer = mock_backend.create_producer.return_value + # Track sent messages sent_messages = [] def track_send(message, properties=None): sent_messages.append((message, properties)) - + mock_producer.send.side_effect = track_send - + ws = MockWebSocket() running = Running() - + # Create import handler import_handler = TriplesImport( ws=ws, running=running, - pulsar_client=mock_client, + backend=mock_backend, queue="test-triples-import" ) @@ -151,11 +149,9 @@ async def test_import_graceful_shutdown_integration(): @pytest.mark.asyncio -async def test_export_no_message_loss_integration(): +async def test_export_no_message_loss_integration(mock_backend): """Test export path doesn't lose acknowledged messages.""" - mock_client = MagicMock() - mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer + mock_consumer = mock_backend.create_consumer.return_value # Create test messages test_messages = [] @@ -202,7 +198,7 @@ async def test_export_no_message_loss_integration(): export_handler = TriplesExport( ws=ws, running=running, - pulsar_client=mock_client, + backend=mock_backend, queue="test-triples-export", consumer="test-consumer", subscriber="test-subscriber" @@ -245,14 +241,14 @@ async def test_export_no_message_loss_integration(): async def test_concurrent_import_export_shutdown(): """Test concurrent import and export shutdown scenarios.""" # Setup mock clients - import_client = MagicMock() - export_client = MagicMock() + import_backend = MagicMock() + export_backend = MagicMock() import_producer = MagicMock() export_consumer = MagicMock() - import_client.create_producer.return_value = import_producer - export_client.subscribe.return_value = export_consumer + import_backend.create_producer.return_value = import_producer + export_backend.subscribe.return_value = export_consumer # Track operations import_operations = [] @@ -280,14 +276,14 @@ async def test_concurrent_import_export_shutdown(): import_handler = TriplesImport( ws=import_ws, running=import_running, - pulsar_client=import_client, + backend=import_backend, queue="concurrent-import" ) export_handler = TriplesExport( ws=export_ws, running=export_running, - pulsar_client=export_client, + backend=export_backend, queue="concurrent-export", consumer="concurrent-consumer", subscriber="concurrent-subscriber" @@ -328,9 +324,9 @@ async def test_concurrent_import_export_shutdown(): @pytest.mark.asyncio async def test_websocket_close_during_message_processing(): """Test graceful handling when websocket closes during active message processing.""" - mock_client = MagicMock() + mock_backend_local = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend_local.create_producer.return_value = mock_producer # Simulate slow message processing processed_messages = [] @@ -346,7 +342,7 @@ async def test_websocket_close_during_message_processing(): import_handler = TriplesImport( ws=ws, running=running, - pulsar_client=mock_client, + backend=mock_backend_local, queue="slow-processing-import" ) @@ -395,9 +391,9 @@ async def test_websocket_close_during_message_processing(): @pytest.mark.asyncio async def test_backpressure_during_shutdown(): """Test graceful shutdown under backpressure conditions.""" - mock_client = MagicMock() + mock_backend_local = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer + mock_backend_local.subscribe.return_value = mock_consumer # Mock slow websocket class SlowWebSocket(MockWebSocket): @@ -410,8 +406,8 @@ async def test_backpressure_during_shutdown(): export_handler = TriplesExport( ws=ws, - running=running, - pulsar_client=mock_client, + running=running, + backend=mock_backend_local, queue="backpressure-export", consumer="backpressure-consumer", subscriber="backpressure-subscriber" diff --git a/tests/integration/test_objects_cassandra_integration.py b/tests/integration/test_objects_cassandra_integration.py index 21b414c1..3310b396 100644 --- a/tests/integration/test_objects_cassandra_integration.py +++ b/tests/integration/test_objects_cassandra_integration.py @@ -117,7 +117,7 @@ class TestObjectsCassandraIntegration: assert "customer_records" in processor.schemas # Step 1.5: Create the collection first (simulate tg-set-collection) - await processor.create_collection("test_user", "import_2024") + await processor.create_collection("test_user", "import_2024", {}) # Step 2: Process an ExtractedObject test_obj = ExtractedObject( @@ -213,8 +213,8 @@ class TestObjectsCassandraIntegration: assert len(processor.schemas) == 2 # Create collections first - await processor.create_collection("shop", "catalog") - await processor.create_collection("shop", "sales") + await processor.create_collection("shop", "catalog", {}) + await processor.create_collection("shop", "sales", {}) # Process objects for different schemas product_obj = ExtractedObject( @@ -263,7 +263,7 @@ class TestObjectsCassandraIntegration: ) # Create collection first - await processor.create_collection("test", "test") + await processor.create_collection("test", "test", {}) # Create object missing required field test_obj = ExtractedObject( @@ -302,7 +302,7 @@ class TestObjectsCassandraIntegration: ) # Create collection first - await processor.create_collection("logger", "app_events") + await processor.create_collection("logger", "app_events", {}) # Process object test_obj = ExtractedObject( @@ -407,7 +407,7 @@ class TestObjectsCassandraIntegration: # Create all collections first for coll in collections: - await processor.create_collection("analytics", coll) + await processor.create_collection("analytics", coll, {}) for coll in collections: obj = ExtractedObject( @@ -486,7 +486,7 @@ class TestObjectsCassandraIntegration: ) # Create collection first - await processor.create_collection("test_user", "batch_import") + await processor.create_collection("test_user", "batch_import", {}) msg = MagicMock() msg.value.return_value = batch_obj @@ -532,7 +532,7 @@ class TestObjectsCassandraIntegration: ) # Create collection first - await processor.create_collection("test", "empty") + await processor.create_collection("test", "empty", {}) # Process empty batch object empty_obj = ExtractedObject( @@ -573,7 +573,7 @@ class TestObjectsCassandraIntegration: ) # Create collection first - await processor.create_collection("test", "mixed") + await processor.create_collection("test", "mixed", {}) # Single object (backward compatibility) single_obj = ExtractedObject( diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py new file mode 100644 index 00000000..d2ceea95 --- /dev/null +++ b/tests/integration/test_rag_streaming_protocol.py @@ -0,0 +1,351 @@ +""" +Integration tests for RAG service streaming protocol compliance. + +These tests verify that RAG services correctly forward end_of_stream flags +and don't duplicate final chunks, ensuring proper streaming semantics. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call +from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.retrieval.document_rag.document_rag import DocumentRag + + +class TestGraphRagStreamingProtocol: + """Integration tests for GraphRAG streaming protocol""" + + @pytest.fixture + def mock_embeddings_client(self): + """Mock embeddings client""" + client = AsyncMock() + client.embed.return_value = [[0.1, 0.2, 0.3]] + return client + + @pytest.fixture + def mock_graph_embeddings_client(self): + """Mock graph embeddings client""" + client = AsyncMock() + client.query.return_value = ["entity1", "entity2"] + return client + + @pytest.fixture + def mock_triples_client(self): + """Mock triples client""" + client = AsyncMock() + client.query.return_value = [] + return client + + @pytest.fixture + def mock_streaming_prompt_client(self): + """Mock prompt client that simulates realistic streaming with end_of_stream flags""" + client = AsyncMock() + + async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None): + if streaming and chunk_callback: + # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True + await chunk_callback("The", False) + await chunk_callback(" answer", False) + await chunk_callback(" is here.", False) + await chunk_callback("", True) # Empty final chunk with end_of_stream=True + return "" # Return value not used since callback handles everything + else: + return "The answer is here." + + client.kg_prompt.side_effect = kg_prompt_side_effect + return client + + @pytest.fixture + def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, + mock_triples_client, mock_streaming_prompt_client): + """Create GraphRag instance with mocked dependencies""" + return GraphRag( + embeddings_client=mock_embeddings_client, + graph_embeddings_client=mock_graph_embeddings_client, + triples_client=mock_triples_client, + prompt_client=mock_streaming_prompt_client, + verbose=False + ) + + @pytest.mark.asyncio + async def test_callback_receives_end_of_stream_parameter(self, graph_rag): + """Test that callback receives end_of_stream parameter""" + # Arrange + callback = AsyncMock() + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=callback + ) + + # Assert - callback should receive (chunk, end_of_stream) signature + assert callback.call_count == 4 + # All calls should have 2 arguments + for call_args in callback.call_args_list: + assert len(call_args.args) == 2, "Callback should receive (chunk, end_of_stream)" + + @pytest.mark.asyncio + async def test_end_of_stream_flag_forwarded_correctly(self, graph_rag): + """Test that end_of_stream flags are forwarded correctly""" + # Arrange + chunks_with_flags = [] + + async def collect(chunk, end_of_stream): + chunks_with_flags.append((chunk, end_of_stream)) + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert len(chunks_with_flags) == 4 + + # First three chunks should have end_of_stream=False + assert chunks_with_flags[0] == ("The", False) + assert chunks_with_flags[1] == (" answer", False) + assert chunks_with_flags[2] == (" is here.", False) + + # Final chunk should have end_of_stream=True + assert chunks_with_flags[3] == ("", True) + + @pytest.mark.asyncio + async def test_no_duplicate_final_chunk(self, graph_rag): + """Test that final chunk is not duplicated""" + # Arrange + chunks = [] + + async def collect(chunk, end_of_stream): + chunks.append(chunk) + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert - should have exactly 4 chunks, no duplicates + assert len(chunks) == 4 + assert chunks == ["The", " answer", " is here.", ""] + + # The last chunk appears exactly once + assert chunks.count("") == 1 + + @pytest.mark.asyncio + async def test_exactly_one_end_of_stream_true(self, graph_rag): + """Test that exactly one message has end_of_stream=True""" + # Arrange + end_of_stream_flags = [] + + async def collect(chunk, end_of_stream): + end_of_stream_flags.append(end_of_stream) + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert - exactly one True + assert end_of_stream_flags.count(True) == 1 + assert end_of_stream_flags.count(False) == 3 + + @pytest.mark.asyncio + async def test_empty_final_chunk_preserved(self, graph_rag): + """Test that empty final chunks are preserved and forwarded""" + # Arrange + final_chunk = None + final_flag = None + + async def collect(chunk, end_of_stream): + nonlocal final_chunk, final_flag + if end_of_stream: + final_chunk = chunk + final_flag = end_of_stream + + # Act + await graph_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert final_flag is True + assert final_chunk == "", "Empty final chunk should be preserved" + + +class TestDocumentRagStreamingProtocol: + """Integration tests for DocumentRAG streaming protocol""" + + @pytest.fixture + def mock_embeddings_client(self): + """Mock embeddings client""" + client = AsyncMock() + client.embed.return_value = [[0.1, 0.2, 0.3]] + return client + + @pytest.fixture + def mock_doc_embeddings_client(self): + """Mock document embeddings client""" + client = AsyncMock() + client.query.return_value = ["doc1", "doc2"] + return client + + @pytest.fixture + def mock_streaming_prompt_client(self): + """Mock prompt client with streaming support""" + client = AsyncMock() + + async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None): + if streaming and chunk_callback: + # Simulate streaming with non-empty final chunk (some LLMs do this) + await chunk_callback("Document", False) + await chunk_callback(" summary", False) + await chunk_callback(".", True) # Non-empty final chunk + return "" + else: + return "Document summary." + + client.document_prompt.side_effect = document_prompt_side_effect + return client + + @pytest.fixture + def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, + mock_streaming_prompt_client): + """Create DocumentRag instance with mocked dependencies""" + return DocumentRag( + embeddings_client=mock_embeddings_client, + doc_embeddings_client=mock_doc_embeddings_client, + prompt_client=mock_streaming_prompt_client, + verbose=False + ) + + @pytest.mark.asyncio + async def test_callback_receives_end_of_stream_parameter(self, document_rag): + """Test that callback receives end_of_stream parameter""" + # Arrange + callback = AsyncMock() + + # Act + await document_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 3 + for call_args in callback.call_args_list: + assert len(call_args.args) == 2 + + @pytest.mark.asyncio + async def test_non_empty_final_chunk_preserved(self, document_rag): + """Test that non-empty final chunks are preserved with correct flag""" + # Arrange + chunks_with_flags = [] + + async def collect(chunk, end_of_stream): + chunks_with_flags.append((chunk, end_of_stream)) + + # Act + await document_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert len(chunks_with_flags) == 3 + assert chunks_with_flags[0] == ("Document", False) + assert chunks_with_flags[1] == (" summary", False) + assert chunks_with_flags[2] == (".", True) # Non-empty final chunk + + @pytest.mark.asyncio + async def test_no_duplicate_final_chunk(self, document_rag): + """Test that final chunk is not duplicated""" + # Arrange + chunks = [] + + async def collect(chunk, end_of_stream): + chunks.append(chunk) + + # Act + await document_rag.query( + query="test query", + user="test_user", + collection="test_collection", + streaming=True, + chunk_callback=collect + ) + + # Assert - final "." appears exactly once + assert chunks.count(".") == 1 + assert chunks == ["Document", " summary", "."] + + +class TestStreamingProtocolEdgeCases: + """Test edge cases in streaming protocol""" + + @pytest.mark.asyncio + async def test_multiple_empty_chunks_before_final(self): + """Test handling of multiple empty chunks (edge case)""" + # Arrange + client = AsyncMock() + + async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None): + if streaming and chunk_callback: + await chunk_callback("text", False) + await chunk_callback("", False) # Empty but not final + await chunk_callback("more", False) + await chunk_callback("", True) # Empty and final + return "" + else: + return "textmore" + + client.kg_prompt.side_effect = kg_prompt_with_empties + + rag = GraphRag( + embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])), + graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), + triples_client=AsyncMock(query=AsyncMock(return_value=[])), + prompt_client=client, + verbose=False + ) + + chunks_with_flags = [] + + async def collect(chunk, end_of_stream): + chunks_with_flags.append((chunk, end_of_stream)) + + # Act + await rag.query( + query="test", + streaming=True, + chunk_callback=collect + ) + + # Assert + assert len(chunks_with_flags) == 4 + assert chunks_with_flags[-1] == ("", True) # Final empty chunk + end_of_stream_flags = [f for c, f in chunks_with_flags] + assert end_of_stream_flags.count(True) == 1 diff --git a/tests/unit/test_base/test_async_processor.py b/tests/unit/test_base/test_async_processor.py index 8e7ad70f..464e459a 100644 --- a/tests/unit/test_base/test_async_processor.py +++ b/tests/unit/test_base/test_async_processor.py @@ -14,19 +14,20 @@ from trustgraph.base.async_processor import AsyncProcessor class TestAsyncProcessorSimple(IsolatedAsyncioTestCase): """Test AsyncProcessor base class functionality""" - @patch('trustgraph.base.async_processor.PulsarClient') + @patch('trustgraph.base.async_processor.get_pubsub') @patch('trustgraph.base.async_processor.Consumer') @patch('trustgraph.base.async_processor.ProcessorMetrics') @patch('trustgraph.base.async_processor.ConsumerMetrics') - async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics, - mock_consumer, mock_pulsar_client): + async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics, + mock_consumer, mock_get_pubsub): """Test basic AsyncProcessor initialization""" # Arrange - mock_pulsar_client.return_value = MagicMock() + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend mock_consumer.return_value = MagicMock() mock_processor_metrics.return_value = MagicMock() mock_consumer_metrics.return_value = MagicMock() - + config = { 'id': 'test-async-processor', 'taskgroup': AsyncMock() @@ -42,14 +43,14 @@ class TestAsyncProcessorSimple(IsolatedAsyncioTestCase): assert processor.running == True assert hasattr(processor, 'config_handlers') assert processor.config_handlers == [] - - # Verify PulsarClient was created - mock_pulsar_client.assert_called_once_with(**config) - + + # Verify get_pubsub was called to create backend + mock_get_pubsub.assert_called_once_with(**config) + # Verify metrics were initialized mock_processor_metrics.assert_called_once() mock_consumer_metrics.assert_called_once() - + # Verify Consumer was created for config subscription mock_consumer.assert_called_once() diff --git a/tests/unit/test_base/test_cassandra_config.py b/tests/unit/test_base/test_cassandra_config.py index 547ff637..5703c7e1 100644 --- a/tests/unit/test_base/test_cassandra_config.py +++ b/tests/unit/test_base/test_cassandra_config.py @@ -145,7 +145,7 @@ class TestResolveCassandraConfig: def test_default_configuration(self): """Test resolution with no parameters or environment variables.""" with patch.dict(os.environ, {}, clear=True): - hosts, username, password = resolve_cassandra_config() + hosts, username, password, keyspace = resolve_cassandra_config() assert hosts == ['cassandra'] assert username is None @@ -160,7 +160,7 @@ class TestResolveCassandraConfig: } with patch.dict(os.environ, env_vars, clear=True): - hosts, username, password = resolve_cassandra_config() + hosts, username, password, keyspace = resolve_cassandra_config() assert hosts == ['env1', 'env2', 'env3'] assert username == 'env-user' @@ -175,7 +175,7 @@ class TestResolveCassandraConfig: } with patch.dict(os.environ, env_vars, clear=True): - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host='explicit-host', username='explicit-user', password='explicit-pass' @@ -188,19 +188,19 @@ class TestResolveCassandraConfig: def test_host_list_parsing(self): """Test different host list formats.""" # Single host - hosts, _, _ = resolve_cassandra_config(host='single-host') + hosts, _, _, _ = resolve_cassandra_config(host='single-host') assert hosts == ['single-host'] # Multiple hosts with spaces - hosts, _, _ = resolve_cassandra_config(host='host1, host2 ,host3') + hosts, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3') assert hosts == ['host1', 'host2', 'host3'] # Empty elements filtered out - hosts, _, _ = resolve_cassandra_config(host='host1,,host2,') + hosts, _, _, _ = resolve_cassandra_config(host='host1,,host2,') assert hosts == ['host1', 'host2'] # Already a list - hosts, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2']) + hosts, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2']) assert hosts == ['list-host1', 'list-host2'] def test_args_object_resolution(self): @@ -212,7 +212,7 @@ class TestResolveCassandraConfig: cassandra_password = 'args-pass' args = MockArgs() - hosts, username, password = resolve_cassandra_config(args) + hosts, username, password, keyspace = resolve_cassandra_config(args) assert hosts == ['args-host1', 'args-host2'] assert username == 'args-user' @@ -233,7 +233,7 @@ class TestResolveCassandraConfig: with patch.dict(os.environ, env_vars, clear=True): args = PartialArgs() - hosts, username, password = resolve_cassandra_config(args) + hosts, username, password, keyspace = resolve_cassandra_config(args) assert hosts == ['args-host'] # From args assert username == 'env-user' # From env @@ -251,7 +251,7 @@ class TestGetCassandraConfigFromParams: 'cassandra_password': 'new-pass' } - hosts, username, password = get_cassandra_config_from_params(params) + hosts, username, password, keyspace = get_cassandra_config_from_params(params) assert hosts == ['new-host1', 'new-host2'] assert username == 'new-user' @@ -265,7 +265,7 @@ class TestGetCassandraConfigFromParams: 'graph_password': 'old-pass' } - hosts, username, password = get_cassandra_config_from_params(params) + hosts, username, password, keyspace = get_cassandra_config_from_params(params) # Should use defaults since graph_* params are not recognized assert hosts == ['cassandra'] # Default @@ -280,7 +280,7 @@ class TestGetCassandraConfigFromParams: 'cassandra_password': 'compat-pass' } - hosts, username, password = get_cassandra_config_from_params(params) + hosts, username, password, keyspace = get_cassandra_config_from_params(params) assert hosts == ['compat-host'] assert username is None # cassandra_user is not recognized @@ -298,7 +298,7 @@ class TestGetCassandraConfigFromParams: 'graph_password': 'old-pass' } - hosts, username, password = get_cassandra_config_from_params(params) + hosts, username, password, keyspace = get_cassandra_config_from_params(params) assert hosts == ['new-host'] # Only cassandra_* params work assert username == 'new-user' # Only cassandra_* params work @@ -314,7 +314,7 @@ class TestGetCassandraConfigFromParams: with patch.dict(os.environ, env_vars, clear=True): params = {} - hosts, username, password = get_cassandra_config_from_params(params) + hosts, username, password, keyspace = get_cassandra_config_from_params(params) assert hosts == ['fallback-host1', 'fallback-host2'] assert username == 'fallback-user' @@ -334,7 +334,7 @@ class TestConfigurationPriority: with patch.dict(os.environ, env_vars, clear=True): # CLI args should override everything - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host='cli-host', username='cli-user', password='cli-pass' @@ -354,7 +354,7 @@ class TestConfigurationPriority: with patch.dict(os.environ, env_vars, clear=True): # Only provide host via CLI - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host='cli-host' # username and password not provided ) @@ -366,7 +366,7 @@ class TestConfigurationPriority: def test_no_config_defaults(self): """Test that defaults are used when no configuration is provided.""" with patch.dict(os.environ, {}, clear=True): - hosts, username, password = resolve_cassandra_config() + hosts, username, password, keyspace = resolve_cassandra_config() assert hosts == ['cassandra'] # Default assert username is None # Default @@ -378,17 +378,17 @@ class TestEdgeCases: def test_empty_host_string(self): """Test handling of empty host string falls back to default.""" - hosts, _, _ = resolve_cassandra_config(host='') + hosts, _, _, _ = resolve_cassandra_config(host='') assert hosts == ['cassandra'] # Falls back to default def test_whitespace_only_host(self): """Test handling of whitespace-only host string.""" - hosts, _, _ = resolve_cassandra_config(host=' ') + hosts, _, _, _ = resolve_cassandra_config(host=' ') assert hosts == [] # Empty after stripping whitespace def test_none_values_preserved(self): """Test that None values are preserved correctly.""" - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=None, username=None, password=None @@ -401,7 +401,7 @@ class TestEdgeCases: def test_mixed_none_and_values(self): """Test mixing None and actual values.""" - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host='mixed-host', username=None, password='mixed-pass' diff --git a/tests/unit/test_base/test_prompt_client_streaming.py b/tests/unit/test_base/test_prompt_client_streaming.py new file mode 100644 index 00000000..83a4b90e --- /dev/null +++ b/tests/unit/test_base/test_prompt_client_streaming.py @@ -0,0 +1,260 @@ +""" +Unit tests for PromptClient streaming callback behavior. + +These tests verify that the prompt client correctly passes the end_of_stream +flag to chunk callbacks, ensuring proper streaming protocol compliance. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call, patch +from trustgraph.base.prompt_client import PromptClient +from trustgraph.schema import PromptResponse + + +class TestPromptClientStreamingCallback: + """Test PromptClient streaming callback behavior""" + + @pytest.fixture + def prompt_client(self): + """Create a PromptClient with mocked dependencies""" + # Mock all the required initialization parameters + with patch.object(PromptClient, '__init__', lambda self: None): + client = PromptClient() + return client + + @pytest.fixture + def mock_request_response(self): + """Create a mock request/response handler""" + async def mock_request(request, recipient=None, timeout=600): + if recipient: + # Simulate streaming responses + responses = [ + PromptResponse(text="Hello", object=None, error=None, end_of_stream=False), + PromptResponse(text=" world", object=None, error=None, end_of_stream=False), + PromptResponse(text="!", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + else: + # Non-streaming response + return PromptResponse(text="Hello world!", object=None, error=None) + + return mock_request + + @pytest.mark.asyncio + async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response): + """Test that callback receives both chunk text and end_of_stream flag""" + # Arrange + prompt_client.request = mock_request_response + + callback = AsyncMock() + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert - callback should be called with (chunk, end_of_stream) signature + assert callback.call_count == 4 + + # Verify first chunk: text + end_of_stream=False + assert callback.call_args_list[0] == call("Hello", False) + + # Verify second chunk + assert callback.call_args_list[1] == call(" world", False) + + # Verify third chunk + assert callback.call_args_list[2] == call("!", False) + + # Verify final chunk: empty text + end_of_stream=True + assert callback.call_args_list[3] == call("", True) + + @pytest.mark.asyncio + async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response): + """Test that empty final chunks are passed to callback""" + # Arrange + prompt_client.request = mock_request_response + + chunks_received = [] + + async def collect_chunks(chunk, end_of_stream): + chunks_received.append((chunk, end_of_stream)) + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=collect_chunks + ) + + # Assert - should receive the empty final chunk + final_chunk = chunks_received[-1] + assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True" + + @pytest.mark.asyncio + async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client): + """Test callback signature when LLM sends non-empty final chunk""" + # Arrange + async def mock_request_non_empty_final(request, recipient=None, timeout=600): + if recipient: + # Some LLMs send content in the final chunk + responses = [ + PromptResponse(text="Hello", object=None, error=None, end_of_stream=False), + PromptResponse(text=" world!", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request_non_empty_final + + callback = AsyncMock() + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("Hello", False) + assert callback.call_args_list[1] == call(" world!", True) + + @pytest.mark.asyncio + async def test_callback_not_called_without_text(self, prompt_client): + """Test that callback is not called for responses without text""" + # Arrange + async def mock_request_no_text(request, recipient=None, timeout=600): + if recipient: + # Response with only end_of_stream, no text + responses = [ + PromptResponse(text="Content", object=None, error=None, end_of_stream=False), + PromptResponse(text=None, object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request_no_text + + callback = AsyncMock() + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert - callback should only be called once (for "Content") + assert callback.call_count == 1 + assert callback.call_args_list[0] == call("Content", False) + + @pytest.mark.asyncio + async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client): + """Test that synchronous callbacks also receive end_of_stream parameter""" + # Arrange + async def mock_request(request, recipient=None, timeout=600): + if recipient: + responses = [ + PromptResponse(text="test", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request + + callback = MagicMock() # Synchronous mock + + # Act + await prompt_client.prompt( + id="test-prompt", + variables={"query": "test"}, + streaming=True, + chunk_callback=callback + ) + + # Assert - synchronous callback should also get both parameters + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("test", False) + assert callback.call_args_list[1] == call("", True) + + @pytest.mark.asyncio + async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client): + """Test that kg_prompt correctly passes streaming parameters""" + # Arrange + async def mock_request(request, recipient=None, timeout=600): + if recipient: + responses = [ + PromptResponse(text="Answer", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request + + callback = AsyncMock() + + # Act + await prompt_client.kg_prompt( + query="What is machine learning?", + kg=[("subject", "predicate", "object")], + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("Answer", False) + assert callback.call_args_list[1] == call("", True) + + @pytest.mark.asyncio + async def test_document_prompt_passes_parameters_to_callback(self, prompt_client): + """Test that document_prompt correctly passes streaming parameters""" + # Arrange + async def mock_request(request, recipient=None, timeout=600): + if recipient: + responses = [ + PromptResponse(text="Summary", object=None, error=None, end_of_stream=False), + PromptResponse(text="", object=None, error=None, end_of_stream=True), + ] + for resp in responses: + should_stop = await recipient(resp) + if should_stop: + break + + prompt_client.request = mock_request + + callback = AsyncMock() + + # Act + await prompt_client.document_prompt( + query="Summarize this", + documents=["doc1", "doc2"], + streaming=True, + chunk_callback=callback + ) + + # Assert + assert callback.call_count == 2 + assert callback.call_args_list[0] == call("Summary", False) + assert callback.call_args_list[1] == call("", True) diff --git a/tests/unit/test_base/test_publisher_graceful_shutdown.py b/tests/unit/test_base/test_publisher_graceful_shutdown.py index e15cb1ec..3c5cb967 100644 --- a/tests/unit/test_base/test_publisher_graceful_shutdown.py +++ b/tests/unit/test_base/test_publisher_graceful_shutdown.py @@ -8,22 +8,22 @@ from trustgraph.base.publisher import Publisher @pytest.fixture -def mock_pulsar_client(): - """Mock Pulsar client for testing.""" - client = MagicMock() +def mock_pulsar_backend(): + """Mock Pulsar backend for testing.""" + backend = MagicMock() producer = AsyncMock() producer.send = MagicMock() producer.flush = MagicMock() producer.close = MagicMock() - client.create_producer.return_value = producer - return client + backend.create_producer.return_value = producer + return backend @pytest.fixture -def publisher(mock_pulsar_client): +def publisher(mock_pulsar_backend): """Create Publisher instance for testing.""" return Publisher( - client=mock_pulsar_client, + backend=mock_pulsar_backend, topic="test-topic", schema=dict, max_size=10, @@ -34,12 +34,12 @@ def publisher(mock_pulsar_client): @pytest.mark.asyncio async def test_publisher_queue_drain(): """Verify Publisher drains queue on shutdown.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend.create_producer.return_value = mock_producer publisher = Publisher( - client=mock_client, + backend=mock_backend, topic="test-topic", schema=dict, max_size=10, @@ -85,12 +85,12 @@ async def test_publisher_queue_drain(): @pytest.mark.asyncio async def test_publisher_rejects_messages_during_drain(): """Verify Publisher rejects new messages during shutdown.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend.create_producer.return_value = mock_producer publisher = Publisher( - client=mock_client, + backend=mock_backend, topic="test-topic", schema=dict, max_size=10, @@ -113,12 +113,12 @@ async def test_publisher_rejects_messages_during_drain(): @pytest.mark.asyncio async def test_publisher_drain_timeout(): """Verify Publisher respects drain timeout.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend.create_producer.return_value = mock_producer publisher = Publisher( - client=mock_client, + backend=mock_backend, topic="test-topic", schema=dict, max_size=10, @@ -169,12 +169,12 @@ async def test_publisher_drain_timeout(): @pytest.mark.asyncio async def test_publisher_successful_drain(): """Verify Publisher drains successfully under normal conditions.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend.create_producer.return_value = mock_producer publisher = Publisher( - client=mock_client, + backend=mock_backend, topic="test-topic", schema=dict, max_size=10, @@ -224,12 +224,12 @@ async def test_publisher_successful_drain(): @pytest.mark.asyncio async def test_publisher_state_transitions(): """Test Publisher state transitions during graceful shutdown.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend.create_producer.return_value = mock_producer publisher = Publisher( - client=mock_client, + backend=mock_backend, topic="test-topic", schema=dict, max_size=10, @@ -276,9 +276,9 @@ async def test_publisher_state_transitions(): @pytest.mark.asyncio async def test_publisher_exception_handling(): """Test Publisher handles exceptions during drain gracefully.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_producer = MagicMock() - mock_client.create_producer.return_value = mock_producer + mock_backend.create_producer.return_value = mock_producer # Mock producer.send to raise exception on second call call_count = 0 @@ -291,7 +291,7 @@ async def test_publisher_exception_handling(): mock_producer.send.side_effect = failing_send publisher = Publisher( - client=mock_client, + backend=mock_backend, topic="test-topic", schema=dict, max_size=10, diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py b/tests/unit/test_base/test_subscriber_graceful_shutdown.py index 1a3f8b82..ea5d04cc 100644 --- a/tests/unit/test_base/test_subscriber_graceful_shutdown.py +++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py @@ -6,23 +6,11 @@ import uuid from unittest.mock import AsyncMock, MagicMock, patch from trustgraph.base.subscriber import Subscriber -# Mock JsonSchema globally to avoid schema issues in tests -# Patch at the module level where it's imported in subscriber -@patch('trustgraph.base.subscriber.JsonSchema') -def mock_json_schema_global(mock_schema): - mock_schema.return_value = MagicMock() - return mock_schema - -# Apply the global patch -_json_schema_patch = patch('trustgraph.base.subscriber.JsonSchema') -_mock_json_schema = _json_schema_patch.start() -_mock_json_schema.return_value = MagicMock() - @pytest.fixture -def mock_pulsar_client(): - """Mock Pulsar client for testing.""" - client = MagicMock() +def mock_pulsar_backend(): + """Mock Pulsar backend for testing.""" + backend = MagicMock() consumer = MagicMock() consumer.receive = MagicMock() consumer.acknowledge = MagicMock() @@ -30,15 +18,15 @@ def mock_pulsar_client(): consumer.pause_message_listener = MagicMock() consumer.unsubscribe = MagicMock() consumer.close = MagicMock() - client.subscribe.return_value = consumer - return client + backend.create_consumer.return_value = consumer + return backend @pytest.fixture -def subscriber(mock_pulsar_client): +def subscriber(mock_pulsar_backend): """Create Subscriber instance for testing.""" return Subscriber( - client=mock_pulsar_client, + backend=mock_pulsar_backend, topic="test-topic", subscription="test-subscription", consumer_name="test-consumer", @@ -60,14 +48,14 @@ def create_mock_message(message_id="test-id", data=None): @pytest.mark.asyncio async def test_subscriber_deferred_acknowledgment_success(): """Verify Subscriber only acks on successful delivery.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + subscriber = Subscriber( - client=mock_client, + backend=mock_backend, topic="test-topic", - subscription="test-subscription", + subscription="test-subscription", consumer_name="test-consumer", schema=dict, max_size=10, @@ -102,15 +90,15 @@ async def test_subscriber_deferred_acknowledgment_success(): @pytest.mark.asyncio async def test_subscriber_deferred_acknowledgment_failure(): """Verify Subscriber negative acks on delivery failure.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + subscriber = Subscriber( - client=mock_client, + backend=mock_backend, topic="test-topic", subscription="test-subscription", - consumer_name="test-consumer", + consumer_name="test-consumer", schema=dict, max_size=1, # Very small queue backpressure_strategy="drop_new" @@ -140,14 +128,14 @@ async def test_subscriber_deferred_acknowledgment_failure(): @pytest.mark.asyncio async def test_subscriber_backpressure_strategies(): """Test different backpressure strategies.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + # Test drop_oldest strategy subscriber = Subscriber( - client=mock_client, - topic="test-topic", + backend=mock_backend, + topic="test-topic", subscription="test-subscription", consumer_name="test-consumer", schema=dict, @@ -187,12 +175,12 @@ async def test_subscriber_backpressure_strategies(): @pytest.mark.asyncio async def test_subscriber_graceful_shutdown(): """Test Subscriber graceful shutdown with queue draining.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + subscriber = Subscriber( - client=mock_client, + backend=mock_backend, topic="test-topic", subscription="test-subscription", consumer_name="test-consumer", @@ -253,14 +241,14 @@ async def test_subscriber_graceful_shutdown(): @pytest.mark.asyncio async def test_subscriber_drain_timeout(): """Test Subscriber respects drain timeout.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + subscriber = Subscriber( - client=mock_client, + backend=mock_backend, topic="test-topic", - subscription="test-subscription", + subscription="test-subscription", consumer_name="test-consumer", schema=dict, max_size=10, @@ -288,12 +276,12 @@ async def test_subscriber_drain_timeout(): @pytest.mark.asyncio async def test_subscriber_pending_acks_cleanup(): """Test Subscriber cleans up pending acknowledgments on shutdown.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + subscriber = Subscriber( - client=mock_client, + backend=mock_backend, topic="test-topic", subscription="test-subscription", consumer_name="test-consumer", @@ -342,12 +330,12 @@ async def test_subscriber_pending_acks_cleanup(): @pytest.mark.asyncio async def test_subscriber_multiple_subscribers(): """Test Subscriber with multiple concurrent subscribers.""" - mock_client = MagicMock() + mock_backend = MagicMock() mock_consumer = MagicMock() - mock_client.subscribe.return_value = mock_consumer - + mock_backend.create_consumer.return_value = mock_consumer + subscriber = Subscriber( - client=mock_client, + backend=mock_backend, topic="test-topic", subscription="test-subscription", consumer_name="test-consumer", diff --git a/tests/unit/test_cli/test_config_commands.py b/tests/unit/test_cli/test_config_commands.py index 286054b9..68ae1a54 100644 --- a/tests/unit/test_cli/test_config_commands.py +++ b/tests/unit/test_cli/test_config_commands.py @@ -108,7 +108,8 @@ class TestListConfigItems: mock_list.assert_called_once_with( url='http://custom.com', config_type='prompt', - format_type='json' + format_type='json', + token=None ) def test_list_main_uses_defaults(self): @@ -126,7 +127,8 @@ class TestListConfigItems: mock_list.assert_called_once_with( url='http://localhost:8088/', config_type='prompt', - format_type='text' + format_type='text', + token=None ) @@ -193,7 +195,8 @@ class TestGetConfigItem: url='http://custom.com', config_type='prompt', key='template-1', - format_type='json' + format_type='json', + token=None ) @@ -249,7 +252,8 @@ class TestPutConfigItem: url='http://custom.com', config_type='prompt', key='new-template', - value='Custom prompt: {input}' + value='Custom prompt: {input}', + token=None ) def test_put_main_with_stdin_arg(self): @@ -273,7 +277,8 @@ class TestPutConfigItem: url='http://localhost:8088/', config_type='prompt', key='stdin-template', - value=stdin_content + value=stdin_content, + token=None ) def test_put_main_mutually_exclusive_args(self): @@ -328,7 +333,8 @@ class TestDeleteConfigItem: mock_delete.assert_called_once_with( url='http://custom.com', config_type='prompt', - key='old-template' + key='old-template', + token=None ) diff --git a/tests/unit/test_cli/test_load_knowledge.py b/tests/unit/test_cli/test_load_knowledge.py index c7070200..63045ef9 100644 --- a/tests/unit/test_cli/test_load_knowledge.py +++ b/tests/unit/test_cli/test_load_knowledge.py @@ -2,17 +2,16 @@ Unit tests for the load_knowledge CLI module. Tests the business logic of loading triples and entity contexts from Turtle files -while mocking WebSocket connections and external dependencies. +using the BulkClient API. """ import pytest -import json import tempfile -import asyncio -from unittest.mock import AsyncMock, Mock, patch, mock_open, MagicMock +from unittest.mock import Mock, patch, MagicMock, call from pathlib import Path from trustgraph.cli.load_knowledge import KnowledgeLoader, main +from trustgraph.api import Triple @pytest.fixture @@ -38,159 +37,80 @@ def temp_turtle_file(sample_turtle_content): f.write(sample_turtle_content) f.flush() yield f.name - + # Cleanup Path(f.name).unlink(missing_ok=True) -@pytest.fixture -def mock_websocket(): - """Mock WebSocket connection.""" - mock_ws = MagicMock() - - async def async_send(data): - return None - - async def async_recv(): - return "" - - async def async_close(): - return None - - mock_ws.send = Mock(side_effect=async_send) - mock_ws.recv = Mock(side_effect=async_recv) - mock_ws.close = Mock(side_effect=async_close) - return mock_ws - - @pytest.fixture def knowledge_loader(): """Create a KnowledgeLoader instance with test parameters.""" return KnowledgeLoader( files=["test.ttl"], flow="test-flow", - user="test-user", + user="test-user", collection="test-collection", document_id="test-doc-123", - url="ws://test.example.com/" + url="http://test.example.com/", + token=None ) class TestKnowledgeLoader: """Test the KnowledgeLoader class business logic.""" - def test_init_constructs_urls_correctly(self): - """Test that URLs are constructed properly.""" + def test_init_stores_parameters_correctly(self): + """Test that initialization stores parameters correctly.""" loader = KnowledgeLoader( - files=["test.ttl"], + files=["file1.ttl", "file2.ttl"], flow="my-flow", user="user1", - collection="col1", - document_id="doc1", - url="ws://example.com/" - ) - - assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples" - assert loader.entity_contexts_url == "ws://example.com/api/v1/flow/my-flow/import/entity-contexts" - assert loader.user == "user1" - assert loader.collection == "col1" - assert loader.document_id == "doc1" - - def test_init_adds_trailing_slash(self): - """Test that trailing slash is added to URL if missing.""" - loader = KnowledgeLoader( - files=["test.ttl"], - flow="my-flow", - user="user1", collection="col1", document_id="doc1", - url="ws://example.com" # No trailing slash + url="http://example.com/", + token="test-token" ) - - assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples" - @pytest.mark.asyncio - async def test_load_triples_sends_correct_messages(self, temp_turtle_file, mock_websocket): - """Test that triple loading sends correctly formatted messages.""" - loader = KnowledgeLoader( - files=[temp_turtle_file], - flow="test-flow", - user="test-user", - collection="test-collection", - document_id="test-doc" - ) - - await loader.load_triples(temp_turtle_file, mock_websocket) - - # Verify WebSocket send was called - assert mock_websocket.send.call_count > 0 - - # Check message format for one of the calls - sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list] - - # Verify message structure - sample_message = sent_messages[0] - assert "metadata" in sample_message - assert "triples" in sample_message - - metadata = sample_message["metadata"] - assert metadata["id"] == "test-doc" - assert metadata["user"] == "test-user" - assert metadata["collection"] == "test-collection" - assert isinstance(metadata["metadata"], list) - - triple = sample_message["triples"][0] - assert "s" in triple - assert "p" in triple - assert "o" in triple - - # Check Value structure - assert "v" in triple["s"] - assert "e" in triple["s"] - assert triple["s"]["e"] is True # Subject should be URI + assert loader.files == ["file1.ttl", "file2.ttl"] + assert loader.flow == "my-flow" + assert loader.user == "user1" + assert loader.collection == "col1" + assert loader.document_id == "doc1" + assert loader.url == "http://example.com/" + assert loader.token == "test-token" - @pytest.mark.asyncio - async def test_load_entity_contexts_processes_literals_only(self, temp_turtle_file, mock_websocket): + def test_load_triples_from_file_yields_triples(self, temp_turtle_file, knowledge_loader): + """Test that load_triples_from_file yields Triple objects.""" + triples = list(knowledge_loader.load_triples_from_file(temp_turtle_file)) + + # Should have triples for all statements in the file + assert len(triples) > 0 + + # Verify they are Triple objects + for triple in triples: + assert isinstance(triple, Triple) + assert hasattr(triple, 's') + assert hasattr(triple, 'p') + assert hasattr(triple, 'o') + assert isinstance(triple.s, str) + assert isinstance(triple.p, str) + assert isinstance(triple.o, str) + + def test_load_entity_contexts_from_file_yields_literals_only(self, temp_turtle_file, knowledge_loader): """Test that entity contexts are created only for literals.""" - loader = KnowledgeLoader( - files=[temp_turtle_file], - flow="test-flow", - user="test-user", - collection="test-collection", - document_id="test-doc" - ) - - await loader.load_entity_contexts(temp_turtle_file, mock_websocket) - - # Get all sent messages - sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list] - - # Verify we got entity context messages - assert len(sent_messages) > 0 - - for message in sent_messages: - assert "metadata" in message - assert "entities" in message - - metadata = message["metadata"] - assert metadata["id"] == "test-doc" - assert metadata["user"] == "test-user" - assert metadata["collection"] == "test-collection" - - entity_context = message["entities"][0] - assert "entity" in entity_context - assert "context" in entity_context - - entity = entity_context["entity"] - assert "v" in entity - assert "e" in entity - assert entity["e"] is True # Entity should be URI (subject) - - # Context should be a string (the literal value) - assert isinstance(entity_context["context"], str) + contexts = list(knowledge_loader.load_entity_contexts_from_file(temp_turtle_file)) - @pytest.mark.asyncio - async def test_load_entity_contexts_skips_uri_objects(self, mock_websocket): + # Should have contexts for literal objects (foaf:name, foaf:age, foaf:email) + assert len(contexts) > 0 + + # Verify format: (entity, context) tuples + for entity, context in contexts: + assert isinstance(entity, str) + assert isinstance(context, str) + # Entity should be a URI (subject) + assert entity.startswith("http://") + + def test_load_entity_contexts_skips_uri_objects(self): """Test that URI objects don't generate entity contexts.""" # Create turtle with only URI objects (no literals) turtle_content = """ @@ -198,242 +118,229 @@ class TestKnowledgeLoader: ex:john ex:knows ex:mary . ex:mary ex:knows ex:bob . """ - + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f: f.write(turtle_content) f.flush() - + loader = KnowledgeLoader( files=[f.name], flow="test-flow", - user="test-user", + user="test-user", collection="test-collection", - document_id="test-doc" + document_id="test-doc", + url="http://test.example.com/" ) - - await loader.load_entity_contexts(f.name, mock_websocket) - - Path(f.name).unlink(missing_ok=True) - - # Should not send any messages since there are no literals - mock_websocket.send.assert_not_called() - @pytest.mark.asyncio - @patch('trustgraph.cli.load_knowledge.connect') - async def test_run_calls_both_loaders(self, mock_connect, knowledge_loader, temp_turtle_file): - """Test that run() calls both triple and entity context loaders.""" - knowledge_loader.files = [temp_turtle_file] - - # Create a simple mock websocket - mock_ws = MagicMock() - async def mock_send(data): - pass - mock_ws.send = mock_send - - # Create async context manager mock - async def mock_aenter(self): - return mock_ws - - async def mock_aexit(self, exc_type, exc_val, exc_tb): - return None - - mock_connection = MagicMock() - mock_connection.__aenter__ = mock_aenter - mock_connection.__aexit__ = mock_aexit - mock_connect.return_value = mock_connection - - # Create AsyncMock objects that can track calls properly - mock_load_triples = AsyncMock(return_value=None) - mock_load_contexts = AsyncMock(return_value=None) - - with patch.object(knowledge_loader, 'load_triples', mock_load_triples), \ - patch.object(knowledge_loader, 'load_entity_contexts', mock_load_contexts): - - await knowledge_loader.run() - - # Verify both methods were called - mock_load_triples.assert_called_once_with(temp_turtle_file, mock_ws) - mock_load_contexts.assert_called_once_with(temp_turtle_file, mock_ws) - - # Verify WebSocket connections were made to both URLs - assert mock_connect.call_count == 2 + contexts = list(loader.load_entity_contexts_from_file(f.name)) + + Path(f.name).unlink(missing_ok=True) + + # Should have no contexts since there are no literals + assert len(contexts) == 0 + + @patch('trustgraph.cli.load_knowledge.Api') + def test_run_calls_bulk_api(self, mock_api_class, temp_turtle_file): + """Test that run() uses BulkClient API.""" + # Setup mocks + mock_api = MagicMock() + mock_bulk = MagicMock() + mock_api_class.return_value = mock_api + mock_api.bulk.return_value = mock_bulk + + loader = KnowledgeLoader( + files=[temp_turtle_file], + flow="test-flow", + user="test-user", + collection="test-collection", + document_id="test-doc", + url="http://test.example.com/", + token="test-token" + ) + + loader.run() + + # Verify Api was created with correct parameters + mock_api_class.assert_called_once_with( + url="http://test.example.com/", + token="test-token" + ) + + # Verify bulk client was obtained + mock_api.bulk.assert_called_once() + + # Verify import_triples was called + assert mock_bulk.import_triples.call_count == 1 + call_args = mock_bulk.import_triples.call_args + assert call_args[1]['flow'] == "test-flow" + assert call_args[1]['metadata']['id'] == "test-doc" + assert call_args[1]['metadata']['user'] == "test-user" + assert call_args[1]['metadata']['collection'] == "test-collection" + + # Verify import_entity_contexts was called + assert mock_bulk.import_entity_contexts.call_count == 1 + call_args = mock_bulk.import_entity_contexts.call_args + assert call_args[1]['flow'] == "test-flow" + assert call_args[1]['metadata']['id'] == "test-doc" class TestCLIArgumentParsing: """Test CLI argument parsing and main function.""" @patch('trustgraph.cli.load_knowledge.KnowledgeLoader') - @patch('trustgraph.cli.load_knowledge.asyncio.run') - def test_main_parses_args_correctly(self, mock_asyncio_run, mock_loader_class): + @patch('trustgraph.cli.load_knowledge.time.sleep') + def test_main_parses_args_correctly(self, mock_sleep, mock_loader_class): """Test that main() parses arguments correctly.""" mock_loader_instance = MagicMock() mock_loader_class.return_value = mock_loader_instance - + test_args = [ 'tg-load-knowledge', '-i', 'doc-123', - '-f', 'my-flow', + '-f', 'my-flow', '-U', 'my-user', '-C', 'my-collection', - '-u', 'ws://custom.example.com/', + '-u', 'http://custom.example.com/', + '-t', 'my-token', 'file1.ttl', 'file2.ttl' ] - + with patch('sys.argv', test_args): main() - + # Verify KnowledgeLoader was instantiated with correct args mock_loader_class.assert_called_once_with( document_id='doc-123', - url='ws://custom.example.com/', + url='http://custom.example.com/', + token='my-token', flow='my-flow', files=['file1.ttl', 'file2.ttl'], user='my-user', collection='my-collection' ) - - # Verify asyncio.run was called once - mock_asyncio_run.assert_called_once() + + # Verify run was called + mock_loader_instance.run.assert_called_once() @patch('trustgraph.cli.load_knowledge.KnowledgeLoader') - @patch('trustgraph.cli.load_knowledge.asyncio.run') - def test_main_uses_defaults(self, mock_asyncio_run, mock_loader_class): + @patch('trustgraph.cli.load_knowledge.time.sleep') + def test_main_uses_defaults(self, mock_sleep, mock_loader_class): """Test that main() uses default values when not specified.""" mock_loader_instance = MagicMock() mock_loader_class.return_value = mock_loader_instance - + test_args = [ 'tg-load-knowledge', '-i', 'doc-123', 'file1.ttl' ] - + with patch('sys.argv', test_args): main() - + # Verify defaults were used call_args = mock_loader_class.call_args[1] assert call_args['flow'] == 'default' assert call_args['user'] == 'trustgraph' assert call_args['collection'] == 'default' - assert call_args['url'] == 'ws://localhost:8088/' + assert call_args['url'] == 'http://localhost:8088/' + assert call_args['token'] is None class TestErrorHandling: """Test error handling scenarios.""" - @pytest.mark.asyncio - async def test_load_triples_handles_invalid_turtle(self, mock_websocket): + def test_load_triples_handles_invalid_turtle(self, knowledge_loader): """Test handling of invalid Turtle content.""" # Create file with invalid Turtle content with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f: f.write("Invalid Turtle Content {{{") f.flush() - - loader = KnowledgeLoader( - files=[f.name], - flow="test-flow", - user="test-user", - collection="test-collection", - document_id="test-doc" - ) - + # Should raise an exception for invalid Turtle with pytest.raises(Exception): - await loader.load_triples(f.name, mock_websocket) - + list(knowledge_loader.load_triples_from_file(f.name)) + Path(f.name).unlink(missing_ok=True) - @pytest.mark.asyncio - async def test_load_entity_contexts_handles_invalid_turtle(self, mock_websocket): + def test_load_entity_contexts_handles_invalid_turtle(self, knowledge_loader): """Test handling of invalid Turtle content in entity contexts.""" # Create file with invalid Turtle content with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f: f.write("Invalid Turtle Content {{{") f.flush() - - loader = KnowledgeLoader( - files=[f.name], - flow="test-flow", - user="test-user", - collection="test-collection", - document_id="test-doc" - ) - + # Should raise an exception for invalid Turtle with pytest.raises(Exception): - await loader.load_entity_contexts(f.name, mock_websocket) - + list(knowledge_loader.load_entity_contexts_from_file(f.name)) + Path(f.name).unlink(missing_ok=True) - @pytest.mark.asyncio - @patch('trustgraph.cli.load_knowledge.connect') + @patch('trustgraph.cli.load_knowledge.Api') @patch('builtins.print') # Mock print to avoid output during tests - async def test_run_handles_connection_errors(self, mock_print, mock_connect, knowledge_loader, temp_turtle_file): - """Test handling of WebSocket connection errors.""" - knowledge_loader.files = [temp_turtle_file] - - # Mock connection failure - mock_connect.side_effect = ConnectionError("Failed to connect") - - # Should not raise exception, just print error - await knowledge_loader.run() + def test_run_handles_api_errors(self, mock_print, mock_api_class, temp_turtle_file): + """Test handling of API errors.""" + # Mock API to raise an error + mock_api_class.side_effect = Exception("API connection failed") + + loader = KnowledgeLoader( + files=[temp_turtle_file], + flow="test-flow", + user="test-user", + collection="test-collection", + document_id="test-doc", + url="http://test.example.com/" + ) + + # Should raise the exception + with pytest.raises(Exception, match="API connection failed"): + loader.run() @patch('trustgraph.cli.load_knowledge.KnowledgeLoader') - @patch('trustgraph.cli.load_knowledge.asyncio.run') @patch('trustgraph.cli.load_knowledge.time.sleep') @patch('builtins.print') # Mock print to avoid output during tests - def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_asyncio_run, mock_loader_class): + def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_loader_class): """Test that main() retries on exceptions.""" mock_loader_instance = MagicMock() mock_loader_class.return_value = mock_loader_instance - + # First call raises exception, second succeeds - mock_asyncio_run.side_effect = [Exception("Test error"), None] - + mock_loader_instance.run.side_effect = [Exception("Test error"), None] + test_args = [ 'tg-load-knowledge', - '-i', 'doc-123', + '-i', 'doc-123', 'file1.ttl' ] - + with patch('sys.argv', test_args): main() - + # Should have been called twice (first failed, second succeeded) - assert mock_asyncio_run.call_count == 2 + assert mock_loader_instance.run.call_count == 2 mock_sleep.assert_called_once_with(10) class TestDataValidation: """Test data validation and edge cases.""" - @pytest.mark.asyncio - async def test_empty_turtle_file(self, mock_websocket): + def test_empty_turtle_file(self, knowledge_loader): """Test handling of empty Turtle files.""" with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f: f.write("") # Empty file f.flush() - - loader = KnowledgeLoader( - files=[f.name], - flow="test-flow", - user="test-user", - collection="test-collection", - document_id="test-doc" - ) - - await loader.load_triples(f.name, mock_websocket) - await loader.load_entity_contexts(f.name, mock_websocket) - - # Should not send any messages for empty file - mock_websocket.send.assert_not_called() - + + triples = list(knowledge_loader.load_triples_from_file(f.name)) + contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name)) + + # Should return empty lists for empty file + assert len(triples) == 0 + assert len(contexts) == 0 + Path(f.name).unlink(missing_ok=True) - @pytest.mark.asyncio - async def test_turtle_with_mixed_literals_and_uris(self, mock_websocket): + def test_turtle_with_mixed_literals_and_uris(self, knowledge_loader): """Test handling of Turtle with mixed literal and URI objects.""" turtle_content = """ @prefix ex: . @@ -443,37 +350,23 @@ ex:john ex:name "John Smith" ; ex:city "New York" . ex:mary ex:name "Mary Johnson" . """ - + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f: f.write(turtle_content) f.flush() - - loader = KnowledgeLoader( - files=[f.name], - flow="test-flow", - user="test-user", - collection="test-collection", - document_id="test-doc" - ) - - await loader.load_entity_contexts(f.name, mock_websocket) - - sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list] - + + contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name)) + # Should have 4 entity contexts (for the 4 literals: "John Smith", "25", "New York", "Mary Johnson") # URI ex:mary should be skipped - assert len(sent_messages) == 4 - + assert len(contexts) == 4 + # Verify all contexts are for literals (subjects should be URIs) - contexts = [] - for message in sent_messages: - entity_context = message["entities"][0] - assert entity_context["entity"]["e"] is True # Subject is URI - contexts.append(entity_context["context"]) - - assert "John Smith" in contexts - assert "25" in contexts - assert "New York" in contexts - assert "Mary Johnson" in contexts - - Path(f.name).unlink(missing_ok=True) \ No newline at end of file + context_values = [context for entity, context in contexts] + + assert "John Smith" in context_values + assert "25" in context_values + assert "New York" in context_values + assert "Mary Johnson" in context_values + + Path(f.name).unlink(missing_ok=True) diff --git a/tests/unit/test_cli/test_tool_commands.py b/tests/unit/test_cli/test_tool_commands.py index 64cf9441..913fe416 100644 --- a/tests/unit/test_cli/test_tool_commands.py +++ b/tests/unit/test_cli/test_tool_commands.py @@ -135,7 +135,8 @@ class TestSetToolStructuredQuery: arguments=[], group=None, state=None, - applicable_states=None + applicable_states=None, + token=None ) def test_set_main_structured_query_no_arguments_needed(self): @@ -313,7 +314,7 @@ class TestShowToolsStructuredQuery: show_main() - mock_show.assert_called_once_with(url='http://custom.com') + mock_show.assert_called_once_with(url='http://custom.com', token=None) class TestStructuredQueryToolValidation: diff --git a/tests/unit/test_gateway/test_config_receiver.py b/tests/unit/test_gateway/test_config_receiver.py index c186c768..ee500766 100644 --- a/tests/unit/test_gateway/test_config_receiver.py +++ b/tests/unit/test_gateway/test_config_receiver.py @@ -22,18 +22,18 @@ class TestConfigReceiver: def test_config_receiver_initialization(self): """Test ConfigReceiver initialization""" - mock_pulsar_client = Mock() + mock_backend = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + config_receiver = ConfigReceiver(mock_backend) - assert config_receiver.pulsar_client == mock_pulsar_client + assert config_receiver.backend == mock_backend assert config_receiver.flow_handlers == [] assert config_receiver.flows == {} def test_add_handler(self): """Test adding flow handlers""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) handler1 = Mock() handler2 = Mock() @@ -48,8 +48,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_on_config_with_new_flows(self): """Test on_config method with new flows""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Track calls manually instead of using AsyncMock start_flow_calls = [] @@ -87,8 +87,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_on_config_with_removed_flows(self): """Test on_config method with removed flows""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Pre-populate with existing flows config_receiver.flows = { @@ -128,8 +128,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_on_config_with_no_flows(self): """Test on_config method with no flows in config""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Mock the start_flow and stop_flow methods with async functions async def mock_start_flow(*args): @@ -158,8 +158,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_on_config_exception_handling(self): """Test on_config method handles exceptions gracefully""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Create mock message that will cause an exception mock_msg = Mock() @@ -174,8 +174,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_start_flow_with_handlers(self): """Test start_flow method with multiple handlers""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Add mock handlers handler1 = Mock() @@ -197,8 +197,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_start_flow_with_handler_exception(self): """Test start_flow method handles handler exceptions""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Add mock handler that raises exception handler = Mock() @@ -217,8 +217,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_stop_flow_with_handlers(self): """Test stop_flow method with multiple handlers""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Add mock handlers handler1 = Mock() @@ -240,8 +240,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_stop_flow_with_handler_exception(self): """Test stop_flow method handles handler exceptions""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Add mock handler that raises exception handler = Mock() @@ -260,9 +260,9 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_config_loader_creates_consumer(self): """Test config_loader method creates Pulsar consumer""" - mock_pulsar_client = Mock() + mock_backend = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + config_receiver = ConfigReceiver(mock_backend) # Temporarily restore the real config_loader for this test config_receiver.config_loader = _real_config_loader.__get__(config_receiver) @@ -291,8 +291,8 @@ class TestConfigReceiver: # Verify Consumer was created with correct parameters mock_consumer_class.assert_called_once() call_args = mock_consumer_class.call_args - - assert call_args[1]['client'] == mock_pulsar_client + + assert call_args[1]['backend'] == mock_backend assert call_args[1]['subscriber'] == "gateway-test-uuid" assert call_args[1]['handler'] == config_receiver.on_config assert call_args[1]['start_of_messages'] is True @@ -301,8 +301,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_start_creates_config_loader_task(self, mock_create_task): """Test start method creates config loader task""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Mock create_task to avoid actually creating tasks with real coroutines mock_task = Mock() @@ -320,8 +320,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_on_config_mixed_flow_operations(self): """Test on_config with mixed add/remove operations""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Pre-populate with existing flows config_receiver.flows = { @@ -380,8 +380,8 @@ class TestConfigReceiver: @pytest.mark.asyncio async def test_on_config_invalid_json_flow_data(self): """Test on_config handles invalid JSON in flow data""" - mock_pulsar_client = Mock() - config_receiver = ConfigReceiver(mock_pulsar_client) + mock_backend = Mock() + config_receiver = ConfigReceiver(mock_backend) # Mock the start_flow method with an async function async def mock_start_flow(*args): diff --git a/tests/unit/test_gateway/test_dispatch_config.py b/tests/unit/test_gateway/test_dispatch_config.py index df319bdc..4fbd8484 100644 --- a/tests/unit/test_gateway/test_dispatch_config.py +++ b/tests/unit/test_gateway/test_dispatch_config.py @@ -24,10 +24,10 @@ class TestConfigRequestor: mock_translator_registry.get_response_translator.return_value = mock_response_translator # Mock dependencies - mock_pulsar_client = Mock() + mock_backend = Mock() requestor = ConfigRequestor( - pulsar_client=mock_pulsar_client, + backend=mock_backend, consumer="test-consumer", subscriber="test-subscriber", timeout=60 @@ -55,7 +55,7 @@ class TestConfigRequestor: with patch.object(ServiceRequestor, 'start', return_value=None), \ patch.object(ServiceRequestor, 'process', return_value=None): requestor = ConfigRequestor( - pulsar_client=Mock(), + backend=Mock(), consumer="test-consumer", subscriber="test-subscriber" ) @@ -79,7 +79,7 @@ class TestConfigRequestor: mock_response_translator.from_response_with_completion.return_value = "translated_response" requestor = ConfigRequestor( - pulsar_client=Mock(), + backend=Mock(), consumer="test-consumer", subscriber="test-subscriber" ) diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index 6bb2e4d1..33f1229d 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -39,12 +39,12 @@ class TestDispatcherManager: def test_dispatcher_manager_initialization(self): """Test DispatcherManager initialization""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) - assert manager.pulsar_client == mock_pulsar_client + assert manager.backend == mock_backend assert manager.config_receiver == mock_config_receiver assert manager.prefix == "api-gateway" # default prefix assert manager.flows == {} @@ -55,19 +55,19 @@ class TestDispatcherManager: def test_dispatcher_manager_initialization_with_custom_prefix(self): """Test DispatcherManager initialization with custom prefix""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver, prefix="custom-prefix") + manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix") assert manager.prefix == "custom-prefix" @pytest.mark.asyncio async def test_start_flow(self): """Test start_flow method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) flow_data = {"name": "test_flow", "steps": []} @@ -79,9 +79,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_stop_flow(self): """Test stop_flow method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Pre-populate with a flow flow_data = {"name": "test_flow", "steps": []} @@ -93,9 +93,9 @@ class TestDispatcherManager: def test_dispatch_global_service_returns_wrapper(self): """Test dispatch_global_service returns DispatcherWrapper""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) wrapper = manager.dispatch_global_service() @@ -104,9 +104,9 @@ class TestDispatcherManager: def test_dispatch_core_export_returns_wrapper(self): """Test dispatch_core_export returns DispatcherWrapper""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) wrapper = manager.dispatch_core_export() @@ -115,9 +115,9 @@ class TestDispatcherManager: def test_dispatch_core_import_returns_wrapper(self): """Test dispatch_core_import returns DispatcherWrapper""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) wrapper = manager.dispatch_core_import() @@ -127,9 +127,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_core_import(self): """Test process_core_import method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import: mock_importer = Mock() @@ -138,16 +138,16 @@ class TestDispatcherManager: result = await manager.process_core_import("data", "error", "ok", "request") - mock_core_import.assert_called_once_with(mock_pulsar_client) + mock_core_import.assert_called_once_with(mock_backend) mock_importer.process.assert_called_once_with("data", "error", "ok", "request") assert result == "import_result" @pytest.mark.asyncio async def test_process_core_export(self): """Test process_core_export method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export: mock_exporter = Mock() @@ -156,16 +156,16 @@ class TestDispatcherManager: result = await manager.process_core_export("data", "error", "ok", "request") - mock_core_export.assert_called_once_with(mock_pulsar_client) + mock_core_export.assert_called_once_with(mock_backend) mock_exporter.process.assert_called_once_with("data", "error", "ok", "request") assert result == "export_result" @pytest.mark.asyncio async def test_process_global_service(self): """Test process_global_service method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) manager.invoke_global_service = AsyncMock(return_value="global_result") @@ -178,9 +178,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_global_service_with_existing_dispatcher(self): """Test invoke_global_service with existing dispatcher""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Pre-populate with existing dispatcher mock_dispatcher = Mock() @@ -195,9 +195,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_global_service_creates_new_dispatcher(self): """Test invoke_global_service creates new dispatcher""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() @@ -211,10 +211,12 @@ class TestDispatcherManager: # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( - pulsar_client=mock_pulsar_client, + backend=mock_backend, timeout=120, consumer="api-gateway-config-request", - subscriber="api-gateway-config-request" + subscriber="api-gateway-config-request", + request_queue=None, + response_queue=None ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") @@ -225,9 +227,9 @@ class TestDispatcherManager: def test_dispatch_flow_import_returns_method(self): """Test dispatch_flow_import returns correct method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) result = manager.dispatch_flow_import() @@ -235,9 +237,9 @@ class TestDispatcherManager: def test_dispatch_flow_export_returns_method(self): """Test dispatch_flow_export returns correct method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) result = manager.dispatch_flow_export() @@ -245,9 +247,9 @@ class TestDispatcherManager: def test_dispatch_socket_returns_method(self): """Test dispatch_socket returns correct method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) result = manager.dispatch_socket() @@ -255,9 +257,9 @@ class TestDispatcherManager: def test_dispatch_flow_service_returns_wrapper(self): """Test dispatch_flow_service returns DispatcherWrapper""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) wrapper = manager.dispatch_flow_service() @@ -267,9 +269,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_flow_import_with_valid_flow_and_kind(self): """Test process_flow_import with valid flow and kind""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow manager.flows["test_flow"] = { @@ -292,7 +294,7 @@ class TestDispatcherManager: result = await manager.process_flow_import("ws", "running", params) mock_dispatcher_class.assert_called_once_with( - pulsar_client=mock_pulsar_client, + backend=mock_backend, ws="ws", running="running", queue={"queue": "test_queue"} @@ -303,9 +305,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_flow_import_with_invalid_flow(self): """Test process_flow_import with invalid flow""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) params = {"flow": "invalid_flow", "kind": "triples"} @@ -318,9 +320,9 @@ class TestDispatcherManager: import warnings with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow manager.flows["test_flow"] = { @@ -340,9 +342,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_flow_export_with_valid_flow_and_kind(self): """Test process_flow_export with valid flow and kind""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow manager.flows["test_flow"] = { @@ -364,7 +366,7 @@ class TestDispatcherManager: result = await manager.process_flow_export("ws", "running", params) mock_dispatcher_class.assert_called_once_with( - pulsar_client=mock_pulsar_client, + backend=mock_backend, ws="ws", running="running", queue={"queue": "test_queue"}, @@ -376,9 +378,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_socket(self): """Test process_socket method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux: mock_mux_instance = Mock() @@ -392,9 +394,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_flow_service(self): """Test process_flow_service method""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) manager.invoke_flow_service = AsyncMock(return_value="flow_result") @@ -407,9 +409,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_flow_service_with_existing_dispatcher(self): """Test invoke_flow_service with existing dispatcher""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Add flow to the flows dictionary manager.flows["test_flow"] = {"services": {"agent": {}}} @@ -427,9 +429,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_flow_service_creates_request_response_dispatcher(self): """Test invoke_flow_service creates request-response dispatcher""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow manager.flows["test_flow"] = { @@ -454,7 +456,7 @@ class TestDispatcherManager: # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( - pulsar_client=mock_pulsar_client, + backend=mock_backend, request_queue="agent_request_queue", response_queue="agent_response_queue", timeout=120, @@ -471,9 +473,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_flow_service_creates_sender_dispatcher(self): """Test invoke_flow_service creates sender dispatcher""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow manager.flows["test_flow"] = { @@ -498,7 +500,7 @@ class TestDispatcherManager: # Verify dispatcher was created with correct parameters mock_dispatcher_class.assert_called_once_with( - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue={"queue": "text_load_queue"} ) mock_dispatcher.start.assert_called_once() @@ -511,9 +513,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_flow_service_invalid_flow(self): """Test invoke_flow_service with invalid flow""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) with pytest.raises(RuntimeError, match="Invalid flow"): await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent") @@ -521,9 +523,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_flow_service_unsupported_kind_by_flow(self): """Test invoke_flow_service with kind not supported by flow""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow without agent interface manager.flows["test_flow"] = { @@ -538,9 +540,9 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_invoke_flow_service_invalid_kind(self): """Test invoke_flow_service with invalid kind""" - mock_pulsar_client = Mock() + mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver) # Setup test flow with interface but unsupported kind manager.flows["test_flow"] = { diff --git a/tests/unit/test_gateway/test_dispatch_requestor.py b/tests/unit/test_gateway/test_dispatch_requestor.py index e9c89e1d..6b294540 100644 --- a/tests/unit/test_gateway/test_dispatch_requestor.py +++ b/tests/unit/test_gateway/test_dispatch_requestor.py @@ -15,12 +15,12 @@ class TestServiceRequestor: @patch('trustgraph.gateway.dispatch.requestor.Subscriber') def test_service_requestor_initialization(self, mock_subscriber, mock_publisher): """Test ServiceRequestor initialization""" - mock_pulsar_client = MagicMock() + mock_backend = MagicMock() mock_request_schema = MagicMock() mock_response_schema = MagicMock() requestor = ServiceRequestor( - pulsar_client=mock_pulsar_client, + backend=mock_backend, request_queue="test-request-queue", request_schema=mock_request_schema, response_queue="test-response-queue", @@ -32,12 +32,12 @@ class TestServiceRequestor: # Verify Publisher was created correctly mock_publisher.assert_called_once_with( - mock_pulsar_client, "test-request-queue", schema=mock_request_schema + mock_backend, "test-request-queue", schema=mock_request_schema ) # Verify Subscriber was created correctly mock_subscriber.assert_called_once_with( - mock_pulsar_client, "test-response-queue", + mock_backend, "test-response-queue", "test-subscription", "test-consumer", mock_response_schema ) @@ -48,12 +48,12 @@ class TestServiceRequestor: @patch('trustgraph.gateway.dispatch.requestor.Subscriber') def test_service_requestor_with_defaults(self, mock_subscriber, mock_publisher): """Test ServiceRequestor initialization with default parameters""" - mock_pulsar_client = MagicMock() + mock_backend = MagicMock() mock_request_schema = MagicMock() mock_response_schema = MagicMock() requestor = ServiceRequestor( - pulsar_client=mock_pulsar_client, + backend=mock_backend, request_queue="test-queue", request_schema=mock_request_schema, response_queue="response-queue", @@ -62,7 +62,7 @@ class TestServiceRequestor: # Verify default values mock_subscriber.assert_called_once_with( - mock_pulsar_client, "response-queue", + mock_backend, "response-queue", "api-gateway", "api-gateway", mock_response_schema ) assert requestor.timeout == 600 # Default timeout @@ -72,14 +72,14 @@ class TestServiceRequestor: @pytest.mark.asyncio async def test_service_requestor_start(self, mock_subscriber, mock_publisher): """Test ServiceRequestor start method""" - mock_pulsar_client = MagicMock() + mock_backend = MagicMock() mock_sub_instance = AsyncMock() mock_pub_instance = AsyncMock() mock_subscriber.return_value = mock_sub_instance mock_publisher.return_value = mock_pub_instance requestor = ServiceRequestor( - pulsar_client=mock_pulsar_client, + backend=mock_backend, request_queue="test-queue", request_schema=MagicMock(), response_queue="response-queue", @@ -98,14 +98,14 @@ class TestServiceRequestor: @patch('trustgraph.gateway.dispatch.requestor.Subscriber') def test_service_requestor_attributes(self, mock_subscriber, mock_publisher): """Test ServiceRequestor has correct attributes""" - mock_pulsar_client = MagicMock() + mock_backend = MagicMock() mock_pub_instance = AsyncMock() mock_sub_instance = AsyncMock() mock_publisher.return_value = mock_pub_instance mock_subscriber.return_value = mock_sub_instance requestor = ServiceRequestor( - pulsar_client=mock_pulsar_client, + backend=mock_backend, request_queue="test-queue", request_schema=MagicMock(), response_queue="response-queue", diff --git a/tests/unit/test_gateway/test_dispatch_sender.py b/tests/unit/test_gateway/test_dispatch_sender.py index 692604d5..06d828dd 100644 --- a/tests/unit/test_gateway/test_dispatch_sender.py +++ b/tests/unit/test_gateway/test_dispatch_sender.py @@ -14,18 +14,18 @@ class TestServiceSender: @patch('trustgraph.gateway.dispatch.sender.Publisher') def test_service_sender_initialization(self, mock_publisher): """Test ServiceSender initialization""" - mock_pulsar_client = MagicMock() + mock_backend = MagicMock() mock_schema = MagicMock() sender = ServiceSender( - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue", schema=mock_schema ) # Verify Publisher was created correctly mock_publisher.assert_called_once_with( - mock_pulsar_client, "test-queue", schema=mock_schema + mock_backend, "test-queue", schema=mock_schema ) @patch('trustgraph.gateway.dispatch.sender.Publisher') @@ -36,7 +36,7 @@ class TestServiceSender: mock_publisher.return_value = mock_pub_instance sender = ServiceSender( - pulsar_client=MagicMock(), + backend=MagicMock(), queue="test-queue", schema=MagicMock() ) @@ -55,7 +55,7 @@ class TestServiceSender: mock_publisher.return_value = mock_pub_instance sender = ServiceSender( - pulsar_client=MagicMock(), + backend=MagicMock(), queue="test-queue", schema=MagicMock() ) @@ -70,7 +70,7 @@ class TestServiceSender: def test_service_sender_to_request_not_implemented(self, mock_publisher): """Test ServiceSender to_request method raises RuntimeError""" sender = ServiceSender( - pulsar_client=MagicMock(), + backend=MagicMock(), queue="test-queue", schema=MagicMock() ) @@ -91,7 +91,7 @@ class TestServiceSender: return {"processed": request} sender = ConcreteSender( - pulsar_client=MagicMock(), + backend=MagicMock(), queue="test-queue", schema=MagicMock() ) @@ -111,7 +111,7 @@ class TestServiceSender: mock_publisher.return_value = mock_pub_instance sender = ServiceSender( - pulsar_client=MagicMock(), + backend=MagicMock(), queue="test-queue", schema=MagicMock() ) diff --git a/tests/unit/test_gateway/test_objects_import_dispatcher.py b/tests/unit/test_gateway/test_objects_import_dispatcher.py index ed9e8faa..0332c1a1 100644 --- a/tests/unit/test_gateway/test_objects_import_dispatcher.py +++ b/tests/unit/test_gateway/test_objects_import_dispatcher.py @@ -16,7 +16,7 @@ from trustgraph.schema import Metadata, ExtractedObject @pytest.fixture -def mock_pulsar_client(): +def mock_backend(): """Mock Pulsar client.""" client = Mock() return client @@ -96,7 +96,7 @@ class TestObjectsImportInitialization: """Test ObjectsImport initialization.""" @patch('trustgraph.gateway.dispatch.objects_import.Publisher') - def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that ObjectsImport creates Publisher with correct parameters.""" mock_publisher_instance = Mock() mock_publisher_class.return_value = mock_publisher_instance @@ -104,13 +104,13 @@ class TestObjectsImportInitialization: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-objects-queue" ) # Verify Publisher was created with correct parameters mock_publisher_class.assert_called_once_with( - mock_pulsar_client, + mock_backend, topic="test-objects-queue", schema=ExtractedObject ) @@ -121,12 +121,12 @@ class TestObjectsImportInitialization: assert objects_import.publisher == mock_publisher_instance @patch('trustgraph.gateway.dispatch.objects_import.Publisher') - def test_init_stores_references_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that ObjectsImport stores all required references.""" objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="objects-queue" ) @@ -139,7 +139,7 @@ class TestObjectsImportLifecycle: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_start_calls_publisher_start(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that start() calls publisher.start().""" mock_publisher_instance = Mock() mock_publisher_instance.start = AsyncMock() @@ -148,7 +148,7 @@ class TestObjectsImportLifecycle: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -158,7 +158,7 @@ class TestObjectsImportLifecycle: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that destroy() properly stops publisher and closes websocket.""" mock_publisher_instance = Mock() mock_publisher_instance.stop = AsyncMock() @@ -167,7 +167,7 @@ class TestObjectsImportLifecycle: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -180,7 +180,7 @@ class TestObjectsImportLifecycle: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_pulsar_client, mock_running): + async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running): """Test that destroy() handles None websocket gracefully.""" mock_publisher_instance = Mock() mock_publisher_instance.stop = AsyncMock() @@ -189,7 +189,7 @@ class TestObjectsImportLifecycle: objects_import = ObjectsImport( ws=None, # None websocket running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -205,7 +205,7 @@ class TestObjectsImportMessageProcessing: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message): + async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message): """Test that receive() processes complete message correctly.""" mock_publisher_instance = Mock() mock_publisher_instance.send = AsyncMock() @@ -214,7 +214,7 @@ class TestObjectsImportMessageProcessing: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -248,7 +248,7 @@ class TestObjectsImportMessageProcessing: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, minimal_objects_message): + async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message): """Test that receive() handles message with minimal required fields.""" mock_publisher_instance = Mock() mock_publisher_instance.send = AsyncMock() @@ -257,7 +257,7 @@ class TestObjectsImportMessageProcessing: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -281,7 +281,7 @@ class TestObjectsImportMessageProcessing: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_uses_default_values(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() uses appropriate default values for optional fields.""" mock_publisher_instance = Mock() mock_publisher_instance.send = AsyncMock() @@ -290,7 +290,7 @@ class TestObjectsImportMessageProcessing: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -323,7 +323,7 @@ class TestObjectsImportRunMethod: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep') @pytest.mark.asyncio - async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that run() loops while running.get() returns True.""" mock_sleep.return_value = None mock_publisher_class.return_value = Mock() @@ -334,7 +334,7 @@ class TestObjectsImportRunMethod: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -353,7 +353,7 @@ class TestObjectsImportRunMethod: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep') @pytest.mark.asyncio - async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_running): + async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running): """Test that run() handles None websocket gracefully.""" mock_sleep.return_value = None mock_publisher_class.return_value = Mock() @@ -363,7 +363,7 @@ class TestObjectsImportRunMethod: objects_import = ObjectsImport( ws=None, # None websocket running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -417,7 +417,7 @@ class TestObjectsImportBatchProcessing: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message): + async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message): """Test that receive() processes batch message correctly.""" mock_publisher_instance = Mock() mock_publisher_instance.send = AsyncMock() @@ -426,7 +426,7 @@ class TestObjectsImportBatchProcessing: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -467,7 +467,7 @@ class TestObjectsImportBatchProcessing: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() handles empty batch correctly.""" mock_publisher_instance = Mock() mock_publisher_instance.send = AsyncMock() @@ -476,7 +476,7 @@ class TestObjectsImportBatchProcessing: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -507,7 +507,7 @@ class TestObjectsImportErrorHandling: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message): + async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message): """Test that receive() propagates publisher send errors.""" mock_publisher_instance = Mock() mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error")) @@ -516,7 +516,7 @@ class TestObjectsImportErrorHandling: objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) @@ -528,14 +528,14 @@ class TestObjectsImportErrorHandling: @patch('trustgraph.gateway.dispatch.objects_import.Publisher') @pytest.mark.asyncio - async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running): + async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running): """Test that receive() handles malformed JSON appropriately.""" mock_publisher_class.return_value = Mock() objects_import = ObjectsImport( ws=mock_websocket, running=mock_running, - pulsar_client=mock_pulsar_client, + backend=mock_backend, queue="test-queue" ) diff --git a/tests/unit/test_gateway/test_service.py b/tests/unit/test_gateway/test_service.py index a943078f..22d9ab04 100644 --- a/tests/unit/test_gateway/test_service.py +++ b/tests/unit/test_gateway/test_service.py @@ -19,23 +19,21 @@ class TestApi: def test_api_initialization_with_defaults(self): """Test Api initialization with default values""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() - + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_backend = Mock() + mock_get_pubsub.return_value = mock_backend + api = Api() - + assert api.port == default_port assert api.timeout == default_timeout assert api.pulsar_host == default_pulsar_host assert api.pulsar_api_key is None assert api.prometheus_url == default_prometheus_url + "/" assert api.auth.allow_all is True - - # Verify Pulsar client was created without API key - mock_client.assert_called_once_with( - default_pulsar_host, - listener_name=None - ) + + # Verify get_pubsub was called + mock_get_pubsub.assert_called_once() def test_api_initialization_with_custom_config(self): """Test Api initialization with custom configuration""" @@ -48,14 +46,13 @@ class TestApi: "prometheus_url": "http://custom-prometheus:9090", "api_token": "secret-token" } - - with patch('pulsar.Client') as mock_client, \ - patch('pulsar.AuthenticationToken') as mock_auth: - mock_client.return_value = Mock() - mock_auth.return_value = Mock() - + + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_backend = Mock() + mock_get_pubsub.return_value = mock_backend + api = Api(**config) - + assert api.port == 9000 assert api.timeout == 300 assert api.pulsar_host == "pulsar://custom-host:6650" @@ -63,35 +60,25 @@ class TestApi: assert api.prometheus_url == "http://custom-prometheus:9090/" assert api.auth.token == "secret-token" assert api.auth.allow_all is False - - # Verify Pulsar client was created with API key - mock_auth.assert_called_once_with("test-api-key") - mock_client.assert_called_once_with( - "pulsar://custom-host:6650", - listener_name="custom-listener", - authentication=mock_auth.return_value - ) + + # Verify get_pubsub was called with config + mock_get_pubsub.assert_called_once_with(**config) def test_api_initialization_with_pulsar_api_key(self): """Test Api initialization with Pulsar API key authentication""" - with patch('pulsar.Client') as mock_client, \ - patch('pulsar.AuthenticationToken') as mock_auth: - mock_client.return_value = Mock() - mock_auth.return_value = Mock() - + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() + api = Api(pulsar_api_key="test-key") - - mock_auth.assert_called_once_with("test-key") - mock_client.assert_called_once_with( - default_pulsar_host, - listener_name=None, - authentication=mock_auth.return_value - ) + + # Verify api key was stored + assert api.pulsar_api_key == "test-key" + mock_get_pubsub.assert_called_once() def test_api_initialization_prometheus_url_normalization(self): """Test that prometheus_url gets normalized with trailing slash""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() # Test URL without trailing slash api = Api(prometheus_url="http://prometheus:9090") @@ -103,16 +90,16 @@ class TestApi: def test_api_initialization_empty_api_token_means_no_auth(self): """Test that empty API token results in allow_all authentication""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() api = Api(api_token="") assert api.auth.allow_all is True def test_api_initialization_none_api_token_means_no_auth(self): """Test that None API token results in allow_all authentication""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() api = Api(api_token=None) assert api.auth.allow_all is True @@ -120,8 +107,8 @@ class TestApi: @pytest.mark.asyncio async def test_app_factory_creates_application(self): """Test that app_factory creates aiohttp application""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() api = Api() @@ -147,8 +134,8 @@ class TestApi: @pytest.mark.asyncio async def test_app_factory_with_custom_endpoints(self): """Test app_factory with custom endpoints""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() api = Api() @@ -180,13 +167,13 @@ class TestApi: def test_run_method_calls_web_run_app(self): """Test that run method calls web.run_app""" - with patch('pulsar.Client') as mock_client, \ + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \ patch('aiohttp.web.run_app') as mock_run_app: - mock_client.return_value = Mock() - + mock_get_pubsub.return_value = Mock() + api = Api(port=8080) api.run() - + # Verify run_app was called once with the correct port mock_run_app.assert_called_once() args, kwargs = mock_run_app.call_args @@ -195,19 +182,19 @@ class TestApi: def test_api_components_initialization(self): """Test that all API components are properly initialized""" - with patch('pulsar.Client') as mock_client: - mock_client.return_value = Mock() - + with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: + mock_get_pubsub.return_value = Mock() + api = Api() - + # Verify all components are initialized assert api.config_receiver is not None assert api.dispatcher_manager is not None assert api.endpoint_manager is not None assert api.endpoints == [] - + # Verify component relationships - assert api.dispatcher_manager.pulsar_client == api.pulsar_client + assert api.dispatcher_manager.backend == api.pubsub_backend assert api.dispatcher_manager.config_receiver == api.config_receiver assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager # EndpointManager doesn't store auth directly, it passes it to individual endpoints diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py index 4e8768a1..1a63227d 100644 --- a/tests/unit/test_gateway/test_socket_graceful_shutdown.py +++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py @@ -102,7 +102,7 @@ async def test_handle_normal_flow(): """Test normal websocket handling flow.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + dispatcher_created = False async def mock_dispatcher_factory(ws, running, match_info): nonlocal dispatcher_created @@ -110,33 +110,43 @@ async def test_handle_normal_flow(): dispatcher = AsyncMock() dispatcher.destroy = AsyncMock() return dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() mock_ws.close = AsyncMock() mock_ws.closed = False mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: # Mock task group context manager mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(return_value=None) - mock_tg.create_task = MagicMock(return_value=AsyncMock()) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + # Consume the coroutine to avoid "was never awaited" warning + coro.close() + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + result = await socket_endpoint.handle(request) - + # Should have created dispatcher assert dispatcher_created is True - + # Should return websocket assert result == mock_ws @@ -146,50 +156,64 @@ async def test_handle_exception_group_cleanup(): """Test exception group triggers dispatcher cleanup.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() - + async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + # Mock TaskGroup to raise ExceptionGroup class TestException(Exception): pass - + exception_group = ExceptionGroup("Test exceptions", [TestException("test")]) - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() - mock_ws.close = AsyncMock() + mock_ws.close = AsyncMock() mock_ws.closed = False mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(side_effect=exception_group) - mock_tg.create_task = MagicMock(side_effect=TestException("test")) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + # Consume the coroutine to avoid "was never awaited" warning + coro.close() + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - - with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for: - mock_wait_for.return_value = None - + + with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for: + # Make wait_for consume the coroutine passed to it + async def wait_for_side_effect(coro, timeout=None): + coro.close() # Consume the coroutine + return None + mock_wait_for.side_effect = wait_for_side_effect + result = await socket_endpoint.handle(request) - + # Should have attempted graceful cleanup mock_wait_for.assert_called_once() - + # Should have called destroy in finally block assert mock_dispatcher.destroy.call_count >= 1 - + # Should have closed websocket mock_ws.close.assert_called() @@ -199,48 +223,62 @@ async def test_handle_dispatcher_cleanup_timeout(): """Test dispatcher cleanup with timeout.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + # Mock dispatcher that takes long to destroy mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() - + async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + # Mock TaskGroup to raise exception exception_group = ExceptionGroup("Test", [Exception("test")]) - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() mock_ws.close = AsyncMock() mock_ws.closed = False mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(side_effect=exception_group) - mock_tg.create_task = MagicMock(side_effect=Exception("test")) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + # Consume the coroutine to avoid "was never awaited" warning + coro.close() + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + # Mock asyncio.wait_for to raise TimeoutError - with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for: - mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout") - + with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for: + # Make wait_for consume the coroutine before raising + async def wait_for_timeout(coro, timeout=None): + coro.close() # Consume the coroutine + raise asyncio.TimeoutError("Cleanup timeout") + mock_wait_for.side_effect = wait_for_timeout + result = await socket_endpoint.handle(request) - + # Should have attempted cleanup with timeout mock_wait_for.assert_called_once() # Check that timeout was passed correctly assert mock_wait_for.call_args[1]['timeout'] == 5.0 - + # Should still call destroy in finally block assert mock_dispatcher.destroy.call_count >= 1 @@ -290,37 +328,47 @@ async def test_handle_websocket_already_closed(): """Test handling when websocket is already closed.""" mock_auth = MagicMock() mock_auth.permitted.return_value = True - + mock_dispatcher = AsyncMock() mock_dispatcher.destroy = AsyncMock() - + async def mock_dispatcher_factory(ws, running, match_info): return mock_dispatcher - + socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory) - + request = MagicMock() request.query = {"token": "valid-token"} request.match_info = {} - + with patch('aiohttp.web.WebSocketResponse') as mock_ws_class: mock_ws = AsyncMock() mock_ws.prepare = AsyncMock() mock_ws.close = AsyncMock() mock_ws.closed = True # Already closed mock_ws_class.return_value = mock_ws - + with patch('asyncio.TaskGroup') as mock_task_group: mock_tg = AsyncMock() mock_tg.__aenter__ = AsyncMock(return_value=mock_tg) mock_tg.__aexit__ = AsyncMock(return_value=None) - mock_tg.create_task = MagicMock(return_value=AsyncMock()) + + # Create proper mock tasks that look like asyncio.Task objects + def create_task_mock(coro): + # Consume the coroutine to avoid "was never awaited" warning + coro.close() + task = AsyncMock() + task.done = MagicMock(return_value=True) + task.cancelled = MagicMock(return_value=False) + return task + + mock_tg.create_task = MagicMock(side_effect=create_task_mock) mock_task_group.return_value = mock_tg - + result = await socket_endpoint.handle(request) - + # Should still have called destroy mock_dispatcher.destroy.assert_called() - + # Should not attempt to close already closed websocket mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True \ No newline at end of file diff --git a/tests/unit/test_gateway/test_streaming_translators.py b/tests/unit/test_gateway/test_streaming_translators.py new file mode 100644 index 00000000..e767edd4 --- /dev/null +++ b/tests/unit/test_gateway/test_streaming_translators.py @@ -0,0 +1,326 @@ +""" +Unit tests for streaming behavior in message translators. + +These tests verify that translators correctly handle empty strings and +end_of_stream flags in streaming responses, preventing bugs where empty +final chunks could be dropped due to falsy value checks. +""" + +import pytest +from unittest.mock import MagicMock +from trustgraph.messaging.translators.retrieval import ( + GraphRagResponseTranslator, + DocumentRagResponseTranslator, +) +from trustgraph.messaging.translators.prompt import PromptResponseTranslator +from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator +from trustgraph.schema import ( + GraphRagResponse, + DocumentRagResponse, + PromptResponse, + TextCompletionResponse, +) + + +class TestGraphRagResponseTranslator: + """Test GraphRagResponseTranslator streaming behavior""" + + def test_from_pulsar_with_empty_response(self): + """Test that empty response strings are preserved""" + # Arrange + translator = GraphRagResponseTranslator() + response = GraphRagResponse( + response="", + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert - Empty string should be included in result + assert "response" in result + assert result["response"] == "" + assert result["end_of_stream"] is True + + def test_from_pulsar_with_non_empty_response(self): + """Test that non-empty responses work correctly""" + # Arrange + translator = GraphRagResponseTranslator() + response = GraphRagResponse( + response="Some text", + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert result["response"] == "Some text" + assert result["end_of_stream"] is False + + def test_from_pulsar_with_none_response(self): + """Test that None response is handled correctly""" + # Arrange + translator = GraphRagResponseTranslator() + response = GraphRagResponse( + response=None, + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert - None should not be included + assert "response" not in result + assert result["end_of_stream"] is True + + def test_from_response_with_completion_returns_correct_flag(self): + """Test that from_response_with_completion returns correct is_final flag""" + # Arrange + translator = GraphRagResponseTranslator() + + # Test non-final chunk + response_chunk = GraphRagResponse( + response="chunk", + end_of_stream=False, + error=None + ) + + # Act + result, is_final = translator.from_response_with_completion(response_chunk) + + # Assert + assert is_final is False + assert result["end_of_stream"] is False + + # Test final chunk with empty content + final_response = GraphRagResponse( + response="", + end_of_stream=True, + error=None + ) + + # Act + result, is_final = translator.from_response_with_completion(final_response) + + # Assert + assert is_final is True + assert result["response"] == "" + assert result["end_of_stream"] is True + + +class TestDocumentRagResponseTranslator: + """Test DocumentRagResponseTranslator streaming behavior""" + + def test_from_pulsar_with_empty_response(self): + """Test that empty response strings are preserved""" + # Arrange + translator = DocumentRagResponseTranslator() + response = DocumentRagResponse( + response="", + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "response" in result + assert result["response"] == "" + assert result["end_of_stream"] is True + + def test_from_pulsar_with_non_empty_response(self): + """Test that non-empty responses work correctly""" + # Arrange + translator = DocumentRagResponseTranslator() + response = DocumentRagResponse( + response="Document content", + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert result["response"] == "Document content" + assert result["end_of_stream"] is False + + +class TestPromptResponseTranslator: + """Test PromptResponseTranslator streaming behavior""" + + def test_from_pulsar_with_empty_text(self): + """Test that empty text strings are preserved""" + # Arrange + translator = PromptResponseTranslator() + response = PromptResponse( + text="", + object=None, + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "text" in result + assert result["text"] == "" + assert result["end_of_stream"] is True + + def test_from_pulsar_with_non_empty_text(self): + """Test that non-empty text works correctly""" + # Arrange + translator = PromptResponseTranslator() + response = PromptResponse( + text="Some prompt response", + object=None, + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert result["text"] == "Some prompt response" + assert result["end_of_stream"] is False + + def test_from_pulsar_with_none_text(self): + """Test that None text is handled correctly""" + # Arrange + translator = PromptResponseTranslator() + response = PromptResponse( + text=None, + object='{"result": "data"}', + end_of_stream=True, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "text" not in result + assert "object" in result + assert result["end_of_stream"] is True + + def test_from_pulsar_includes_end_of_stream(self): + """Test that end_of_stream flag is always included""" + # Arrange + translator = PromptResponseTranslator() + + # Test with end_of_stream=False + response = PromptResponse( + text="chunk", + object=None, + end_of_stream=False, + error=None + ) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "end_of_stream" in result + assert result["end_of_stream"] is False + + +class TestTextCompletionResponseTranslator: + """Test TextCompletionResponseTranslator streaming behavior""" + + def test_from_pulsar_always_includes_response(self): + """Test that response field is always included, even if empty""" + # Arrange + translator = TextCompletionResponseTranslator() + response = TextCompletionResponse( + response="", + end_of_stream=True, + error=None, + in_token=100, + out_token=5, + model="test-model" + ) + + # Act + result = translator.from_pulsar(response) + + # Assert - Response should always be present + assert "response" in result + assert result["response"] == "" + + def test_from_response_with_completion_with_empty_final(self): + """Test that empty final response is handled correctly""" + # Arrange + translator = TextCompletionResponseTranslator() + response = TextCompletionResponse( + response="", + end_of_stream=True, + error=None, + in_token=100, + out_token=5, + model="test-model" + ) + + # Act + result, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True + assert result["response"] == "" + + +class TestStreamingProtocolCompliance: + """Test that all translators follow streaming protocol conventions""" + + @pytest.mark.parametrize("translator_class,response_class,field_name", [ + (GraphRagResponseTranslator, GraphRagResponse, "response"), + (DocumentRagResponseTranslator, DocumentRagResponse, "response"), + (PromptResponseTranslator, PromptResponse, "text"), + (TextCompletionResponseTranslator, TextCompletionResponse, "response"), + ]) + def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name): + """Test that all translators preserve empty final chunks""" + # Arrange + translator = translator_class() + kwargs = { + field_name: "", + "end_of_stream": True, + "error": None, + } + response = response_class(**kwargs) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty" + assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string" + + @pytest.mark.parametrize("translator_class,response_class,field_name", [ + (GraphRagResponseTranslator, GraphRagResponse, "response"), + (DocumentRagResponseTranslator, DocumentRagResponse, "response"), + (TextCompletionResponseTranslator, TextCompletionResponse, "response"), + ]) + def test_end_of_stream_flag_included(self, translator_class, response_class, field_name): + """Test that end_of_stream flag is included in all response translators""" + # Arrange + translator = translator_class() + kwargs = { + field_name: "test content", + "end_of_stream": True, + "error": None, + } + response = response_class(**kwargs) + + # Act + result = translator.from_pulsar(response) + + # Assert + assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag" + assert result["end_of_stream"] is True diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py new file mode 100644 index 00000000..f86ae3da --- /dev/null +++ b/tests/unit/test_python_api_client.py @@ -0,0 +1,446 @@ +""" +Unit tests for TrustGraph Python API client library + +These tests use mocks and do not require a running server. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock, call +import json + +from trustgraph.api import ( + Api, + Triple, + AgentThought, + AgentObservation, + AgentAnswer, + RAGChunk, +) + + +class TestApiInstantiation: + """Test Api class instantiation and configuration""" + + def test_api_instantiation_defaults(self): + """Test Api with default parameters""" + api = Api() + assert api.url == "http://localhost:8088/api/v1/" + assert api.timeout == 60 + assert api.token is None + + def test_api_instantiation_with_url(self): + """Test Api with custom URL""" + api = Api(url="http://test-server:9000/") + assert api.url == "http://test-server:9000/api/v1/" + + def test_api_instantiation_with_url_trailing_slash(self): + """Test Api adds trailing slash if missing""" + api = Api(url="http://test-server:9000") + assert api.url == "http://test-server:9000/api/v1/" + + def test_api_instantiation_with_token(self): + """Test Api with authentication token""" + api = Api(token="test-token-123") + assert api.token == "test-token-123" + + def test_api_instantiation_with_timeout(self): + """Test Api with custom timeout""" + api = Api(timeout=120) + assert api.timeout == 120 + + +class TestApiLazyInitialization: + """Test lazy initialization of client components""" + + def test_socket_client_lazy_init(self): + """Test socket client is created on first access""" + api = Api(url="http://test/", token="token") + + assert api._socket_client is None + socket = api.socket() + assert api._socket_client is not None + assert socket is api._socket_client + + # Second access returns same instance + socket2 = api.socket() + assert socket2 is socket + + def test_bulk_client_lazy_init(self): + """Test bulk client is created on first access""" + api = Api(url="http://test/") + + assert api._bulk_client is None + bulk = api.bulk() + assert api._bulk_client is not None + + def test_async_flow_lazy_init(self): + """Test async flow is created on first access""" + api = Api(url="http://test/") + + assert api._async_flow is None + async_flow = api.async_flow() + assert api._async_flow is not None + + def test_metrics_lazy_init(self): + """Test metrics client is created on first access""" + api = Api(url="http://test/") + + assert api._metrics is None + metrics = api.metrics() + assert api._metrics is not None + + +class TestApiContextManager: + """Test context manager functionality""" + + def test_sync_context_manager(self): + """Test synchronous context manager""" + with Api(url="http://test/") as api: + assert api is not None + assert isinstance(api, Api) + # Should exit cleanly + + @pytest.mark.asyncio + async def test_async_context_manager(self): + """Test asynchronous context manager""" + async with Api(url="http://test/") as api: + assert api is not None + assert isinstance(api, Api) + # Should exit cleanly + + +class TestFlowClient: + """Test Flow client functionality""" + + @patch('requests.post') + def test_flow_list(self, mock_post): + """Test listing flows""" + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"flow-ids": ["flow1", "flow2"]} + + api = Api(url="http://test/") + flows = api.flow().list() + + assert flows == ["flow1", "flow2"] + assert mock_post.called + + @patch('requests.post') + def test_flow_list_with_token(self, mock_post): + """Test flow listing includes auth token""" + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"flow-ids": []} + + api = Api(url="http://test/", token="my-token") + api.flow().list() + + # Verify Authorization header was set + call_args = mock_post.call_args + headers = call_args[1]['headers'] if 'headers' in call_args[1] else {} + assert 'Authorization' in headers + assert headers['Authorization'] == 'Bearer my-token' + + @patch('requests.post') + def test_flow_get(self, mock_post): + """Test getting flow definition""" + flow_def = {"name": "test-flow", "description": "Test"} + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"flow": json.dumps(flow_def)} + + api = Api(url="http://test/") + result = api.flow().get("test-flow") + + assert result == flow_def + + def test_flow_instance_creation(self): + """Test creating flow instance""" + api = Api(url="http://test/") + flow_instance = api.flow().id("my-flow") + + assert flow_instance is not None + assert flow_instance.id == "my-flow" + + def test_flow_instance_has_methods(self): + """Test flow instance has expected methods""" + api = Api(url="http://test/") + flow_instance = api.flow().id("my-flow") + + expected_methods = [ + 'text_completion', 'agent', 'graph_rag', 'document_rag', + 'graph_embeddings_query', 'embeddings', 'prompt', + 'triples_query', 'objects_query' + ] + + for method in expected_methods: + assert hasattr(flow_instance, method), f"Missing method: {method}" + + +class TestSocketClient: + """Test WebSocket client functionality""" + + def test_socket_client_url_conversion_http(self): + """Test HTTP URL converted to WebSocket""" + api = Api(url="http://test-server:8088/") + socket = api.socket() + + assert socket.url.startswith("ws://") + assert "test-server" in socket.url + + def test_socket_client_url_conversion_https(self): + """Test HTTPS URL converted to secure WebSocket""" + api = Api(url="https://test-server:8088/") + socket = api.socket() + + assert socket.url.startswith("wss://") + + def test_socket_client_token_passed(self): + """Test token is passed to socket client""" + api = Api(url="http://test/", token="socket-token") + socket = api.socket() + + assert socket.token == "socket-token" + + def test_socket_flow_instance(self): + """Test creating socket flow instance""" + api = Api(url="http://test/") + socket = api.socket() + flow_instance = socket.flow("test-flow") + + assert flow_instance is not None + assert flow_instance.flow_id == "test-flow" + + def test_socket_flow_has_methods(self): + """Test socket flow instance has expected methods""" + api = Api(url="http://test/") + flow_instance = api.socket().flow("test-flow") + + expected_methods = [ + 'agent', 'text_completion', 'graph_rag', 'document_rag', + 'prompt', 'graph_embeddings_query', 'embeddings', + 'triples_query', 'objects_query', 'mcp_tool' + ] + + for method in expected_methods: + assert hasattr(flow_instance, method), f"Missing method: {method}" + + +class TestBulkClient: + """Test bulk operations client""" + + def test_bulk_client_url_conversion(self): + """Test bulk client uses WebSocket URL""" + api = Api(url="http://test/") + bulk = api.bulk() + + assert bulk.url.startswith("ws://") + + def test_bulk_client_has_import_methods(self): + """Test bulk client has import methods""" + api = Api(url="http://test/") + bulk = api.bulk() + + import_methods = [ + 'import_triples', + 'import_graph_embeddings', + 'import_document_embeddings', + 'import_entity_contexts', + 'import_objects' + ] + + for method in import_methods: + assert hasattr(bulk, method), f"Missing method: {method}" + + def test_bulk_client_has_export_methods(self): + """Test bulk client has export methods""" + api = Api(url="http://test/") + bulk = api.bulk() + + export_methods = [ + 'export_triples', + 'export_graph_embeddings', + 'export_document_embeddings', + 'export_entity_contexts' + ] + + for method in export_methods: + assert hasattr(bulk, method), f"Missing method: {method}" + + +class TestMetricsClient: + """Test metrics client""" + + @patch('requests.get') + def test_metrics_get(self, mock_get): + """Test getting metrics""" + mock_get.return_value.status_code = 200 + mock_get.return_value.text = "# HELP metric_name\nmetric_name 42" + + api = Api(url="http://test/") + metrics_text = api.metrics().get() + + assert "metric_name" in metrics_text + assert mock_get.called + + @patch('requests.get') + def test_metrics_with_token(self, mock_get): + """Test metrics request includes token""" + mock_get.return_value.status_code = 200 + mock_get.return_value.text = "metrics" + + api = Api(url="http://test/", token="metrics-token") + api.metrics().get() + + # Verify token in headers + call_args = mock_get.call_args + headers = call_args[1].get('headers', {}) + assert 'Authorization' in headers + + +class TestStreamingTypes: + """Test streaming chunk types""" + + def test_agent_thought_creation(self): + """Test creating AgentThought chunk""" + chunk = AgentThought(content="thinking...", end_of_message=False) + + assert chunk.content == "thinking..." + assert chunk.end_of_message is False + assert chunk.chunk_type == "thought" + + def test_agent_observation_creation(self): + """Test creating AgentObservation chunk""" + chunk = AgentObservation(content="observing...", end_of_message=False) + + assert chunk.content == "observing..." + assert chunk.chunk_type == "observation" + + def test_agent_answer_creation(self): + """Test creating AgentAnswer chunk""" + chunk = AgentAnswer( + content="answer", + end_of_message=True, + end_of_dialog=True + ) + + assert chunk.content == "answer" + assert chunk.end_of_message is True + assert chunk.end_of_dialog is True + assert chunk.chunk_type == "final-answer" + + def test_rag_chunk_creation(self): + """Test creating RAGChunk""" + chunk = RAGChunk( + content="response chunk", + end_of_stream=False, + error=None + ) + + assert chunk.content == "response chunk" + assert chunk.end_of_stream is False + assert chunk.error is None + + def test_rag_chunk_with_error(self): + """Test RAGChunk with error""" + error_dict = {"type": "error", "message": "failed"} + chunk = RAGChunk( + content="", + end_of_stream=True, + error=error_dict + ) + + assert chunk.error == error_dict + + +class TestTripleType: + """Test Triple data structure""" + + def test_triple_creation(self): + """Test creating Triple""" + triple = Triple(s="subject", p="predicate", o="object") + + assert triple.s == "subject" + assert triple.p == "predicate" + assert triple.o == "object" + + def test_triple_with_uris(self): + """Test Triple with URI values""" + triple = Triple( + s="http://example.org/entity1", + p="http://example.org/relation", + o="http://example.org/entity2" + ) + + assert triple.s.startswith("http://") + assert triple.p.startswith("http://") + assert triple.o.startswith("http://") + + +class TestAsyncClients: + """Test async client availability""" + + def test_async_flow_creation(self): + """Test creating async flow client""" + api = Api(url="http://test/") + async_flow = api.async_flow() + + assert async_flow is not None + + def test_async_socket_creation(self): + """Test creating async socket client""" + api = Api(url="http://test/") + async_socket = api.async_socket() + + assert async_socket is not None + assert async_socket.url.startswith("ws://") + + def test_async_bulk_creation(self): + """Test creating async bulk client""" + api = Api(url="http://test/") + async_bulk = api.async_bulk() + + assert async_bulk is not None + + def test_async_metrics_creation(self): + """Test creating async metrics client""" + api = Api(url="http://test/") + async_metrics = api.async_metrics() + + assert async_metrics is not None + + +class TestErrorHandling: + """Test error handling""" + + @patch('requests.post') + def test_protocol_exception_on_non_200(self, mock_post): + """Test ProtocolException raised on non-200 status""" + from trustgraph.api.exceptions import ProtocolException + + mock_post.return_value.status_code = 500 + + api = Api(url="http://test/") + + with pytest.raises(ProtocolException): + api.flow().list() + + @patch('requests.post') + def test_application_exception_on_error_response(self, mock_post): + """Test ApplicationException on error in response""" + from trustgraph.api.exceptions import ApplicationException + + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "error": { + "type": "ValidationError", + "message": "Invalid input" + } + } + + api = Api(url="http://test/") + + with pytest.raises(ApplicationException): + api.flow().list() + + +# Run tests with: pytest tests/unit/test_python_api_client.py -v +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_retrieval/test_structured_diag/test_schema_contracts.py b/tests/unit/test_retrieval/test_structured_diag/test_schema_contracts.py index 99f66dc7..240bad89 100644 --- a/tests/unit/test_retrieval/test_structured_diag/test_schema_contracts.py +++ b/tests/unit/test_retrieval/test_structured_diag/test_schema_contracts.py @@ -23,9 +23,9 @@ class TestStructuredDiagnosisSchemaContract: assert request.operation == "detect-type" assert request.sample == "test data" - assert request.type is None # Optional, defaults to None - assert request.schema_name is None # Optional, defaults to None - assert request.options is None # Optional, defaults to None + assert request.type == "" # Optional, defaults to empty string + assert request.schema_name == "" # Optional, defaults to empty string + assert request.options == {} # Optional, defaults to empty dict def test_request_schema_all_operations(self): """Test request schema supports all operations""" @@ -66,9 +66,9 @@ class TestStructuredDiagnosisSchemaContract: assert response.detected_type == "xml" assert response.confidence == 0.9 assert response.error is None - assert response.descriptor is None - assert response.metadata is None - assert response.schema_matches is None # New field, defaults to None + assert response.descriptor == "" # Defaults to empty string + assert response.metadata == {} # Defaults to empty dict + assert response.schema_matches == [] # Defaults to empty list def test_response_schema_with_error(self): """Test response schema with error""" @@ -140,6 +140,7 @@ class TestStructuredDiagnosisSchemaContract: assert response.metadata == metadata assert response.metadata["field_count"] == "5" + @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses") def test_schema_serialization(self): """Test that schemas can be serialized and deserialized correctly""" # Test request serialization @@ -158,6 +159,7 @@ class TestStructuredDiagnosisSchemaContract: assert deserialized.sample == request.sample assert deserialized.options == request.options + @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses") def test_response_serialization_with_schema_matches(self): """Test response serialization with schema_matches array""" response = StructuredDataDiagnosisResponse( @@ -185,7 +187,7 @@ class TestStructuredDiagnosisSchemaContract: ) # Verify default value for new field - assert response.schema_matches is None # Defaults to None when not set + assert response.schema_matches == [] # Defaults to empty list when not set # Verify old fields still work assert response.detected_type == "json" @@ -221,7 +223,7 @@ class TestStructuredDiagnosisSchemaContract: ) assert error_response.error is not None - assert error_response.schema_matches is None # Default None when not set + assert error_response.schema_matches == [] # Default empty list when not set def test_all_operations_supported(self): """Verify all operations are properly supported in the contract""" diff --git a/tests/unit/test_rev_gateway/test_dispatcher.py b/tests/unit/test_rev_gateway/test_dispatcher.py index b4fa2eb1..2a9c8df0 100644 --- a/tests/unit/test_rev_gateway/test_dispatcher.py +++ b/tests/unit/test_rev_gateway/test_dispatcher.py @@ -72,7 +72,7 @@ class TestMessageDispatcher: assert dispatcher.max_workers == 10 assert dispatcher.semaphore._value == 10 assert dispatcher.active_tasks == set() - assert dispatcher.pulsar_client is None + assert dispatcher.backend is None assert dispatcher.dispatcher_manager is None assert len(dispatcher.service_mapping) > 0 @@ -86,7 +86,7 @@ class TestMessageDispatcher: @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): """Test MessageDispatcher initialization with pulsar_client and config_receiver""" - mock_pulsar_client = MagicMock() + mock_backend = MagicMock() mock_config_receiver = MagicMock() mock_dispatcher_instance = MagicMock() mock_dispatcher_manager.return_value = mock_dispatcher_instance @@ -94,14 +94,14 @@ class TestMessageDispatcher: dispatcher = MessageDispatcher( max_workers=8, config_receiver=mock_config_receiver, - pulsar_client=mock_pulsar_client + backend=mock_backend ) assert dispatcher.max_workers == 8 - assert dispatcher.pulsar_client == mock_pulsar_client + assert dispatcher.backend == mock_backend assert dispatcher.dispatcher_manager == mock_dispatcher_instance mock_dispatcher_manager.assert_called_once_with( - mock_pulsar_client, mock_config_receiver, prefix="rev-gateway" + mock_backend, mock_config_receiver, prefix="rev-gateway" ) def test_message_dispatcher_service_mapping(self): diff --git a/tests/unit/test_rev_gateway/test_rev_gateway_service.py b/tests/unit/test_rev_gateway/test_rev_gateway_service.py index d991ba45..23aff18e 100644 --- a/tests/unit/test_rev_gateway/test_rev_gateway_service.py +++ b/tests/unit/test_rev_gateway/test_rev_gateway_service.py @@ -16,11 +16,11 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_initialization_defaults(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway initialization with default parameters""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() @@ -38,11 +38,11 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_initialization_custom_params(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway initialization with custom parameters""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway( websocket_uri="wss://example.com:8080/websocket", @@ -65,11 +65,11 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_initialization_with_missing_path(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway initialization with WebSocket URI missing path""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway(websocket_uri="ws://example.com") @@ -78,53 +78,49 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_initialization_invalid_scheme(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway initialization with invalid WebSocket scheme""" with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): ReverseGateway(websocket_uri="http://example.com") @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_initialization_missing_hostname(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway initialization with missing hostname""" with pytest.raises(ValueError, match="WebSocket URI must include hostname"): ReverseGateway(websocket_uri="ws://") @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_pulsar_client_with_auth(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway creates Pulsar client with authentication""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance - - with patch('pulsar.AuthenticationToken') as mock_auth: - mock_auth_instance = MagicMock() - mock_auth.return_value = mock_auth_instance - - gateway = ReverseGateway( - pulsar_api_key="test-key", - pulsar_listener="test-listener" - ) - - mock_auth.assert_called_once_with("test-key") - mock_pulsar_client.assert_called_once_with( - "pulsar://pulsar:6650", - listener_name="test-listener", - authentication=mock_auth_instance - ) + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): + """Test ReverseGateway creates backend with authentication""" + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend + + gateway = ReverseGateway( + pulsar_api_key="test-key", + pulsar_listener="test-listener" + ) + + # Verify get_pubsub was called with the correct parameters + mock_get_pubsub.assert_called_once_with( + pulsar_host="pulsar://pulsar:6650", + pulsar_api_key="test-key", + pulsar_listener="test-listener" + ) @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_success(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway successful connection""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend mock_session = AsyncMock() mock_ws = AsyncMock() @@ -142,13 +138,13 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway connection failure""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend mock_session = AsyncMock() mock_session.ws_connect.side_effect = Exception("Connection failed") @@ -162,12 +158,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_disconnect(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway disconnect""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() @@ -189,12 +185,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_send_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway send message""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() @@ -211,12 +207,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_send_message_closed_connection(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway send message with closed connection""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() @@ -234,12 +230,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_handle_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway handle message""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend mock_dispatcher_instance = AsyncMock() mock_dispatcher_instance.handle_message.return_value = {"response": "success"} @@ -263,12 +259,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_handle_message_invalid_json(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway handle message with invalid JSON""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() @@ -285,12 +281,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_listen_text_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway listen with text message""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() gateway.running = True @@ -318,12 +314,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_listen_binary_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway listen with binary message""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() gateway.running = True @@ -351,12 +347,12 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_listen_close_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway listen with close message""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() gateway.running = True @@ -383,36 +379,36 @@ class TestReverseGateway: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_shutdown(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway shutdown""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance - + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend + mock_dispatcher_instance = AsyncMock() mock_dispatcher.return_value = mock_dispatcher_instance - + gateway = ReverseGateway() gateway.running = True - + # Mock disconnect gateway.disconnect = AsyncMock() - + await gateway.shutdown() - + assert gateway.running is False mock_dispatcher_instance.shutdown.assert_called_once() gateway.disconnect.assert_called_once() - mock_client_instance.close.assert_called_once() + mock_backend.close.assert_called_once() @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') - def test_reverse_gateway_stop(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + @patch('trustgraph.rev_gateway.service.get_pubsub') + def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway stop""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend gateway = ReverseGateway() gateway.running = True @@ -427,12 +423,12 @@ class TestReverseGatewayRun: @patch('trustgraph.rev_gateway.service.ConfigReceiver') @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('pulsar.Client') + @patch('trustgraph.rev_gateway.service.get_pubsub') @pytest.mark.asyncio - async def test_reverse_gateway_run_successful_cycle(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver): + async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): """Test ReverseGateway run method with successful connect/listen cycle""" - mock_client_instance = MagicMock() - mock_pulsar_client.return_value = mock_client_instance + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend mock_config_receiver_instance = AsyncMock() mock_config_receiver.return_value = mock_config_receiver_instance diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index f99d9883..fc839482 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -15,11 +15,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): """Test Qdrant document embeddings storage functionality""" @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client): + async def test_processor_initialization_basic(self, mock_qdrant_client): """Test basic Qdrant processor initialization""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance @@ -34,9 +32,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - # Verify base class initialization was called - mock_base_init.assert_called_once() - # Verify QdrantClient was created with correct parameters mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key') @@ -45,11 +40,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): assert processor.qdrant == mock_qdrant_instance @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client): + async def test_processor_initialization_with_defaults(self, mock_qdrant_client): """Test processor initialization with default values""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance @@ -68,11 +61,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_store_document_embeddings_basic(self, mock_uuid, mock_qdrant_client): """Test storing document embeddings with basic message""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True # Collection already exists mock_qdrant_client.return_value = mock_qdrant_instance @@ -87,7 +78,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('test_user', 'test_collection')] = {} + # Create mock message with chunks and vectors mock_message = MagicMock() mock_message.metadata.user = 'test_user' @@ -121,11 +115,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client): """Test storing document embeddings with multiple chunks""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -140,7 +132,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('multi_user', 'multi_collection')] = {} + # Create mock message with multiple chunks mock_message = MagicMock() mock_message.metadata.user = 'multi_user' @@ -180,11 +175,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client): """Test storing document embeddings with multiple vectors per chunk""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -199,7 +192,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('vector_user', 'vector_collection')] = {} + # Create mock message with chunk having multiple vectors mock_message = MagicMock() mock_message.metadata.user = 'vector_user' @@ -237,11 +233,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): assert point.payload['doc'] == 'multi-vector document chunk' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client): + async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client): """Test storing document embeddings skips empty chunks""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True # Collection exists mock_qdrant_client.return_value = mock_qdrant_instance @@ -277,11 +271,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_collection_creation_when_not_exists(self, mock_uuid, mock_qdrant_client): """Test that writing to non-existent collection creates it lazily""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist mock_qdrant_client.return_value = mock_qdrant_instance @@ -297,6 +289,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) + # Add collection to known_collections (simulates config push) + processor.known_collections[('new_user', 'new_collection')] = {} + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'new_user' @@ -326,11 +321,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_collection_creation_exception(self, mock_uuid, mock_qdrant_client): """Test that collection creation errors are propagated""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist # Simulate creation failure @@ -348,6 +341,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) + # Add collection to known_collections (simulates config push) + processor.known_collections[('error_user', 'error_collection')] = {} + # Create mock message mock_message = MagicMock() mock_message.metadata.user = 'error_user' @@ -364,12 +360,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): await processor.store_document_embeddings(mock_message) @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client): + async def test_collection_validation_on_write(self, mock_uuid, mock_qdrant_client): """Test collection validation checks collection exists before writing""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -385,6 +379,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) + # Add collection to known_collections (simulates config push) + processor.known_collections[('cache_user', 'cache_collection')] = {} + # Create first mock message mock_message1 = MagicMock() mock_message1.metadata.user = 'cache_user' @@ -428,11 +425,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_different_dimensions_different_collections(self, mock_uuid, mock_qdrant_client): """Test that different vector dimensions create different collections""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -448,6 +443,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) + # Add collection to known_collections (simulates config push) + processor.known_collections[('dim_user', 'dim_collection')] = {} + # Create mock message with different dimension vectors mock_message = MagicMock() mock_message.metadata.user = 'dim_user' @@ -482,11 +480,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client): + async def test_add_args_calls_parent(self, mock_qdrant_client): """Test that add_args() calls parent add_args method""" # Arrange - mock_base_init.return_value = None mock_qdrant_client.return_value = MagicMock() mock_parser = MagicMock() @@ -502,11 +498,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client): """Test proper UTF-8 decoding of chunk text""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -521,7 +515,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('utf8_user', 'utf8_collection')] = {} + # Create mock message with UTF-8 encoded text mock_message = MagicMock() mock_message.metadata.user = 'utf8_user' @@ -546,11 +543,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé' @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') - async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client): + async def test_chunk_decode_exception_handling(self, mock_qdrant_client): """Test handling of chunk decode exceptions""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance @@ -562,7 +557,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('decode_user', 'decode_collection')] = {} + # Create mock message with decode error mock_message = MagicMock() mock_message.metadata.user = 'decode_user' diff --git a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py index c4b603c9..d240b892 100644 --- a/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py @@ -15,11 +15,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): """Test Qdrant graph embeddings storage functionality""" @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client): + async def test_processor_initialization_basic(self, mock_qdrant_client): """Test basic Qdrant processor initialization""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance @@ -34,9 +32,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - # Verify base class initialization was called - mock_base_init.assert_called_once() - # Verify QdrantClient was created with correct parameters mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key') @@ -46,11 +41,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_store_graph_embeddings_basic(self, mock_uuid, mock_qdrant_client): """Test storing graph embeddings with basic message""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True # Collection already exists mock_qdrant_client.return_value = mock_qdrant_instance @@ -64,7 +57,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('test_user', 'test_collection')] = {} + # Create mock message with entities and vectors mock_message = MagicMock() mock_message.metadata.user = 'test_user' @@ -98,11 +94,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_store_graph_embeddings_multiple_entities(self, mock_uuid, mock_qdrant_client): """Test storing graph embeddings with multiple entities""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -116,7 +110,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('multi_user', 'multi_collection')] = {} + # Create mock message with multiple entities mock_message = MagicMock() mock_message.metadata.user = 'multi_user' @@ -156,11 +153,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client): + async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client): """Test storing graph embeddings with multiple vectors per entity""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_client.return_value = mock_qdrant_instance @@ -174,7 +169,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - + + # Add collection to known_collections (simulates config push) + processor.known_collections[('vector_user', 'vector_collection')] = {} + # Create mock message with entity having multiple vectors mock_message = MagicMock() mock_message.metadata.user = 'vector_user' @@ -212,11 +210,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): assert point.payload['entity'] == 'multi_vector_entity' @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client): + async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client): """Test storing graph embeddings skips empty entity values""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance @@ -253,11 +249,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_instance.collection_exists.assert_not_called() @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client): + async def test_processor_initialization_with_defaults(self, mock_qdrant_client): """Test processor initialization with default values""" # Arrange - mock_base_init.return_value = None mock_qdrant_instance = MagicMock() mock_qdrant_client.return_value = mock_qdrant_instance @@ -275,11 +269,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase): mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None) @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient') - @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__') - async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client): + async def test_add_args_calls_parent(self, mock_qdrant_client): """Test that add_args() calls parent add_args method""" # Arrange - mock_base_init.return_value = None mock_qdrant_client.return_value = MagicMock() mock_parser = MagicMock() diff --git a/trustgraph-base/pyproject.toml b/trustgraph-base/pyproject.toml index c36cab10..7d9f9219 100644 --- a/trustgraph-base/pyproject.toml +++ b/trustgraph-base/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "pulsar-client", "prometheus-client", "requests", + "python-logging-loki", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index daa6a964..0ecb760e 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -1,3 +1,114 @@ -from . api import * +# Core API +from .api import Api + +# Flow clients +from .flow import Flow, FlowInstance +from .async_flow import AsyncFlow, AsyncFlowInstance + +# WebSocket clients +from .socket_client import SocketClient, SocketFlowInstance +from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance + +# Bulk operation clients +from .bulk_client import BulkClient +from .async_bulk_client import AsyncBulkClient + +# Metrics clients +from .metrics import Metrics +from .async_metrics import AsyncMetrics + +# Types +from .types import ( + Triple, + ConfigKey, + ConfigValue, + DocumentMetadata, + ProcessingMetadata, + CollectionMetadata, + StreamingChunk, + AgentThought, + AgentObservation, + AgentAnswer, + RAGChunk, +) + +# Exceptions +from .exceptions import ( + ProtocolException, + TrustGraphException, + AgentError, + ConfigError, + DocumentRagError, + FlowError, + GatewayError, + GraphRagError, + LLMError, + LoadError, + LookupError, + NLPQueryError, + ObjectsQueryError, + RequestError, + StructuredQueryError, + UnexpectedError, + # Legacy alias + ApplicationException, +) + +__all__ = [ + # Core API + "Api", + + # Flow clients + "Flow", + "FlowInstance", + "AsyncFlow", + "AsyncFlowInstance", + + # WebSocket clients + "SocketClient", + "SocketFlowInstance", + "AsyncSocketClient", + "AsyncSocketFlowInstance", + + # Bulk operation clients + "BulkClient", + "AsyncBulkClient", + + # Metrics clients + "Metrics", + "AsyncMetrics", + + # Types + "Triple", + "ConfigKey", + "ConfigValue", + "DocumentMetadata", + "ProcessingMetadata", + "CollectionMetadata", + "StreamingChunk", + "AgentThought", + "AgentObservation", + "AgentAnswer", + "RAGChunk", + + # Exceptions + "ProtocolException", + "TrustGraphException", + "AgentError", + "ConfigError", + "DocumentRagError", + "FlowError", + "GatewayError", + "GraphRagError", + "LLMError", + "LoadError", + "LookupError", + "NLPQueryError", + "ObjectsQueryError", + "RequestError", + "StructuredQueryError", + "UnexpectedError", + "ApplicationException", # Legacy alias +] diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index b0bae8ce..d1f07513 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -3,6 +3,7 @@ import requests import json import base64 import time +from typing import Optional from . library import Library from . flow import Flow @@ -26,7 +27,7 @@ def check_error(response): class Api: - def __init__(self, url="http://localhost:8088/", timeout=60): + def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None): self.url = url @@ -36,6 +37,16 @@ class Api: self.url += "api/v1/" self.timeout = timeout + self.token = token + + # Lazy initialization for new clients + self._socket_client = None + self._bulk_client = None + self._async_flow = None + self._async_socket_client = None + self._async_bulk_client = None + self._metrics = None + self._async_metrics = None def flow(self): return Flow(api=self) @@ -50,8 +61,12 @@ class Api: url = f"{self.url}{path}" + headers = {} + if self.token: + headers["Authorization"] = f"Bearer {self.token}" + # Invoke the API, input is passed as JSON - resp = requests.post(url, json=request, timeout=self.timeout) + resp = requests.post(url, json=request, timeout=self.timeout, headers=headers) # Should be a 200 status code if resp.status_code != 200: @@ -72,3 +87,96 @@ class Api: def collection(self): return Collection(self) + + # New synchronous methods + def socket(self): + """Synchronous WebSocket-based interface for streaming operations""" + if self._socket_client is None: + from . socket_client import SocketClient + # Extract base URL (remove api/v1/ suffix) + base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") + self._socket_client = SocketClient(base_url, self.timeout, self.token) + return self._socket_client + + def bulk(self): + """Synchronous bulk operations interface for import/export""" + if self._bulk_client is None: + from . bulk_client import BulkClient + # Extract base URL (remove api/v1/ suffix) + base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") + self._bulk_client = BulkClient(base_url, self.timeout, self.token) + return self._bulk_client + + def metrics(self): + """Synchronous metrics interface""" + if self._metrics is None: + from . metrics import Metrics + # Extract base URL (remove api/v1/ suffix) + base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") + self._metrics = Metrics(base_url, self.timeout, self.token) + return self._metrics + + # New asynchronous methods + def async_flow(self): + """Asynchronous REST-based flow interface""" + if self._async_flow is None: + from . async_flow import AsyncFlow + self._async_flow = AsyncFlow(self.url, self.timeout, self.token) + return self._async_flow + + def async_socket(self): + """Asynchronous WebSocket-based interface for streaming operations""" + if self._async_socket_client is None: + from . async_socket_client import AsyncSocketClient + # Extract base URL (remove api/v1/ suffix) + base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") + self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token) + return self._async_socket_client + + def async_bulk(self): + """Asynchronous bulk operations interface for import/export""" + if self._async_bulk_client is None: + from . async_bulk_client import AsyncBulkClient + # Extract base URL (remove api/v1/ suffix) + base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") + self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token) + return self._async_bulk_client + + def async_metrics(self): + """Asynchronous metrics interface""" + if self._async_metrics is None: + from . async_metrics import AsyncMetrics + # Extract base URL (remove api/v1/ suffix) + base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") + self._async_metrics = AsyncMetrics(base_url, self.timeout, self.token) + return self._async_metrics + + # Resource management + def close(self): + """Close all synchronous connections""" + if self._socket_client: + self._socket_client.close() + if self._bulk_client: + self._bulk_client.close() + + async def aclose(self): + """Close all asynchronous connections""" + if self._async_socket_client: + await self._async_socket_client.aclose() + if self._async_bulk_client: + await self._async_bulk_client.aclose() + if self._async_flow: + await self._async_flow.aclose() + + # Context manager support + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + await self.aclose() diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py new file mode 100644 index 00000000..76cb9f56 --- /dev/null +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -0,0 +1,131 @@ + +import json +import websockets +from typing import Optional, AsyncIterator, Dict, Any, Iterator + +from . types import Triple + + +class AsyncBulkClient: + """Asynchronous bulk operations client""" + + def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + self.url: str = self._convert_to_ws_url(url) + self.timeout: int = timeout + self.token: Optional[str] = token + + def _convert_to_ws_url(self, url: str) -> str: + """Convert HTTP URL to WebSocket URL""" + if url.startswith("http://"): + return url.replace("http://", "ws://", 1) + elif url.startswith("https://"): + return url.replace("https://", "wss://", 1) + elif url.startswith("ws://") or url.startswith("wss://"): + return url + else: + return f"ws://{url}" + + async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None: + """Bulk import triples via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for triple in triples: + message = { + "s": triple.s, + "p": triple.p, + "o": triple.o + } + await websocket.send(json.dumps(message)) + + async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]: + """Bulk export triples via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + data = json.loads(raw_message) + yield Triple( + s=data.get("s", ""), + p=data.get("p", ""), + o=data.get("o", "") + ) + + async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import graph embeddings via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for embedding in embeddings: + await websocket.send(json.dumps(embedding)) + + async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: + """Bulk export graph embeddings via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + yield json.loads(raw_message) + + async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import document embeddings via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for embedding in embeddings: + await websocket.send(json.dumps(embedding)) + + async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: + """Bulk export document embeddings via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + yield json.loads(raw_message) + + async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import entity contexts via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for context in contexts: + await websocket.send(json.dumps(context)) + + async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: + """Bulk export entity contexts via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + yield json.loads(raw_message) + + async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import objects via WebSocket""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for obj in objects: + await websocket.send(json.dumps(obj)) + + async def aclose(self) -> None: + """Close connections""" + # Cleanup handled by context managers + pass diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py new file mode 100644 index 00000000..5d3cd486 --- /dev/null +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -0,0 +1,245 @@ + +import aiohttp +import json +from typing import Optional, Dict, Any, List + +from . exceptions import ProtocolException, ApplicationException + + +def check_error(response): + if "error" in response: + try: + msg = response["error"]["message"] + tp = response["error"]["type"] + except: + raise ApplicationException(response["error"]) + + raise ApplicationException(f"{tp}: {msg}") + + +class AsyncFlow: + """Asynchronous REST-based flow interface""" + + def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + self.url: str = url + self.timeout: int = timeout + self.token: Optional[str] = token + + async def request(self, path: str, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Make async HTTP request to Gateway API""" + url = f"{self.url}{path}" + + headers = {"Content-Type": "application/json"} + if self.token: + headers["Authorization"] = f"Bearer {self.token}" + + timeout = aiohttp.ClientTimeout(total=self.timeout) + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=request_data, headers=headers) as resp: + if resp.status != 200: + raise ProtocolException(f"Status code {resp.status}") + + try: + obj = await resp.json() + except: + raise ProtocolException(f"Expected JSON response") + + check_error(obj) + return obj + + async def list(self) -> List[str]: + """List all flows""" + result = await self.request("flow", {"operation": "list-flows"}) + return result.get("flow-ids", []) + + async def get(self, id: str) -> Dict[str, Any]: + """Get flow definition""" + result = await self.request("flow", { + "operation": "get-flow", + "flow-id": id + }) + return json.loads(result.get("flow", "{}")) + + async def start(self, class_name: str, id: str, description: str, parameters: Optional[Dict] = None): + """Start a flow""" + request_data = { + "operation": "start-flow", + "flow-id": id, + "class-name": class_name, + "description": description + } + if parameters: + request_data["parameters"] = json.dumps(parameters) + + await self.request("flow", request_data) + + async def stop(self, id: str): + """Stop a flow""" + await self.request("flow", { + "operation": "stop-flow", + "flow-id": id + }) + + async def list_classes(self) -> List[str]: + """List flow classes""" + result = await self.request("flow", {"operation": "list-classes"}) + return result.get("class-names", []) + + async def get_class(self, class_name: str) -> Dict[str, Any]: + """Get flow class definition""" + result = await self.request("flow", { + "operation": "get-class", + "class-name": class_name + }) + return json.loads(result.get("class-definition", "{}")) + + async def put_class(self, class_name: str, definition: Dict[str, Any]): + """Create/update flow class""" + await self.request("flow", { + "operation": "put-class", + "class-name": class_name, + "class-definition": json.dumps(definition) + }) + + async def delete_class(self, class_name: str): + """Delete flow class""" + await self.request("flow", { + "operation": "delete-class", + "class-name": class_name + }) + + def id(self, flow_id: str): + """Get async flow instance""" + return AsyncFlowInstance(self, flow_id) + + async def aclose(self) -> None: + """Close connection (cleanup handled by aiohttp session)""" + pass + + +class AsyncFlowInstance: + """Asynchronous REST flow instance""" + + def __init__(self, flow: AsyncFlow, flow_id: str): + self.flow = flow + self.flow_id = flow_id + + async def request(self, service: str, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Make request to flow-scoped service""" + return await self.flow.request(f"flow/{self.flow_id}/service/{service}", request_data) + + async def agent(self, question: str, user: str, state: Optional[Dict] = None, + group: Optional[str] = None, history: Optional[List] = None, **kwargs: Any) -> Dict[str, Any]: + """Execute agent (non-streaming, use async_socket for streaming)""" + request_data = { + "question": question, + "user": user, + "streaming": False # REST doesn't support streaming + } + if state is not None: + request_data["state"] = state + if group is not None: + request_data["group"] = group + if history is not None: + request_data["history"] = history + request_data.update(kwargs) + + return await self.request("agent", request_data) + + async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str: + """Text completion (non-streaming, use async_socket for streaming)""" + request_data = { + "system": system, + "prompt": prompt, + "streaming": False + } + request_data.update(kwargs) + + result = await self.request("text-completion", request_data) + return result.get("response", "") + + async def graph_rag(self, query: str, user: str, collection: str, + max_subgraph_size: int = 1000, max_subgraph_count: int = 5, + max_entity_distance: int = 3, **kwargs: Any) -> str: + """Graph RAG (non-streaming, use async_socket for streaming)""" + request_data = { + "query": query, + "user": user, + "collection": collection, + "max-subgraph-size": max_subgraph_size, + "max-subgraph-count": max_subgraph_count, + "max-entity-distance": max_entity_distance, + "streaming": False + } + request_data.update(kwargs) + + result = await self.request("graph-rag", request_data) + return result.get("response", "") + + async def document_rag(self, query: str, user: str, collection: str, + doc_limit: int = 10, **kwargs: Any) -> str: + """Document RAG (non-streaming, use async_socket for streaming)""" + request_data = { + "query": query, + "user": user, + "collection": collection, + "doc-limit": doc_limit, + "streaming": False + } + request_data.update(kwargs) + + result = await self.request("document-rag", request_data) + return result.get("response", "") + + async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any): + """Query graph embeddings for semantic search""" + request_data = { + "text": text, + "user": user, + "collection": collection, + "limit": limit + } + request_data.update(kwargs) + + return await self.request("graph-embeddings", request_data) + + async def embeddings(self, text: str, **kwargs: Any): + """Generate text embeddings""" + request_data = {"text": text} + request_data.update(kwargs) + + return await self.request("embeddings", request_data) + + async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs: Any): + """Triple pattern query""" + request_data = {"limit": limit} + if s is not None: + request_data["s"] = str(s) + if p is not None: + request_data["p"] = str(p) + if o is not None: + request_data["o"] = str(o) + if user is not None: + request_data["user"] = user + if collection is not None: + request_data["collection"] = collection + request_data.update(kwargs) + + return await self.request("triples", request_data) + + async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + operation_name: Optional[str] = None, **kwargs: Any): + """GraphQL query""" + request_data = { + "query": query, + "user": user, + "collection": collection + } + if variables: + request_data["variables"] = variables + if operation_name: + request_data["operationName"] = operation_name + request_data.update(kwargs) + + return await self.request("objects", request_data) diff --git a/trustgraph-base/trustgraph/api/async_metrics.py b/trustgraph-base/trustgraph/api/async_metrics.py new file mode 100644 index 00000000..9ba22f02 --- /dev/null +++ b/trustgraph-base/trustgraph/api/async_metrics.py @@ -0,0 +1,33 @@ + +import aiohttp +from typing import Optional, Dict + + +class AsyncMetrics: + """Asynchronous metrics client""" + + def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + self.url: str = url + self.timeout: int = timeout + self.token: Optional[str] = token + + async def get(self) -> str: + """Get Prometheus metrics as text""" + url: str = f"{self.url}/api/metrics" + + headers: Dict[str, str] = {} + if self.token: + headers["Authorization"] = f"Bearer {self.token}" + + timeout = aiohttp.ClientTimeout(total=self.timeout) + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as resp: + if resp.status != 200: + raise Exception(f"Status code {resp.status}") + + return await resp.text() + + async def aclose(self) -> None: + """Close connections""" + pass diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py new file mode 100644 index 00000000..cb6c8605 --- /dev/null +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -0,0 +1,343 @@ + +import json +import websockets +from typing import Optional, Dict, Any, AsyncIterator, Union + +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk +from . exceptions import ProtocolException, ApplicationException + + +class AsyncSocketClient: + """Asynchronous WebSocket client""" + + def __init__(self, url: str, timeout: int, token: Optional[str]): + self.url = self._convert_to_ws_url(url) + self.timeout = timeout + self.token = token + self._request_counter = 0 + + def _convert_to_ws_url(self, url: str) -> str: + """Convert HTTP URL to WebSocket URL""" + if url.startswith("http://"): + return url.replace("http://", "ws://", 1) + elif url.startswith("https://"): + return url.replace("https://", "wss://", 1) + elif url.startswith("ws://") or url.startswith("wss://"): + return url + else: + # Assume ws:// + return f"ws://{url}" + + def flow(self, flow_id: str): + """Get async flow instance for WebSocket operations""" + return AsyncSocketFlowInstance(self, flow_id) + + async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]): + """Async WebSocket request implementation (non-streaming)""" + # Generate unique request ID + self._request_counter += 1 + request_id = f"req-{self._request_counter}" + + # Build WebSocket URL with optional token + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + # Build request message + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow + + # Connect and send request + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + await websocket.send(json.dumps(message)) + + # Wait for single response + raw_message = await websocket.recv() + response = json.loads(raw_message) + + if response.get("id") != request_id: + raise ProtocolException(f"Response ID mismatch") + + if "error" in response: + raise ApplicationException(response["error"]) + + if "response" not in response: + raise ProtocolException(f"Missing response in message") + + return response["response"] + + async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]): + """Async WebSocket request implementation (streaming)""" + # Generate unique request ID + self._request_counter += 1 + request_id = f"req-{self._request_counter}" + + # Build WebSocket URL with optional token + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + # Build request message + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow + + # Connect and send request + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + await websocket.send(json.dumps(message)) + + # Yield chunks as they arrive + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != request_id: + continue # Ignore messages for other requests + + if "error" in response: + raise ApplicationException(response["error"]) + + if "response" in response: + resp = response["response"] + + # Parse different chunk types + chunk = self._parse_chunk(resp) + yield chunk + + # Check if this is the final chunk + if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"): + break + + def _parse_chunk(self, resp: Dict[str, Any]): + """Parse response chunk into appropriate type""" + chunk_type = resp.get("chunk_type") + + if chunk_type == "thought": + return AgentThought( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False) + ) + elif chunk_type == "observation": + return AgentObservation( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False) + ) + elif chunk_type == "answer" or chunk_type == "final-answer": + return AgentAnswer( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False), + end_of_dialog=resp.get("end_of_dialog", False) + ) + elif chunk_type == "action": + # Agent action chunks - treat as thoughts for display purposes + return AgentThought( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False) + ) + else: + # RAG-style chunk (or generic chunk) + # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field + content = resp.get("response", resp.get("chunk", resp.get("text", ""))) + return RAGChunk( + content=content, + end_of_stream=resp.get("end_of_stream", False), + error=None # Errors are always thrown, never stored + ) + + async def aclose(self): + """Close WebSocket connection""" + # Cleanup handled by context manager + pass + + +class AsyncSocketFlowInstance: + """Asynchronous WebSocket flow instance""" + + def __init__(self, client: AsyncSocketClient, flow_id: str): + self.client = client + self.flow_id = flow_id + + async def agent(self, question: str, user: str, state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, history: Optional[list] = None, + streaming: bool = False, **kwargs) -> Union[Dict[str, Any], AsyncIterator]: + """Agent with optional streaming""" + request = { + "question": question, + "user": user, + "streaming": streaming + } + if state is not None: + request["state"] = state + if group is not None: + request["group"] = group + if history is not None: + request["history"] = history + request.update(kwargs) + + if streaming: + return self.client._send_request_streaming("agent", self.flow_id, request) + else: + return await self.client._send_request("agent", self.flow_id, request) + + async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs): + """Text completion with optional streaming""" + request = { + "system": system, + "prompt": prompt, + "streaming": streaming + } + request.update(kwargs) + + if streaming: + return self._text_completion_streaming(request) + else: + result = await self.client._send_request("text-completion", self.flow_id, request) + return result.get("response", "") + + async def _text_completion_streaming(self, request): + """Helper for streaming text completion""" + async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request): + if hasattr(chunk, 'content'): + yield chunk.content + + async def graph_rag(self, query: str, user: str, collection: str, + max_subgraph_size: int = 1000, max_subgraph_count: int = 5, + max_entity_distance: int = 3, streaming: bool = False, **kwargs): + """Graph RAG with optional streaming""" + request = { + "query": query, + "user": user, + "collection": collection, + "max-subgraph-size": max_subgraph_size, + "max-subgraph-count": max_subgraph_count, + "max-entity-distance": max_entity_distance, + "streaming": streaming + } + request.update(kwargs) + + if streaming: + return self._graph_rag_streaming(request) + else: + result = await self.client._send_request("graph-rag", self.flow_id, request) + return result.get("response", "") + + async def _graph_rag_streaming(self, request): + """Helper for streaming graph RAG""" + async for chunk in self.client._send_request_streaming("graph-rag", self.flow_id, request): + if hasattr(chunk, 'content'): + yield chunk.content + + async def document_rag(self, query: str, user: str, collection: str, + doc_limit: int = 10, streaming: bool = False, **kwargs): + """Document RAG with optional streaming""" + request = { + "query": query, + "user": user, + "collection": collection, + "doc-limit": doc_limit, + "streaming": streaming + } + request.update(kwargs) + + if streaming: + return self._document_rag_streaming(request) + else: + result = await self.client._send_request("document-rag", self.flow_id, request) + return result.get("response", "") + + async def _document_rag_streaming(self, request): + """Helper for streaming document RAG""" + async for chunk in self.client._send_request_streaming("document-rag", self.flow_id, request): + if hasattr(chunk, 'content'): + yield chunk.content + + async def prompt(self, id: str, variables: Dict[str, str], streaming: bool = False, **kwargs): + """Execute prompt with optional streaming""" + request = { + "id": id, + "variables": variables, + "streaming": streaming + } + request.update(kwargs) + + if streaming: + return self._prompt_streaming(request) + else: + result = await self.client._send_request("prompt", self.flow_id, request) + return result.get("response", "") + + async def _prompt_streaming(self, request): + """Helper for streaming prompt""" + async for chunk in self.client._send_request_streaming("prompt", self.flow_id, request): + if hasattr(chunk, 'content'): + yield chunk.content + + async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs): + """Query graph embeddings for semantic search""" + request = { + "text": text, + "user": user, + "collection": collection, + "limit": limit + } + request.update(kwargs) + + return await self.client._send_request("graph-embeddings", self.flow_id, request) + + async def embeddings(self, text: str, **kwargs): + """Generate text embeddings""" + request = {"text": text} + request.update(kwargs) + + return await self.client._send_request("embeddings", self.flow_id, request) + + async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs): + """Triple pattern query""" + request = {"limit": limit} + if s is not None: + request["s"] = str(s) + if p is not None: + request["p"] = str(p) + if o is not None: + request["o"] = str(o) + if user is not None: + request["user"] = user + if collection is not None: + request["collection"] = collection + request.update(kwargs) + + return await self.client._send_request("triples", self.flow_id, request) + + async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None, + operation_name: Optional[str] = None, **kwargs): + """GraphQL query""" + request = { + "query": query, + "user": user, + "collection": collection + } + if variables: + request["variables"] = variables + if operation_name: + request["operationName"] = operation_name + request.update(kwargs) + + return await self.client._send_request("objects", self.flow_id, request) + + async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs): + """Execute MCP tool""" + request = { + "name": name, + "parameters": parameters + } + request.update(kwargs) + + return await self.client._send_request("mcp-tool", self.flow_id, request) diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py new file mode 100644 index 00000000..a119668d --- /dev/null +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -0,0 +1,270 @@ + +import json +import asyncio +import websockets +from typing import Optional, Iterator, Dict, Any, Coroutine + +from . types import Triple +from . exceptions import ProtocolException + + +class BulkClient: + """Synchronous bulk operations client""" + + def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + self.url: str = self._convert_to_ws_url(url) + self.timeout: int = timeout + self.token: Optional[str] = token + + def _convert_to_ws_url(self, url: str) -> str: + """Convert HTTP URL to WebSocket URL""" + if url.startswith("http://"): + return url.replace("http://", "ws://", 1) + elif url.startswith("https://"): + return url.replace("https://", "wss://", 1) + elif url.startswith("ws://") or url.startswith("wss://"): + return url + else: + return f"ws://{url}" + + def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any: + """Run async coroutine synchronously""" + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(coro) + + def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None: + """Bulk import triples via WebSocket""" + self._run_async(self._import_triples_async(flow, triples)) + + async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None: + """Async implementation of triple import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + for triple in triples: + message = { + "s": triple.s, + "p": triple.p, + "o": triple.o + } + await websocket.send(json.dumps(message)) + + def export_triples(self, flow: str, **kwargs: Any) -> Iterator[Triple]: + """Bulk export triples via WebSocket""" + async_gen = self._export_triples_async(flow) + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + while True: + try: + triple = loop.run_until_complete(async_gen.__anext__()) + yield triple + except StopAsyncIteration: + break + finally: + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _export_triples_async(self, flow: str) -> Iterator[Triple]: + """Async implementation of triple export""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + data = json.loads(raw_message) + yield Triple( + s=data.get("s", ""), + p=data.get("p", ""), + o=data.get("o", "") + ) + + def import_graph_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import graph embeddings via WebSocket""" + self._run_async(self._import_graph_embeddings_async(flow, embeddings)) + + async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: + """Async implementation of graph embeddings import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + for embedding in embeddings: + await websocket.send(json.dumps(embedding)) + + def export_graph_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]: + """Bulk export graph embeddings via WebSocket""" + async_gen = self._export_graph_embeddings_async(flow) + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + while True: + try: + embedding = loop.run_until_complete(async_gen.__anext__()) + yield embedding + except StopAsyncIteration: + break + finally: + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: + """Async implementation of graph embeddings export""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + yield json.loads(raw_message) + + def import_document_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import document embeddings via WebSocket""" + self._run_async(self._import_document_embeddings_async(flow, embeddings)) + + async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: + """Async implementation of document embeddings import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + for embedding in embeddings: + await websocket.send(json.dumps(embedding)) + + def export_document_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]: + """Bulk export document embeddings via WebSocket""" + async_gen = self._export_document_embeddings_async(flow) + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + while True: + try: + embedding = loop.run_until_complete(async_gen.__anext__()) + yield embedding + except StopAsyncIteration: + break + finally: + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: + """Async implementation of document embeddings export""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + yield json.loads(raw_message) + + def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import entity contexts via WebSocket""" + self._run_async(self._import_entity_contexts_async(flow, contexts)) + + async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None: + """Async implementation of entity contexts import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + for context in contexts: + await websocket.send(json.dumps(context)) + + def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]: + """Bulk export entity contexts via WebSocket""" + async_gen = self._export_entity_contexts_async(flow) + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + while True: + try: + context = loop.run_until_complete(async_gen.__anext__()) + yield context + except StopAsyncIteration: + break + finally: + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]: + """Async implementation of entity contexts export""" + ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + async for raw_message in websocket: + yield json.loads(raw_message) + + def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None: + """Bulk import objects via WebSocket""" + self._run_async(self._import_objects_async(flow, objects)) + + async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None: + """Async implementation of objects import""" + ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + for obj in objects: + await websocket.send(json.dumps(obj)) + + def close(self) -> None: + """Close connections""" + # Cleanup handled by context managers + pass diff --git a/trustgraph-base/trustgraph/api/collection.py b/trustgraph-base/trustgraph/api/collection.py index 0e1abeaf..5a1f0850 100644 --- a/trustgraph-base/trustgraph/api/collection.py +++ b/trustgraph-base/trustgraph/api/collection.py @@ -41,9 +41,7 @@ class Collection: collection = v["collection"], name = v["name"], description = v["description"], - tags = v["tags"], - created_at = v["created_at"], - updated_at = v["updated_at"] + tags = v["tags"] ) for v in collections ] @@ -76,9 +74,7 @@ class Collection: collection = v["collection"], name = v["name"], description = v["description"], - tags = v["tags"], - created_at = v["created_at"], - updated_at = v["updated_at"] + tags = v["tags"] ) return None except Exception as e: diff --git a/trustgraph-base/trustgraph/api/exceptions.py b/trustgraph-base/trustgraph/api/exceptions.py index b3f732d4..311d2651 100644 --- a/trustgraph-base/trustgraph/api/exceptions.py +++ b/trustgraph-base/trustgraph/api/exceptions.py @@ -1,6 +1,134 @@ +""" +TrustGraph API Exceptions +Exception hierarchy for errors returned by TrustGraph services. +Each service error type maps to a specific exception class. +""" + +# Protocol-level exceptions (communication errors) class ProtocolException(Exception): + """Raised when WebSocket protocol errors occur""" pass -class ApplicationException(Exception): + +# Base class for all TrustGraph application errors +class TrustGraphException(Exception): + """Base class for all TrustGraph service errors""" + def __init__(self, message: str, error_type: str = None): + super().__init__(message) + self.message = message + self.error_type = error_type + + +# Service-specific exceptions +class AgentError(TrustGraphException): + """Agent service error""" pass + + +class ConfigError(TrustGraphException): + """Configuration service error""" + pass + + +class DocumentRagError(TrustGraphException): + """Document RAG retrieval error""" + pass + + +class FlowError(TrustGraphException): + """Flow management error""" + pass + + +class GatewayError(TrustGraphException): + """API Gateway error""" + pass + + +class GraphRagError(TrustGraphException): + """Graph RAG retrieval error""" + pass + + +class LLMError(TrustGraphException): + """LLM service error""" + pass + + +class LoadError(TrustGraphException): + """Data loading error""" + pass + + +class LookupError(TrustGraphException): + """Lookup/search error""" + pass + + +class NLPQueryError(TrustGraphException): + """NLP query service error""" + pass + + +class ObjectsQueryError(TrustGraphException): + """Objects query service error""" + pass + + +class RequestError(TrustGraphException): + """Request processing error""" + pass + + +class StructuredQueryError(TrustGraphException): + """Structured query service error""" + pass + + +class UnexpectedError(TrustGraphException): + """Unexpected/unknown error""" + pass + + +# Mapping from error type string to exception class +ERROR_TYPE_MAPPING = { + "agent-error": AgentError, + "config-error": ConfigError, + "document-rag-error": DocumentRagError, + "flow-error": FlowError, + "gateway-error": GatewayError, + "graph-rag-error": GraphRagError, + "llm-error": LLMError, + "load-error": LoadError, + "lookup-error": LookupError, + "nlp-query-error": NLPQueryError, + "objects-query-error": ObjectsQueryError, + "request-error": RequestError, + "structured-query-error": StructuredQueryError, + "unexpected-error": UnexpectedError, +} + + +def raise_from_error_dict(error_dict: dict) -> None: + """ + Raise appropriate exception from TrustGraph error dictionary. + + Args: + error_dict: Dictionary with 'type' and 'message' keys + + Raises: + Appropriate TrustGraphException subclass based on error type + """ + error_type = error_dict.get("type", "unexpected-error") + message = error_dict.get("message", "Unknown error") + + # Look up exception class, default to UnexpectedError + exception_class = ERROR_TYPE_MAPPING.get(error_type, UnexpectedError) + + # Raise the appropriate exception + raise exception_class(message, error_type) + + +# Legacy exception for backwards compatibility +ApplicationException = TrustGraphException diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 0214a4bd..744ad2e7 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -160,14 +160,14 @@ class FlowInstance: )["answer"] def graph_rag( - self, question, user="trustgraph", collection="default", + self, query, user="trustgraph", collection="default", entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2, ): # The input consists of a question input = { - "query": question, + "query": query, "user": user, "collection": collection, "entity-limit": entity_limit, @@ -182,13 +182,13 @@ class FlowInstance: )["response"] def document_rag( - self, question, user="trustgraph", collection="default", + self, query, user="trustgraph", collection="default", doc_limit=10, ): # The input consists of a question input = { - "query": question, + "query": query, "user": user, "collection": collection, "doc-limit": doc_limit, @@ -211,6 +211,21 @@ class FlowInstance: input )["vectors"] + def graph_embeddings_query(self, text, user, collection, limit=10): + + # Query graph embeddings for semantic search + input = { + "text": text, + "user": user, + "collection": collection, + "limit": limit + } + + return self.request( + "service/graph-embeddings", + input + ) + def prompt(self, id, variables): input = { diff --git a/trustgraph-base/trustgraph/api/metrics.py b/trustgraph-base/trustgraph/api/metrics.py new file mode 100644 index 00000000..68968349 --- /dev/null +++ b/trustgraph-base/trustgraph/api/metrics.py @@ -0,0 +1,27 @@ + +import requests +from typing import Optional, Dict + + +class Metrics: + """Synchronous metrics client""" + + def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + self.url: str = url + self.timeout: int = timeout + self.token: Optional[str] = token + + def get(self) -> str: + """Get Prometheus metrics as text""" + url: str = f"{self.url}/api/metrics" + + headers: Dict[str, str] = {} + if self.token: + headers["Authorization"] = f"Bearer {self.token}" + + resp = requests.get(url, timeout=self.timeout, headers=headers) + + if resp.status_code != 200: + raise Exception(f"Status code {resp.status_code}") + + return resp.text diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py new file mode 100644 index 00000000..b1be0195 --- /dev/null +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -0,0 +1,457 @@ + +import json +import asyncio +import websockets +from typing import Optional, Dict, Any, Iterator, Union, List +from threading import Lock + +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk +from . exceptions import ProtocolException, raise_from_error_dict + + +class SocketClient: + """Synchronous WebSocket client (wraps async websockets library)""" + + def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + self.url: str = self._convert_to_ws_url(url) + self.timeout: int = timeout + self.token: Optional[str] = token + self._connection: Optional[Any] = None + self._request_counter: int = 0 + self._lock: Lock = Lock() + self._loop: Optional[asyncio.AbstractEventLoop] = None + + def _convert_to_ws_url(self, url: str) -> str: + """Convert HTTP URL to WebSocket URL""" + if url.startswith("http://"): + return url.replace("http://", "ws://", 1) + elif url.startswith("https://"): + return url.replace("https://", "wss://", 1) + elif url.startswith("ws://") or url.startswith("wss://"): + return url + else: + # Assume ws:// + return f"ws://{url}" + + def flow(self, flow_id: str) -> "SocketFlowInstance": + """Get flow instance for WebSocket operations""" + return SocketFlowInstance(self, flow_id) + + def _send_request_sync( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any], + streaming: bool = False + ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]: + """Synchronous wrapper around async WebSocket communication""" + # Create event loop if needed + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # If loop is running (e.g., in Jupyter), create new loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if streaming: + # For streaming, we need to return an iterator + # Create a generator that runs async code + return self._streaming_generator(service, flow, request, loop) + else: + # For non-streaming, just run the async code and return result + return loop.run_until_complete(self._send_request_async(service, flow, request)) + + def _streaming_generator( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any], + loop: asyncio.AbstractEventLoop + ) -> Iterator[StreamingChunk]: + """Generator that yields streaming chunks""" + async_gen = self._send_request_async_streaming(service, flow, request) + + try: + while True: + try: + chunk = loop.run_until_complete(async_gen.__anext__()) + yield chunk + except StopAsyncIteration: + break + finally: + # Clean up async generator + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _send_request_async( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any] + ) -> Dict[str, Any]: + """Async implementation of WebSocket request (non-streaming)""" + # Generate unique request ID + with self._lock: + self._request_counter += 1 + request_id = f"req-{self._request_counter}" + + # Build WebSocket URL with optional token + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + # Build request message + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow + + # Connect and send request + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + await websocket.send(json.dumps(message)) + + # Wait for single response + raw_message = await websocket.recv() + response = json.loads(raw_message) + + if response.get("id") != request_id: + raise ProtocolException(f"Response ID mismatch") + + if "error" in response: + raise_from_error_dict(response["error"]) + + if "response" not in response: + raise ProtocolException(f"Missing response in message") + + return response["response"] + + async def _send_request_async_streaming( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any] + ) -> Iterator[StreamingChunk]: + """Async implementation of WebSocket request (streaming)""" + # Generate unique request ID + with self._lock: + self._request_counter += 1 + request_id = f"req-{self._request_counter}" + + # Build WebSocket URL with optional token + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + # Build request message + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow + + # Connect and send request + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + await websocket.send(json.dumps(message)) + + # Yield chunks as they arrive + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != request_id: + continue # Ignore messages for other requests + + if "error" in response: + raise_from_error_dict(response["error"]) + + if "response" in response: + resp = response["response"] + + # Check for errors in response chunks + if "error" in resp: + raise_from_error_dict(resp["error"]) + + # Parse different chunk types + chunk = self._parse_chunk(resp) + yield chunk + + # Check if this is the final chunk + if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"): + break + + def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk: + """Parse response chunk into appropriate type""" + chunk_type = resp.get("chunk_type") + + if chunk_type == "thought": + return AgentThought( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False) + ) + elif chunk_type == "observation": + return AgentObservation( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False) + ) + elif chunk_type == "answer" or chunk_type == "final-answer": + return AgentAnswer( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False), + end_of_dialog=resp.get("end_of_dialog", False) + ) + elif chunk_type == "action": + # Agent action chunks - treat as thoughts for display purposes + return AgentThought( + content=resp.get("content", ""), + end_of_message=resp.get("end_of_message", False) + ) + else: + # RAG-style chunk (or generic chunk) + # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field + content = resp.get("response", resp.get("chunk", resp.get("text", ""))) + return RAGChunk( + content=content, + end_of_stream=resp.get("end_of_stream", False), + error=None # Errors are always thrown, never stored + ) + + def close(self) -> None: + """Close WebSocket connection""" + # Cleanup handled by context manager in async code + pass + + +class SocketFlowInstance: + """Synchronous WebSocket flow instance with same interface as REST FlowInstance""" + + def __init__(self, client: SocketClient, flow_id: str) -> None: + self.client: SocketClient = client + self.flow_id: str = flow_id + + def agent( + self, + question: str, + user: str, + state: Optional[Dict[str, Any]] = None, + group: Optional[str] = None, + history: Optional[List[Dict[str, Any]]] = None, + streaming: bool = False, + **kwargs: Any + ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]: + """Agent with optional streaming""" + request = { + "question": question, + "user": user, + "streaming": streaming + } + if state is not None: + request["state"] = state + if group is not None: + request["group"] = group + if history is not None: + request["history"] = history + request.update(kwargs) + + return self.client._send_request_sync("agent", self.flow_id, request, streaming) + + def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: + """Text completion with optional streaming""" + request = { + "system": system, + "prompt": prompt, + "streaming": streaming + } + request.update(kwargs) + + result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming) + + if streaming: + # For text completion, yield just the content + for chunk in result: + if hasattr(chunk, 'content'): + yield chunk.content + else: + return result.get("response", "") + + def graph_rag( + self, + query: str, + user: str, + collection: str, + max_subgraph_size: int = 1000, + max_subgraph_count: int = 5, + max_entity_distance: int = 3, + streaming: bool = False, + **kwargs: Any + ) -> Union[str, Iterator[str]]: + """Graph RAG with optional streaming""" + request = { + "query": query, + "user": user, + "collection": collection, + "max-subgraph-size": max_subgraph_size, + "max-subgraph-count": max_subgraph_count, + "max-entity-distance": max_entity_distance, + "streaming": streaming + } + request.update(kwargs) + + result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming) + + if streaming: + for chunk in result: + if hasattr(chunk, 'content'): + yield chunk.content + else: + return result.get("response", "") + + def document_rag( + self, + query: str, + user: str, + collection: str, + doc_limit: int = 10, + streaming: bool = False, + **kwargs: Any + ) -> Union[str, Iterator[str]]: + """Document RAG with optional streaming""" + request = { + "query": query, + "user": user, + "collection": collection, + "doc-limit": doc_limit, + "streaming": streaming + } + request.update(kwargs) + + result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming) + + if streaming: + for chunk in result: + if hasattr(chunk, 'content'): + yield chunk.content + else: + return result.get("response", "") + + def prompt( + self, + id: str, + variables: Dict[str, str], + streaming: bool = False, + **kwargs: Any + ) -> Union[str, Iterator[str]]: + """Execute prompt with optional streaming""" + request = { + "id": id, + "variables": variables, + "streaming": streaming + } + request.update(kwargs) + + result = self.client._send_request_sync("prompt", self.flow_id, request, streaming) + + if streaming: + for chunk in result: + if hasattr(chunk, 'content'): + yield chunk.content + else: + return result.get("response", "") + + def graph_embeddings_query( + self, + text: str, + user: str, + collection: str, + limit: int = 10, + **kwargs: Any + ) -> Dict[str, Any]: + """Query graph embeddings for semantic search""" + request = { + "text": text, + "user": user, + "collection": collection, + "limit": limit + } + request.update(kwargs) + + return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False) + + def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]: + """Generate text embeddings""" + request = {"text": text} + request.update(kwargs) + + return self.client._send_request_sync("embeddings", self.flow_id, request, False) + + def triples_query( + self, + s: Optional[str] = None, + p: Optional[str] = None, + o: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + limit: int = 100, + **kwargs: Any + ) -> Dict[str, Any]: + """Triple pattern query""" + request = {"limit": limit} + if s is not None: + request["s"] = str(s) + if p is not None: + request["p"] = str(p) + if o is not None: + request["o"] = str(o) + if user is not None: + request["user"] = user + if collection is not None: + request["collection"] = collection + request.update(kwargs) + + return self.client._send_request_sync("triples", self.flow_id, request, False) + + def objects_query( + self, + query: str, + user: str, + collection: str, + variables: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + **kwargs: Any + ) -> Dict[str, Any]: + """GraphQL query""" + request = { + "query": query, + "user": user, + "collection": collection + } + if variables: + request["variables"] = variables + if operation_name: + request["operationName"] = operation_name + request.update(kwargs) + + return self.client._send_request_sync("objects", self.flow_id, request, False) + + def mcp_tool( + self, + name: str, + parameters: Dict[str, Any], + **kwargs: Any + ) -> Dict[str, Any]: + """Execute MCP tool""" + request = { + "name": name, + "parameters": parameters + } + request.update(kwargs) + + return self.client._send_request_sync("mcp-tool", self.flow_id, request, False) diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 71b438f6..a8608853 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -1,7 +1,7 @@ import dataclasses import datetime -from typing import List +from typing import List, Optional, Dict, Any from .. knowledge import hash, Uri, Literal @dataclasses.dataclass @@ -49,5 +49,34 @@ class CollectionMetadata: name : str description : str tags : List[str] - created_at : str - updated_at : str + +# Streaming chunk types + +@dataclasses.dataclass +class StreamingChunk: + """Base class for streaming chunks""" + content: str + end_of_message: bool = False + +@dataclasses.dataclass +class AgentThought(StreamingChunk): + """Agent reasoning chunk""" + chunk_type: str = "thought" + +@dataclasses.dataclass +class AgentObservation(StreamingChunk): + """Agent tool observation chunk""" + chunk_type: str = "observation" + +@dataclasses.dataclass +class AgentAnswer(StreamingChunk): + """Agent final answer chunk""" + chunk_type: str = "final-answer" + end_of_dialog: bool = False + +@dataclasses.dataclass +class RAGChunk(StreamingChunk): + """RAG streaming chunk""" + chunk_type: str = "rag" + end_of_stream: bool = False + error: Optional[Dict[str, str]] = None diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index b329f52e..e8530f6c 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -1,11 +1,12 @@ -from . pubsub import PulsarClient +from . pubsub import PulsarClient, get_pubsub from . async_processor import AsyncProcessor from . consumer import Consumer from . producer import Producer from . publisher import Publisher from . subscriber import Subscriber from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics +from . logging import add_logging_args, setup_logging from . flow_processor import FlowProcessor from . consumer_spec import ConsumerSpec from . parameter_spec import ParameterSpec @@ -33,4 +34,5 @@ from . tool_service import ToolService from . tool_client import ToolClientSpec from . agent_client import AgentClientSpec from . structured_query_client import StructuredQueryClientSpec +from . collection_config_handler import CollectionConfigHandler diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index e496da7c..8068c67d 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -15,10 +15,11 @@ from prometheus_client import start_http_server, Info from .. schema import ConfigPush, config_push_queue from .. log_level import LogLevel -from . pubsub import PulsarClient +from . pubsub import PulsarClient, get_pubsub from . producer import Producer from . consumer import Consumer from . metrics import ProcessorMetrics, ConsumerMetrics +from . logging import add_logging_args, setup_logging default_config_queue = config_push_queue @@ -33,8 +34,11 @@ class AsyncProcessor: # Store the identity self.id = params.get("id") - # Register a pulsar client - self.pulsar_client_object = PulsarClient(**params) + # Create pub/sub backend via factory + self.pubsub_backend = get_pubsub(**params) + + # Store pulsar_host for backward compatibility + self._pulsar_host = params.get("pulsar_host", "pulsar://pulsar:6650") # Initialise metrics, records the parameters ProcessorMetrics(processor = self.id).info({ @@ -69,7 +73,7 @@ class AsyncProcessor: self.config_sub_task = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub_backend, # Changed from client to backend subscriber = config_subscriber_id, flow = None, @@ -95,16 +99,16 @@ class AsyncProcessor: # This is called to stop all threads. An over-ride point for extra # functionality def stop(self): - self.pulsar_client.close() + self.pubsub_backend.close() self.running = False - # Returns the pulsar host + # Returns the pub/sub backend (new interface) @property - def pulsar_host(self): return self.pulsar_client_object.pulsar_host + def pubsub(self): return self.pubsub_backend - # Returns the pulsar client + # Returns the pulsar host (backward compatibility) @property - def pulsar_client(self): return self.pulsar_client_object.client + def pulsar_host(self): return self._pulsar_host # Register a new event handler for configuration change def register_config_handler(self, handler): @@ -165,18 +169,9 @@ class AsyncProcessor: raise e @classmethod - def setup_logging(cls, log_level='INFO'): + def setup_logging(cls, args): """Configure logging for the entire application""" - # Support environment variable override - env_log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', log_level) - - # Configure logging - logging.basicConfig( - level=getattr(logging, env_log_level.upper()), - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] - ) - logger.info(f"Logging configured with level: {env_log_level}") + setup_logging(args) # Startup fabric. launch calls launch_async in async mode. @classmethod @@ -203,7 +198,7 @@ class AsyncProcessor: args = vars(args) # Setup logging before anything else - cls.setup_logging(args.get('log_level', 'INFO').upper()) + cls.setup_logging(args) # Debug logger.debug(f"Arguments: {args}") @@ -255,12 +250,21 @@ class AsyncProcessor: @staticmethod def add_args(parser): + # Pub/sub backend selection + parser.add_argument( + '--pubsub-backend', + default=os.getenv('PUBSUB_BACKEND', 'pulsar'), + choices=['pulsar', 'mqtt'], + help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', + ) + PulsarClient.add_args(parser) + add_logging_args(parser) parser.add_argument( - '--config-queue', + '--config-push-queue', default=default_config_queue, - help=f'Config push queue {default_config_queue}', + help=f'Config push queue (default: {default_config_queue})', ) parser.add_argument( diff --git a/trustgraph-base/trustgraph/base/backend.py b/trustgraph-base/trustgraph/base/backend.py new file mode 100644 index 00000000..b9f5f923 --- /dev/null +++ b/trustgraph-base/trustgraph/base/backend.py @@ -0,0 +1,148 @@ +""" +Backend abstraction interfaces for pub/sub systems. + +This module defines Protocol classes that all pub/sub backends must implement, +allowing TrustGraph to work with different messaging systems (Pulsar, MQTT, Kafka, etc.) +""" + +from typing import Protocol, Any, runtime_checkable + + +@runtime_checkable +class Message(Protocol): + """Protocol for a received message.""" + + def value(self) -> Any: + """ + Get the deserialized message content. + + Returns: + Dataclass instance representing the message + """ + ... + + def properties(self) -> dict: + """ + Get message properties/metadata. + + Returns: + Dictionary of message properties + """ + ... + + +@runtime_checkable +class BackendProducer(Protocol): + """Protocol for backend-specific producer.""" + + def send(self, message: Any, properties: dict = {}) -> None: + """ + Send a message (dataclass instance) with optional properties. + + Args: + message: Dataclass instance to send + properties: Optional metadata properties + """ + ... + + def flush(self) -> None: + """Flush any buffered messages.""" + ... + + def close(self) -> None: + """Close the producer.""" + ... + + +@runtime_checkable +class BackendConsumer(Protocol): + """Protocol for backend-specific consumer.""" + + def receive(self, timeout_millis: int = 2000) -> Message: + """ + Receive a message from the topic. + + Args: + timeout_millis: Timeout in milliseconds + + Returns: + Message object + + Raises: + TimeoutError: If no message received within timeout + """ + ... + + def acknowledge(self, message: Message) -> None: + """ + Acknowledge successful processing of a message. + + Args: + message: The message to acknowledge + """ + ... + + def negative_acknowledge(self, message: Message) -> None: + """ + Negative acknowledge - triggers redelivery. + + Args: + message: The message to negatively acknowledge + """ + ... + + def unsubscribe(self) -> None: + """Unsubscribe from the topic.""" + ... + + def close(self) -> None: + """Close the consumer.""" + ... + + +@runtime_checkable +class PubSubBackend(Protocol): + """Protocol defining the interface all pub/sub backends must implement.""" + + def create_producer(self, topic: str, schema: type, **options) -> BackendProducer: + """ + Create a producer for a topic. + + Args: + topic: Generic topic format (qos/tenant/namespace/queue) + schema: Dataclass type for messages + **options: Backend-specific options (e.g., chunking_enabled) + + Returns: + Backend-specific producer instance + """ + ... + + def create_consumer( + self, + topic: str, + subscription: str, + schema: type, + initial_position: str = 'latest', + consumer_type: str = 'shared', + **options + ) -> BackendConsumer: + """ + Create a consumer for a topic. + + Args: + topic: Generic topic format (qos/tenant/namespace/queue) + subscription: Subscription/consumer group name + schema: Dataclass type for messages + initial_position: 'earliest' or 'latest' (some backends may ignore) + consumer_type: 'shared', 'exclusive', 'failover' (some backends may ignore) + **options: Backend-specific options + + Returns: + Backend-specific consumer instance + """ + ... + + def close(self) -> None: + """Close the backend connection.""" + ... diff --git a/trustgraph-base/trustgraph/base/cassandra_config.py b/trustgraph-base/trustgraph/base/cassandra_config.py index 46a1745d..bacc4313 100644 --- a/trustgraph-base/trustgraph/base/cassandra_config.py +++ b/trustgraph-base/trustgraph/base/cassandra_config.py @@ -13,14 +13,15 @@ from typing import Optional, Tuple, List, Any def get_cassandra_defaults() -> dict: """ Get default Cassandra configuration values from environment variables or fallback defaults. - + Returns: - dict: Dictionary with 'host', 'username', and 'password' keys + dict: Dictionary with 'host', 'username', 'password', and 'keyspace' keys """ return { 'host': os.getenv('CASSANDRA_HOST', 'cassandra'), 'username': os.getenv('CASSANDRA_USERNAME'), - 'password': os.getenv('CASSANDRA_PASSWORD') + 'password': os.getenv('CASSANDRA_PASSWORD'), + 'keyspace': os.getenv('CASSANDRA_KEYSPACE') } @@ -53,82 +54,108 @@ def add_cassandra_args(parser: argparse.ArgumentParser) -> None: password_help += " (default: )" if 'CASSANDRA_PASSWORD' in os.environ: password_help += " [from CASSANDRA_PASSWORD]" - + + keyspace_help = "Cassandra keyspace (default: service-specific)" + if defaults['keyspace']: + keyspace_help = f"Cassandra keyspace (default: {defaults['keyspace']})" + if 'CASSANDRA_KEYSPACE' in os.environ: + keyspace_help += " [from CASSANDRA_KEYSPACE]" + parser.add_argument( '--cassandra-host', default=defaults['host'], help=host_help ) - + parser.add_argument( '--cassandra-username', default=defaults['username'], help=username_help ) - + parser.add_argument( '--cassandra-password', default=defaults['password'], help=password_help ) + parser.add_argument( + '--cassandra-keyspace', + default=defaults['keyspace'], + help=keyspace_help + ) + def resolve_cassandra_config( args: Optional[Any] = None, host: Optional[str] = None, username: Optional[str] = None, - password: Optional[str] = None -) -> Tuple[List[str], Optional[str], Optional[str]]: + password: Optional[str] = None, + default_keyspace: Optional[str] = None +) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]: """ Resolve Cassandra configuration from various sources. - + Can accept either argparse args object or explicit parameters. Converts host string to list format for Cassandra driver. - + Args: - args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password + args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace host: Optional explicit host parameter (overrides args) username: Optional explicit username parameter (overrides args) password: Optional explicit password parameter (overrides args) - + default_keyspace: Optional default keyspace if not specified elsewhere + Returns: - tuple: (hosts_list, username, password) + tuple: (hosts_list, username, password, keyspace) """ # If args provided, extract values + keyspace = None if args is not None: host = host or getattr(args, 'cassandra_host', None) username = username or getattr(args, 'cassandra_username', None) password = password or getattr(args, 'cassandra_password', None) - + keyspace = getattr(args, 'cassandra_keyspace', None) + # Apply defaults if still None defaults = get_cassandra_defaults() host = host or defaults['host'] username = username or defaults['username'] password = password or defaults['password'] - + keyspace = keyspace or defaults['keyspace'] or default_keyspace + # Convert host string to list if isinstance(host, str): hosts = [h.strip() for h in host.split(',') if h.strip()] else: hosts = host - - return hosts, username, password + + return hosts, username, password, keyspace -def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[str], Optional[str]]: +def get_cassandra_config_from_params( + params: dict, + default_keyspace: Optional[str] = None +) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]: """ Extract and resolve Cassandra configuration from a parameters dictionary. - + Args: params: Dictionary of parameters that may contain Cassandra configuration - + default_keyspace: Optional default keyspace if not specified in params + Returns: - tuple: (hosts_list, username, password) + tuple: (hosts_list, username, password, keyspace) """ # Get Cassandra parameters host = params.get('cassandra_host') username = params.get('cassandra_username') password = params.get('cassandra_password') - + # Use resolve function to handle defaults and list conversion - return resolve_cassandra_config(host=host, username=username, password=password) \ No newline at end of file + return resolve_cassandra_config( + host=host, + username=username, + password=password, + default_keyspace=default_keyspace + ) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/base/collection_config_handler.py b/trustgraph-base/trustgraph/base/collection_config_handler.py new file mode 100644 index 00000000..8c1af822 --- /dev/null +++ b/trustgraph-base/trustgraph/base/collection_config_handler.py @@ -0,0 +1,128 @@ +""" +Handler for storage services to process collection configuration from config push +""" + +import json +import logging +from typing import Dict, Set + +logger = logging.getLogger(__name__) + +class CollectionConfigHandler: + """ + Handles collection configuration from config push messages for storage services. + + Storage services should: + 1. Inherit from this class along with their service base class + 2. Call register_config_handler(self.on_collection_config) in __init__ + 3. Implement create_collection(user, collection, metadata) method + 4. Implement delete_collection(user, collection) method + """ + + def __init__(self, **kwargs): + # Track known collections: {(user, collection): metadata_dict} + self.known_collections: Dict[tuple, dict] = {} + # Pass remaining kwargs up the inheritance chain + super().__init__(**kwargs) + + async def on_collection_config(self, config: dict, version: int): + """ + Handle config push messages and extract collection information + + Args: + config: Configuration dictionary from ConfigPush message + version: Configuration version number + """ + logger.info(f"Processing collection configuration (version {version})") + + # Extract collections from config (treat missing key as empty) + collection_config = config.get("collection", {}) + + # Track which collections we've seen in this config + current_collections: Set[tuple] = set() + + # Process each collection in the config + for key, value_json in collection_config.items(): + try: + # Parse user:collection key + if ":" not in key: + logger.warning(f"Invalid collection key format (expected user:collection): {key}") + continue + + user, collection = key.split(":", 1) + current_collections.add((user, collection)) + + # Parse metadata + metadata = json.loads(value_json) + + # Check if this is a new collection or updated + collection_key = (user, collection) + if collection_key not in self.known_collections: + logger.info(f"New collection detected: {user}/{collection}") + await self.create_collection(user, collection, metadata) + self.known_collections[collection_key] = metadata + else: + # Collection already exists, update metadata if changed + if self.known_collections[collection_key] != metadata: + logger.info(f"Collection metadata updated: {user}/{collection}") + # Most storage services don't need to do anything for metadata updates + # They just need to know the collection exists + self.known_collections[collection_key] = metadata + + except Exception as e: + logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True) + + # Find collections that were deleted (in known but not in current) + deleted_collections = set(self.known_collections.keys()) - current_collections + for user, collection in deleted_collections: + logger.info(f"Collection deleted: {user}/{collection}") + try: + # Remove from known_collections FIRST to immediately reject new writes + # This eliminates race condition with worker threads + del self.known_collections[(user, collection)] + # Physical deletion happens after - worker threads already rejecting writes + await self.delete_collection(user, collection) + except Exception as e: + logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True) + # If physical deletion failed, should we re-add to known_collections? + # For now, keep it removed - collection is logically deleted per config + + logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}") + + async def create_collection(self, user: str, collection: str, metadata: dict): + """ + Create a collection in the storage backend. + + Subclasses must implement this method. + + Args: + user: User ID + collection: Collection ID + metadata: Collection metadata dictionary + """ + raise NotImplementedError("Storage service must implement create_collection method") + + async def delete_collection(self, user: str, collection: str): + """ + Delete a collection from the storage backend. + + Subclasses must implement this method. + + Args: + user: User ID + collection: Collection ID + """ + raise NotImplementedError("Storage service must implement delete_collection method") + + def collection_exists(self, user: str, collection: str) -> bool: + """ + Check if a collection is known to exist + + Args: + user: User ID + collection: Collection ID + + Returns: + True if collection exists, False otherwise + """ + return (user, collection) in self.known_collections diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 43b4bc51..2a220312 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -9,9 +9,6 @@ # one handler, and a single thread of concurrency, nothing too outrageous # will happen if synchronous / blocking code is used -from pulsar.schema import JsonSchema -import pulsar -import _pulsar import asyncio import time import logging @@ -21,11 +18,15 @@ from .. exceptions import TooManyRequests # Module logger logger = logging.getLogger(__name__) +# Timeout exception - can come from different backends +class TimeoutError(Exception): + pass + class Consumer: def __init__( - self, taskgroup, flow, client, topic, subscriber, schema, - handler, + self, taskgroup, flow, backend, topic, subscriber, schema, + handler, metrics = None, start_of_messages=False, rate_limit_retry_time = 10, rate_limit_timeout = 7200, @@ -35,7 +36,7 @@ class Consumer: self.taskgroup = taskgroup self.flow = flow - self.client = client + self.backend = backend # Changed from 'client' to 'backend' self.topic = topic self.subscriber = subscriber self.schema = schema @@ -96,18 +97,20 @@ class Consumer: logger.info(f"Subscribing to topic: {self.topic}") + # Determine initial position if self.start_of_messages: - pos = pulsar.InitialPosition.Earliest + initial_pos = 'earliest' else: - pos = pulsar.InitialPosition.Latest + initial_pos = 'latest' + # Create consumer via backend self.consumer = await asyncio.to_thread( - self.client.subscribe, + self.backend.create_consumer, topic = self.topic, - subscription_name = self.subscriber, - schema = JsonSchema(self.schema), - initial_position = pos, - consumer_type = pulsar.ConsumerType.Shared, + subscription = self.subscriber, + schema = self.schema, + initial_position = initial_pos, + consumer_type = 'shared', ) except Exception as e: @@ -159,9 +162,10 @@ class Consumer: self.consumer.receive, timeout_millis=2000 ) - except _pulsar.Timeout: - continue except Exception as e: + # Handle timeout from any backend + if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower(): + continue raise e await self.handle_one_from_queue(msg) diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py index 89581b02..0ef4672b 100644 --- a/trustgraph-base/trustgraph/base/consumer_spec.py +++ b/trustgraph-base/trustgraph/base/consumer_spec.py @@ -19,7 +19,7 @@ class ConsumerSpec(Spec): consumer = Consumer( taskgroup = processor.taskgroup, flow = flow, - client = processor.pulsar_client, + backend = processor.pubsub, topic = definition[self.name], subscriber = processor.id + "--" + flow.name + "--" + self.name, schema = self.schema, diff --git a/trustgraph-base/trustgraph/base/logging.py b/trustgraph-base/trustgraph/base/logging.py new file mode 100644 index 00000000..7bab6091 --- /dev/null +++ b/trustgraph-base/trustgraph/base/logging.py @@ -0,0 +1,159 @@ + +""" +Centralized logging configuration for TrustGraph server-side components. + +This module provides standardized logging setup across all TrustGraph services, +ensuring consistent log formats, levels, and command-line arguments. + +Supports dual output to console and Loki for centralized log aggregation. +""" + +import logging +import logging.handlers +from queue import Queue +import os + + +def add_logging_args(parser): + """ + Add standard logging arguments to an argument parser. + + Args: + parser: argparse.ArgumentParser instance to add arguments to + """ + parser.add_argument( + '-l', '--log-level', + default='INFO', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Log level (default: INFO)' + ) + + parser.add_argument( + '--loki-enabled', + action='store_true', + default=True, + help='Enable Loki logging (default: True)' + ) + + parser.add_argument( + '--no-loki-enabled', + dest='loki_enabled', + action='store_false', + help='Disable Loki logging' + ) + + parser.add_argument( + '--loki-url', + default=os.getenv('LOKI_URL', 'http://loki:3100/loki/api/v1/push'), + help='Loki push URL (default: http://loki:3100/loki/api/v1/push)' + ) + + parser.add_argument( + '--loki-username', + default=os.getenv('LOKI_USERNAME', None), + help='Loki username for authentication (optional)' + ) + + parser.add_argument( + '--loki-password', + default=os.getenv('LOKI_PASSWORD', None), + help='Loki password for authentication (optional)' + ) + + +def setup_logging(args): + """ + Configure logging from parsed command-line arguments. + + Sets up logging with a standardized format and output to stdout. + Optionally enables Loki integration for centralized log aggregation. + + This should be called early in application startup, before any + logging calls are made. + + Args: + args: Dictionary of parsed arguments (typically from vars(args)) + Must contain 'log_level' key, optional Loki configuration + """ + log_level = args.get('log_level', 'INFO') + loki_enabled = args.get('loki_enabled', True) + + # Build list of handlers starting with console + handlers = [logging.StreamHandler()] + + # Add Loki handler if enabled + queue_listener = None + if loki_enabled: + loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push') + loki_username = args.get('loki_username') + loki_password = args.get('loki_password') + processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion") + + try: + from logging_loki import LokiHandler + + # Create Loki handler with optional authentication and processor label + loki_handler_kwargs = { + 'url': loki_url, + 'version': "1", + } + + if loki_username and loki_password: + loki_handler_kwargs['auth'] = (loki_username, loki_password) + + # Add processor label if available (for consistency with Prometheus metrics) + if processor_id: + loki_handler_kwargs['tags'] = {'processor': processor_id} + + loki_handler = LokiHandler(**loki_handler_kwargs) + + # Wrap in QueueHandler for non-blocking operation + log_queue = Queue(maxsize=500) + queue_handler = logging.handlers.QueueHandler(log_queue) + handlers.append(queue_handler) + + # Start QueueListener in background thread + queue_listener = logging.handlers.QueueListener( + log_queue, + loki_handler, + respect_handler_level=True + ) + queue_listener.start() + + # Store listener reference for potential cleanup + # (attached to root logger for access if needed) + logging.getLogger().loki_queue_listener = queue_listener + + except ImportError: + # Graceful degradation if python-logging-loki not installed + print("WARNING: python-logging-loki not installed, Loki logging disabled") + print("Install with: pip install python-logging-loki") + except Exception as e: + # Graceful degradation if Loki connection fails + print(f"WARNING: Failed to setup Loki logging: {e}") + print("Continuing with console-only logging") + + # Get processor ID for log formatting (use 'unknown' if not available) + processor_id = args.get('id', 'unknown') + + # Configure logging with all handlers + # Use processor ID as the primary identifier in logs + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s', + handlers=handlers, + force=True # Force reconfiguration if already configured + ) + + # Prevent recursive logging from Loki's HTTP client + if loki_enabled and queue_listener: + # Disable urllib3 logging to prevent infinite loop + logging.getLogger('urllib3').setLevel(logging.WARNING) + logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING) + + logger = logging.getLogger(__name__) + logger.info(f"Logging configured with level: {log_level}") + if loki_enabled and queue_listener: + logger.info(f"Loki logging enabled: {loki_url}") + elif loki_enabled: + logger.warning("Loki logging requested but not available") diff --git a/trustgraph-base/trustgraph/base/producer.py b/trustgraph-base/trustgraph/base/producer.py index 0d65d1de..20b4b0d6 100644 --- a/trustgraph-base/trustgraph/base/producer.py +++ b/trustgraph-base/trustgraph/base/producer.py @@ -1,5 +1,4 @@ -from pulsar.schema import JsonSchema import asyncio import logging @@ -8,10 +7,10 @@ logger = logging.getLogger(__name__) class Producer: - def __init__(self, client, topic, schema, metrics=None, + def __init__(self, backend, topic, schema, metrics=None, chunking_enabled=True): - self.client = client + self.backend = backend # Changed from 'client' to 'backend' self.topic = topic self.schema = schema @@ -44,9 +43,9 @@ class Producer: try: logger.info(f"Connecting publisher to {self.topic}...") - self.producer = self.client.create_producer( + self.producer = self.backend.create_producer( topic = self.topic, - schema = JsonSchema(self.schema), + schema = self.schema, chunking_enabled = self.chunking_enabled, ) logger.info(f"Connected publisher to {self.topic}") diff --git a/trustgraph-base/trustgraph/base/producer_spec.py b/trustgraph-base/trustgraph/base/producer_spec.py index 9c8bbc6a..cf46b958 100644 --- a/trustgraph-base/trustgraph/base/producer_spec.py +++ b/trustgraph-base/trustgraph/base/producer_spec.py @@ -15,7 +15,7 @@ class ProducerSpec(Spec): ) producer = Producer( - client = processor.pulsar_client, + backend = processor.pubsub, topic = definition[self.name], schema = self.schema, metrics = producer_metrics, diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 307a118a..74b25132 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -37,33 +37,34 @@ class PromptClient(RequestResponse): else: logger.info("DEBUG prompt_client: Streaming path") - # Streaming path - collect all chunks - full_text = "" - full_object = None + # Streaming path - just forward chunks, don't accumulate + last_text = "" + last_object = None - async def collect_chunks(resp): - nonlocal full_text, full_object - logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}") + async def forward_chunks(resp): + nonlocal last_text, last_object + logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}") if resp.error: logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}") raise RuntimeError(resp.error.message) - if resp.text: - full_text += resp.text - logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars") - # Call chunk callback if provided + end_stream = getattr(resp, 'end_of_stream', False) + + # Always call callback if there's text OR if it's the final message + if resp.text is not None: + last_text = resp.text + # Call chunk callback if provided with both chunk and end_of_stream flag if chunk_callback: - logger.info(f"DEBUG prompt_client: Calling chunk_callback") + logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}") if asyncio.iscoroutinefunction(chunk_callback): - await chunk_callback(resp.text) + await chunk_callback(resp.text, end_stream) else: - chunk_callback(resp.text) + chunk_callback(resp.text, end_stream) elif resp.object: logger.info(f"DEBUG prompt_client: Got object response") - full_object = resp.object + last_object = resp.object - end_stream = getattr(resp, 'end_of_stream', False) logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}") return end_stream @@ -79,17 +80,17 @@ class PromptClient(RequestResponse): logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}") await self.request( req, - recipient=collect_chunks, + recipient=forward_chunks, timeout=timeout ) - logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars") + logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}") - if full_text: - logger.info("DEBUG prompt_client: Returning full_text") - return full_text + if last_text: + logger.info("DEBUG prompt_client: Returning last_text") + return last_text - logger.info("DEBUG prompt_client: Returning parsed full_object") - return json.loads(full_object) + logger.info("DEBUG prompt_client: Returning parsed last_object") + return json.loads(last_object) if last_object else None async def extract_definitions(self, text, timeout=600): return await self.prompt( diff --git a/trustgraph-base/trustgraph/base/publisher.py b/trustgraph-base/trustgraph/base/publisher.py index 5a481f82..0297d2b5 100644 --- a/trustgraph-base/trustgraph/base/publisher.py +++ b/trustgraph-base/trustgraph/base/publisher.py @@ -1,9 +1,6 @@ -from pulsar.schema import JsonSchema - import asyncio import time -import pulsar import logging # Module logger @@ -11,9 +8,9 @@ logger = logging.getLogger(__name__) class Publisher: - def __init__(self, client, topic, schema=None, max_size=10, + def __init__(self, backend, topic, schema=None, max_size=10, chunking_enabled=True, drain_timeout=5.0): - self.client = client + self.backend = backend # Changed from 'client' to 'backend' self.topic = topic self.schema = schema self.q = asyncio.Queue(maxsize=max_size) @@ -47,9 +44,9 @@ class Publisher: try: - producer = self.client.create_producer( + producer = self.backend.create_producer( topic=self.topic, - schema=JsonSchema(self.schema), + schema=self.schema, chunking_enabled=self.chunking_enabled, ) diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py index 412363f2..a7772b67 100644 --- a/trustgraph-base/trustgraph/base/pubsub.py +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -4,8 +4,45 @@ import pulsar import _pulsar import uuid from pulsar.schema import JsonSchema +import logging from .. log_level import LogLevel +from .pulsar_backend import PulsarBackend + +logger = logging.getLogger(__name__) + + +def get_pubsub(**config): + """ + Factory function to create a pub/sub backend based on configuration. + + Args: + config: Configuration dictionary from command-line args + Must include 'pubsub_backend' key + + Returns: + Backend instance (PulsarBackend, MQTTBackend, etc.) + + Example: + backend = get_pubsub( + pubsub_backend='pulsar', + pulsar_host='pulsar://localhost:6650' + ) + """ + backend_type = config.get('pubsub_backend', 'pulsar') + + if backend_type == 'pulsar': + return PulsarBackend( + host=config.get('pulsar_host', PulsarClient.default_pulsar_host), + api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key), + listener=config.get('pulsar_listener'), + ) + elif backend_type == 'mqtt': + # TODO: Implement MQTT backend + raise NotImplementedError("MQTT backend not yet implemented") + else: + raise ValueError(f"Unknown pub/sub backend: {backend_type}") + class PulsarClient: @@ -71,10 +108,3 @@ class PulsarClient: '--pulsar-listener', help=f'Pulsar listener (default: none)', ) - - parser.add_argument( - '-l', '--log-level', - default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], - help=f'Log level (default: INFO)' - ) diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py new file mode 100644 index 00000000..c6248622 --- /dev/null +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -0,0 +1,350 @@ +""" +Pulsar backend implementation for pub/sub abstraction. + +This module provides a Pulsar-specific implementation of the backend interfaces, +handling topic mapping, serialization, and Pulsar client management. +""" + +import pulsar +import _pulsar +import json +import logging +import base64 +import types +from dataclasses import asdict, is_dataclass +from typing import Any + +from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message + +logger = logging.getLogger(__name__) + + +def dataclass_to_dict(obj: Any) -> dict: + """ + Recursively convert a dataclass to a dictionary, handling None values and bytes. + + None values are excluded from the dictionary (not serialized). + Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior). + Handles nested dataclasses, lists, and dictionaries recursively. + """ + if obj is None: + return None + + # Handle bytes - decode to UTF-8 for JSON serialization + if isinstance(obj, bytes): + return obj.decode('utf-8') + + # Handle dataclass - convert to dict then recursively process all values + if is_dataclass(obj): + result = {} + for key, value in asdict(obj).items(): + result[key] = dataclass_to_dict(value) if value is not None else None + return result + + # Handle list - recursively process all items + if isinstance(obj, list): + return [dataclass_to_dict(item) for item in obj] + + # Handle dict - recursively process all values + if isinstance(obj, dict): + return {k: dataclass_to_dict(v) for k, v in obj.items()} + + # Return primitive types as-is + return obj + + +def dict_to_dataclass(data: dict, cls: type) -> Any: + """ + Convert a dictionary back to a dataclass instance. + + Handles nested dataclasses and missing fields. + """ + if data is None: + return None + + if not is_dataclass(cls): + return data + + # Get field types from the dataclass + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + + for key, value in data.items(): + if key in field_types: + field_type = field_types[key] + + # Handle modern union types (X | Y) + if isinstance(field_type, types.UnionType): + # Check if it's Optional (X | None) + if type(None) in field_type.__args__: + # Get the non-None type + actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) + if actual_type and is_dataclass(actual_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, actual_type) + else: + kwargs[key] = value + else: + kwargs[key] = value + # Check if this is a generic type (list, dict, etc.) + elif hasattr(field_type, '__origin__'): + # Handle list[T] + if field_type.__origin__ == list: + item_type = field_type.__args__[0] if field_type.__args__ else None + if item_type and is_dataclass(item_type) and isinstance(value, list): + kwargs[key] = [ + dict_to_dataclass(item, item_type) if isinstance(item, dict) else item + for item in value + ] + else: + kwargs[key] = value + # Handle old-style Optional[T] (which is Union[T, None]) + elif hasattr(field_type, '__args__') and type(None) in field_type.__args__: + # Get the non-None type from Union + actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) + if actual_type and is_dataclass(actual_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, actual_type) + else: + kwargs[key] = value + else: + kwargs[key] = value + # Handle direct dataclass fields + elif is_dataclass(field_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, field_type) + # Handle bytes fields (UTF-8 encoded strings from JSON) + elif field_type == bytes and isinstance(value, str): + kwargs[key] = value.encode('utf-8') + else: + kwargs[key] = value + + return cls(**kwargs) + + +class PulsarMessage: + """Wrapper for Pulsar messages to match Message protocol.""" + + def __init__(self, pulsar_msg, schema_cls): + self._msg = pulsar_msg + self._schema_cls = schema_cls + self._value = None + + def value(self) -> Any: + """Deserialize and return the message value as a dataclass.""" + if self._value is None: + # Get JSON string from Pulsar message + json_data = self._msg.data().decode('utf-8') + data_dict = json.loads(json_data) + # Convert to dataclass + self._value = dict_to_dataclass(data_dict, self._schema_cls) + return self._value + + def properties(self) -> dict: + """Return message properties.""" + return self._msg.properties() + + +class PulsarBackendProducer: + """Pulsar-specific producer implementation.""" + + def __init__(self, pulsar_producer, schema_cls): + self._producer = pulsar_producer + self._schema_cls = schema_cls + + def send(self, message: Any, properties: dict = {}) -> None: + """Send a dataclass message.""" + # Convert dataclass to dict, excluding None values + data_dict = dataclass_to_dict(message) + # Serialize to JSON + json_data = json.dumps(data_dict) + # Send via Pulsar + self._producer.send(json_data.encode('utf-8'), properties=properties) + + def flush(self) -> None: + """Flush buffered messages.""" + self._producer.flush() + + def close(self) -> None: + """Close the producer.""" + self._producer.close() + + +class PulsarBackendConsumer: + """Pulsar-specific consumer implementation.""" + + def __init__(self, pulsar_consumer, schema_cls): + self._consumer = pulsar_consumer + self._schema_cls = schema_cls + + def receive(self, timeout_millis: int = 2000) -> Message: + """Receive a message.""" + pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis) + return PulsarMessage(pulsar_msg, self._schema_cls) + + def acknowledge(self, message: Message) -> None: + """Acknowledge a message.""" + if isinstance(message, PulsarMessage): + self._consumer.acknowledge(message._msg) + + def negative_acknowledge(self, message: Message) -> None: + """Negative acknowledge a message.""" + if isinstance(message, PulsarMessage): + self._consumer.negative_acknowledge(message._msg) + + def unsubscribe(self) -> None: + """Unsubscribe from the topic.""" + self._consumer.unsubscribe() + + def close(self) -> None: + """Close the consumer.""" + self._consumer.close() + + +class PulsarBackend: + """ + Pulsar backend implementation. + + Handles topic mapping, client management, and creation of Pulsar-specific + producers and consumers. + """ + + def __init__(self, host: str, api_key: str = None, listener: str = None): + """ + Initialize Pulsar backend. + + Args: + host: Pulsar broker URL (e.g., pulsar://localhost:6650) + api_key: Optional API key for authentication + listener: Optional listener name for multi-homed setups + """ + self.host = host + self.api_key = api_key + self.listener = listener + + # Create Pulsar client + client_args = {'service_url': host} + + if listener: + client_args['listener_name'] = listener + + if api_key: + client_args['authentication'] = pulsar.AuthenticationToken(api_key) + + self.client = pulsar.Client(**client_args) + logger.info(f"Pulsar client connected to {host}") + + def map_topic(self, generic_topic: str) -> str: + """ + Map generic topic format to Pulsar URI. + + Format: qos/tenant/namespace/queue + Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue + + Args: + generic_topic: Generic topic string or already-formatted Pulsar URI + + Returns: + Pulsar topic URI + """ + # If already a Pulsar URI, return as-is + if '://' in generic_topic: + return generic_topic + + parts = generic_topic.split('/', 3) + if len(parts) != 4: + raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue") + + qos, tenant, namespace, queue = parts + + # Map QoS to persistence + if qos == 'q0': + persistence = 'non-persistent' + elif qos in ['q1', 'q2']: + persistence = 'persistent' + else: + raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2") + + return f"{persistence}://{tenant}/{namespace}/{queue}" + + def create_producer(self, topic: str, schema: type, **options) -> BackendProducer: + """ + Create a Pulsar producer. + + Args: + topic: Generic topic format (qos/tenant/namespace/queue) + schema: Dataclass type for messages + **options: Backend-specific options (e.g., chunking_enabled) + + Returns: + PulsarBackendProducer instance + """ + pulsar_topic = self.map_topic(topic) + + producer_args = { + 'topic': pulsar_topic, + 'schema': pulsar.schema.BytesSchema(), # We handle serialization ourselves + } + + # Add optional parameters + if 'chunking_enabled' in options: + producer_args['chunking_enabled'] = options['chunking_enabled'] + + pulsar_producer = self.client.create_producer(**producer_args) + logger.debug(f"Created producer for topic: {pulsar_topic}") + + return PulsarBackendProducer(pulsar_producer, schema) + + def create_consumer( + self, + topic: str, + subscription: str, + schema: type, + initial_position: str = 'latest', + consumer_type: str = 'shared', + **options + ) -> BackendConsumer: + """ + Create a Pulsar consumer. + + Args: + topic: Generic topic format (qos/tenant/namespace/queue) + subscription: Subscription name + schema: Dataclass type for messages + initial_position: 'earliest' or 'latest' + consumer_type: 'shared', 'exclusive', or 'failover' + **options: Backend-specific options + + Returns: + PulsarBackendConsumer instance + """ + pulsar_topic = self.map_topic(topic) + + # Map initial position + if initial_position == 'earliest': + pos = pulsar.InitialPosition.Earliest + else: + pos = pulsar.InitialPosition.Latest + + # Map consumer type + if consumer_type == 'exclusive': + ctype = pulsar.ConsumerType.Exclusive + elif consumer_type == 'failover': + ctype = pulsar.ConsumerType.Failover + else: + ctype = pulsar.ConsumerType.Shared + + consumer_args = { + 'topic': pulsar_topic, + 'subscription_name': subscription, + 'schema': pulsar.schema.BytesSchema(), # We handle deserialization ourselves + 'initial_position': pos, + 'consumer_type': ctype, + } + + pulsar_consumer = self.client.subscribe(**consumer_args) + logger.debug(f"Created consumer for topic: {pulsar_topic}, subscription: {subscription}") + + return PulsarBackendConsumer(pulsar_consumer, schema) + + def close(self) -> None: + """Close the Pulsar client.""" + self.client.close() + logger.info("Pulsar client closed") diff --git a/trustgraph-base/trustgraph/base/request_response_spec.py b/trustgraph-base/trustgraph/base/request_response_spec.py index 82574e9d..e4c80c74 100644 --- a/trustgraph-base/trustgraph/base/request_response_spec.py +++ b/trustgraph-base/trustgraph/base/request_response_spec.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class RequestResponse(Subscriber): def __init__( - self, client, subscription, consumer_name, + self, backend, subscription, consumer_name, request_topic, request_schema, request_metrics, response_topic, response_schema, @@ -22,7 +22,7 @@ class RequestResponse(Subscriber): ): super(RequestResponse, self).__init__( - client = client, + backend = backend, subscription = subscription, consumer_name = consumer_name, topic = response_topic, @@ -31,7 +31,7 @@ class RequestResponse(Subscriber): ) self.producer = Producer( - client = client, + backend = backend, topic = request_topic, schema = request_schema, metrics = request_metrics, @@ -126,7 +126,7 @@ class RequestResponseSpec(Spec): ) rr = self.impl( - client = processor.pulsar_client, + backend = processor.pubsub, # Make subscription names unique, so that all subscribers get # to see all response messages diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 503fac80..d59bcab3 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -3,9 +3,7 @@ # off of a queue and make it available using an internal broker system, # so suitable for when multiple recipients are reading from the same queue -from pulsar.schema import JsonSchema import asyncio -import _pulsar import time import logging import uuid @@ -13,12 +11,16 @@ import uuid # Module logger logger = logging.getLogger(__name__) +# Timeout exception - can come from different backends +class TimeoutError(Exception): + pass + class Subscriber: - def __init__(self, client, topic, subscription, consumer_name, + def __init__(self, backend, topic, subscription, consumer_name, schema=None, max_size=100, metrics=None, backpressure_strategy="block", drain_timeout=5.0): - self.client = client + self.backend = backend # Changed from 'client' to 'backend' self.topic = topic self.subscription = subscription self.consumer_name = consumer_name @@ -43,18 +45,14 @@ class Subscriber: async def start(self): - # Build subscribe arguments - subscribe_args = { - 'topic': self.topic, - 'subscription_name': self.subscription, - 'consumer_name': self.consumer_name, - } - - # Only add schema if provided (omit if None) - if self.schema is not None: - subscribe_args['schema'] = JsonSchema(self.schema) - - self.consumer = self.client.subscribe(**subscribe_args) + # Create consumer via backend + self.consumer = await asyncio.to_thread( + self.backend.create_consumer, + topic=self.topic, + subscription=self.subscription, + schema=self.schema, + consumer_type='shared', + ) self.task = asyncio.create_task(self.run()) @@ -94,12 +92,13 @@ class Subscriber: drain_end_time = time.time() + self.drain_timeout logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s") - # Stop accepting new messages from Pulsar during drain - if self.consumer: + # Stop accepting new messages during drain + # Note: Not all backends support pausing message listeners + if self.consumer and hasattr(self.consumer, 'pause_message_listener'): try: self.consumer.pause_message_listener() - except _pulsar.InvalidConfiguration: - # Not all consumers have message listeners (e.g., blocking receive mode) + except Exception: + # Not all consumers support message listeners pass # Check drain timeout @@ -133,9 +132,10 @@ class Subscriber: self.consumer.receive, timeout_millis=250 ) - except _pulsar.Timeout: - continue except Exception as e: + # Handle timeout from any backend + if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower(): + continue logger.error(f"Exception in subscriber receive: {e}", exc_info=True) raise e @@ -157,19 +157,20 @@ class Subscriber: for msg in self.pending_acks.values(): try: self.consumer.negative_acknowledge(msg) - except _pulsar.AlreadyClosed: - pass # Consumer already closed + except Exception: + pass # Consumer already closed or error self.pending_acks.clear() if self.consumer: - try: - self.consumer.unsubscribe() - except _pulsar.AlreadyClosed: - pass # Already closed + if hasattr(self.consumer, 'unsubscribe'): + try: + self.consumer.unsubscribe() + except Exception: + pass # Already closed or error try: self.consumer.close() - except _pulsar.AlreadyClosed: - pass # Already closed + except Exception: + pass # Already closed or error self.consumer = None diff --git a/trustgraph-base/trustgraph/base/subscriber_spec.py b/trustgraph-base/trustgraph/base/subscriber_spec.py index 7dca09db..b408366c 100644 --- a/trustgraph-base/trustgraph/base/subscriber_spec.py +++ b/trustgraph-base/trustgraph/base/subscriber_spec.py @@ -16,7 +16,7 @@ class SubscriberSpec(Spec): ) subscriber = Subscriber( - client = processor.pulsar_client, + backend = processor.pubsub, topic = definition[self.name], subscription = flow.id, consumer_name = flow.id, diff --git a/trustgraph-base/trustgraph/clients/base.py b/trustgraph-base/trustgraph/clients/base.py index 25eac3b7..3a4da6ec 100644 --- a/trustgraph-base/trustgraph/clients/base.py +++ b/trustgraph-base/trustgraph/clients/base.py @@ -7,6 +7,7 @@ import time from pulsar.schema import JsonSchema from .. exceptions import * +from ..base.pubsub import get_pubsub # Default timeout for a request/response. In seconds. DEFAULT_TIMEOUT=300 @@ -39,30 +40,25 @@ class BaseClient: if subscriber == None: subscriber = str(uuid.uuid4()) - if pulsar_api_key: - auth = pulsar.AuthenticationToken(pulsar_api_key) - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - authentication=auth, - listener=listener, - ) - else: - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level), - listener_name=listener, - ) + # Create backend using factory + self.backend = get_pubsub( + pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, + pulsar_listener=listener, + pubsub_backend='pulsar' + ) - self.producer = self.client.create_producer( + self.producer = self.backend.create_producer( topic=input_queue, - schema=JsonSchema(input_schema), + schema=input_schema, chunking_enabled=True, ) - self.consumer = self.client.subscribe( - output_queue, subscriber, - schema=JsonSchema(output_schema), + self.consumer = self.backend.create_consumer( + topic=output_queue, + subscription=subscriber, + schema=output_schema, + consumer_type='shared', ) self.input_schema = input_schema @@ -136,10 +132,11 @@ class BaseClient: if hasattr(self, "consumer"): self.consumer.close() - + if hasattr(self, "producer"): self.producer.flush() self.producer.close() - - self.client.close() + + if hasattr(self, "backend"): + self.backend.close() diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index ed8c704a..be2bf5b9 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -64,7 +64,6 @@ class ConfigClient(BaseClient): def get(self, keys, timeout=300): resp = self.call( - id=id, operation="get", keys=[ ConfigKey( @@ -88,7 +87,6 @@ class ConfigClient(BaseClient): def list(self, type, timeout=300): resp = self.call( - id=id, operation="list", type=type, timeout=timeout @@ -99,7 +97,6 @@ class ConfigClient(BaseClient): def getvalues(self, type, timeout=300): resp = self.call( - id=id, operation="getvalues", type=type, timeout=timeout @@ -117,7 +114,6 @@ class ConfigClient(BaseClient): def delete(self, keys, timeout=300): resp = self.call( - id=id, operation="delete", keys=[ ConfigKey( @@ -134,7 +130,6 @@ class ConfigClient(BaseClient): def put(self, values, timeout=300): resp = self.call( - id=id, operation="put", values=[ ConfigValue( @@ -152,7 +147,6 @@ class ConfigClient(BaseClient): def config(self, timeout=300): resp = self.call( - id=id, operation="config", timeout=timeout ) diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index 38ac813b..22c82828 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -15,8 +15,6 @@ class CollectionManagementRequestTranslator(MessageTranslator): name=data.get("name"), description=data.get("description"), tags=data.get("tags"), - created_at=data.get("created_at"), - updated_at=data.get("updated_at"), tag_filter=data.get("tag_filter"), limit=data.get("limit") ) @@ -38,10 +36,6 @@ class CollectionManagementRequestTranslator(MessageTranslator): result["description"] = obj.description if obj.tags is not None: result["tags"] = list(obj.tags) - if obj.created_at is not None: - result["created_at"] = obj.created_at - if obj.updated_at is not None: - result["updated_at"] = obj.updated_at if obj.tag_filter is not None: result["tag_filter"] = list(obj.tag_filter) if obj.limit is not None: @@ -73,9 +67,7 @@ class CollectionManagementResponseTranslator(MessageTranslator): collection=coll_data.get("collection"), name=coll_data.get("name"), description=coll_data.get("description"), - tags=coll_data.get("tags"), - created_at=coll_data.get("created_at"), - updated_at=coll_data.get("updated_at") + tags=coll_data.get("tags", []) )) return CollectionManagementResponse( @@ -104,9 +96,7 @@ class CollectionManagementResponseTranslator(MessageTranslator): "collection": coll.collection, "name": coll.name, "description": coll.description, - "tags": list(coll.tags) if coll.tags else [], - "created_at": coll.created_at, - "updated_at": coll.updated_at + "tags": list(coll.tags) if coll.tags else [] }) print("RESULT IS", result, flush=True) diff --git a/trustgraph-base/trustgraph/messaging/translators/diagnosis.py b/trustgraph-base/trustgraph/messaging/translators/diagnosis.py index 92bad16f..e0cb6a89 100644 --- a/trustgraph-base/trustgraph/messaging/translators/diagnosis.py +++ b/trustgraph-base/trustgraph/messaging/translators/diagnosis.py @@ -57,7 +57,9 @@ class StructuredDataDiagnosisResponseTranslator(MessageTranslator): result["descriptor"] = obj.descriptor if obj.metadata: result["metadata"] = obj.metadata - if obj.schema_matches is not None: + # For schema-selection, always include schema_matches (even if empty) + # For other operations, only include if non-empty + if obj.operation == "schema-selection" or obj.schema_matches: result["schema-matches"] = obj.schema_matches return result diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py index 8916a77c..5ff99fdc 100644 --- a/trustgraph-base/trustgraph/messaging/translators/prompt.py +++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py @@ -42,12 +42,17 @@ class PromptResponseTranslator(MessageTranslator): def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]: result = {} - - if obj.text: + + # Include text field if present (even if empty string) + if obj.text is not None: result["text"] = obj.text - if obj.object: + # Include object field if present + if obj.object is not None: result["object"] = obj.object - + + # Always include end_of_stream flag for streaming support + result["end_of_stream"] = getattr(obj, "end_of_stream", False) + return result def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 441a9d18..22166bd9 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -34,14 +34,12 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: result = {} - # Check if this is a streaming response (has chunk) - if hasattr(obj, 'chunk') and obj.chunk: - result["chunk"] = obj.chunk - result["end_of_stream"] = getattr(obj, "end_of_stream", False) - else: - # Non-streaming response - if obj.response: - result["response"] = obj.response + # Include response content (even if empty string) + if obj.response is not None: + result["response"] = obj.response + + # Include end_of_stream flag + result["end_of_stream"] = getattr(obj, "end_of_stream", False) # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: @@ -51,13 +49,7 @@ class DocumentRagResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - # For streaming responses, check end_of_stream - if hasattr(obj, 'chunk') and obj.chunk: - is_final = getattr(obj, 'end_of_stream', False) - else: - # For non-streaming responses, it's always final - is_final = True - + is_final = getattr(obj, 'end_of_stream', False) return self.from_pulsar(obj), is_final @@ -98,14 +90,12 @@ class GraphRagResponseTranslator(MessageTranslator): def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: result = {} - # Check if this is a streaming response (has chunk) - if hasattr(obj, 'chunk') and obj.chunk: - result["chunk"] = obj.chunk - result["end_of_stream"] = getattr(obj, "end_of_stream", False) - else: - # Non-streaming response - if obj.response: - result["response"] = obj.response + # Include response content (even if empty string) + if obj.response is not None: + result["response"] = obj.response + + # Include end_of_stream flag + result["end_of_stream"] = getattr(obj, "end_of_stream", False) # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: @@ -115,11 +105,5 @@ class GraphRagResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - # For streaming responses, check end_of_stream - if hasattr(obj, 'chunk') and obj.chunk: - is_final = getattr(obj, 'end_of_stream', False) - else: - # For non-streaming responses, it's always final - is_final = True - + is_final = getattr(obj, 'end_of_stream', False) return self.from_pulsar(obj), is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py index b4ba4d13..fa3749b5 100644 --- a/trustgraph-base/trustgraph/messaging/translators/text_completion.py +++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py @@ -28,14 +28,17 @@ class TextCompletionResponseTranslator(MessageTranslator): def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]: result = {"response": obj.response} - + if obj.in_token: result["in_token"] = obj.in_token if obj.out_token: result["out_token"] = obj.out_token if obj.model: result["model"] = obj.model - + + # Always include end_of_stream flag for streaming support + result["end_of_stream"] = getattr(obj, "end_of_stream", False) + return result def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/schema/core/metadata.py b/trustgraph-base/trustgraph/schema/core/metadata.py index cb2022ac..1888e612 100644 --- a/trustgraph-base/trustgraph/schema/core/metadata.py +++ b/trustgraph-base/trustgraph/schema/core/metadata.py @@ -1,16 +1,14 @@ - -from pulsar.schema import Record, String, Array +from dataclasses import dataclass, field from .primitives import Triple -class Metadata(Record): - +@dataclass +class Metadata: # Source identifier - id = String() + id: str = "" # Subgraph - metadata = Array(Triple()) + metadata: list[Triple] = field(default_factory=list) # Collection management - user = String() - collection = String() - + user: str = "" + collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/core/primitives.py b/trustgraph-base/trustgraph/schema/core/primitives.py index fb85d05c..02517614 100644 --- a/trustgraph-base/trustgraph/schema/core/primitives.py +++ b/trustgraph-base/trustgraph/schema/core/primitives.py @@ -1,34 +1,39 @@ -from pulsar.schema import Record, String, Boolean, Array, Integer +from dataclasses import dataclass, field -class Error(Record): - type = String() - message = String() +@dataclass +class Error: + type: str = "" + message: str = "" -class Value(Record): - value = String() - is_uri = Boolean() - type = String() +@dataclass +class Value: + value: str = "" + is_uri: bool = False + type: str = "" -class Triple(Record): - s = Value() - p = Value() - o = Value() +@dataclass +class Triple: + s: Value | None = None + p: Value | None = None + o: Value | None = None -class Field(Record): - name = String() +@dataclass +class Field: + name: str = "" # int, string, long, bool, float, double, timestamp - type = String() - size = Integer() - primary = Boolean() - description = String() + type: str = "" + size: int = 0 + primary: bool = False + description: str = "" # NEW FIELDS for structured data: - required = Boolean() # Whether field is required - enum_values = Array(String()) # For enum type fields - indexed = Boolean() # Whether field should be indexed + required: bool = False # Whether field is required + enum_values: list[str] = field(default_factory=list) # For enum type fields + indexed: bool = False # Whether field should be indexed -class RowSchema(Record): - name = String() - description = String() - fields = Array(Field()) +@dataclass +class RowSchema: + name: str = "" + description: str = "" + fields: list[Field] = field(default_factory=list) diff --git a/trustgraph-base/trustgraph/schema/core/topic.py b/trustgraph-base/trustgraph/schema/core/topic.py index cdd643b7..09c633e4 100644 --- a/trustgraph-base/trustgraph/schema/core/topic.py +++ b/trustgraph-base/trustgraph/schema/core/topic.py @@ -1,4 +1,23 @@ -def topic(topic, kind='persistent', tenant='tg', namespace='flow'): - return f"{kind}://{tenant}/{namespace}/{topic}" +def topic(queue_name, qos='q1', tenant='tg', namespace='flow'): + """ + Create a generic topic identifier that can be mapped by backends. + + Args: + queue_name: The queue/topic name + qos: Quality of service + - 'q0' = best-effort (no ack) + - 'q1' = at-least-once (ack required) + - 'q2' = exactly-once (two-phase ack) + tenant: Tenant identifier for multi-tenancy + namespace: Namespace within tenant + + Returns: + Generic topic string: qos/tenant/namespace/queue_name + + Examples: + topic('my-queue') # q1/tg/flow/my-queue + topic('config', qos='q2', namespace='config') # q2/tg/config/config + """ + return f"{qos}/{tenant}/{namespace}/{queue_name}" diff --git a/trustgraph-base/trustgraph/schema/knowledge/document.py b/trustgraph-base/trustgraph/schema/knowledge/document.py index f41ee8a6..d8ce97b4 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/document.py +++ b/trustgraph-base/trustgraph/schema/knowledge/document.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, Bytes +from dataclasses import dataclass from ..core.metadata import Metadata from ..core.topic import topic @@ -6,24 +6,27 @@ from ..core.topic import topic ############################################################################ # PDF docs etc. -class Document(Record): - metadata = Metadata() - data = Bytes() +@dataclass +class Document: + metadata: Metadata | None = None + data: bytes = b"" ############################################################################ # Text documents / text from PDF -class TextDocument(Record): - metadata = Metadata() - text = Bytes() +@dataclass +class TextDocument: + metadata: Metadata | None = None + text: bytes = b"" ############################################################################ # Chunks of text -class Chunk(Record): - metadata = Metadata() - chunk = Bytes() +@dataclass +class Chunk: + metadata: Metadata | None = None + chunk: bytes = b"" -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py index cfdae068..a3e5b394 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/embeddings.py +++ b/trustgraph-base/trustgraph/schema/knowledge/embeddings.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double, Map +from dataclasses import dataclass, field from ..core.metadata import Metadata from ..core.primitives import Value, RowSchema @@ -8,49 +8,55 @@ from ..core.topic import topic # Graph embeddings are embeddings associated with a graph entity -class EntityEmbeddings(Record): - entity = Value() - vectors = Array(Array(Double())) +@dataclass +class EntityEmbeddings: + entity: Value | None = None + vectors: list[list[float]] = field(default_factory=list) # This is a 'batching' mechanism for the above data -class GraphEmbeddings(Record): - metadata = Metadata() - entities = Array(EntityEmbeddings()) +@dataclass +class GraphEmbeddings: + metadata: Metadata | None = None + entities: list[EntityEmbeddings] = field(default_factory=list) ############################################################################ # Document embeddings are embeddings associated with a chunk -class ChunkEmbeddings(Record): - chunk = Bytes() - vectors = Array(Array(Double())) +@dataclass +class ChunkEmbeddings: + chunk: bytes = b"" + vectors: list[list[float]] = field(default_factory=list) # This is a 'batching' mechanism for the above data -class DocumentEmbeddings(Record): - metadata = Metadata() - chunks = Array(ChunkEmbeddings()) +@dataclass +class DocumentEmbeddings: + metadata: Metadata | None = None + chunks: list[ChunkEmbeddings] = field(default_factory=list) ############################################################################ # Object embeddings are embeddings associated with the primary key of an # object -class ObjectEmbeddings(Record): - metadata = Metadata() - vectors = Array(Array(Double())) - name = String() - key_name = String() - id = String() +@dataclass +class ObjectEmbeddings: + metadata: Metadata | None = None + vectors: list[list[float]] = field(default_factory=list) + name: str = "" + key_name: str = "" + id: str = "" ############################################################################ # Structured object embeddings with enhanced capabilities -class StructuredObjectEmbedding(Record): - metadata = Metadata() - vectors = Array(Array(Double())) - schema_name = String() - object_id = String() # Primary key value - field_embeddings = Map(Array(Double())) # Per-field embeddings +@dataclass +class StructuredObjectEmbedding: + metadata: Metadata | None = None + vectors: list[list[float]] = field(default_factory=list) + schema_name: str = "" + object_id: str = "" # Primary key value + field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/graph.py b/trustgraph-base/trustgraph/schema/knowledge/graph.py index 1d55c8f0..9040c25e 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/graph.py +++ b/trustgraph-base/trustgraph/schema/knowledge/graph.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Array +from dataclasses import dataclass, field from ..core.primitives import Value, Triple from ..core.metadata import Metadata @@ -8,21 +8,24 @@ from ..core.topic import topic # Entity context are an entity associated with textual context -class EntityContext(Record): - entity = Value() - context = String() +@dataclass +class EntityContext: + entity: Value | None = None + context: str = "" # This is a 'batching' mechanism for the above data -class EntityContexts(Record): - metadata = Metadata() - entities = Array(EntityContext()) +@dataclass +class EntityContexts: + metadata: Metadata | None = None + entities: list[EntityContext] = field(default_factory=list) ############################################################################ # Graph triples -class Triples(Record): - metadata = Metadata() - triples = Array(Triple()) +@dataclass +class Triples: + metadata: Metadata | None = None + triples: list[Triple] = field(default_factory=list) -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 7cd5450e..cffcbac7 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -1,5 +1,4 @@ - -from pulsar.schema import Record, Bytes, String, Array, Long, Boolean +from dataclasses import dataclass, field from ..core.primitives import Triple, Error from ..core.topic import topic from ..core.metadata import Metadata @@ -22,40 +21,40 @@ from .embeddings import GraphEmbeddings # <- () # <- (error) -class KnowledgeRequest(Record): - +@dataclass +class KnowledgeRequest: # get-kg-core, delete-kg-core, list-kg-cores, put-kg-core # load-kg-core, unload-kg-core - operation = String() + operation: str = "" # list-kg-cores, delete-kg-core, put-kg-core - user = String() + user: str = "" # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core, # load-kg-core, unload-kg-core - id = String() + id: str = "" # load-kg-core - flow = String() + flow: str = "" # load-kg-core - collection = String() + collection: str = "" # put-kg-core - triples = Triples() - graph_embeddings = GraphEmbeddings() + triples: Triples | None = None + graph_embeddings: GraphEmbeddings | None = None -class KnowledgeResponse(Record): - error = Error() - ids = Array(String()) - eos = Boolean() # Indicates end of knowledge core stream - triples = Triples() - graph_embeddings = GraphEmbeddings() +@dataclass +class KnowledgeResponse: + error: Error | None = None + ids: list[str] = field(default_factory=list) + eos: bool = False # Indicates end of knowledge core stream + triples: Triples | None = None + graph_embeddings: GraphEmbeddings | None = None knowledge_request_queue = topic( - 'knowledge', kind='non-persistent', namespace='request' + 'knowledge', qos='q0', namespace='request' ) knowledge_response_queue = topic( - 'knowledge', kind='non-persistent', namespace='response', + 'knowledge', qos='q0', namespace='response', ) - diff --git a/trustgraph-base/trustgraph/schema/knowledge/nlp.py b/trustgraph-base/trustgraph/schema/knowledge/nlp.py index 0ffc3ba1..10b5f215 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/nlp.py +++ b/trustgraph-base/trustgraph/schema/knowledge/nlp.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Boolean +from dataclasses import dataclass from ..core.topic import topic @@ -6,21 +6,25 @@ from ..core.topic import topic # NLP extraction data types -class Definition(Record): - name = String() - definition = String() +@dataclass +class Definition: + name: str = "" + definition: str = "" -class Topic(Record): - name = String() - definition = String() +@dataclass +class Topic: + name: str = "" + definition: str = "" -class Relationship(Record): - s = String() - p = String() - o = String() - o_entity = Boolean() +@dataclass +class Relationship: + s: str = "" + p: str = "" + o: str = "" + o_entity: bool = False -class Fact(Record): - s = String() - p = String() - o = String() \ No newline at end of file +@dataclass +class Fact: + s: str = "" + p: str = "" + o: str = "" diff --git a/trustgraph-base/trustgraph/schema/knowledge/object.py b/trustgraph-base/trustgraph/schema/knowledge/object.py index 537eb95e..39b0095f 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/object.py +++ b/trustgraph-base/trustgraph/schema/knowledge/object.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Map, Double, Array +from dataclasses import dataclass, field from ..core.metadata import Metadata from ..core.topic import topic @@ -7,11 +7,13 @@ from ..core.topic import topic # Extracted object from text processing -class ExtractedObject(Record): - metadata = Metadata() - schema_name = String() # Which schema this object belongs to - values = Array(Map(String())) # Array of objects, each object is field name -> value - confidence = Double() - source_span = String() # Text span where object was found +@dataclass +class ExtractedObject: + metadata: Metadata | None = None + schema_name: str = "" # Which schema this object belongs to + values: list[dict[str, str]] = field(default_factory=list) # Array of objects, each object is field name -> value + confidence: float = 0.0 + source_span: str = "" # Text span where object was found + +############################################################################ -############################################################################ \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/rows.py b/trustgraph-base/trustgraph/schema/knowledge/rows.py index 8b1c79ef..ca2131df 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/rows.py +++ b/trustgraph-base/trustgraph/schema/knowledge/rows.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, Array, Map, String +from dataclasses import dataclass, field from ..core.metadata import Metadata from ..core.primitives import RowSchema @@ -8,9 +8,10 @@ from ..core.topic import topic # Stores rows of information -class Rows(Record): - metadata = Metadata() - row_schema = RowSchema() - rows = Array(Map(String())) +@dataclass +class Rows: + metadata: Metadata | None = None + row_schema: RowSchema | None = None + rows: list[dict[str, str]] = field(default_factory=list) -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/knowledge/structured.py b/trustgraph-base/trustgraph/schema/knowledge/structured.py index 3d2b1311..c227d767 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/structured.py +++ b/trustgraph-base/trustgraph/schema/knowledge/structured.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Bytes, Map +from dataclasses import dataclass, field from ..core.metadata import Metadata from ..core.topic import topic @@ -7,11 +7,13 @@ from ..core.topic import topic # Structured data submission for fire-and-forget processing -class StructuredDataSubmission(Record): - metadata = Metadata() - format = String() # "json", "csv", "xml" - schema_name = String() # Reference to schema in config - data = Bytes() # Raw data to ingest - options = Map(String()) # Format-specific options +@dataclass +class StructuredDataSubmission: + metadata: Metadata | None = None + format: str = "" # "json", "csv", "xml" + schema_name: str = "" # Reference to schema in config + data: bytes = b"" # Raw data to ingest + options: dict[str, str] = field(default_factory=dict) # Format-specific options + +############################################################################ -############################################################################ \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index 6e8be5eb..9f883ff2 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -1,5 +1,5 @@ -from pulsar.schema import Record, String, Array, Map, Boolean +from dataclasses import dataclass, field from ..core.topic import topic from ..core.primitives import Error @@ -8,33 +8,36 @@ from ..core.primitives import Error # Prompt services, abstract the prompt generation -class AgentStep(Record): - thought = String() - action = String() - arguments = Map(String()) - observation = String() - user = String() # User context for the step +@dataclass +class AgentStep: + thought: str = "" + action: str = "" + arguments: dict[str, str] = field(default_factory=dict) + observation: str = "" + user: str = "" # User context for the step -class AgentRequest(Record): - question = String() - state = String() - group = Array(String()) - history = Array(AgentStep()) - user = String() # User context for multi-tenancy - streaming = Boolean() # NEW: Enable streaming response delivery (default false) +@dataclass +class AgentRequest: + question: str = "" + state: str = "" + group: list[str] | None = None + history: list[AgentStep] = field(default_factory=list) + user: str = "" # User context for multi-tenancy + streaming: bool = False # NEW: Enable streaming response delivery (default false) -class AgentResponse(Record): +@dataclass +class AgentResponse: # Streaming-first design - chunk_type = String() # "thought", "action", "observation", "answer", "error" - content = String() # The actual content (interpretation depends on chunk_type) - end_of_message = Boolean() # Current chunk type (thought/action/etc.) is complete - end_of_dialog = Boolean() # Entire agent dialog is complete + chunk_type: str = "" # "thought", "action", "observation", "answer", "error" + content: str = "" # The actual content (interpretation depends on chunk_type) + end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete + end_of_dialog: bool = False # Entire agent dialog is complete # Legacy fields (deprecated but kept for backward compatibility) - answer = String() - error = Error() - thought = String() - observation = String() + answer: str = "" + error: Error | None = None + thought: str = "" + observation: str = "" ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index 905b2056..74381abb 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Integer, Array +from dataclasses import dataclass, field from datetime import datetime from ..core.primitives import Error @@ -10,41 +10,40 @@ from ..core.topic import topic # Collection metadata operations (for librarian service) -class CollectionMetadata(Record): +@dataclass +class CollectionMetadata: """Collection metadata record""" - user = String() - collection = String() - name = String() - description = String() - tags = Array(String()) - created_at = String() # ISO timestamp - updated_at = String() # ISO timestamp + user: str = "" + collection: str = "" + name: str = "" + description: str = "" + tags: list[str] = field(default_factory=list) ############################################################################ -class CollectionManagementRequest(Record): +@dataclass +class CollectionManagementRequest: """Request for collection management operations""" - operation = String() # e.g., "delete-collection" + operation: str = "" # e.g., "delete-collection" # For 'list-collections' - user = String() - collection = String() - timestamp = String() # ISO timestamp - name = String() - description = String() - tags = Array(String()) - created_at = String() # ISO timestamp - updated_at = String() # ISO timestamp + user: str = "" + collection: str = "" + timestamp: str = "" # ISO timestamp + name: str = "" + description: str = "" + tags: list[str] = field(default_factory=list) # For list - tag_filter = Array(String()) # Optional filter by tags - limit = Integer() + tag_filter: list[str] = field(default_factory=list) # Optional filter by tags + limit: int = 0 -class CollectionManagementResponse(Record): +@dataclass +class CollectionManagementResponse: """Response for collection management operations""" - error = Error() # Only populated if there's an error - timestamp = String() # ISO timestamp - collections = Array(CollectionMetadata()) + error: Error | None = None # Only populated if there's an error + timestamp: str = "" # ISO timestamp + collections: list[CollectionMetadata] = field(default_factory=list) ############################################################################ @@ -52,8 +51,9 @@ class CollectionManagementResponse(Record): # Topics collection_request_queue = topic( - 'collection', kind='non-persistent', namespace='request' + 'collection', qos='q0', namespace='request' ) collection_response_queue = topic( - 'collection', kind='non-persistent', namespace='response' + 'collection', qos='q0', namespace='response' ) + diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index a0955eab..38bd1cbf 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -1,5 +1,5 @@ -from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer +from dataclasses import dataclass, field from ..core.topic import topic from ..core.primitives import Error @@ -13,58 +13,61 @@ from ..core.primitives import Error # put(values) -> () # delete(keys) -> () # config() -> (version, config) -class ConfigKey(Record): - type = String() - key = String() +@dataclass +class ConfigKey: + type: str = "" + key: str = "" -class ConfigValue(Record): - type = String() - key = String() - value = String() +@dataclass +class ConfigValue: + type: str = "" + key: str = "" + value: str = "" # Prompt services, abstract the prompt generation -class ConfigRequest(Record): - - operation = String() # get, list, getvalues, delete, put, config +@dataclass +class ConfigRequest: + operation: str = "" # get, list, getvalues, delete, put, config # get, delete - keys = Array(ConfigKey()) + keys: list[ConfigKey] = field(default_factory=list) # list, getvalues - type = String() + type: str = "" # put - values = Array(ConfigValue()) - -class ConfigResponse(Record): + values: list[ConfigValue] = field(default_factory=list) +@dataclass +class ConfigResponse: # get, list, getvalues, config - version = Integer() + version: int = 0 # get, getvalues - values = Array(ConfigValue()) + values: list[ConfigValue] = field(default_factory=list) # list - directory = Array(String()) + directory: list[str] = field(default_factory=list) # config - config = Map(Map(String())) + config: dict[str, dict[str, str]] = field(default_factory=dict) # Everything - error = Error() + error: Error | None = None -class ConfigPush(Record): - version = Integer() - config = Map(Map(String())) +@dataclass +class ConfigPush: + version: int = 0 + config: dict[str, dict[str, str]] = field(default_factory=dict) config_request_queue = topic( - 'config', kind='non-persistent', namespace='request' + 'config', qos='q0', namespace='request' ) config_response_queue = topic( - 'config', kind='non-persistent', namespace='response' + 'config', qos='q0', namespace='response' ) config_push_queue = topic( - 'config', kind='persistent', namespace='config' + 'config', qos='q2', namespace='config' ) ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/diagnosis.py b/trustgraph-base/trustgraph/schema/services/diagnosis.py index 1bd6d3ed..529e7d9e 100644 --- a/trustgraph-base/trustgraph/schema/services/diagnosis.py +++ b/trustgraph-base/trustgraph/schema/services/diagnosis.py @@ -1,33 +1,36 @@ -from pulsar.schema import Record, String, Map, Double, Array +from dataclasses import dataclass, field from ..core.primitives import Error ############################################################################ # Structured data diagnosis services -class StructuredDataDiagnosisRequest(Record): - operation = String() # "detect-type", "generate-descriptor", "diagnose", or "schema-selection" - sample = String() # Data sample to analyze (text content) - type = String() # Data type (csv, json, xml) - optional, required for generate-descriptor - schema_name = String() # Target schema name for descriptor generation - optional +@dataclass +class StructuredDataDiagnosisRequest: + operation: str = "" # "detect-type", "generate-descriptor", "diagnose", or "schema-selection" + sample: str = "" # Data sample to analyze (text content) + type: str = "" # Data type (csv, json, xml) - optional, required for generate-descriptor + schema_name: str = "" # Target schema name for descriptor generation - optional # JSON encoded options (e.g., delimiter for CSV) - options = Map(String()) + options: dict[str, str] = field(default_factory=dict) -class StructuredDataDiagnosisResponse(Record): - error = Error() +@dataclass +class StructuredDataDiagnosisResponse: + error: Error | None = None - operation = String() # The operation that was performed - detected_type = String() # Detected data type (for detect-type/diagnose) - optional - confidence = Double() # Confidence score for type detection - optional + operation: str = "" # The operation that was performed + detected_type: str = "" # Detected data type (for detect-type/diagnose) - optional + confidence: float = 0.0 # Confidence score for type detection - optional # JSON encoded descriptor (for generate-descriptor/diagnose) - optional - descriptor = String() + descriptor: str = "" # JSON encoded additional metadata (e.g., field count, sample records) - metadata = Map(String()) + metadata: dict[str, str] = field(default_factory=dict) # Array of matching schema IDs (for schema-selection operation) - optional - schema_matches = Array(String()) + schema_matches: list[str] = field(default_factory=list) + +############################################################################ -############################################################################ \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/flow.py b/trustgraph-base/trustgraph/schema/services/flow.py index d03e559b..b993b1b3 100644 --- a/trustgraph-base/trustgraph/schema/services/flow.py +++ b/trustgraph-base/trustgraph/schema/services/flow.py @@ -1,5 +1,5 @@ -from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer +from dataclasses import dataclass, field from ..core.topic import topic from ..core.primitives import Error @@ -11,61 +11,61 @@ from ..core.primitives import Error # get_class(classname) -> (class) # put_class(class) -> (class) # delete_class(classname) -> () -# +# # list_flows() -> (flowid[]) # get_flow(flowid) -> (flow) # start_flow(flowid, classname) -> () # stop_flow(flowid) -> () # Prompt services, abstract the prompt generation -class FlowRequest(Record): - - operation = String() # list-classes, get-class, put-class, delete-class +@dataclass +class FlowRequest: + operation: str = "" # list-classes, get-class, put-class, delete-class # list-flows, get-flow, start-flow, stop-flow # get_class, put_class, delete_class, start_flow - class_name = String() + class_name: str = "" # put_class - class_definition = String() + class_definition: str = "" # start_flow - description = String() + description: str = "" # get_flow, start_flow, stop_flow - flow_id = String() + flow_id: str = "" # start_flow - optional parameters for flow customization - parameters = Map(String()) - -class FlowResponse(Record): + parameters: dict[str, str] = field(default_factory=dict) +@dataclass +class FlowResponse: # list_classes - class_names = Array(String()) + class_names: list[str] = field(default_factory=list) # list_flows - flow_ids = Array(String()) + flow_ids: list[str] = field(default_factory=list) # get_class - class_definition = String() + class_definition: str = "" # get_flow - flow = String() + flow: str = "" # get_flow - description = String() + description: str = "" # get_flow - parameters used when flow was started - parameters = Map(String()) + parameters: dict[str, str] = field(default_factory=dict) # Everything - error = Error() + error: Error | None = None flow_request_queue = topic( - 'flow', kind='non-persistent', namespace='request' + 'flow', qos='q0', namespace='request' ) flow_response_queue = topic( - 'flow', kind='non-persistent', namespace='response' + 'flow', qos='q0', namespace='response' ) ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index d9678a90..391d49e1 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -1,9 +1,8 @@ - -from pulsar.schema import Record, Bytes, String, Array, Long +from dataclasses import dataclass, field from ..core.primitives import Triple, Error from ..core.topic import topic from ..core.metadata import Metadata -from ..knowledge.document import Document, TextDocument +# Note: Document imports will be updated after knowledge schemas are converted # add-document # -> (document_id, document_metadata, content) @@ -50,76 +49,79 @@ from ..knowledge.document import Document, TextDocument # <- (processing_metadata[]) # <- (error) -class DocumentMetadata(Record): - id = String() - time = Long() - kind = String() - title = String() - comments = String() - metadata = Array(Triple()) - user = String() - tags = Array(String()) +@dataclass +class DocumentMetadata: + id: str = "" + time: int = 0 + kind: str = "" + title: str = "" + comments: str = "" + metadata: list[Triple] = field(default_factory=list) + user: str = "" + tags: list[str] = field(default_factory=list) -class ProcessingMetadata(Record): - id = String() - document_id = String() - time = Long() - flow = String() - user = String() - collection = String() - tags = Array(String()) +@dataclass +class ProcessingMetadata: + id: str = "" + document_id: str = "" + time: int = 0 + flow: str = "" + user: str = "" + collection: str = "" + tags: list[str] = field(default_factory=list) -class Criteria(Record): - key = String() - value = String() - operator = String() - -class LibrarianRequest(Record): +@dataclass +class Criteria: + key: str = "" + value: str = "" + operator: str = "" +@dataclass +class LibrarianRequest: # add-document, remove-document, update-document, get-document-metadata, # get-document-content, add-processing, remove-processing, list-documents, # list-processing - operation = String() + operation: str = "" # add-document, remove-document, update-document, get-document-metadata, # get-document-content - document_id = String() + document_id: str = "" # add-processing, remove-processing - processing_id = String() + processing_id: str = "" # add-document, update-document - document_metadata = DocumentMetadata() + document_metadata: DocumentMetadata | None = None # add-processing - processing_metadata = ProcessingMetadata() + processing_metadata: ProcessingMetadata | None = None # add-document - content = Bytes() + content: bytes = b"" # list-documents, list-processing - user = String() + user: str = "" # list-documents?, list-processing? - collection = String() + collection: str = "" - # - criteria = Array(Criteria()) + # + criteria: list[Criteria] = field(default_factory=list) -class LibrarianResponse(Record): - error = Error() - document_metadata = DocumentMetadata() - content = Bytes() - document_metadatas = Array(DocumentMetadata()) - processing_metadatas = Array(ProcessingMetadata()) +@dataclass +class LibrarianResponse: + error: Error | None = None + document_metadata: DocumentMetadata | None = None + content: bytes = b"" + document_metadatas: list[DocumentMetadata] = field(default_factory=list) + processing_metadatas: list[ProcessingMetadata] = field(default_factory=list) # FIXME: Is this right? Using persistence on librarian so that # message chunking works librarian_request_queue = topic( - 'librarian', kind='persistent', namespace='request' + 'librarian', qos='q1', namespace='request' ) librarian_response_queue = topic( - 'librarian', kind='persistent', namespace='response', + 'librarian', qos='q1', namespace='response', ) - diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index 3fd21937..1261158e 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -1,5 +1,5 @@ -from pulsar.schema import Record, String, Array, Double, Integer, Boolean +from dataclasses import dataclass, field from ..core.topic import topic from ..core.primitives import Error @@ -8,46 +8,49 @@ from ..core.primitives import Error # LLM text completion -class TextCompletionRequest(Record): - system = String() - prompt = String() - streaming = Boolean() # Default false for backward compatibility +@dataclass +class TextCompletionRequest: + system: str = "" + prompt: str = "" + streaming: bool = False # Default false for backward compatibility -class TextCompletionResponse(Record): - error = Error() - response = String() - in_token = Integer() - out_token = Integer() - model = String() - end_of_stream = Boolean() # Indicates final message in stream +@dataclass +class TextCompletionResponse: + error: Error | None = None + response: str = "" + in_token: int = 0 + out_token: int = 0 + model: str = "" + end_of_stream: bool = False # Indicates final message in stream ############################################################################ # Embeddings -class EmbeddingsRequest(Record): - text = String() +@dataclass +class EmbeddingsRequest: + text: str = "" -class EmbeddingsResponse(Record): - error = Error() - vectors = Array(Array(Double())) +@dataclass +class EmbeddingsResponse: + error: Error | None = None + vectors: list[list[float]] = field(default_factory=list) ############################################################################ # Tool request/response -class ToolRequest(Record): - name = String() - +@dataclass +class ToolRequest: + name: str = "" # Parameters are JSON encoded - parameters = String() - -class ToolResponse(Record): - error = Error() + parameters: str = "" +@dataclass +class ToolResponse: + error: Error | None = None # Plain text aka "unstructured" - text = String() - + text: str = "" # JSON-encoded object aka "structured" - object = String() + object: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/lookup.py b/trustgraph-base/trustgraph/schema/services/lookup.py index 7cc0bd03..bdeac636 100644 --- a/trustgraph-base/trustgraph/schema/services/lookup.py +++ b/trustgraph-base/trustgraph/schema/services/lookup.py @@ -1,5 +1,4 @@ - -from pulsar.schema import Record, String +from dataclasses import dataclass from ..core.primitives import Error, Value, Triple from ..core.topic import topic @@ -9,13 +8,14 @@ from ..core.metadata import Metadata # Lookups -class LookupRequest(Record): - kind = String() - term = String() +@dataclass +class LookupRequest: + kind: str = "" + term: str = "" -class LookupResponse(Record): - text = String() - error = Error() +@dataclass +class LookupResponse: + text: str = "" + error: Error | None = None ############################################################################ - diff --git a/trustgraph-base/trustgraph/schema/services/nlp_query.py b/trustgraph-base/trustgraph/schema/services/nlp_query.py index a3e709a1..6cd65f0e 100644 --- a/trustgraph-base/trustgraph/schema/services/nlp_query.py +++ b/trustgraph-base/trustgraph/schema/services/nlp_query.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Array, Map, Integer, Double +from dataclasses import dataclass, field from ..core.primitives import Error from ..core.topic import topic @@ -7,15 +7,18 @@ from ..core.topic import topic # NLP to Structured Query Service - converts natural language to GraphQL -class QuestionToStructuredQueryRequest(Record): - question = String() - max_results = Integer() +@dataclass +class QuestionToStructuredQueryRequest: + question: str = "" + max_results: int = 0 -class QuestionToStructuredQueryResponse(Record): - error = Error() - graphql_query = String() # Generated GraphQL query - variables = Map(String()) # GraphQL variables if any - detected_schemas = Array(String()) # Which schemas the query targets - confidence = Double() +@dataclass +class QuestionToStructuredQueryResponse: + error: Error | None = None + graphql_query: str = "" # Generated GraphQL query + variables: dict[str, str] = field(default_factory=dict) # GraphQL variables if any + detected_schemas: list[str] = field(default_factory=list) # Which schemas the query targets + confidence: float = 0.0 ############################################################################ + diff --git a/trustgraph-base/trustgraph/schema/services/objects_query.py b/trustgraph-base/trustgraph/schema/services/objects_query.py index 6c3a307c..e24daef3 100644 --- a/trustgraph-base/trustgraph/schema/services/objects_query.py +++ b/trustgraph-base/trustgraph/schema/services/objects_query.py @@ -1,4 +1,5 @@ -from pulsar.schema import Record, String, Map, Array +from dataclasses import dataclass, field +from typing import Optional from ..core.primitives import Error from ..core.topic import topic @@ -7,22 +8,25 @@ from ..core.topic import topic # Objects Query Service - executes GraphQL queries against structured data -class GraphQLError(Record): - message = String() - path = Array(String()) # Path to the field that caused the error - extensions = Map(String()) # Additional error metadata +@dataclass +class GraphQLError: + message: str = "" + path: list[str] = field(default_factory=list) # Path to the field that caused the error + extensions: dict[str, str] = field(default_factory=dict) # Additional error metadata -class ObjectsQueryRequest(Record): - user = String() # Cassandra keyspace (follows pattern from TriplesQueryRequest) - collection = String() # Data collection identifier (required for partition key) - query = String() # GraphQL query string - variables = Map(String()) # GraphQL variables - operation_name = String() # Operation to execute for multi-operation documents +@dataclass +class ObjectsQueryRequest: + user: str = "" # Cassandra keyspace (follows pattern from TriplesQueryRequest) + collection: str = "" # Data collection identifier (required for partition key) + query: str = "" # GraphQL query string + variables: dict[str, str] = field(default_factory=dict) # GraphQL variables + operation_name: Optional[str] = None # Operation to execute for multi-operation documents -class ObjectsQueryResponse(Record): - error = Error() # System-level error (connection, timeout, etc.) - data = String() # JSON-encoded GraphQL response data - errors = Array(GraphQLError()) # GraphQL field-level errors - extensions = Map(String()) # Query metadata (execution time, etc.) +@dataclass +class ObjectsQueryResponse: + error: Error | None = None # System-level error (connection, timeout, etc.) + data: str = "" # JSON-encoded GraphQL response data + errors: list[GraphQLError] = field(default_factory=list) # GraphQL field-level errors + extensions: dict[str, str] = field(default_factory=dict) # Query metadata (execution time, etc.) -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index edb569c9..f7a31c14 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Map, Boolean +from dataclasses import dataclass, field from ..core.primitives import Error from ..core.topic import topic @@ -18,27 +18,28 @@ from ..core.topic import topic # extract-rows # schema, chunk -> rows -class PromptRequest(Record): - id = String() +@dataclass +class PromptRequest: + id: str = "" # JSON encoded values - terms = Map(String()) + terms: dict[str, str] = field(default_factory=dict) # Streaming support (default false for backward compatibility) - streaming = Boolean() - -class PromptResponse(Record): + streaming: bool = False +@dataclass +class PromptResponse: # Error case - error = Error() + error: Error | None = None # Just plain text - text = String() + text: str = "" # JSON encoded - object = String() + object: str = "" # Indicates final message in stream - end_of_stream = Boolean() + end_of_stream: bool = False ############################################################################ \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 91231ade..31d0852d 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Integer, Array, Double +from dataclasses import dataclass, field from ..core.primitives import Error, Value, Triple from ..core.topic import topic @@ -7,49 +7,55 @@ from ..core.topic import topic # Graph embeddings query -class GraphEmbeddingsRequest(Record): - vectors = Array(Array(Double())) - limit = Integer() - user = String() - collection = String() +@dataclass +class GraphEmbeddingsRequest: + vectors: list[list[float]] = field(default_factory=list) + limit: int = 0 + user: str = "" + collection: str = "" -class GraphEmbeddingsResponse(Record): - error = Error() - entities = Array(Value()) +@dataclass +class GraphEmbeddingsResponse: + error: Error | None = None + entities: list[Value] = field(default_factory=list) ############################################################################ # Graph triples query -class TriplesQueryRequest(Record): - user = String() - collection = String() - s = Value() - p = Value() - o = Value() - limit = Integer() +@dataclass +class TriplesQueryRequest: + user: str = "" + collection: str = "" + s: Value | None = None + p: Value | None = None + o: Value | None = None + limit: int = 0 -class TriplesQueryResponse(Record): - error = Error() - triples = Array(Triple()) +@dataclass +class TriplesQueryResponse: + error: Error | None = None + triples: list[Triple] = field(default_factory=list) ############################################################################ # Doc embeddings query -class DocumentEmbeddingsRequest(Record): - vectors = Array(Array(Double())) - limit = Integer() - user = String() - collection = String() +@dataclass +class DocumentEmbeddingsRequest: + vectors: list[list[float]] = field(default_factory=list) + limit: int = 0 + user: str = "" + collection: str = "" -class DocumentEmbeddingsResponse(Record): - error = Error() - chunks = Array(String()) +@dataclass +class DocumentEmbeddingsResponse: + error: Error | None = None + chunks: list[str] = field(default_factory=list) document_embeddings_request_queue = topic( - "non-persistent://trustgraph/document-embeddings-request" + "document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow' ) document_embeddings_response_queue = topic( - "non-persistent://trustgraph/document-embeddings-response" + "document-embeddings-response", qos='q0', tenant='trustgraph', namespace='flow' ) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 3cd7f792..72085ae8 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -1,5 +1,4 @@ - -from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double +from dataclasses import dataclass from ..core.topic import topic from ..core.primitives import Error, Value @@ -7,36 +6,37 @@ from ..core.primitives import Error, Value # Graph RAG text retrieval -class GraphRagQuery(Record): - query = String() - user = String() - collection = String() - entity_limit = Integer() - triple_limit = Integer() - max_subgraph_size = Integer() - max_path_length = Integer() - streaming = Boolean() +@dataclass +class GraphRagQuery: + query: str = "" + user: str = "" + collection: str = "" + entity_limit: int = 0 + triple_limit: int = 0 + max_subgraph_size: int = 0 + max_path_length: int = 0 + streaming: bool = False -class GraphRagResponse(Record): - error = Error() - response = String() - chunk = String() - end_of_stream = Boolean() +@dataclass +class GraphRagResponse: + error: Error | None = None + response: str = "" + end_of_stream: bool = False ############################################################################ # Document RAG text retrieval -class DocumentRagQuery(Record): - query = String() - user = String() - collection = String() - doc_limit = Integer() - streaming = Boolean() - -class DocumentRagResponse(Record): - error = Error() - response = String() - chunk = String() - end_of_stream = Boolean() +@dataclass +class DocumentRagQuery: + query: str = "" + user: str = "" + collection: str = "" + doc_limit: int = 0 + streaming: bool = False +@dataclass +class DocumentRagResponse: + error: Error | None = None + response: str = "" + end_of_stream: bool = False diff --git a/trustgraph-base/trustgraph/schema/services/storage.py b/trustgraph-base/trustgraph/schema/services/storage.py index 16791615..b010e54b 100644 --- a/trustgraph-base/trustgraph/schema/services/storage.py +++ b/trustgraph-base/trustgraph/schema/services/storage.py @@ -1,42 +1,8 @@ -from pulsar.schema import Record, String +# This file previously contained legacy storage management queue definitions +# (StorageManagementRequest, StorageManagementResponse, and related topics). +# +# These have been removed as collection management now uses a config-based +# approach via CollectionConfigHandler instead of request/response queues. +# +# This file is kept for potential future storage-related schema definitions. -from ..core.primitives import Error -from ..core.topic import topic - -############################################################################ - -# Storage management operations - -class StorageManagementRequest(Record): - """Request for storage management operations sent to store processors""" - operation = String() # e.g., "delete-collection" - user = String() - collection = String() - -class StorageManagementResponse(Record): - """Response from storage processors for management operations""" - error = Error() # Only populated if there's an error, if null success - -############################################################################ - -# Storage management topics - -# Topics for sending collection management requests to different storage types -vector_storage_management_topic = topic( - 'vector-storage-management', kind='non-persistent', namespace='request' -) - -object_storage_management_topic = topic( - 'object-storage-management', kind='non-persistent', namespace='request' -) - -triples_storage_management_topic = topic( - 'triples-storage-management', kind='non-persistent', namespace='request' -) - -# Topic for receiving responses from storage processors -storage_management_response_topic = topic( - 'storage-management', kind='non-persistent', namespace='response' -) - -############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/structured_query.py b/trustgraph-base/trustgraph/schema/services/structured_query.py index df21bfe2..ae1eaa5f 100644 --- a/trustgraph-base/trustgraph/schema/services/structured_query.py +++ b/trustgraph-base/trustgraph/schema/services/structured_query.py @@ -1,4 +1,4 @@ -from pulsar.schema import Record, String, Map, Array +from dataclasses import dataclass, field from ..core.primitives import Error from ..core.topic import topic @@ -7,14 +7,17 @@ from ..core.topic import topic # Structured Query Service - executes GraphQL queries -class StructuredQueryRequest(Record): - question = String() - user = String() # Cassandra keyspace identifier - collection = String() # Data collection identifier +@dataclass +class StructuredQueryRequest: + question: str = "" + user: str = "" # Cassandra keyspace identifier + collection: str = "" # Data collection identifier -class StructuredQueryResponse(Record): - error = Error() - data = String() # JSON-encoded GraphQL response data - errors = Array(String()) # GraphQL errors if any +@dataclass +class StructuredQueryResponse: + error: Error | None = None + data: str = "" # JSON-encoded GraphQL response data + errors: list[str] = field(default_factory=list) # GraphQL errors if any ############################################################################ + diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 3442d47f..c9192794 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 572238b7..65921d92 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", "requests", "pulsar-client", "aiohttp", @@ -84,6 +84,7 @@ tg-unload-kg-core = "trustgraph.cli.unload_kg_core:main" tg-start-library-processing = "trustgraph.cli.start_library_processing:main" tg-stop-flow = "trustgraph.cli.stop_flow:main" tg-stop-library-processing = "trustgraph.cli.stop_library_processing:main" +tg-verify-system-status = "trustgraph.cli.verify_system_status:main" tg-list-config-items = "trustgraph.cli.list_config_items:main" tg-get-config-item = "trustgraph.cli.get_config_item:main" tg-put-config-item = "trustgraph.cli.put_config_item:main" diff --git a/trustgraph-cli/trustgraph/cli/delete_config_item.py b/trustgraph-cli/trustgraph/cli/delete_config_item.py index 1de02890..cf4cba93 100644 --- a/trustgraph-cli/trustgraph/cli/delete_config_item.py +++ b/trustgraph-cli/trustgraph/cli/delete_config_item.py @@ -8,10 +8,11 @@ from trustgraph.api import Api from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def delete_config_item(url, config_type, key): +def delete_config_item(url, config_type, key, token=None): - api = Api(url).config() + api = Api(url, token=token).config() config_key = ConfigKey(type=config_type, key=key) api.delete([config_key]) @@ -43,6 +44,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -51,6 +58,7 @@ def main(): url=args.api_url, config_type=args.type, key=args.key, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/dump_queues.py b/trustgraph-cli/trustgraph/cli/dump_queues.py index 93151858..0a298450 100644 --- a/trustgraph-cli/trustgraph/cli/dump_queues.py +++ b/trustgraph-cli/trustgraph/cli/dump_queues.py @@ -17,6 +17,7 @@ from datetime import datetime import argparse from trustgraph.base.subscriber import Subscriber +from trustgraph.base.pubsub import get_pubsub def format_message(queue_name, msg): """Format a message with timestamp and queue name.""" @@ -167,11 +168,11 @@ async def async_main(queues, output_file, pulsar_host, listener_name, subscriber print(f"Mode: {'append' if append_mode else 'overwrite'}") print(f"Press Ctrl+C to stop\n") - # Connect to Pulsar + # Create backend connection try: - client = pulsar.Client(pulsar_host, listener_name=listener_name) + backend = get_pubsub(pulsar_host=pulsar_host, pulsar_listener=listener_name, pubsub_backend='pulsar') except Exception as e: - print(f"Error connecting to Pulsar at {pulsar_host}: {e}", file=sys.stderr) + print(f"Error connecting to backend at {pulsar_host}: {e}", file=sys.stderr) sys.exit(1) # Create Subscribers and central queue @@ -181,7 +182,7 @@ async def async_main(queues, output_file, pulsar_host, listener_name, subscriber for queue_name in queues: try: sub = Subscriber( - client=client, + backend=backend, topic=queue_name, subscription=subscriber_name, consumer_name=f"{subscriber_name}-{queue_name}", @@ -195,7 +196,7 @@ async def async_main(queues, output_file, pulsar_host, listener_name, subscriber if not subscribers: print("\nNo subscribers created. Exiting.", file=sys.stderr) - client.close() + backend.close() sys.exit(1) print(f"\nListening for messages...\n") @@ -256,7 +257,7 @@ async def async_main(queues, output_file, pulsar_host, listener_name, subscriber # Clean shutdown of Subscribers for _, sub in subscribers: await sub.stop() - client.close() + backend.close() print(f"\nMessages logged to: {output_file}") diff --git a/trustgraph-cli/trustgraph/cli/get_config_item.py b/trustgraph-cli/trustgraph/cli/get_config_item.py index 832d2711..c2421e94 100644 --- a/trustgraph-cli/trustgraph/cli/get_config_item.py +++ b/trustgraph-cli/trustgraph/cli/get_config_item.py @@ -9,10 +9,11 @@ from trustgraph.api import Api from trustgraph.api.types import ConfigKey default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def get_config_item(url, config_type, key, format_type): +def get_config_item(url, config_type, key, format_type, token=None): - api = Api(url).config() + api = Api(url, token=token).config() config_key = ConfigKey(type=config_type, key=key) values = api.get([config_key]) @@ -59,6 +60,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -68,6 +75,7 @@ def main(): config_type=args.type, key=args.key, format_type=args.format, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index 6e0a8bc0..b75f7155 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -14,6 +14,7 @@ import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def write_triple(f, data): msg = ( @@ -51,13 +52,16 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, user, id, output): +async def fetch(url, user, id, output, token=None): if not url.endswith("/"): url += "/" url = url + "api/v1/socket" + if token: + url = f"{url}?token={token}" + mid = str(uuid.uuid4()) async with connect(url) as ws: @@ -138,6 +142,12 @@ def main(): help=f'Output file' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -148,6 +158,7 @@ def main(): user = args.user, id = args.id, output = args.output, + token = args.token, ) ) diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index e6e82edd..de70021b 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -5,12 +5,10 @@ Uses the agent service to answer a question import argparse import os import textwrap -import uuid -import asyncio -import json -from websockets.asyncio.client import connect +from trustgraph.api import Api -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' @@ -99,79 +97,47 @@ def output(text, prefix="> ", width=78): ) print(out) -async def question( +def question( url, question, flow_id, user, collection, - plan=None, state=None, group=None, verbose=False, streaming=True + plan=None, state=None, group=None, verbose=False, streaming=True, + token=None ): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - if verbose: output(wrap(question), "\U00002753 ") print() - # Track last chunk type and current outputter for streaming - last_chunk_type = None - current_outputter = None + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) - def think(x): - if verbose: - output(wrap(x), "\U0001f914 ") - print() + # Prepare request parameters + request_params = { + "question": question, + "user": user, + "streaming": streaming, + } - def observe(x): - if verbose: - output(wrap(x), "\U0001f4a1 ") - print() + # Only add optional fields if they have values + if state is not None: + request_params["state"] = state + if group is not None: + request_params["group"] = group - mid = str(uuid.uuid4()) + try: + # Call agent + response = flow.agent(**request_params) - async with connect(url) as ws: + # Handle streaming response + if streaming: + # Track last chunk type and current outputter for streaming + last_chunk_type = None + current_outputter = None - req = { - "id": mid, - "service": "agent", - "flow": flow_id, - "request": { - "question": question, - "user": user, - "history": [], - "streaming": streaming - } - } - - # Only add optional fields if they have values - if state is not None: - req["request"]["state"] = state - if group is not None: - req["request"]["group"] = group - - req = json.dumps(req) - - await ws.send(req) - - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "error" in obj: - raise RuntimeError(obj["error"]) - - if obj["id"] != mid: - print("Ignore message") - continue - - response = obj["response"] - - # Handle streaming format (new format with chunk_type) - if "chunk_type" in response: - chunk_type = response["chunk_type"] - content = response.get("content", "") + for chunk in response: + chunk_type = chunk.chunk_type + content = chunk.content # Check if we're switching to a new message type if last_chunk_type != chunk_type: @@ -195,33 +161,32 @@ async def question( # Output the chunk if current_outputter: current_outputter.output(content) - elif chunk_type == "answer": + # Flush word buffer after each chunk to avoid delay + if current_outputter.word_buffer: + print(current_outputter.word_buffer, end="", flush=True) + current_outputter.column += len(current_outputter.word_buffer) + current_outputter.word_buffer = "" + elif chunk_type == "final-answer": print(content, end="", flush=True) - else: - # Handle legacy format (backward compatibility) - if "thought" in response: - think(response["thought"]) - if "observation" in response: - observe(response["observation"]) + # Close any remaining outputter + if current_outputter: + current_outputter.__exit__(None, None, None) + current_outputter = None + # Add final newline if we were outputting answer + elif last_chunk_type == "final-answer": + print() - if "answer" in response: - print(response["answer"]) + else: + # Non-streaming response + if "answer" in response: + print(response["answer"]) + if "error" in response: + raise RuntimeError(response["error"]) - if "error" in response: - raise RuntimeError(response["error"]) - - if obj["complete"]: - # Close any remaining outputter - if current_outputter: - current_outputter.__exit__(None, None, None) - current_outputter = None - # Add final newline if we were outputting answer - elif last_chunk_type == "answer": - print() - break - - await ws.close() + finally: + # Clean up socket connection + socket.close() def main(): @@ -236,6 +201,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -292,19 +263,18 @@ def main(): try: - asyncio.run( - question( - url = args.url, - flow_id = args.flow_id, - question = args.question, - user = args.user, - collection = args.collection, - plan = args.plan, - state = args.state, - group = args.group, - verbose = args.verbose, - streaming = not args.no_streaming, - ) + question( + url = args.url, + flow_id = args.flow_id, + question = args.question, + user = args.user, + collection = args.collection, + plan = args.plan, + state = args.state, + group = args.group, + verbose = args.verbose, + streaming = not args.no_streaming, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index e6a040ac..7e88bdc4 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -4,89 +4,50 @@ Uses the DocumentRAG service to answer a question import argparse import os -import asyncio -import json -import uuid -from websockets.asyncio.client import connect from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' default_doc_limit = 10 -async def question_streaming(url, flow_id, question, user, collection, doc_limit): - """Streaming version using websockets""" +def question(url, flow_id, question, user, collection, doc_limit, streaming=True, token=None): - # Convert http:// to ws:// - if url.startswith('http://'): - url = 'ws://' + url[7:] - elif url.startswith('https://'): - url = 'wss://' + url[8:] + # Create API client + api = Api(url=url, token=token) - if not url.endswith("/"): - url += "/" + if streaming: + # Use socket client for streaming + socket = api.socket() + flow = socket.flow(flow_id) - url = url + "api/v1/socket" + try: + response = flow.document_rag( + query=question, + user=user, + collection=collection, + doc_limit=doc_limit, + streaming=True + ) - mid = str(uuid.uuid4()) - - async with connect(url) as ws: - req = { - "id": mid, - "service": "document-rag", - "flow": flow_id, - "request": { - "query": question, - "user": user, - "collection": collection, - "doc-limit": doc_limit, - "streaming": True - } - } - - req = json.dumps(req) - await ws.send(req) - - while True: - msg = await ws.recv() - obj = json.loads(msg) - - if "error" in obj: - raise RuntimeError(obj["error"]) - - if obj["id"] != mid: - print("Ignore message") - continue - - response = obj["response"] - - # Handle streaming format (chunk) - if "chunk" in response: - chunk = response["chunk"] + # Stream output + for chunk in response: print(chunk, end="", flush=True) - elif "response" in response: - # Final response with complete text - # Already printed via chunks, just add newline - pass + print() # Final newline - if obj["complete"]: - print() # Final newline - break - - await ws.close() - -def question_non_streaming(url, flow_id, question, user, collection, doc_limit): - """Non-streaming version using HTTP API""" - - api = Api(url).flow().id(flow_id) - - resp = api.document_rag( - question=question, user=user, collection=collection, - doc_limit=doc_limit, - ) - - print(resp) + finally: + socket.close() + else: + # Use REST API for non-streaming + flow = api.flow().id(flow_id) + resp = flow.document_rag( + query=question, + user=user, + collection=collection, + doc_limit=doc_limit, + ) + print(resp) def main(): @@ -101,6 +62,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -127,6 +94,7 @@ def main(): parser.add_argument( '-d', '--doc-limit', + type=int, default=default_doc_limit, help=f'Document limit (default: {default_doc_limit})' ) @@ -141,30 +109,20 @@ def main(): try: - if not args.no_streaming: - asyncio.run( - question_streaming( - url=args.url, - flow_id=args.flow_id, - question=args.question, - user=args.user, - collection=args.collection, - doc_limit=args.doc_limit, - ) - ) - else: - question_non_streaming( - url=args.url, - flow_id=args.flow_id, - question=args.question, - user=args.user, - collection=args.collection, - doc_limit=args.doc_limit, - ) + question( + url=args.url, + flow_id=args.flow_id, + question=args.question, + user=args.user, + collection=args.collection, + doc_limit=args.doc_limit, + streaming=not args.no_streaming, + token=args.token, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 45d02b6d..5fa359ab 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -4,13 +4,10 @@ Uses the GraphRAG service to answer a question import argparse import os -import asyncio -import json -import uuid -from websockets.asyncio.client import connect from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' default_entity_limit = 50 @@ -18,89 +15,51 @@ default_triple_limit = 30 default_max_subgraph_size = 150 default_max_path_length = 2 -async def question_streaming( +def question( url, flow_id, question, user, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length + max_subgraph_size, max_path_length, streaming=True, token=None ): - """Streaming version using websockets""" - # Convert http:// to ws:// - if url.startswith('http://'): - url = 'ws://' + url[7:] - elif url.startswith('https://'): - url = 'wss://' + url[8:] + # Create API client + api = Api(url=url, token=token) - if not url.endswith("/"): - url += "/" + if streaming: + # Use socket client for streaming + socket = api.socket() + flow = socket.flow(flow_id) - url = url + "api/v1/socket" + try: + response = flow.graph_rag( + query=question, + user=user, + collection=collection, + entity_limit=entity_limit, + triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, + streaming=True + ) - mid = str(uuid.uuid4()) - - async with connect(url) as ws: - req = { - "id": mid, - "service": "graph-rag", - "flow": flow_id, - "request": { - "query": question, - "user": user, - "collection": collection, - "entity-limit": entity_limit, - "triple-limit": triple_limit, - "max-subgraph-size": max_subgraph_size, - "max-path-length": max_path_length, - "streaming": True - } - } - - req = json.dumps(req) - await ws.send(req) - - while True: - msg = await ws.recv() - obj = json.loads(msg) - - if "error" in obj: - raise RuntimeError(obj["error"]) - - if obj["id"] != mid: - print("Ignore message") - continue - - response = obj["response"] - - # Handle streaming format (chunk) - if "chunk" in response: - chunk = response["chunk"] + # Stream output + for chunk in response: print(chunk, end="", flush=True) - elif "response" in response: - # Final response with complete text - # Already printed via chunks, just add newline - pass + print() # Final newline - if obj["complete"]: - print() # Final newline - break - - await ws.close() - -def question_non_streaming( - url, flow_id, question, user, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length -): - """Non-streaming version using HTTP API""" - - api = Api(url).flow().id(flow_id) - - resp = api.graph_rag( - question=question, user=user, collection=collection, - entity_limit=entity_limit, triple_limit=triple_limit, - max_subgraph_size=max_subgraph_size, - max_path_length=max_path_length - ) - - print(resp) + finally: + socket.close() + else: + # Use REST API for non-streaming + flow = api.flow().id(flow_id) + resp = flow.graph_rag( + query=question, + user=user, + collection=collection, + entity_limit=entity_limit, + triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length + ) + print(resp) def main(): @@ -115,6 +74,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -141,24 +106,28 @@ def main(): parser.add_argument( '-e', '--entity-limit', + type=int, default=default_entity_limit, help=f'Entity limit (default: {default_entity_limit})' ) parser.add_argument( - '-t', '--triple-limit', + '--triple-limit', + type=int, default=default_triple_limit, help=f'Triple limit (default: {default_triple_limit})' ) parser.add_argument( '-s', '--max-subgraph-size', + type=int, default=default_max_subgraph_size, help=f'Max subgraph size (default: {default_max_subgraph_size})' ) parser.add_argument( '-p', '--max-path-length', + type=int, default=default_max_path_length, help=f'Max path length (default: {default_max_path_length})' ) @@ -173,36 +142,23 @@ def main(): try: - if not args.no_streaming: - asyncio.run( - question_streaming( - url=args.url, - flow_id=args.flow_id, - question=args.question, - user=args.user, - collection=args.collection, - entity_limit=args.entity_limit, - triple_limit=args.triple_limit, - max_subgraph_size=args.max_subgraph_size, - max_path_length=args.max_path_length, - ) - ) - else: - question_non_streaming( - url=args.url, - flow_id=args.flow_id, - question=args.question, - user=args.user, - collection=args.collection, - entity_limit=args.entity_limit, - triple_limit=args.triple_limit, - max_subgraph_size=args.max_subgraph_size, - max_path_length=args.max_path_length, - ) + question( + url=args.url, + flow_id=args.flow_id, + question=args.question, + user=args.user, + collection=args.collection, + entity_limit=args.entity_limit, + triple_limit=args.triple_limit, + max_subgraph_size=args.max_subgraph_size, + max_path_length=args.max_path_length, + streaming=not args.no_streaming, + token=args.token, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py index da69dcd6..a1611625 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_llm.py +++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py @@ -5,64 +5,39 @@ and user prompt. Both arguments are required. import argparse import os -import json -import uuid -import asyncio -from websockets.asyncio.client import connect +from trustgraph.api import Api -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -async def query(url, flow_id, system, prompt, streaming=True): +def query(url, flow_id, system, prompt, streaming=True, token=None): - if not url.endswith("/"): - url += "/" + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) - url = url + "api/v1/socket" + try: + # Call text completion + response = flow.text_completion( + system=system, + prompt=prompt, + streaming=streaming + ) - mid = str(uuid.uuid4()) + if streaming: + # Stream output to stdout without newline + for chunk in response: + print(chunk, end="", flush=True) + # Add final newline after streaming + print() + else: + # Non-streaming: print complete response + print(response) - async with connect(url) as ws: - - req = { - "id": mid, - "service": "text-completion", - "flow": flow_id, - "request": { - "system": system, - "prompt": prompt, - "streaming": streaming - } - } - - await ws.send(json.dumps(req)) - - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "error" in obj: - raise RuntimeError(obj["error"]) - - if obj["id"] != mid: - continue - - if "response" in obj["response"]: - if streaming: - # Stream output to stdout without newline - print(obj["response"]["response"], end="", flush=True) - else: - # Non-streaming: print complete response - print(obj["response"]["response"]) - - if obj["complete"]: - if streaming: - # Add final newline after streaming - print() - break - - await ws.close() + finally: + # Clean up socket connection + socket.close() def main(): @@ -77,6 +52,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( 'system', nargs=1, @@ -105,17 +86,18 @@ def main(): try: - asyncio.run(query( + query( url=args.url, flow_id=args.flow_id, system=args.system[0], prompt=args.prompt[0], - streaming=not args.no_streaming - )) + streaming=not args.no_streaming, + token=args.token, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py index c996c57d..09cc9043 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py +++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py @@ -10,76 +10,41 @@ using key=value arguments on the command line, and these replace import argparse import os import json -import uuid -import asyncio -from websockets.asyncio.client import connect +from trustgraph.api import Api -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -async def query(url, flow_id, template_id, variables, streaming=True): +def query(url, flow_id, template_id, variables, streaming=True, token=None): - if not url.endswith("/"): - url += "/" + # Create API client + api = Api(url=url, token=token) + socket = api.socket() + flow = socket.flow(flow_id) - url = url + "api/v1/socket" + try: + # Call prompt + response = flow.prompt( + id=template_id, + variables=variables, + streaming=streaming + ) - mid = str(uuid.uuid4()) + if streaming: + # Stream output (prompt yields strings directly) + for chunk in response: + if chunk: + print(chunk, end="", flush=True) + # Add final newline after streaming + print() - async with connect(url) as ws: + else: + # Non-streaming: print complete response + print(response) - req = { - "id": mid, - "service": "prompt", - "flow": flow_id, - "request": { - "id": template_id, - "variables": variables, - "streaming": streaming - } - } - - await ws.send(json.dumps(req)) - - full_response = {"text": "", "object": ""} - - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "error" in obj: - raise RuntimeError(obj["error"]) - - if obj["id"] != mid: - continue - - response = obj["response"] - - # Handle text responses (streaming) - if "text" in response and response["text"]: - if streaming: - # Stream output to stdout without newline - print(response["text"], end="", flush=True) - full_response["text"] += response["text"] - else: - # Non-streaming: print complete response - print(response["text"]) - - # Handle object responses (JSON, never streamed) - if "object" in response and response["object"]: - full_response["object"] = response["object"] - - if obj["complete"]: - if streaming and full_response["text"]: - # Add final newline after streaming text - print() - elif full_response["object"]: - # Print JSON object (pretty-printed) - print(json.dumps(json.loads(full_response["object"]), indent=4)) - break - - await ws.close() + finally: + # Clean up socket connection + socket.close() def main(): @@ -94,6 +59,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-f', '--flow-id', default="default", @@ -135,17 +106,18 @@ specified multiple times''', try: - asyncio.run(query( + query( url=args.url, flow_id=args.flow_id, template_id=args.id[0], variables=variables, - streaming=not args.no_streaming - )) + streaming=not args.no_streaming, + token=args.token, + ) except Exception as e: print("Exception:", e, flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/list_collections.py b/trustgraph-cli/trustgraph/cli/list_collections.py index 56929e93..4086f471 100644 --- a/trustgraph-cli/trustgraph/cli/list_collections.py +++ b/trustgraph-cli/trustgraph/cli/list_collections.py @@ -28,19 +28,17 @@ def list_collections(url, user, tag_filter): collection.collection, collection.name, collection.description, - ", ".join(collection.tags), - collection.created_at, - collection.updated_at + ", ".join(collection.tags) ]) - headers = ["Collection", "Name", "Description", "Tags", "Created", "Updated"] + headers = ["Collection", "Name", "Description", "Tags"] print(tabulate.tabulate( table, headers=headers, tablefmt="pretty", stralign="left", - maxcolwidths=[20, 30, 50, 30, 19, 19], + maxcolwidths=[20, 30, 50, 30], )) def main(): diff --git a/trustgraph-cli/trustgraph/cli/list_config_items.py b/trustgraph-cli/trustgraph/cli/list_config_items.py index 33e8f7ba..5cd0f233 100644 --- a/trustgraph-cli/trustgraph/cli/list_config_items.py +++ b/trustgraph-cli/trustgraph/cli/list_config_items.py @@ -8,10 +8,11 @@ import json from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def list_config_items(url, config_type, format_type): +def list_config_items(url, config_type, format_type, token=None): - api = Api(url).config() + api = Api(url, token=token).config() keys = api.list(config_type) @@ -47,6 +48,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -55,6 +62,7 @@ def main(): url=args.api_url, config_type=args.type, format_type=args.format, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/load_knowledge.py b/trustgraph-cli/trustgraph/cli/load_knowledge.py index 58081fa1..ff6ca980 100644 --- a/trustgraph-cli/trustgraph/cli/load_knowledge.py +++ b/trustgraph-cli/trustgraph/cli/load_knowledge.py @@ -2,18 +2,17 @@ Loads triples and entity contexts into the knowledge graph. """ -import asyncio import argparse import os import time import rdflib -import json -from websockets.asyncio.client import connect -from typing import List, Dict, Any +from typing import Iterator, Tuple +from trustgraph.api import Api, Triple from trustgraph.log_level import LogLevel -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' @@ -26,108 +25,114 @@ class KnowledgeLoader: user, collection, document_id, - url = default_url, + url=default_url, + token=None, ): - - if not url.endswith("/"): - url += "/" - - self.triples_url = url + f"api/v1/flow/{flow}/import/triples" - self.entity_contexts_url = url + f"api/v1/flow/{flow}/import/entity-contexts" - self.files = files + self.flow = flow self.user = user self.collection = collection self.document_id = document_id + self.url = url + self.token = token - async def run(self): - - try: - # Load triples first - async with connect(self.triples_url) as ws: - for file in self.files: - await self.load_triples(file, ws) - - # Then load entity contexts - async with connect(self.entity_contexts_url) as ws: - for file in self.files: - await self.load_entity_contexts(file, ws) - - except Exception as e: - print(e, flush=True) - - async def load_triples(self, file, ws): + def load_triples_from_file(self, file) -> Iterator[Triple]: + """Generator that yields Triple objects from a Turtle file""" g = rdflib.Graph() g.parse(file, format="turtle") - def Value(value, is_uri): - return { "v": value, "e": is_uri } - for e in g: - s = Value(value=str(e[0]), is_uri=True) - p = Value(value=str(e[1]), is_uri=True) - if type(e[2]) == rdflib.term.URIRef: - o = Value(value=str(e[2]), is_uri=True) + # Extract subject, predicate, object + s_value = str(e[0]) + p_value = str(e[1]) + + # Check if object is a URI or literal + if isinstance(e[2], rdflib.term.URIRef): + o_value = str(e[2]) + o_is_uri = True else: - o = Value(value=str(e[2]), is_uri=False) + o_value = str(e[2]) + o_is_uri = False - req = { - "metadata": { - "id": self.document_id, - "metadata": [], - "user": self.user, - "collection": self.collection - }, - "triples": [ - { - "s": s, - "p": p, - "o": o, - } - ] - } + # Create Triple object + # Note: The Triple dataclass has 's', 'p', 'o' fields as strings + # The API will handle the metadata wrapping + yield Triple(s=s_value, p=p_value, o=o_value) - await ws.send(json.dumps(req)) - - async def load_entity_contexts(self, file, ws): - """ - Load entity contexts by extracting entities from the RDF graph - and generating contextual descriptions based on their relationships. - """ + def load_entity_contexts_from_file(self, file) -> Iterator[Tuple[str, str]]: + """Generator that yields (entity, context) tuples from a Turtle file""" g = rdflib.Graph() g.parse(file, format="turtle") for s, p, o in g: - # If object is a URI, do nothing + # If object is a URI, skip (we only want literal contexts) if isinstance(o, rdflib.term.URIRef): continue - - # If object is a literal, create entity context for subject with literal as context + + # If object is a literal, create entity context for subject s_str = str(s) o_str = str(o) - - req = { - "metadata": { - "id": self.document_id, - "metadata": [], - "user": self.user, - "collection": self.collection - }, - "entities": [ - { - "entity": { - "v": s_str, - "e": True - }, - "context": o_str + + yield (s_str, o_str) + + def run(self): + """Load triples and entity contexts using Python API""" + + try: + # Create API client + api = Api(url=self.url, token=self.token) + bulk = api.bulk() + + # Load triples from all files + print("Loading triples...") + for file in self.files: + print(f" Processing {file}...") + triples = self.load_triples_from_file(file) + + bulk.import_triples( + flow=self.flow, + triples=triples, + metadata={ + "id": self.document_id, + "metadata": [], + "user": self.user, + "collection": self.collection } - ] - } + ) - await ws.send(json.dumps(req)) + print("Triples loaded.") + # Load entity contexts from all files + print("Loading entity contexts...") + for file in self.files: + print(f" Processing {file}...") + + # Convert tuples to the format expected by import_entity_contexts + def entity_context_generator(): + for entity, context in self.load_entity_contexts_from_file(file): + yield { + "entity": {"v": entity, "e": True}, + "context": context + } + + bulk.import_entity_contexts( + flow=self.flow, + entities=entity_context_generator(), + metadata={ + "id": self.document_id, + "metadata": [], + "user": self.user, + "collection": self.collection + } + ) + + print("Entity contexts loaded.") + + except Exception as e: + print(f"Error: {e}", flush=True) + raise def main(): @@ -142,6 +147,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -166,7 +177,6 @@ def main(): help=f'Collection ID (default: {default_collection})' ) - parser.add_argument( 'files', nargs='+', help=f'Turtle files to load' @@ -178,15 +188,16 @@ def main(): try: loader = KnowledgeLoader( - document_id = args.document_id, - url = args.api_url, - flow = args.flow_id, - files = args.files, - user = args.user, - collection = args.collection, + document_id=args.document_id, + url=args.api_url, + token=args.token, + flow=args.flow_id, + files=args.files, + user=args.user, + collection=args.collection, ) - asyncio.run(loader.run()) + loader.run() print("Triples and entity contexts loaded.") break @@ -199,4 +210,4 @@ def main(): time.sleep(10) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/load_sample_documents.py b/trustgraph-cli/trustgraph/cli/load_sample_documents.py index fd6751be..186006a8 100644 --- a/trustgraph-cli/trustgraph/cli/load_sample_documents.py +++ b/trustgraph-cli/trustgraph/cli/load_sample_documents.py @@ -13,6 +13,7 @@ from trustgraph.api.types import hash, Uri, Literal, Triple default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) from requests.adapters import HTTPAdapter @@ -655,10 +656,10 @@ documents = [ class Loader: def __init__( - self, url, user + self, url, user, token=None ): - self.api = Api(url).library() + self.api = Api(url, token=token).library() self.user = user def load(self, documents): @@ -719,6 +720,12 @@ def main(): help=f'User ID (default: {default_user})' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -726,6 +733,7 @@ def main(): p = Loader( url=args.url, user=args.user, + token=args.token, ) p.load(documents) diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py index 9bb9f78c..bf112417 100644 --- a/trustgraph-cli/trustgraph/cli/load_structured_data.py +++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py @@ -22,6 +22,7 @@ import logging logger = logging.getLogger(__name__) default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def load_structured_data( @@ -41,7 +42,8 @@ def load_structured_data( user: str = 'trustgraph', collection: str = 'default', dry_run: bool = False, - verbose: bool = False + verbose: bool = False, + token: str = None ): """ Load structured data using a descriptor configuration. @@ -133,9 +135,9 @@ def load_structured_data( # Get batch size from descriptor batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) - + # Send to TrustGraph using shared function - imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size) + imported_count = _send_to_trustgraph(output_objects, api_url, flow, batch_size, token=token) # Summary format_info = descriptor.get('format', {}) @@ -288,10 +290,10 @@ def load_structured_data( # Get batch size from descriptor or use default batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) - + # Send to TrustGraph print(f"🚀 Importing {len(output_records)} records to TrustGraph...") - imported_count = _send_to_trustgraph(output_records, api_url, flow, batch_size) + imported_count = _send_to_trustgraph(output_records, api_url, flow, batch_size, token=token) # Get summary info from descriptor format_info = descriptor.get('format', {}) @@ -571,66 +573,30 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample return output_records, descriptor -def _send_to_trustgraph(objects, api_url, flow, batch_size=1000): - """Send ExtractedObject records to TrustGraph using WebSocket""" - import json - import asyncio - from websockets.asyncio.client import connect - +def _send_to_trustgraph(objects, api_url, flow, batch_size=1000, token=None): + """Send ExtractedObject records to TrustGraph using Python API""" + from trustgraph.api import Api + try: - # Construct objects import URL similar to load_knowledge pattern - if not api_url.endswith("/"): - api_url += "/" - - # Convert HTTP URL to WebSocket URL if needed - ws_url = api_url.replace("http://", "ws://").replace("https://", "wss://") - objects_url = ws_url + f"api/v1/flow/{flow}/import/objects" - - logger.info(f"Connecting to objects import endpoint: {objects_url}") - - async def import_objects(): - async with connect(objects_url) as ws: - imported_count = 0 - - for record in objects: - try: - # Send individual ExtractedObject records - await ws.send(json.dumps(record)) - imported_count += 1 - - if imported_count % 100 == 0: - logger.debug(f"Imported {imported_count}/{len(objects)} records...") - - except Exception as e: - logger.error(f"Failed to send record {imported_count + 1}: {e}") - print(f"❌ Failed to send record {imported_count + 1}: {e}") - - logger.info(f"Successfully imported {imported_count} records to TrustGraph") - return imported_count - - # Run the async import - imported_count = asyncio.run(import_objects()) - - # Summary total_records = len(objects) - failed_count = total_records - imported_count - + logger.info(f"Importing {total_records} records to TrustGraph...") + + # Use Python API bulk import + api = Api(api_url, token=token) + bulk = api.bulk() + + bulk.import_objects(flow=flow, objects=iter(objects)) + + logger.info(f"Successfully imported {total_records} records to TrustGraph") + + # Summary print(f"\n📊 Import Summary:") print(f"- Total records: {total_records}") - print(f"- Successfully imported: {imported_count}") - print(f"- Failed: {failed_count}") - - if failed_count > 0: - print(f"⚠️ {failed_count} records failed to import. Check logs for details.") - else: - print("✅ All records imported successfully!") - - return imported_count - - except ImportError as e: - logger.error(f"Failed to import required modules: {e}") - print(f"Error: Required modules not available - {e}") - raise + print(f"- Successfully imported: {total_records}") + print("✅ All records imported successfully!") + + return total_records + except Exception as e: logger.error(f"Failed to import data to TrustGraph: {e}") print(f"Import failed: {e}") @@ -1024,7 +990,13 @@ For more information on the descriptor format, see: '--error-file', help='Path to write error records (optional)' ) - + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() # Input validation @@ -1077,7 +1049,8 @@ For more information on the descriptor format, see: user=args.user, collection=args.collection, dry_run=args.dry_run, - verbose=args.verbose + verbose=args.verbose, + token=args.token ) except FileNotFoundError as e: print(f"Error: File not found - {e}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/load_turtle.py b/trustgraph-cli/trustgraph/cli/load_turtle.py index c357c5d9..adb578f5 100644 --- a/trustgraph-cli/trustgraph/cli/load_turtle.py +++ b/trustgraph-cli/trustgraph/cli/load_turtle.py @@ -1,18 +1,18 @@ """ -Loads triples into the knowledge graph. +Loads triples into the knowledge graph from Turtle files. """ -import asyncio import argparse import os import time import rdflib -import json -from websockets.asyncio.client import connect +from typing import Iterator +from trustgraph.api import Api, Triple from trustgraph.log_level import LogLevel -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = 'trustgraph' default_collection = 'default' @@ -25,67 +25,67 @@ class Loader: user, collection, document_id, - url = default_url, + url=default_url, + token=None, ): - - if not url.endswith("/"): - url += "/" - - url = url + f"api/v1/flow/{flow}/import/triples" - - self.url = url - self.files = files + self.flow = flow self.user = user self.collection = collection self.document_id = document_id + self.url = url + self.token = token - async def run(self): - - try: - - async with connect(self.url) as ws: - for file in self.files: - await self.load_file(file, ws) - - except Exception as e: - print(e, flush=True) - - async def load_file(self, file, ws): + def load_triples_from_file(self, file) -> Iterator[Triple]: + """Generator that yields Triple objects from a Turtle file""" g = rdflib.Graph() g.parse(file, format="turtle") - def Value(value, is_uri): - return { "v": value, "e": is_uri } - - triples = [] - for e in g: - s = Value(value=str(e[0]), is_uri=True) - p = Value(value=str(e[1]), is_uri=True) - if type(e[2]) == rdflib.term.URIRef: - o = Value(value=str(e[2]), is_uri=True) + # Extract subject, predicate, object + s_value = str(e[0]) + p_value = str(e[1]) + + # Check if object is a URI or literal + if isinstance(e[2], rdflib.term.URIRef): + o_value = str(e[2]) else: - o = Value(value=str(e[2]), is_uri=False) + o_value = str(e[2]) - req = { - "metadata": { - "id": self.document_id, - "metadata": [], - "user": self.user, - "collection": self.collection - }, - "triples": [ - { - "s": s, - "p": p, - "o": o, + # Create Triple object + yield Triple(s=s_value, p=p_value, o=o_value) + + def run(self): + """Load triples using Python API""" + + try: + # Create API client + api = Api(url=self.url, token=self.token) + bulk = api.bulk() + + # Load triples from all files + print("Loading triples...") + for file in self.files: + print(f" Processing {file}...") + triples = self.load_triples_from_file(file) + + bulk.import_triples( + flow=self.flow, + triples=triples, + metadata={ + "id": self.document_id, + "metadata": [], + "user": self.user, + "collection": self.collection } - ] - } + ) - await ws.send(json.dumps(req)) + print("Triples loaded.") + + except Exception as e: + print(f"Error: {e}", flush=True) + raise def main(): @@ -100,6 +100,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-i', '--document-id', required=True, @@ -134,16 +140,17 @@ def main(): while True: try: - p = Loader( - document_id = args.document_id, - url = args.api_url, - flow = args.flow_id, - files = args.files, - user = args.user, - collection = args.collection, + loader = Loader( + document_id=args.document_id, + url=args.api_url, + token=args.token, + flow=args.flow_id, + files=args.files, + user=args.user, + collection=args.collection, ) - asyncio.run(p.run()) + loader.run() print("File loaded.") break @@ -156,4 +163,4 @@ def main(): time.sleep(10) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/put_config_item.py b/trustgraph-cli/trustgraph/cli/put_config_item.py index d48e29a7..d79864a4 100644 --- a/trustgraph-cli/trustgraph/cli/put_config_item.py +++ b/trustgraph-cli/trustgraph/cli/put_config_item.py @@ -9,10 +9,11 @@ from trustgraph.api import Api from trustgraph.api.types import ConfigValue default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def put_config_item(url, config_type, key, value): +def put_config_item(url, config_type, key, value, token=None): - api = Api(url).config() + api = Api(url, token=token).config() config_value = ConfigValue(type=config_type, key=key, value=value) api.put([config_value]) @@ -56,6 +57,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -70,6 +77,7 @@ def main(): config_type=args.type, key=args.key, value=value, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_flow_class.py b/trustgraph-cli/trustgraph/cli/put_flow_class.py index 5b4bc44b..6a88421d 100644 --- a/trustgraph-cli/trustgraph/cli/put_flow_class.py +++ b/trustgraph-cli/trustgraph/cli/put_flow_class.py @@ -9,10 +9,11 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def put_flow_class(url, class_name, config): +def put_flow_class(url, class_name, config, token=None): - api = Api(url) + api = Api(url, token=token) class_names = api.flow().put_class(class_name, config) @@ -29,6 +30,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-n', '--class-name', help=f'Flow class name', @@ -47,6 +54,7 @@ def main(): url=args.api_url, class_name=args.class_name, config=json.loads(args.config), + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index 6374e2f6..cd0738fe 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -13,6 +13,7 @@ import msgpack default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def read_message(unpacked, id, user): @@ -47,13 +48,16 @@ def read_message(unpacked, id, user): else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, user, id, input): +async def put(url, user, id, input, token=None): if not url.endswith("/"): url += "/" url = url + "api/v1/socket" + if token: + url = f"{url}?token={token}" + async with connect(url) as ws: @@ -160,6 +164,12 @@ def main(): help=f'Input file' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -170,6 +180,7 @@ def main(): user = args.user, id = args.id, input = args.input, + token = args.token, ) ) diff --git a/trustgraph-cli/trustgraph/cli/remove_library_document.py b/trustgraph-cli/trustgraph/cli/remove_library_document.py index f6e6813c..07a1fd59 100644 --- a/trustgraph-cli/trustgraph/cli/remove_library_document.py +++ b/trustgraph-cli/trustgraph/cli/remove_library_document.py @@ -10,11 +10,12 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def remove_doc(url, user, id): +def remove_doc(url, user, id, token=None): - api = Api(url).library() + api = Api(url, token=token).library() api.remove_document(user=user, id=id) @@ -43,11 +44,17 @@ def main(): help=f'Document ID' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: - remove_doc(args.url, args.user, args.identifier) + remove_doc(args.url, args.user, args.identifier, token=args.token) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_collection.py b/trustgraph-cli/trustgraph/cli/set_collection.py index e987c4c8..dd4148ea 100644 --- a/trustgraph-cli/trustgraph/cli/set_collection.py +++ b/trustgraph-cli/trustgraph/cli/set_collection.py @@ -9,10 +9,11 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def set_collection(url, user, collection, name, description, tags): +def set_collection(url, user, collection, name, description, tags, token=None): - api = Api(url).collection() + api = Api(url, token=token).collection() result = api.update_collection( user=user, @@ -30,7 +31,6 @@ def set_collection(url, user, collection, name, description, tags): table.append(("Name", result.name)) table.append(("Description", result.description)) table.append(("Tags", ", ".join(result.tags))) - table.append(("Updated", result.updated_at)) print(tabulate.tabulate( table, @@ -82,6 +82,12 @@ def main(): help='Collection tags (can be specified multiple times)' ) + parser.add_argument( + '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -92,7 +98,8 @@ def main(): collection = args.collection, name = args.name, description = args.description, - tags = args.tags + tags = args.tags, + token = args.token ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py index 05e3823c..7976adbc 100644 --- a/trustgraph-cli/trustgraph/cli/set_mcp_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_mcp_tool.py @@ -20,6 +20,7 @@ import textwrap import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def set_mcp_tool( url : str, @@ -27,9 +28,10 @@ def set_mcp_tool( remote_name : str, tool_url : str, auth_token : str = None, + token : str = None, ): - api = Api(url).config() + api = Api(url, token=token).config() # Build the MCP tool configuration config = { @@ -72,6 +74,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-i', '--id', required=True, @@ -116,7 +124,8 @@ def main(): id=args.id, remote_name=remote_name, tool_url=args.tool_url, - auth_token=args.auth_token + auth_token=args.auth_token, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_prompt.py b/trustgraph-cli/trustgraph/cli/set_prompt.py index f287a9cc..bffc2cf2 100644 --- a/trustgraph-cli/trustgraph/cli/set_prompt.py +++ b/trustgraph-cli/trustgraph/cli/set_prompt.py @@ -10,10 +10,11 @@ import tabulate import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def set_system(url, system): +def set_system(url, system, token=None): - api = Api(url).config() + api = Api(url, token=token).config() api.put([ ConfigValue(type="prompt", key="system", value=json.dumps(system)) @@ -21,9 +22,9 @@ def set_system(url, system): print("System prompt set.") -def set_prompt(url, id, prompt, response, schema): +def set_prompt(url, id, prompt, response, schema, token=None): - api = Api(url).config() + api = Api(url, token=token).config() values = api.get([ ConfigKey(type="prompt", key="template-index") @@ -71,6 +72,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '--id', help=f'Prompt ID', @@ -103,9 +110,9 @@ def main(): if args.system: if args.id or args.prompt or args.schema or args.response: raise RuntimeError("Can't use --system with other args") - + set_system( - url=args.api_url, system=args.system + url=args.api_url, system=args.system, token=args.token ) else: @@ -130,7 +137,7 @@ def main(): set_prompt( url=args.api_url, id=args.id, prompt=args.prompt, - response=args.response, schema=schobj + response=args.response, schema=schobj, token=args.token ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/set_token_costs.py b/trustgraph-cli/trustgraph/cli/set_token_costs.py index 87a4e264..19b8c703 100644 --- a/trustgraph-cli/trustgraph/cli/set_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/set_token_costs.py @@ -10,10 +10,11 @@ import tabulate import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def set_costs(api_url, model, input_costs, output_costs): +def set_costs(api_url, model, input_costs, output_costs, token=None): - api = Api(api_url).config() + api = Api(api_url, token=token).config() api.put([ ConfigValue( @@ -95,6 +96,12 @@ def main(): help=f'Input costs in $ per 1M tokens', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: diff --git a/trustgraph-cli/trustgraph/cli/set_tool.py b/trustgraph-cli/trustgraph/cli/set_tool.py index 2174c79b..36701a8e 100644 --- a/trustgraph-cli/trustgraph/cli/set_tool.py +++ b/trustgraph-cli/trustgraph/cli/set_tool.py @@ -26,6 +26,7 @@ import textwrap import dataclasses default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @dataclasses.dataclass class Argument: @@ -67,9 +68,10 @@ def set_tool( group : List[str], state : str, applicable_states : List[str], + token : str = None, ): - api = Api(url).config() + api = Api(url, token=token).config() values = api.get([ ConfigKey(type="agent", key="tool-index") @@ -156,6 +158,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '--id', help=f'Unique tool identifier', @@ -257,6 +265,7 @@ def main(): group=args.group, state=args.state, applicable_states=args.applicable_states, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_config.py b/trustgraph-cli/trustgraph/cli/show_config.py index 03b2636a..6f426533 100644 --- a/trustgraph-cli/trustgraph/cli/show_config.py +++ b/trustgraph-cli/trustgraph/cli/show_config.py @@ -8,10 +8,11 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_config(url): +def show_config(url, token=None): - api = Api(url).config() + api = Api(url, token=token).config() config, version = api.all() @@ -31,12 +32,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_config( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flow_classes.py b/trustgraph-cli/trustgraph/cli/show_flow_classes.py index d9ce96a7..123f5380 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_classes.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_classes.py @@ -9,6 +9,7 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def format_parameters(params_metadata, config_api): """ @@ -57,9 +58,9 @@ def format_parameters(params_metadata, config_api): return "\n".join(param_list) -def show_flow_classes(url): +def show_flow_classes(url, token=None): - api = Api(url) + api = Api(url, token=token) flow_api = api.flow() config_api = api.config() @@ -106,12 +107,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_flow_classes( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flow_state.py b/trustgraph-cli/trustgraph/cli/show_flow_state.py index ca6d2b1d..6ca4df8f 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_state.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_state.py @@ -9,10 +9,11 @@ import os default_metrics_url = "http://localhost:8088/api/metrics" default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def dump_status(metrics_url, api_url, flow_id): +def dump_status(metrics_url, api_url, flow_id, token=None): - api = Api(api_url).flow() + api = Api(api_url, token=token).flow() flow = api.get(flow_id) class_name = flow["class-name"] @@ -77,11 +78,17 @@ def main(): help=f'Metrics URL (default: {default_metrics_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: - dump_status(args.metrics_url, args.api_url, args.flow_id) + dump_status(args.metrics_url, args.api_url, args.flow_id, token=args.token) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py index 18c1234e..b383ff56 100644 --- a/trustgraph-cli/trustgraph/cli/show_flows.py +++ b/trustgraph-cli/trustgraph/cli/show_flows.py @@ -9,6 +9,7 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def get_interface(config_api, i): @@ -128,9 +129,9 @@ def format_parameters(flow_params, class_params_metadata, config_api): return "\n".join(param_list) if param_list else "None" -def show_flows(url): +def show_flows(url, token=None): - api = Api(url) + api = Api(url, token=token) config_api = api.config() flow_api = api.flow() @@ -199,12 +200,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_flows( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_graph.py b/trustgraph-cli/trustgraph/cli/show_graph.py index 232ebb34..b5b15e3c 100644 --- a/trustgraph-cli/trustgraph/cli/show_graph.py +++ b/trustgraph-cli/trustgraph/cli/show_graph.py @@ -9,10 +9,11 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_graph(url, flow_id, user, collection): +def show_graph(url, flow_id, user, collection, token=None): - api = Api(url).flow().id(flow_id) + api = Api(url, token=token).flow().id(flow_id) rows = api.triples_query( user=user, collection=collection, @@ -53,6 +54,12 @@ def main(): help=f'Collection ID (default: {default_collection})' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: @@ -62,6 +69,7 @@ def main(): flow_id = args.flow_id, user = args.user, collection = args.collection, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_kg_cores.py b/trustgraph-cli/trustgraph/cli/show_kg_cores.py index e3cf9eb4..ea295543 100644 --- a/trustgraph-cli/trustgraph/cli/show_kg_cores.py +++ b/trustgraph-cli/trustgraph/cli/show_kg_cores.py @@ -9,10 +9,11 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_cores(url, user): +def show_cores(url, user, token=None): - api = Api(url).knowledge() + api = Api(url, token=token).knowledge() ids = api.list_kg_cores() @@ -35,6 +36,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-U', '--user', default="trustgraph", @@ -46,7 +53,9 @@ def main(): try: show_cores( - url=args.api_url, user=args.user + url=args.api_url, + user=args.user, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_library_documents.py b/trustgraph-cli/trustgraph/cli/show_library_documents.py index b086238d..6eeceb70 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_documents.py +++ b/trustgraph-cli/trustgraph/cli/show_library_documents.py @@ -9,11 +9,12 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = "trustgraph" -def show_docs(url, user): +def show_docs(url, user, token=None): - api = Api(url).library() + api = Api(url, token=token).library() docs = api.get_documents(user=user) @@ -52,6 +53,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-U', '--user', default=default_user, @@ -63,7 +70,9 @@ def main(): try: show_docs( - url = args.api_url, user = args.user + url = args.api_url, + user = args.user, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_library_processing.py b/trustgraph-cli/trustgraph/cli/show_library_processing.py index 51dbe865..9ab69355 100644 --- a/trustgraph-cli/trustgraph/cli/show_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/show_library_processing.py @@ -9,10 +9,11 @@ import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = "trustgraph" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_procs(url, user): +def show_procs(url, user, token=None): - api = Api(url).library() + api = Api(url, token=token).library() procs = api.get_processings(user = user) @@ -57,12 +58,18 @@ def main(): help=f'User ID (default: {default_user})' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_procs( - url = args.api_url, user = args.user + url = args.api_url, user = args.user, token = args.token ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py index da0154ed..24cbfcfe 100644 --- a/trustgraph-cli/trustgraph/cli/show_mcp_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_mcp_tools.py @@ -10,10 +10,11 @@ import tabulate import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_config(url): +def show_config(url, token=None): - api = Api(url).config() + api = Api(url, token=token).config() values = api.get_values(type="mcp") @@ -57,12 +58,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_config( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_parameter_types.py b/trustgraph-cli/trustgraph/cli/show_parameter_types.py index 606c5016..e5b842b5 100644 --- a/trustgraph-cli/trustgraph/cli/show_parameter_types.py +++ b/trustgraph-cli/trustgraph/cli/show_parameter_types.py @@ -13,6 +13,7 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) def format_enum_values(enum_list): """ @@ -75,11 +76,11 @@ def format_constraints(param_type_def): return ", ".join(constraints) if constraints else "None" -def show_parameter_types(url): +def show_parameter_types(url, token=None): """ Show all parameter type definitions """ - api = Api(url) + api = Api(url, token=token) config_api = api.config() # Get list of all parameter types @@ -145,6 +146,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-t', '--type', help='Show only the specified parameter type', @@ -155,19 +162,19 @@ def main(): try: if args.type: # Show specific parameter type - show_specific_parameter_type(args.api_url, args.type) + show_specific_parameter_type(args.api_url, args.type, args.token) else: # Show all parameter types - show_parameter_types(args.api_url) + show_parameter_types(args.api_url, args.token) except Exception as e: print("Exception:", e, flush=True) -def show_specific_parameter_type(url, param_type_name): +def show_specific_parameter_type(url, param_type_name, token=None): """ Show a specific parameter type definition """ - api = Api(url) + api = Api(url, token=token) config_api = api.config() try: diff --git a/trustgraph-cli/trustgraph/cli/show_prompts.py b/trustgraph-cli/trustgraph/cli/show_prompts.py index 4c2ca4d7..0e1cb2ae 100644 --- a/trustgraph-cli/trustgraph/cli/show_prompts.py +++ b/trustgraph-cli/trustgraph/cli/show_prompts.py @@ -10,10 +10,11 @@ import tabulate import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_config(url): +def show_config(url, token=None): - api = Api(url).config() + api = Api(url, token=token).config() values = api.get([ ConfigKey(type="prompt", key="system"), @@ -78,12 +79,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_config( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_token_costs.py b/trustgraph-cli/trustgraph/cli/show_token_costs.py index 2f889eef..9e7c352a 100644 --- a/trustgraph-cli/trustgraph/cli/show_token_costs.py +++ b/trustgraph-cli/trustgraph/cli/show_token_costs.py @@ -12,10 +12,11 @@ import textwrap tabulate.PRESERVE_WHITESPACE = True default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_config(url): +def show_config(url, token=None): - api = Api(url).config() + api = Api(url, token=token).config() models = api.list("token-costs") @@ -61,12 +62,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_config( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_tools.py b/trustgraph-cli/trustgraph/cli/show_tools.py index ce79fffc..b8c9a012 100644 --- a/trustgraph-cli/trustgraph/cli/show_tools.py +++ b/trustgraph-cli/trustgraph/cli/show_tools.py @@ -17,10 +17,11 @@ import tabulate import textwrap default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_config(url): +def show_config(url, token=None): - api = Api(url).config() + api = Api(url, token=token).config() values = api.get_values(type="tool") @@ -100,12 +101,19 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + args = parser.parse_args() try: show_config( url=args.api_url, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_flow.py b/trustgraph-cli/trustgraph/cli/start_flow.py index fa9ce6a8..4f9954b0 100644 --- a/trustgraph-cli/trustgraph/cli/start_flow.py +++ b/trustgraph-cli/trustgraph/cli/start_flow.py @@ -17,10 +17,11 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def start_flow(url, class_name, flow_id, description, parameters=None): +def start_flow(url, class_name, flow_id, description, parameters=None, token=None): - api = Api(url).flow() + api = Api(url, token=token).flow() api.start( class_name = class_name, @@ -42,6 +43,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-n', '--class-name', required=True, @@ -112,6 +119,7 @@ def main(): flow_id = args.flow_id, description = args.description, parameters = parameters, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/start_library_processing.py b/trustgraph-cli/trustgraph/cli/start_library_processing.py index 3619628c..ff87ea9f 100644 --- a/trustgraph-cli/trustgraph/cli/start_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/start_library_processing.py @@ -9,13 +9,14 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = "trustgraph" def start_processing( - url, user, document_id, id, flow, collection, tags + url, user, document_id, id, flow, collection, tags, token=None ): - api = Api(url).library() + api = Api(url, token=token).library() if tags: tags = tags.split(",") @@ -44,6 +45,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-U', '--user', default=default_user, @@ -90,7 +97,8 @@ def main(): id = args.id, flow = args.flow_id, collection = args.collection, - tags = args.tags + tags = args.tags, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/stop_flow.py b/trustgraph-cli/trustgraph/cli/stop_flow.py index a5107579..ae3a0415 100644 --- a/trustgraph-cli/trustgraph/cli/stop_flow.py +++ b/trustgraph-cli/trustgraph/cli/stop_flow.py @@ -9,10 +9,11 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def stop_flow(url, flow_id): +def stop_flow(url, flow_id, token=None): - api = Api(url).flow() + api = Api(url, token=token).flow() api.stop(id = flow_id) @@ -29,6 +30,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-i', '--flow-id', required=True, @@ -42,6 +49,7 @@ def main(): stop_flow( url=args.api_url, flow_id=args.flow_id, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/stop_library_processing.py b/trustgraph-cli/trustgraph/cli/stop_library_processing.py index 638ab71c..3d8a2c56 100644 --- a/trustgraph-cli/trustgraph/cli/stop_library_processing.py +++ b/trustgraph-cli/trustgraph/cli/stop_library_processing.py @@ -10,13 +10,14 @@ from trustgraph.api import Api, ConfigKey import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_user = "trustgraph" def stop_processing( - url, user, id + url, user, id, token=None ): - api = Api(url).library() + api = Api(url, token=token).library() api.stop_processing(user = user, id = id) @@ -33,6 +34,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-U', '--user', default=default_user, @@ -53,6 +60,7 @@ def main(): url = args.api_url, user = args.user, id = args.id, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/unload_kg_core.py b/trustgraph-cli/trustgraph/cli/unload_kg_core.py index 079766d2..47f811f3 100644 --- a/trustgraph-cli/trustgraph/cli/unload_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/unload_kg_core.py @@ -11,12 +11,13 @@ from trustgraph.api import Api import json default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_flow = "default" default_collection = "default" -def unload_kg_core(url, user, id, flow): +def unload_kg_core(url, user, id, flow, token=None): - api = Api(url).knowledge() + api = Api(url, token=token).knowledge() class_names = api.unload_kg_core(user = user, id = id, flow=flow) @@ -33,6 +34,12 @@ def main(): help=f'API URL (default: {default_url})', ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + parser.add_argument( '-U', '--user', default="trustgraph", @@ -60,6 +67,7 @@ def main(): user=args.user, id=args.id, flow=args.flow_id, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py new file mode 100644 index 00000000..294a3738 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -0,0 +1,492 @@ +""" +Verifies TrustGraph system health by running comprehensive checks. + +This utility monitors system startup and health by checking: +1. Infrastructure (Pulsar, API Gateway) +2. Core services (processors, flows, prompts) +3. Data services (library) +4. UI (workbench) + +Includes intelligent retry logic with configurable timeouts. +""" + +import argparse +import os +import sys +import time +import requests +from typing import Tuple, Optional + +# Import existing CLI functions to reuse logic +from trustgraph.api import Api + +default_pulsar_url = "http://localhost:8080" +default_api_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +default_ui_url = "http://localhost:8888" +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) + + +class HealthChecker: + """Manages health check execution with retry logic and timeouts.""" + + def __init__( + self, + global_timeout: int = 120, + check_timeout: int = 10, + retry_delay: int = 3, + verbose: bool = False + ): + self.global_timeout = global_timeout + self.check_timeout = check_timeout + self.retry_delay = retry_delay + self.verbose = verbose + self.start_time = time.time() + self.checks_passed = 0 + self.checks_failed = 0 + + def elapsed(self) -> str: + """Return formatted elapsed time MM:SS.""" + elapsed_sec = int(time.time() - self.start_time) + minutes = elapsed_sec // 60 + seconds = elapsed_sec % 60 + return f"{minutes:02d}:{seconds:02d}" + + def time_remaining(self) -> float: + """Return seconds remaining in global timeout.""" + return self.global_timeout - (time.time() - self.start_time) + + def log(self, message: str, level: str = "info"): + """Log a message with timestamp.""" + timestamp = self.elapsed() + if level == "success": + icon = "✓" + elif level == "error": + icon = "✗" + elif level == "progress": + icon = "⏳" + else: + icon = " " + print(f"[{timestamp}] {icon} {message}", flush=True) + + def debug(self, message: str): + """Log a debug message if verbose mode is enabled.""" + if self.verbose: + timestamp = self.elapsed() + print(f"[{timestamp}] {message}", flush=True) + + def run_check( + self, + name: str, + check_func, + *args, + **kwargs + ) -> bool: + """ + Run a check with retry logic until success or global timeout. + + Args: + name: Human-readable name of the check + check_func: Function that returns (success: bool, message: str) + *args, **kwargs: Arguments to pass to check_func + + Returns: + True if check passed, False otherwise + """ + attempt = 0 + + while self.time_remaining() > 0: + attempt += 1 + + if attempt > 1: + self.log(f"Checking {name}... (attempt {attempt})", "progress") + else: + self.log(f"Checking {name}...", "progress") + + try: + # Run the check with timeout + success, message = check_func(*args, **kwargs) + + if success: + self.log(f"{name}: {message}", "success") + self.checks_passed += 1 + return True + else: + self.debug(f"{name} not ready: {message}") + + except Exception as e: + self.debug(f"{name} check failed with exception: {e}") + + # Check if we have time for another attempt + if self.time_remaining() < self.retry_delay: + break + + # Wait before retry + time.sleep(self.retry_delay) + + # Check failed + self.log(f"{name}: Failed (timeout after {attempt} attempts)", "error") + self.checks_failed += 1 + return False + + +def check_pulsar(url: str, timeout: int) -> Tuple[bool, str]: + """Check if Pulsar admin API is responding.""" + try: + resp = requests.get(f"{url}/admin/v2/clusters", timeout=timeout) + if resp.status_code == 200: + clusters = resp.json() + return True, f"Pulsar healthy ({len(clusters)} cluster(s))" + else: + return False, f"Pulsar returned status {resp.status_code}" + except requests.exceptions.Timeout: + return False, "Pulsar connection timeout" + except requests.exceptions.ConnectionError: + return False, "Cannot connect to Pulsar" + except Exception as e: + return False, f"Pulsar error: {e}" + + +def check_api_gateway(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: + """Check if API Gateway is responding.""" + try: + # Try to hit the base URL + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + if not url.endswith('/'): + url += '/' + + resp = requests.get(url, headers=headers, timeout=timeout) + if resp.status_code in [200, 404]: # 404 is OK, means gateway is up + return True, "API Gateway is responding" + else: + return False, f"API Gateway returned status {resp.status_code}" + except requests.exceptions.Timeout: + return False, "API Gateway connection timeout" + except requests.exceptions.ConnectionError: + return False, "Cannot connect to API Gateway" + except Exception as e: + return False, f"API Gateway error: {e}" + + +def check_processors(url: str, min_processors: int, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: + """Check if processors are running via metrics endpoint.""" + try: + # Construct metrics URL from API URL + if not url.endswith('/'): + url += '/' + metrics_url = f"{url}api/metrics/query?query=processor_info" + + resp = requests.get(metrics_url, timeout=timeout) + if resp.status_code == 200: + data = resp.json() + processor_count = len(data.get("data", {}).get("result", [])) + + if processor_count >= min_processors: + return True, f"Found {processor_count} processors (≥ {min_processors})" + else: + return False, f"Only {processor_count} processors running (need {min_processors})" + else: + return False, f"Metrics returned status {resp.status_code}" + + except Exception as e: + return False, f"Processor check error: {e}" + + +def check_flow_classes(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: + """Check if flow classes are loaded.""" + try: + api = Api(url, token=token, timeout=timeout) + flow_api = api.flow() + + classes = flow_api.list_classes() + + if classes and len(classes) > 0: + return True, f"Found {len(classes)} flow class(es)" + else: + return False, "No flow classes found" + + except Exception as e: + return False, f"Flow classes check error: {e}" + + +def check_flows(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: + """Check if flow manager is responding.""" + try: + api = Api(url, token=token, timeout=timeout) + flow_api = api.flow() + + flows = flow_api.list() + + # Success if we get a response (even if empty) + return True, f"Flow manager responding ({len(flows)} flow(s))" + + except Exception as e: + return False, f"Flow manager check error: {e}" + + +def check_prompts(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: + """Check if prompts are loaded.""" + try: + api = Api(url, token=token, timeout=timeout) + config = api.config() + + # Import ConfigKey here to avoid top-level import issues + from trustgraph.api.types import ConfigKey + import json + + # Get the template-index which lists all prompts + values = config.get([ + ConfigKey(type="prompt", key="template-index") + ]) + + ix = json.loads(values[0].value) + + if ix and len(ix) > 0: + return True, f"Found {len(ix)} prompt(s)" + else: + return False, "No prompts found" + + except Exception as e: + return False, f"Prompts check error: {e}" + + +def check_library(url: str, timeout: int, token: Optional[str] = None) -> Tuple[bool, str]: + """Check if library service is responding.""" + try: + api = Api(url, token=token, timeout=timeout) + library_api = api.library() + + # Try to get documents (with default user) + docs = library_api.get_documents(user="trustgraph") + + # Success if we get a valid response (even if empty) + return True, f"Library responding ({len(docs)} document(s))" + + except Exception as e: + return False, f"Library check error: {e}" + + +def check_ui(url: str, timeout: int) -> Tuple[bool, str]: + """Check if Workbench UI is responding.""" + try: + if not url.endswith('/'): + url += '/' + + resp = requests.get(f"{url}index.html", timeout=timeout) + if resp.status_code == 200: + return True, "Workbench UI is responding" + else: + return False, f"UI returned status {resp.status_code}" + except requests.exceptions.Timeout: + return False, "UI connection timeout" + except requests.exceptions.ConnectionError: + return False, "Cannot connect to UI" + except Exception as e: + return False, f"UI error: {e}" + + +def main(): + """Main entry point for the CLI.""" + + parser = argparse.ArgumentParser( + prog='tg-verify-system-status', + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--global-timeout', + type=int, + default=120, + help='Total timeout in seconds (default: 120)' + ) + + parser.add_argument( + '--check-timeout', + type=int, + default=10, + help='Per-check timeout in seconds (default: 10)' + ) + + parser.add_argument( + '--retry-delay', + type=int, + default=3, + help='Delay between retries in seconds (default: 3)' + ) + + parser.add_argument( + '--min-processors', + type=int, + default=15, + help='Minimum processors required (default: 15)' + ) + + parser.add_argument( + '--pulsar-url', + default=default_pulsar_url, + help=f'Pulsar admin URL (default: {default_pulsar_url})' + ) + + parser.add_argument( + '--api-url', + default=default_api_url, + help=f'API Gateway URL (default: {default_api_url})' + ) + + parser.add_argument( + '--ui-url', + default=default_ui_url, + help=f'Workbench UI URL (default: {default_ui_url})' + ) + + parser.add_argument( + '--skip-ui', + action='store_true', + help='Skip UI check (for headless deployments)' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)' + ) + + parser.add_argument( + '-v', '--verbose', + action='store_true', + help='Show detailed output' + ) + + args = parser.parse_args() + + # Create health checker + checker = HealthChecker( + global_timeout=args.global_timeout, + check_timeout=args.check_timeout, + retry_delay=args.retry_delay, + verbose=args.verbose + ) + + print("=" * 60) + print("TrustGraph System Status Verification") + print("=" * 60) +# print(f"Global timeout: {args.global_timeout}s") +# print(f"Check timeout: {args.check_timeout}s") +# print(f"Retry delay: {args.retry_delay}s") +# print("=" * 60) + print() + + # Phase 1: Infrastructure + print("Phase 1: Infrastructure") + print("-" * 60) + + if not checker.run_check( + "Pulsar", + check_pulsar, + args.pulsar_url, + args.check_timeout + ): + print("\n⚠️ Pulsar is not responding - other checks may fail") + print() + + checker.run_check( + "API Gateway", + check_api_gateway, + args.api_url, + args.check_timeout, + args.token + ) + + print() + + # Phase 2: Core Services + print("Phase 2: Core Services") + print("-" * 60) + + checker.run_check( + "Processors", + check_processors, + args.api_url, + args.min_processors, + args.check_timeout, + args.token + ) + + checker.run_check( + "Flow Classes", + check_flow_classes, + args.api_url, + args.check_timeout, + args.token + ) + + checker.run_check( + "Flows", + check_flows, + args.api_url, + args.check_timeout, + args.token + ) + + checker.run_check( + "Prompts", + check_prompts, + args.api_url, + args.check_timeout, + args.token + ) + + print() + + # Phase 3: Data Services + print("Phase 3: Data Services") + print("-" * 60) + + checker.run_check( + "Library", + check_library, + args.api_url, + args.check_timeout, + args.token + ) + + print() + + # Phase 4: UI (optional) + if not args.skip_ui: + print("Phase 4: User Interface") + print("-" * 60) + + checker.run_check( + "Workbench UI", + check_ui, + args.ui_url, + args.check_timeout + ) + + print() + + # Summary + print("=" * 60) + print("Summary") + print("=" * 60) + + total_checks = checker.checks_passed + checker.checks_failed + + print(f"Checks passed: {checker.checks_passed}/{total_checks}") + print(f"Checks failed: {checker.checks_failed}/{total_checks}") + print(f"Total time: {checker.elapsed()}") + + if checker.checks_failed == 0: + print("\n✓ System is healthy!") + sys.exit(0) + else: + print(f"\n✗ System has {checker.checks_failed} failing check(s)") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 478b0fff..3d4fa65c 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", - "trustgraph-flow>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", + "trustgraph-flow>=1.8,<1.9", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 479c0adc..70140147 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", "aiohttp", "anthropic", "scylla-driver", diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index 90bc445c..87cee33d 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -241,8 +241,8 @@ class AgentManager: logger.info("DEBUG: StreamingReActParser created") # Create async chunk callback that feeds parser and sends collected chunks - async def on_chunk(text): - logger.info(f"DEBUG: on_chunk called with {len(text)} chars") + async def on_chunk(text, end_of_stream): + logger.info(f"DEBUG: on_chunk called with {len(text)} chars, end_of_stream={end_of_stream}") # Track what we had before prev_thought_count = len(thought_chunks) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index a4238e36..d4a4d72f 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -433,13 +433,11 @@ class Processor(AgentService): end_of_dialog=True, # Legacy fields for backward compatibility error=error_obj, - response=None, ) else: # Legacy format r = AgentResponse( error=error_obj, - response=None, ) await respond(r) diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py index 701d7f58..c0a5be1e 100644 --- a/trustgraph-flow/trustgraph/config/service/config.py +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -95,9 +95,6 @@ class Configuration: return ConfigResponse( version = await self.get_version(), values = values, - directory = None, - config = None, - error = None, ) async def handle_list(self, v): @@ -117,10 +114,7 @@ class Configuration: return ConfigResponse( version = await self.get_version(), - values = None, directory = await self.table_store.get_keys(v.type), - config = None, - error = None, ) async def handle_getvalues(self, v): @@ -150,9 +144,6 @@ class Configuration: return ConfigResponse( version = await self.get_version(), values = list(values), - directory = None, - config = None, - error = None, ) async def handle_delete(self, v): @@ -179,12 +170,6 @@ class Configuration: await self.push() return ConfigResponse( - version = None, - value = None, - directory = None, - values = None, - config = None, - error = None, ) async def handle_put(self, v): @@ -198,11 +183,6 @@ class Configuration: await self.push() return ConfigResponse( - version = None, - value = None, - directory = None, - values = None, - error = None, ) async def get_config(self): @@ -224,11 +204,7 @@ class Configuration: return ConfigResponse( version = await self.get_version(), - value = None, - directory = None, - values = None, config = config, - error = None, ) async def handle(self, msg): @@ -262,9 +238,6 @@ class Configuration: else: resp = ConfigResponse( - value=None, - directory=None, - values=None, error=Error( type = "bad-operation", message = "Bad operation" diff --git a/trustgraph-flow/trustgraph/config/service/flow.py b/trustgraph-flow/trustgraph/config/service/flow.py index b99b7d0a..42696c31 100644 --- a/trustgraph-flow/trustgraph/config/service/flow.py +++ b/trustgraph-flow/trustgraph/config/service/flow.py @@ -361,9 +361,6 @@ class FlowConfig: else: resp = FlowResponse( - value=None, - directory=None, - values=None, error=Error( type = "bad-operation", message = "Bad operation" diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py index 84ed2a6a..42b256df 100644 --- a/trustgraph-flow/trustgraph/config/service/service.py +++ b/trustgraph-flow/trustgraph/config/service/service.py @@ -26,9 +26,6 @@ from ... base import Consumer, Producer # Module logger logger = logging.getLogger(__name__) -# FIXME: How to ensure this doesn't conflict with other usage? -keyspace = "config" - default_ident = "config-svc" default_config_request_queue = config_request_queue @@ -64,12 +61,13 @@ class Processor(AsyncProcessor): cassandra_host = params.get("cassandra_host") cassandra_username = params.get("cassandra_username") cassandra_password = params.get("cassandra_password") - + # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, - password=cassandra_password + password=cassandra_password, + default_keyspace="config" ) # Store resolved configuration @@ -114,7 +112,7 @@ class Processor(AsyncProcessor): self.config_request_consumer = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub, flow = None, topic = config_request_queue, subscriber = id, @@ -124,14 +122,14 @@ class Processor(AsyncProcessor): ) self.config_response_producer = Producer( - client = self.pulsar_client, + backend = self.pubsub, topic = config_response_queue, schema = ConfigResponse, metrics = config_response_metrics, ) self.config_push_producer = Producer( - client = self.pulsar_client, + backend = self.pubsub, topic = config_push_queue, schema = ConfigPush, metrics = config_push_metrics, @@ -139,7 +137,7 @@ class Processor(AsyncProcessor): self.flow_request_consumer = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub, flow = None, topic = flow_request_queue, subscriber = id, @@ -149,7 +147,7 @@ class Processor(AsyncProcessor): ) self.flow_response_producer = Producer( - client = self.pulsar_client, + backend = self.pubsub, topic = flow_response_queue, schema = FlowResponse, metrics = flow_response_metrics, @@ -180,11 +178,7 @@ class Processor(AsyncProcessor): resp = ConfigPush( version = version, - value = None, - directory = None, - values = None, config = config, - error = None, ) await self.config_push_producer.send(resp) @@ -217,7 +211,6 @@ class Processor(AsyncProcessor): type = "config-error", message = str(e), ), - text=None, ) await self.config_response_producer.send( @@ -242,13 +235,12 @@ class Processor(AsyncProcessor): ) except Exception as e: - + resp = FlowResponse( error=Error( type = "flow-error", message = str(e), ), - text=None, ) await self.flow_response_producer.send( @@ -272,11 +264,7 @@ class Processor(AsyncProcessor): help=f'Config response queue {default_config_response_queue}', ) - parser.add_argument( - '--push-queue', - default=default_config_push_queue, - help=f'Config push queue (default: {default_config_push_queue})' - ) + # Note: --config-push-queue is already added by AsyncProcessor.add_args() parser.add_argument( '--flow-request-queue', diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index 449f1c3b..0d5c3d82 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -234,11 +234,11 @@ class KnowledgeManager: logger.debug(f"Graph embeddings queue: {ge_q}") t_pub = Publisher( - self.flow_config.pulsar_client, t_q, + self.flow_config.pubsub, t_q, schema=Triples, ) ge_pub = Publisher( - self.flow_config.pulsar_client, ge_q, + self.flow_config.pubsub, ge_q, schema=GraphEmbeddings ) diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index 9cb0e1d0..18154fc5 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -33,9 +33,6 @@ default_knowledge_response_queue = knowledge_response_queue default_cassandra_host = "cassandra" -# FIXME: How to ensure this doesn't conflict with other usage? -keyspace = "knowledge" - class Processor(AsyncProcessor): def __init__(self, **params): @@ -53,14 +50,15 @@ class Processor(AsyncProcessor): cassandra_host = params.get("cassandra_host") cassandra_username = params.get("cassandra_username") cassandra_password = params.get("cassandra_password") - + # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, - password=cassandra_password + password=cassandra_password, + default_keyspace="knowledge" ) - + # Store resolved configuration self.cassandra_host = hosts self.cassandra_username = username @@ -86,7 +84,7 @@ class Processor(AsyncProcessor): self.knowledge_request_consumer = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub, flow = None, topic = knowledge_request_queue, subscriber = id, @@ -96,7 +94,7 @@ class Processor(AsyncProcessor): ) self.knowledge_response_producer = Producer( - client = self.pulsar_client, + backend = self.pubsub, topic = knowledge_response_queue, schema = KnowledgeResponse, metrics = knowledge_response_metrics, diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py b/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py new file mode 100644 index 00000000..712aadbe --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/entity_normalizer.py @@ -0,0 +1,164 @@ +""" +Entity URI normalization for ontology-based knowledge extraction. + +Converts entity names and types into consistent, collision-free URIs. +""" + +import re +from typing import Tuple + + +def normalize_entity_name(entity_name: str) -> str: + """Normalize entity name to URI-safe identifier. + + Args: + entity_name: Natural language entity name (e.g., "Cornish pasty") + + Returns: + Normalized identifier (e.g., "cornish-pasty") + """ + # Convert to lowercase + normalized = entity_name.lower() + + # Replace spaces and underscores with hyphens + normalized = re.sub(r'[\s_]+', '-', normalized) + + # Remove any characters that aren't alphanumeric, hyphens, or periods + normalized = re.sub(r'[^a-z0-9\-.]', '', normalized) + + # Remove leading/trailing hyphens + normalized = normalized.strip('-') + + # Collapse multiple hyphens + normalized = re.sub(r'-+', '-', normalized) + + return normalized + + +def normalize_type_identifier(type_id: str) -> str: + """Normalize ontology type identifier to URI-safe format. + + Handles prefixed types like "fo/Recipe" by converting to "fo-recipe". + + Args: + type_id: Ontology type identifier (e.g., "fo/Recipe", "Food") + + Returns: + Normalized type identifier (e.g., "fo-recipe", "food") + """ + # Convert to lowercase + normalized = type_id.lower() + + # Replace slashes, colons, and spaces with hyphens + normalized = re.sub(r'[/:.\s_]+', '-', normalized) + + # Remove any remaining non-alphanumeric characters except hyphens + normalized = re.sub(r'[^a-z0-9\-]', '', normalized) + + # Remove leading/trailing hyphens + normalized = normalized.strip('-') + + # Collapse multiple hyphens + normalized = re.sub(r'-+', '-', normalized) + + return normalized + + +def build_entity_uri(entity_name: str, entity_type: str, ontology_id: str, + base_uri: str = "https://trustgraph.ai") -> str: + """Build a unique URI for an entity based on its name and type. + + The type is included in the URI to prevent collisions when the same + name refers to different entity types (e.g., "Cornish pasty" as both + Recipe and Food). + + Args: + entity_name: Natural language entity name (e.g., "Cornish pasty") + entity_type: Ontology type (e.g., "fo/Recipe") + ontology_id: Ontology identifier (e.g., "food") + base_uri: Base URI for entity URIs (default: "https://trustgraph.ai") + + Returns: + Full entity URI (e.g., "https://trustgraph.ai/food/fo-recipe-cornish-pasty") + + Examples: + >>> build_entity_uri("Cornish pasty", "fo/Recipe", "food") + 'https://trustgraph.ai/food/fo-recipe-cornish-pasty' + + >>> build_entity_uri("Cornish pasty", "fo/Food", "food") + 'https://trustgraph.ai/food/fo-food-cornish-pasty' + + >>> build_entity_uri("beef", "fo/Food", "food") + 'https://trustgraph.ai/food/fo-food-beef' + """ + type_part = normalize_type_identifier(entity_type) + name_part = normalize_entity_name(entity_name) + + # Combine type and name to ensure uniqueness + entity_id = f"{type_part}-{name_part}" + + # Build full URI + return f"{base_uri}/{ontology_id}/{entity_id}" + + +class EntityRegistry: + """Registry to track entity name/type tuples and their assigned URIs. + + Ensures that the same (entity_name, entity_type) tuple always maps + to the same URI, enabling deduplication across the extraction process. + """ + + def __init__(self, ontology_id: str, base_uri: str = "https://trustgraph.ai"): + """Initialize the entity registry. + + Args: + ontology_id: Ontology identifier (e.g., "food") + base_uri: Base URI for entity URIs + """ + self.ontology_id = ontology_id + self.base_uri = base_uri + self._registry = {} # (entity_name, entity_type) -> uri + + def get_or_create_uri(self, entity_name: str, entity_type: str) -> str: + """Get existing URI or create new one for entity. + + Args: + entity_name: Natural language entity name + entity_type: Ontology type identifier + + Returns: + URI for this entity (same URI for same name/type tuple) + """ + key = (entity_name, entity_type) + + if key not in self._registry: + uri = build_entity_uri( + entity_name, + entity_type, + self.ontology_id, + self.base_uri + ) + self._registry[key] = uri + + return self._registry[key] + + def lookup(self, entity_name: str, entity_type: str) -> str: + """Look up URI for entity (returns None if not registered). + + Args: + entity_name: Natural language entity name + entity_type: Ontology type identifier + + Returns: + URI for this entity, or None if not found + """ + key = (entity_name, entity_type) + return self._registry.get(key) + + def clear(self): + """Clear all registered entities.""" + self._registry.clear() + + def size(self) -> int: + """Get number of registered entities.""" + return len(self._registry) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 12832eaf..335f07d2 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -20,6 +20,8 @@ from .ontology_embedder import OntologyEmbedder from .vector_store import InMemoryVectorStore from .text_processor import TextProcessor from .ontology_selector import OntologySelector, OntologySubset +from .simplified_parser import parse_extraction_response +from .triple_converter import TripleConverter logger = logging.getLogger(__name__) @@ -298,25 +300,10 @@ class Processor(FlowProcessor): # Build extraction prompt variables prompt_variables = self.build_extraction_variables(chunk, ontology_subset) - # Call prompt service for extraction - try: - # Use prompt() method with extract-with-ontologies prompt ID - triples_response = await flow("prompt-request").prompt( - id="extract-with-ontologies", - variables=prompt_variables - ) - logger.debug(f"Extraction response: {triples_response}") - - if not isinstance(triples_response, list): - logger.error("Expected list of triples from prompt service") - triples_response = [] - - except Exception as e: - logger.error(f"Prompt service error: {e}", exc_info=True) - triples_response = [] - - # Parse and validate triples - triples = self.parse_and_validate_triples(triples_response, ontology_subset) + # Extract using simplified entity-relationship-attribute format + triples = await self.extract_with_simplified_format( + flow, chunk, ontology_subset, prompt_variables + ) # Add metadata triples for t in v.metadata.metadata: @@ -362,6 +349,55 @@ class Processor(FlowProcessor): [] ) + async def extract_with_simplified_format( + self, + flow, + chunk: str, + ontology_subset: OntologySubset, + prompt_variables: Dict[str, Any] + ) -> List[Triple]: + """Extract triples using simplified entity-relationship-attribute format. + + Args: + flow: Flow object for accessing services + chunk: Text chunk to extract from + ontology_subset: Selected ontology subset + prompt_variables: Variables for prompt template + + Returns: + List of Triple objects + """ + try: + # Call prompt service with simplified format prompt + extraction_response = await flow("prompt-request").prompt( + id="extract-with-ontologies", + variables=prompt_variables + ) + logger.debug(f"Simplified extraction response: {extraction_response}") + + # Parse response into structured format + extraction_result = parse_extraction_response(extraction_response) + + if not extraction_result: + logger.warning("Failed to parse extraction response") + return [] + + logger.info(f"Parsed {len(extraction_result.entities)} entities, " + f"{len(extraction_result.relationships)} relationships, " + f"{len(extraction_result.attributes)} attributes") + + # Convert to RDF triples + converter = TripleConverter(ontology_subset, ontology_subset.ontology_id) + triples = converter.convert_all(extraction_result) + + logger.info(f"Generated {len(triples)} RDF triples from simplified extraction") + + return triples + + except Exception as e: + logger.error(f"Simplified extraction error: {e}", exc_info=True) + return [] + def build_extraction_variables(self, chunk: str, ontology_subset: OntologySubset) -> Dict[str, Any]: """Build variables for ontology-based extraction prompt template. diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py b/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py new file mode 100644 index 00000000..3131d977 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/simplified_parser.py @@ -0,0 +1,234 @@ +""" +Parser for simplified ontology extraction JSON format. + +Parses the new entity-relationship-attribute format from LLM responses. +""" + +import json +import logging +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class Entity: + """Represents an extracted entity.""" + entity: str + type: str + + +@dataclass +class Relationship: + """Represents an extracted relationship.""" + subject: str + subject_type: str + relation: str + object: str + object_type: str + + +@dataclass +class Attribute: + """Represents an extracted attribute.""" + entity: str + entity_type: str + attribute: str + value: str + + +@dataclass +class ExtractionResult: + """Complete extraction result.""" + entities: List[Entity] + relationships: List[Relationship] + attributes: List[Attribute] + + +def parse_extraction_response(response: Any) -> Optional[ExtractionResult]: + """Parse LLM extraction response into structured format. + + Args: + response: LLM response (string JSON or already parsed dict) + + Returns: + ExtractionResult with parsed entities/relationships/attributes, + or None if parsing fails + """ + # Handle string response (parse JSON) + if isinstance(response, str): + try: + data = json.loads(response) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON response: {e}") + logger.debug(f"Response was: {response[:500]}") + return None + elif isinstance(response, dict): + data = response + else: + logger.error(f"Unexpected response type: {type(response)}") + return None + + # Validate structure + if not isinstance(data, dict): + logger.error(f"Expected dict, got {type(data)}") + return None + + # Parse entities + entities = [] + entities_data = data.get('entities', []) + if not isinstance(entities_data, list): + logger.warning(f"'entities' is not a list: {type(entities_data)}") + entities_data = [] + + for entity_data in entities_data: + try: + entity = parse_entity(entity_data) + if entity: + entities.append(entity) + except Exception as e: + logger.warning(f"Failed to parse entity {entity_data}: {e}") + + # Parse relationships + relationships = [] + relationships_data = data.get('relationships', []) + if not isinstance(relationships_data, list): + logger.warning(f"'relationships' is not a list: {type(relationships_data)}") + relationships_data = [] + + for rel_data in relationships_data: + try: + relationship = parse_relationship(rel_data) + if relationship: + relationships.append(relationship) + except Exception as e: + logger.warning(f"Failed to parse relationship {rel_data}: {e}") + + # Parse attributes + attributes = [] + attributes_data = data.get('attributes', []) + if not isinstance(attributes_data, list): + logger.warning(f"'attributes' is not a list: {type(attributes_data)}") + attributes_data = [] + + for attr_data in attributes_data: + try: + attribute = parse_attribute(attr_data) + if attribute: + attributes.append(attribute) + except Exception as e: + logger.warning(f"Failed to parse attribute {attr_data}: {e}") + + return ExtractionResult( + entities=entities, + relationships=relationships, + attributes=attributes + ) + + +def parse_entity(data: Dict[str, Any]) -> Optional[Entity]: + """Parse entity from dict. + + Supports both kebab-case and snake_case field names for compatibility. + + Args: + data: Entity dict with 'entity' and 'type' fields + + Returns: + Entity object or None if invalid + """ + if not isinstance(data, dict): + logger.warning(f"Entity data is not a dict: {type(data)}") + return None + + entity = data.get('entity') + entity_type = data.get('type') + + if not entity or not entity_type: + logger.warning(f"Missing required fields in entity: {data}") + return None + + if not isinstance(entity, str) or not isinstance(entity_type, str): + logger.warning(f"Entity fields must be strings: {data}") + return None + + return Entity(entity=entity, type=entity_type) + + +def parse_relationship(data: Dict[str, Any]) -> Optional[Relationship]: + """Parse relationship from dict. + + Supports both kebab-case and snake_case field names for compatibility. + + Args: + data: Relationship dict with subject, subject-type, relation, object, object-type + + Returns: + Relationship object or None if invalid + """ + if not isinstance(data, dict): + logger.warning(f"Relationship data is not a dict: {type(data)}") + return None + + subject = data.get('subject') + subject_type = data.get('subject-type') or data.get('subject_type') + relation = data.get('relation') + obj = data.get('object') + object_type = data.get('object-type') or data.get('object_type') + + if not all([subject, subject_type, relation, obj, object_type]): + logger.warning(f"Missing required fields in relationship: {data}") + return None + + if not all(isinstance(v, str) for v in [subject, subject_type, relation, obj, object_type]): + logger.warning(f"Relationship fields must be strings: {data}") + return None + + return Relationship( + subject=subject, + subject_type=subject_type, + relation=relation, + object=obj, + object_type=object_type + ) + + +def parse_attribute(data: Dict[str, Any]) -> Optional[Attribute]: + """Parse attribute from dict. + + Supports both kebab-case and snake_case field names for compatibility. + + Args: + data: Attribute dict with entity, entity-type, attribute, value + + Returns: + Attribute object or None if invalid + """ + if not isinstance(data, dict): + logger.warning(f"Attribute data is not a dict: {type(data)}") + return None + + entity = data.get('entity') + entity_type = data.get('entity-type') or data.get('entity_type') + attribute = data.get('attribute') + value = data.get('value') + + if not all([entity, entity_type, attribute, value is not None]): + logger.warning(f"Missing required fields in attribute: {data}") + return None + + if not all(isinstance(v, str) for v in [entity, entity_type, attribute]): + logger.warning(f"Attribute fields must be strings: {data}") + return None + + # Value can be string, number, bool - convert to string + if not isinstance(value, str): + value = str(value) + + return Attribute( + entity=entity, + entity_type=entity_type, + attribute=attribute, + value=value + ) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py new file mode 100644 index 00000000..2eb43b19 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/triple_converter.py @@ -0,0 +1,228 @@ +""" +Converts simplified extraction format to RDF triples. + +Transforms entities, relationships, and attributes into proper RDF triples +with full URIs and correct is_uri flags. +""" + +import logging +from typing import List, Optional + +from .... schema import Triple, Value +from .... rdf import RDF_TYPE, RDF_LABEL + +from .simplified_parser import Entity, Relationship, Attribute, ExtractionResult +from .entity_normalizer import EntityRegistry +from .ontology_selector import OntologySubset + +logger = logging.getLogger(__name__) + + +class TripleConverter: + """Converts extraction results to RDF triples.""" + + def __init__(self, ontology_subset: OntologySubset, ontology_id: str): + """Initialize converter. + + Args: + ontology_subset: Ontology subset with classes and properties + ontology_id: Ontology identifier for URI generation + """ + self.ontology_subset = ontology_subset + self.ontology_id = ontology_id + self.entity_registry = EntityRegistry(ontology_id) + + def convert_all(self, extraction: ExtractionResult) -> List[Triple]: + """Convert complete extraction result to RDF triples. + + Args: + extraction: Parsed extraction with entities/relationships/attributes + + Returns: + List of RDF Triple objects + """ + triples = [] + + # Convert entities (generates type + label triples) + for entity in extraction.entities: + entity_triples = self.convert_entity(entity) + triples.extend(entity_triples) + + # Convert relationships + for relationship in extraction.relationships: + rel_triple = self.convert_relationship(relationship) + if rel_triple: + triples.append(rel_triple) + + # Convert attributes + for attribute in extraction.attributes: + attr_triple = self.convert_attribute(attribute) + if attr_triple: + triples.append(attr_triple) + + return triples + + def convert_entity(self, entity: Entity) -> List[Triple]: + """Convert entity to RDF triples (type + label). + + Args: + entity: Entity object with name and type + + Returns: + List containing type triple and label triple + """ + triples = [] + + # Get or create URI for this entity + entity_uri = self.entity_registry.get_or_create_uri( + entity.entity, + entity.type + ) + + # Look up class URI from ontology + class_uri = self._get_class_uri(entity.type) + if not class_uri: + logger.warning(f"Unknown entity type '{entity.type}', skipping entity '{entity.entity}'") + return triples + + # Generate type triple: entity rdf:type ClassURI + type_triple = Triple( + s=Value(value=entity_uri, is_uri=True), + p=Value(value=RDF_TYPE, is_uri=True), + o=Value(value=class_uri, is_uri=True) + ) + triples.append(type_triple) + + # Generate label triple: entity rdfs:label "entity name" + label_triple = Triple( + s=Value(value=entity_uri, is_uri=True), + p=Value(value=RDF_LABEL, is_uri=True), + o=Value(value=entity.entity, is_uri=False) # Literal! + ) + triples.append(label_triple) + + return triples + + def convert_relationship(self, relationship: Relationship) -> Optional[Triple]: + """Convert relationship to RDF triple. + + Args: + relationship: Relationship with subject/object entities and relation + + Returns: + Triple connecting two entity URIs via property URI, or None if invalid + """ + # Get URIs for subject and object entities + subject_uri = self.entity_registry.get_or_create_uri( + relationship.subject, + relationship.subject_type + ) + + object_uri = self.entity_registry.get_or_create_uri( + relationship.object, + relationship.object_type + ) + + # Look up property URI from ontology + property_uri = self._get_object_property_uri(relationship.relation) + if not property_uri: + logger.warning(f"Unknown relationship '{relationship.relation}', skipping") + return None + + # Generate triple: subject property object + return Triple( + s=Value(value=subject_uri, is_uri=True), + p=Value(value=property_uri, is_uri=True), + o=Value(value=object_uri, is_uri=True) + ) + + def convert_attribute(self, attribute: Attribute) -> Optional[Triple]: + """Convert attribute to RDF triple. + + Args: + attribute: Attribute with entity, attribute name, and literal value + + Returns: + Triple with entity URI, property URI, and literal value, or None if invalid + """ + # Get URI for entity + entity_uri = self.entity_registry.get_or_create_uri( + attribute.entity, + attribute.entity_type + ) + + # Look up property URI from ontology + property_uri = self._get_datatype_property_uri(attribute.attribute) + if not property_uri: + logger.warning(f"Unknown attribute '{attribute.attribute}', skipping") + return None + + # Generate triple: entity property "literal value" + return Triple( + s=Value(value=entity_uri, is_uri=True), + p=Value(value=property_uri, is_uri=True), + o=Value(value=attribute.value, is_uri=False) # Literal! + ) + + def _get_class_uri(self, class_id: str) -> Optional[str]: + """Get full URI for ontology class. + + Args: + class_id: Class identifier (e.g., "fo/Recipe") + + Returns: + Full class URI or None if not found + """ + if class_id not in self.ontology_subset.classes: + return None + + class_def = self.ontology_subset.classes[class_id] + + # Extract URI from class definition + if isinstance(class_def, dict) and 'uri' in class_def: + return class_def['uri'] + + # Fallback: construct URI + return f"https://trustgraph.ai/ontology/{self.ontology_id}#{class_id}" + + def _get_object_property_uri(self, property_id: str) -> Optional[str]: + """Get full URI for object property. + + Args: + property_id: Property identifier (e.g., "fo/has_ingredient") + + Returns: + Full property URI or None if not found + """ + if property_id not in self.ontology_subset.object_properties: + return None + + prop_def = self.ontology_subset.object_properties[property_id] + + # Extract URI from property definition + if isinstance(prop_def, dict) and 'uri' in prop_def: + return prop_def['uri'] + + # Fallback: construct URI + return f"https://trustgraph.ai/ontology/{self.ontology_id}#{property_id}" + + def _get_datatype_property_uri(self, property_id: str) -> Optional[str]: + """Get full URI for datatype property. + + Args: + property_id: Property identifier (e.g., "fo/serves") + + Returns: + Full property URI or None if not found + """ + if property_id not in self.ontology_subset.datatype_properties: + return None + + prop_def = self.ontology_subset.datatype_properties[property_id] + + # Extract URI from property definition + if isinstance(prop_def, dict) and 'uri' in prop_def: + return prop_def['uri'] + + # Fallback: construct URI + return f"https://trustgraph.ai/ontology/{self.ontology_id}#{property_id}" diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index 0427e236..bdd123a9 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -34,9 +34,9 @@ logger.setLevel(logging.INFO) class ConfigReceiver: - def __init__(self, pulsar_client): + def __init__(self, backend): - self.pulsar_client = pulsar_client + self.backend = backend self.flow_handlers = [] @@ -104,8 +104,8 @@ class ConfigReceiver: self.config_cons = Consumer( taskgroup = tg, flow = None, - client = self.pulsar_client, - subscriber = f"gateway-{id}", + backend = self.backend, + subscriber = f"gateway-{id}", topic = config_push_queue, schema = ConfigPush, handler = self.on_config, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/agent.py b/trustgraph-flow/trustgraph/gateway/dispatch/agent.py index 1a5e8299..8867956d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/agent.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/agent.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class AgentRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(AgentRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=AgentRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py index f2755ae8..2fa3759d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/collection_management.py @@ -5,14 +5,20 @@ from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor class CollectionManagementRequestor(ServiceRequestor): - def __init__(self, pulsar_client, consumer, subscriber, timeout=120): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = collection_request_queue + if response_queue is None: + response_queue = collection_response_queue super(CollectionManagementRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, consumer_name = consumer, subscription = subscriber, - request_queue=collection_request_queue, - response_queue=collection_response_queue, + request_queue=request_queue, + response_queue=response_queue, request_schema=CollectionManagementRequest, response_schema=CollectionManagementResponse, timeout=timeout, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/config.py b/trustgraph-flow/trustgraph/gateway/dispatch/config.py index c4fac5fa..9d40e8cc 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/config.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/config.py @@ -7,14 +7,20 @@ from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor class ConfigRequestor(ServiceRequestor): - def __init__(self, pulsar_client, consumer, subscriber, timeout=120): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = config_request_queue + if response_queue is None: + response_queue = config_response_queue super(ConfigRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, consumer_name = consumer, subscription = subscriber, - request_queue=config_request_queue, - response_queue=config_response_queue, + request_queue=request_queue, + response_queue=response_queue, request_schema=ConfigRequest, response_schema=ConfigResponse, timeout=timeout, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 61b0bcbc..62626046 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -10,9 +10,9 @@ logger = logging.getLogger(__name__) class CoreExport: - def __init__(self, pulsar_client): - self.pulsar_client = pulsar_client - + def __init__(self, backend): + self.backend = backend + async def process(self, data, error, ok, request): id = request.query["id"] @@ -21,7 +21,7 @@ class CoreExport: response = await ok() kr = KnowledgeRequestor( - pulsar_client = self.pulsar_client, + backend = self.backend, consumer = "api-gateway-core-export-" + str(uuid.uuid4()), subscriber = "api-gateway-core-export-" + str(uuid.uuid4()), ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index b32fb7f7..af22a5b0 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -11,8 +11,8 @@ logger = logging.getLogger(__name__) class CoreImport: - def __init__(self, pulsar_client): - self.pulsar_client = pulsar_client + def __init__(self, backend): + self.backend = backend async def process(self, data, error, ok, request): @@ -20,7 +20,7 @@ class CoreImport: user = request.query["user"] kr = KnowledgeRequestor( - pulsar_client = self.pulsar_client, + backend = self.backend, consumer = "api-gateway-core-import-" + str(uuid.uuid4()), subscriber = "api-gateway-core-import-" + str(uuid.uuid4()), ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py index f7d53005..8866972d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_export.py @@ -15,12 +15,12 @@ logger = logging.getLogger(__name__) class DocumentEmbeddingsExport: def __init__( - self, ws, running, pulsar_client, queue, consumer, subscriber + self, ws, running, backend, queue, consumer, subscriber ): self.ws = ws self.running = running - self.pulsar_client = pulsar_client + self.backend = backend self.queue = queue self.consumer = consumer self.subscriber = subscriber @@ -48,9 +48,9 @@ class DocumentEmbeddingsExport: async def run(self): """Enhanced run with better error handling""" self.subs = Subscriber( - client = self.pulsar_client, + backend = self.backend, topic = self.queue, - consumer_name = self.consumer, + consumer_name = self.consumer, subscription = self.subscriber, schema = DocumentEmbeddings, backpressure_strategy = "block" # Configurable diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py index 7ec2f595..bd5f9666 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) class DocumentEmbeddingsImport: def __init__( - self, ws, running, pulsar_client, queue + self, ws, running, backend, queue ): self.ws = ws @@ -23,7 +23,7 @@ class DocumentEmbeddingsImport: self.translator = DocumentEmbeddingsTranslator() self.publisher = Publisher( - pulsar_client, topic = queue, schema = DocumentEmbeddings + backend, topic = queue, schema = DocumentEmbeddings ) async def start(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py index 7e38877c..eb68b0b1 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py @@ -11,10 +11,10 @@ from . sender import ServiceSender logger = logging.getLogger(__name__) class DocumentLoad(ServiceSender): - def __init__(self, pulsar_client, queue): + def __init__(self, backend, queue): super(DocumentLoad, self).__init__( - pulsar_client = pulsar_client, + backend = backend, queue = queue, schema = Document, ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py index a7f3634e..83b3cb9a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class DocumentRagRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(DocumentRagRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=DocumentRagQuery, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py b/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py index 47146e57..6c1b55ba 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class EmbeddingsRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(EmbeddingsRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=EmbeddingsRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py index 2be9c703..c03bdda6 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py @@ -15,12 +15,12 @@ logger = logging.getLogger(__name__) class EntityContextsExport: def __init__( - self, ws, running, pulsar_client, queue, consumer, subscriber + self, ws, running, backend, queue, consumer, subscriber ): self.ws = ws self.running = running - self.pulsar_client = pulsar_client + self.backend = backend self.queue = queue self.consumer = consumer self.subscriber = subscriber @@ -48,9 +48,9 @@ class EntityContextsExport: async def run(self): """Enhanced run with better error handling""" self.subs = Subscriber( - client = self.pulsar_client, + backend = self.backend, topic = self.queue, - consumer_name = self.consumer, + consumer_name = self.consumer, subscription = self.subscriber, schema = EntityContexts, backpressure_strategy = "block" # Configurable diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index c76f1612..6e01a5ca 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -16,14 +16,14 @@ logger = logging.getLogger(__name__) class EntityContextsImport: def __init__( - self, ws, running, pulsar_client, queue + self, ws, running, backend, queue ): self.ws = ws self.running = running self.publisher = Publisher( - pulsar_client, topic = queue, schema = EntityContexts + backend, topic = queue, schema = EntityContexts ) async def start(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/flow.py b/trustgraph-flow/trustgraph/gateway/dispatch/flow.py index 30f8d45e..be91995d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/flow.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/flow.py @@ -7,14 +7,20 @@ from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor class FlowRequestor(ServiceRequestor): - def __init__(self, pulsar_client, consumer, subscriber, timeout=120): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = flow_request_queue + if response_queue is None: + response_queue = flow_response_queue super(FlowRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, consumer_name = consumer, subscription = subscriber, - request_queue=flow_request_queue, - response_queue=flow_response_queue, + request_queue=request_queue, + response_queue=response_queue, request_schema=FlowRequest, response_schema=FlowResponse, timeout=timeout, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py index d4abec73..d6d7a1c5 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_export.py @@ -15,12 +15,12 @@ logger = logging.getLogger(__name__) class GraphEmbeddingsExport: def __init__( - self, ws, running, pulsar_client, queue, consumer, subscriber + self, ws, running, backend, queue, consumer, subscriber ): self.ws = ws self.running = running - self.pulsar_client = pulsar_client + self.backend = backend self.queue = queue self.consumer = consumer self.subscriber = subscriber @@ -48,9 +48,9 @@ class GraphEmbeddingsExport: async def run(self): """Enhanced run with better error handling""" self.subs = Subscriber( - client = self.pulsar_client, + backend = self.backend, topic = self.queue, - consumer_name = self.consumer, + consumer_name = self.consumer, subscription = self.subscriber, schema = GraphEmbeddings, backpressure_strategy = "block" # Configurable diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index ee3d88ef..8abf5e9c 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -16,14 +16,14 @@ logger = logging.getLogger(__name__) class GraphEmbeddingsImport: def __init__( - self, ws, running, pulsar_client, queue + self, ws, running, backend, queue ): self.ws = ws self.running = running self.publisher = Publisher( - pulsar_client, topic = queue, schema = GraphEmbeddings + backend, topic = queue, schema = GraphEmbeddings ) async def start(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py index f5be06fb..a7bb1bd8 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class GraphEmbeddingsQueryRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(GraphEmbeddingsQueryRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=GraphEmbeddingsRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py index a15a1aee..a0299a43 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class GraphRagRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(GraphRagRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=GraphRagQuery, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py b/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py index 950b3430..83aefbd0 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py @@ -10,14 +10,20 @@ from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor class KnowledgeRequestor(ServiceRequestor): - def __init__(self, pulsar_client, consumer, subscriber, timeout=120): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = knowledge_request_queue + if response_queue is None: + response_queue = knowledge_response_queue super(KnowledgeRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, consumer_name = consumer, subscription = subscriber, - request_queue=knowledge_request_queue, - response_queue=knowledge_response_queue, + request_queue=request_queue, + response_queue=response_queue, request_schema=KnowledgeRequest, response_schema=KnowledgeResponse, timeout=timeout, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py b/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py index 2155aa5d..bbf7190e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py @@ -9,14 +9,20 @@ from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor class LibrarianRequestor(ServiceRequestor): - def __init__(self, pulsar_client, consumer, subscriber, timeout=120): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = librarian_request_queue + if response_queue is None: + response_queue = librarian_response_queue super(LibrarianRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, consumer_name = consumer, subscription = subscriber, - request_queue=librarian_request_queue, - response_queue=librarian_response_queue, + request_queue=request_queue, + response_queue=response_queue, request_schema=LibrarianRequest, response_schema=LibrarianResponse, timeout=timeout, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index a1821e84..0766e232 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -98,12 +98,17 @@ class DispatcherWrapper: class DispatcherManager: - def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"): - self.pulsar_client = pulsar_client + def __init__(self, backend, config_receiver, prefix="api-gateway", + queue_overrides=None): + self.backend = backend self.config_receiver = config_receiver self.config_receiver.add_handler(self) self.prefix = prefix + # Store queue overrides for global services + # Format: {"config": {"request": "...", "response": "..."}, ...} + self.queue_overrides = queue_overrides or {} + self.flows = {} self.dispatchers = {} @@ -128,12 +133,12 @@ class DispatcherManager: async def process_core_import(self, data, error, ok, request): - ci = CoreImport(self.pulsar_client) + ci = CoreImport(self.backend) return await ci.process(data, error, ok, request) async def process_core_export(self, data, error, ok, request): - ce = CoreExport(self.pulsar_client) + ce = CoreExport(self.backend) return await ce.process(data, error, ok, request) async def process_global_service(self, data, responder, params): @@ -148,11 +153,20 @@ class DispatcherManager: if key in self.dispatchers: return await self.dispatchers[key].process(data, responder) + # Get queue overrides if specified for this service + request_queue = None + response_queue = None + if kind in self.queue_overrides: + request_queue = self.queue_overrides[kind].get("request") + response_queue = self.queue_overrides[kind].get("response") + dispatcher = global_dispatchers[kind]( - pulsar_client = self.pulsar_client, + backend = self.backend, timeout = 120, consumer = f"{self.prefix}-{kind}-request", subscriber = f"{self.prefix}-{kind}-request", + request_queue = request_queue, + response_queue = response_queue, ) await dispatcher.start() @@ -202,7 +216,7 @@ class DispatcherManager: id = str(uuid.uuid4()) dispatcher = import_dispatchers[kind]( - pulsar_client = self.pulsar_client, + backend = self.backend, ws = ws, running = running, queue = qconfig, @@ -240,7 +254,7 @@ class DispatcherManager: id = str(uuid.uuid4()) dispatcher = export_dispatchers[kind]( - pulsar_client = self.pulsar_client, + backend = self.backend, ws = ws, running = running, queue = qconfig, @@ -282,7 +296,7 @@ class DispatcherManager: if kind in request_response_dispatchers: dispatcher = request_response_dispatchers[kind]( - pulsar_client = self.pulsar_client, + backend = self.backend, request_queue = qconfig["request"], response_queue = qconfig["response"], timeout = 120, @@ -291,7 +305,7 @@ class DispatcherManager: ) elif kind in sender_dispatchers: dispatcher = sender_dispatchers[kind]( - pulsar_client = self.pulsar_client, + backend = self.backend, queue = qconfig, ) else: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py b/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py index da2a7bb0..a5f9398e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mcp_tool.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class McpToolRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(McpToolRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=ToolRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py index 3cf5684a..3a6314f2 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/nlp_query.py @@ -5,12 +5,12 @@ from . requestor import ServiceRequestor class NLPQueryRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(NLPQueryRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=QuestionToStructuredQueryRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py index bc0c1b85..fc982b69 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/objects_import.py @@ -15,14 +15,14 @@ logger = logging.getLogger(__name__) class ObjectsImport: def __init__( - self, ws, running, pulsar_client, queue + self, ws, running, backend, queue ): self.ws = ws self.running = running self.publisher = Publisher( - pulsar_client, topic = queue, schema = ExtractedObject + backend, topic = queue, schema = ExtractedObject ) async def start(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py index 2f2535a9..fb8dc81d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/objects_query.py @@ -5,12 +5,12 @@ from . requestor import ServiceRequestor class ObjectsQueryRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(ObjectsQueryRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=ObjectsQueryRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py b/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py index 5c316cf6..23017733 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py @@ -8,12 +8,12 @@ from . requestor import ServiceRequestor class PromptRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(PromptRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=PromptRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/requestor.py b/trustgraph-flow/trustgraph/gateway/dispatch/requestor.py index 1acac5e5..e8f0a63e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/requestor.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/requestor.py @@ -13,7 +13,7 @@ class ServiceRequestor: def __init__( self, - pulsar_client, + backend, request_queue, request_schema, response_queue, response_schema, subscription="api-gateway", consumer_name="api-gateway", @@ -21,12 +21,12 @@ class ServiceRequestor: ): self.pub = Publisher( - pulsar_client, request_queue, + backend, request_queue, schema=request_schema, ) self.sub = Subscriber( - pulsar_client, response_queue, + backend, response_queue, subscription, consumer_name, response_schema ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/sender.py b/trustgraph-flow/trustgraph/gateway/dispatch/sender.py index 2435cdc1..17324b19 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/sender.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/sender.py @@ -14,12 +14,12 @@ class ServiceSender: def __init__( self, - pulsar_client, + backend, queue, schema, ): self.pub = Publisher( - pulsar_client, queue, + backend, queue, schema=schema, ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/streamer.py b/trustgraph-flow/trustgraph/gateway/dispatch/streamer.py index 54674906..9c6d4251 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/streamer.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/streamer.py @@ -13,7 +13,7 @@ class ServiceRequestor: def __init__( self, - pulsar_client, + backend, queue, schema, handler, subscription="api-gateway", consumer_name="api-gateway", @@ -21,7 +21,7 @@ class ServiceRequestor: ): self.sub = Subscriber( - pulsar_client, queue, + backend, queue, subscription, consumer_name, schema ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py b/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py index 8dae646d..895b55be 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/structured_diag.py @@ -5,12 +5,12 @@ from . requestor import ServiceRequestor class StructuredDiagRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(StructuredDiagRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=StructuredDataDiagnosisRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py index f08ef038..9a9fbb6a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/structured_query.py @@ -5,12 +5,12 @@ from . requestor import ServiceRequestor class StructuredQueryRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(StructuredQueryRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=StructuredQueryRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py b/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py index d29d1918..0e77584e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class TextCompletionRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(TextCompletionRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=TextCompletionRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py b/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py index 36922c89..b2562938 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py @@ -11,10 +11,10 @@ from . sender import ServiceSender logger = logging.getLogger(__name__) class TextLoad(ServiceSender): - def __init__(self, pulsar_client, queue): + def __init__(self, backend, queue): super(TextLoad, self).__init__( - pulsar_client = pulsar_client, + backend = backend, queue = queue, schema = TextDocument, ) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py index ff91e461..69fc588d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_export.py @@ -15,12 +15,12 @@ logger = logging.getLogger(__name__) class TriplesExport: def __init__( - self, ws, running, pulsar_client, queue, consumer, subscriber + self, ws, running, backend, queue, consumer, subscriber ): self.ws = ws self.running = running - self.pulsar_client = pulsar_client + self.backend = backend self.queue = queue self.consumer = consumer self.subscriber = subscriber @@ -48,9 +48,9 @@ class TriplesExport: async def run(self): """Enhanced run with better error handling""" self.subs = Subscriber( - client = self.pulsar_client, + backend = self.backend, topic = self.queue, - consumer_name = self.consumer, + consumer_name = self.consumer, subscription = self.subscriber, schema = Triples, backpressure_strategy = "block" # Configurable diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 520a9cbc..6bb46975 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -16,14 +16,14 @@ logger = logging.getLogger(__name__) class TriplesImport: def __init__( - self, ws, running, pulsar_client, queue + self, ws, running, backend, queue ): self.ws = ws self.running = running self.publisher = Publisher( - pulsar_client, topic = queue, schema = Triples + backend, topic = queue, schema = Triples ) async def start(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py index d2def9c1..6b306139 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py @@ -6,12 +6,12 @@ from . requestor import ServiceRequestor class TriplesQueryRequestor(ServiceRequestor): def __init__( - self, pulsar_client, request_queue, response_queue, timeout, + self, backend, request_queue, response_queue, timeout, consumer, subscriber, ): super(TriplesQueryRequestor, self).__init__( - pulsar_client=pulsar_client, + backend=backend, request_queue=request_queue, response_queue=response_queue, request_schema=TriplesQueryRequest, diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 1e2fdb23..aaa6f725 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -9,7 +9,8 @@ from aiohttp import web import logging import os -from .. log_level import LogLevel +from trustgraph.base.logging import setup_logging +from trustgraph.base.pubsub import get_pubsub from . auth import Authenticator from . config.receiver import ConfigReceiver @@ -20,8 +21,15 @@ from . endpoint.manager import EndpointManager import pulsar from prometheus_client import start_http_server +# Import default queue names +from .. schema import ( + config_request_queue, config_response_queue, + flow_request_queue, flow_response_queue, + knowledge_request_queue, knowledge_response_queue, + librarian_request_queue, librarian_response_queue, +) + logger = logging.getLogger("api") -logger.setLevel(logging.INFO) default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") default_prometheus_url = os.getenv("PROMETHEUS_URL", "http://prometheus:9090") @@ -43,15 +51,8 @@ class Api: self.pulsar_listener = config.get("pulsar_listener", None) - if self.pulsar_api_key: - self.pulsar_client = pulsar.Client( - self.pulsar_host, listener_name=self.pulsar_listener, - authentication=pulsar.AuthenticationToken(self.pulsar_api_key) - ) - else: - self.pulsar_client = pulsar.Client( - self.pulsar_host, listener_name=self.pulsar_listener, - ) + # Create backend using factory + self.pubsub_backend = get_pubsub(**config) self.prometheus_url = config.get( "prometheus_url", default_prometheus_url, @@ -68,12 +69,56 @@ class Api: else: self.auth = Authenticator(allow_all=True) - self.config_receiver = ConfigReceiver(self.pulsar_client) + self.config_receiver = ConfigReceiver(self.pubsub_backend) + + # Build queue overrides dictionary from CLI arguments + queue_overrides = {} + + # Config service + config_req = config.get("config_request_queue") + config_resp = config.get("config_response_queue") + if config_req or config_resp: + queue_overrides["config"] = {} + if config_req: + queue_overrides["config"]["request"] = config_req + if config_resp: + queue_overrides["config"]["response"] = config_resp + + # Flow service + flow_req = config.get("flow_request_queue") + flow_resp = config.get("flow_response_queue") + if flow_req or flow_resp: + queue_overrides["flow"] = {} + if flow_req: + queue_overrides["flow"]["request"] = flow_req + if flow_resp: + queue_overrides["flow"]["response"] = flow_resp + + # Knowledge service + knowledge_req = config.get("knowledge_request_queue") + knowledge_resp = config.get("knowledge_response_queue") + if knowledge_req or knowledge_resp: + queue_overrides["knowledge"] = {} + if knowledge_req: + queue_overrides["knowledge"]["request"] = knowledge_req + if knowledge_resp: + queue_overrides["knowledge"]["response"] = knowledge_resp + + # Librarian service + librarian_req = config.get("librarian_request_queue") + librarian_resp = config.get("librarian_response_queue") + if librarian_req or librarian_resp: + queue_overrides["librarian"] = {} + if librarian_req: + queue_overrides["librarian"]["request"] = librarian_req + if librarian_resp: + queue_overrides["librarian"]["response"] = librarian_resp self.dispatcher_manager = DispatcherManager( - pulsar_client = self.pulsar_client, + backend = self.pubsub_backend, config_receiver = self.config_receiver, prefix = "gateway", + queue_overrides = queue_overrides, ) self.endpoint_manager = EndpointManager( @@ -117,6 +162,20 @@ def run(): description=__doc__ ) + parser.add_argument( + '--id', + default='api-gateway', + help='Service identifier for logging and metrics (default: api-gateway)', + ) + + # Pub/sub backend selection + parser.add_argument( + '--pubsub-backend', + default=os.getenv('PUBSUB_BACKEND', 'pulsar'), + choices=['pulsar', 'mqtt'], + help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', + ) + parser.add_argument( '-p', '--pulsar-host', default=default_pulsar_host, @@ -181,9 +240,61 @@ def run(): help=f'Prometheus metrics port (default: 8000)', ) + # Queue override arguments for multi-tenant deployments + parser.add_argument( + '--config-request-queue', + default=None, + help=f'Config service request queue (default: {config_request_queue})', + ) + + parser.add_argument( + '--config-response-queue', + default=None, + help=f'Config service response queue (default: {config_response_queue})', + ) + + parser.add_argument( + '--flow-request-queue', + default=None, + help=f'Flow service request queue (default: {flow_request_queue})', + ) + + parser.add_argument( + '--flow-response-queue', + default=None, + help=f'Flow service response queue (default: {flow_response_queue})', + ) + + parser.add_argument( + '--knowledge-request-queue', + default=None, + help=f'Knowledge service request queue (default: {knowledge_request_queue})', + ) + + parser.add_argument( + '--knowledge-response-queue', + default=None, + help=f'Knowledge service response queue (default: {knowledge_response_queue})', + ) + + parser.add_argument( + '--librarian-request-queue', + default=None, + help=f'Librarian service request queue (default: {librarian_request_queue})', + ) + + parser.add_argument( + '--librarian-response-queue', + default=None, + help=f'Librarian service response queue (default: {librarian_response_queue})', + ) + args = parser.parse_args() args = vars(args) + # Setup logging before creating API instance + setup_logging(args) + if args["metrics"]: start_http_server(args["metrics_port"]) diff --git a/trustgraph-flow/trustgraph/librarian/blob_store.py b/trustgraph-flow/trustgraph/librarian/blob_store.py index e4ccfad9..436e2718 100644 --- a/trustgraph-flow/trustgraph/librarian/blob_store.py +++ b/trustgraph-flow/trustgraph/librarian/blob_store.py @@ -14,29 +14,32 @@ class BlobStore: def __init__( self, - minio_host, minio_access_key, minio_secret_key, bucket_name, + endpoint, access_key, secret_key, bucket_name, + use_ssl=False, region=None, ): - self.minio = Minio( - endpoint = minio_host, - access_key = minio_access_key, - secret_key = minio_secret_key, - secure = False, + self.client = Minio( + endpoint = endpoint, + access_key = access_key, + secret_key = secret_key, + secure = use_ssl, + region = region, ) self.bucket_name = bucket_name - logger.info("Connected to MinIO") + protocol = "https" if use_ssl else "http" + logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}") self.ensure_bucket() def ensure_bucket(self): # Make the bucket if it doesn't exist. - found = self.minio.bucket_exists(bucket_name=self.bucket_name) + found = self.client.bucket_exists(bucket_name=self.bucket_name) if not found: - self.minio.make_bucket(bucket_name=self.bucket_name) + self.client.make_bucket(bucket_name=self.bucket_name) logger.info(f"Created bucket {self.bucket_name}") else: logger.debug(f"Bucket {self.bucket_name} already exists") @@ -44,7 +47,7 @@ class BlobStore: async def add(self, object_id, blob, kind): # FIXME: Loop retry - self.minio.put_object( + self.client.put_object( bucket_name = self.bucket_name, object_name = "doc/" + str(object_id), length = len(blob), @@ -57,7 +60,7 @@ class BlobStore: async def remove(self, object_id): # FIXME: Loop retry - self.minio.remove_object( + self.client.remove_object( bucket_name = self.bucket_name, object_name = "doc/" + str(object_id), ) @@ -68,7 +71,7 @@ class BlobStore: async def get(self, object_id): # FIXME: Loop retry - resp = self.minio.get_object( + resp = self.client.get_object( bucket_name = self.bucket_name, object_name = "doc/" + str(object_id), ) diff --git a/trustgraph-flow/trustgraph/librarian/collection_manager.py b/trustgraph-flow/trustgraph/librarian/collection_manager.py index 1530ed84..34ce1de8 100644 --- a/trustgraph-flow/trustgraph/librarian/collection_manager.py +++ b/trustgraph-flow/trustgraph/librarian/collection_manager.py @@ -1,142 +1,150 @@ """ -Collection management for the librarian +Collection management for the librarian - uses config service for storage """ import asyncio import logging +import json +import uuid from datetime import datetime from typing import Dict, Any, List, Optional from .. schema import CollectionManagementRequest, CollectionManagementResponse, Error from .. schema import CollectionMetadata -from .. schema import StorageManagementRequest, StorageManagementResponse +from .. schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue from .. exceptions import RequestError -from .. tables.library import LibraryTableStore # Module logger logger = logging.getLogger(__name__) +def metadata_to_dict(metadata: CollectionMetadata) -> dict: + """Convert CollectionMetadata to dictionary for JSON serialization""" + return { + 'user': metadata.user, + 'collection': metadata.collection, + 'name': metadata.name, + 'description': metadata.description, + 'tags': list(metadata.tags) + } + class CollectionManager: - """Manages collection metadata and coordinates collection operations across storage types""" + """Manages collection metadata via config service""" def __init__( self, - cassandra_host, - cassandra_username, - cassandra_password, - keyspace, - vector_storage_producer=None, - object_storage_producer=None, - triples_storage_producer=None, - storage_response_consumer=None + config_request_producer, + config_response_consumer, + taskgroup ): """ Initialize the CollectionManager Args: - cassandra_host: Cassandra host(s) - cassandra_username: Cassandra username - cassandra_password: Cassandra password - keyspace: Cassandra keyspace for library data - vector_storage_producer: Producer for vector storage management - object_storage_producer: Producer for object storage management - triples_storage_producer: Producer for triples storage management - storage_response_consumer: Consumer for storage management responses + config_request_producer: Producer for config service requests + config_response_consumer: Consumer for config service responses + taskgroup: Task group for async operations """ - self.table_store = LibraryTableStore( - cassandra_host, cassandra_username, cassandra_password, keyspace - ) + self.config_request_producer = config_request_producer + self.config_response_consumer = config_response_consumer + self.taskgroup = taskgroup - # Storage management producers - self.vector_storage_producer = vector_storage_producer - self.object_storage_producer = object_storage_producer - self.triples_storage_producer = triples_storage_producer - self.storage_response_consumer = storage_response_consumer + # Track pending config requests + self.pending_config_requests = {} - # Track pending deletion operations - self.pending_deletions = {} + logger.info("Collection manager initialized with config service backend") - logger.info("Collection manager initialized") + async def send_config_request(self, request: ConfigRequest) -> ConfigResponse: + """ + Send config request and wait for response + + Args: + request: Config service request (without id field) + + Returns: + ConfigResponse from config service + """ + # Generate request ID - passed via message properties, not in schema + request_id = str(uuid.uuid4()) + + event = asyncio.Event() + self.pending_config_requests[request_id] = event + + # Send request with ID in message properties + await self.config_request_producer.send(request, properties={"id": request_id}) + await event.wait() + + response = self.pending_config_requests.pop(request_id + "_response") + return response + + async def on_config_response(self, message, consumer, flow): + """ + Handle config response + + Args: + message: Pulsar message + consumer: Consumer instance + flow: Flow context + """ + # Get ID from message properties + response_id = message.properties().get("id") + if response_id and response_id in self.pending_config_requests: + response = message.value() + self.pending_config_requests[response_id + "_response"] = response + self.pending_config_requests[response_id].set() async def ensure_collection_exists(self, user: str, collection: str): """ - Ensure a collection exists, creating it if necessary with broadcast to storage + Ensure a collection exists, creating it if necessary Args: user: User ID collection: Collection ID """ try: - # Check if collection already exists - existing = await self.table_store.get_collection(user, collection) - if existing: + # Check if collection exists via config service + request = ConfigRequest( + operation='get', + keys=[ConfigKey(type='collection', key=f'{user}:{collection}')] + ) + + response = await self.send_config_request(request) + + # Validate response + if not response.values or len(response.values) == 0: + raise Exception(f"Invalid response from config service when checking collection {user}/{collection}") + + # Check if collection exists (value not None means it exists) + if response.values[0].value is not None: logger.debug(f"Collection {user}/{collection} already exists") return + # Collection doesn't exist (value is None), proceed to create # Create new collection with default metadata - logger.info(f"Auto-creating collection {user}/{collection} from document submission") - await self.table_store.create_collection( + logger.info(f"Auto-creating collection {user}/{collection}") + + metadata = CollectionMetadata( user=user, collection=collection, name=collection, # Default name to collection ID description="", - tags=set() + tags=[] ) - # Broadcast collection creation to all storage backends - creation_key = (user, collection) - logger.info(f"Broadcasting create-collection for {creation_key}") - - self.pending_deletions[creation_key] = { - "responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples - "responses_received": [], - "all_successful": True, - "error_messages": [], - "deletion_complete": asyncio.Event() - } - - storage_request = StorageManagementRequest( - operation="create-collection", - user=user, - collection=collection + request = ConfigRequest( + operation='put', + values=[ConfigValue( + type='collection', + key=f'{user}:{collection}', + value=json.dumps(metadata_to_dict(metadata)) + )] ) - # Send creation requests to all storage types - if self.vector_storage_producer: - await self.vector_storage_producer.send(storage_request) - if self.object_storage_producer: - await self.object_storage_producer.send(storage_request) - if self.triples_storage_producer: - await self.triples_storage_producer.send(storage_request) + response = await self.send_config_request(request) - # Wait for all storage creations to complete (with timeout) - creation_info = self.pending_deletions[creation_key] - try: - await asyncio.wait_for( - creation_info["deletion_complete"].wait(), - timeout=30.0 # 30 second timeout - ) - except asyncio.TimeoutError: - logger.error(f"Timeout waiting for storage creation responses for {creation_key}") - creation_info["all_successful"] = False - creation_info["error_messages"].append("Timeout waiting for storage creation") + if response.error: + raise RuntimeError(f"Config update failed: {response.error.message}") - # Check if all creations succeeded - if not creation_info["all_successful"]: - error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}" - logger.error(error_msg) - - # Clean up metadata on failure - await self.table_store.delete_collection(user, collection) - - # Clean up tracking - del self.pending_deletions[creation_key] - - raise RuntimeError(error_msg) - - # Clean up tracking - del self.pending_deletions[creation_key] - logger.info(f"Collection {creation_key} auto-created successfully in all storage backends") + logger.info(f"Collection {user}/{collection} auto-created in config service") except Exception as e: logger.error(f"Error ensuring collection exists: {e}") @@ -144,7 +152,7 @@ class CollectionManager: async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse: """ - List collections for a user with optional tag filtering + List collections for a user from config service Args: request: Collection management request @@ -153,25 +161,42 @@ class CollectionManager: CollectionManagementResponse with list of collections """ try: - tag_filter = list(request.tag_filter) if request.tag_filter else None - collections = await self.table_store.list_collections(request.user, tag_filter) + # Get all collections from config service + config_request = ConfigRequest( + operation='getvalues', + type='collection' + ) - collection_metadata = [ - CollectionMetadata( - user=coll["user"], - collection=coll["collection"], - name=coll["name"], - description=coll["description"], - tags=coll["tags"], - created_at=coll["created_at"], - updated_at=coll["updated_at"] - ) - for coll in collections - ] + response = await self.send_config_request(config_request) + + if response.error: + raise RuntimeError(f"Config query failed: {response.error.message}") + + # Parse collections and filter by user + collections = [] + for config_value in response.values: + if ":" in config_value.key: + coll_user, coll_name = config_value.key.split(":", 1) + if coll_user == request.user: + metadata_dict = json.loads(config_value.value) + metadata = CollectionMetadata(**metadata_dict) + collections.append(metadata) + + # Apply tag filtering if specified + if request.tag_filter: + tag_filter_set = set(request.tag_filter) + collections = [ + c for c in collections + if any(tag in tag_filter_set for tag in c.tags) + ] + + # Apply limit if specified + if request.limit and request.limit > 0: + collections = collections[:request.limit] return CollectionManagementResponse( error=None, - collections=collection_metadata, + collections=collections, timestamp=datetime.now().isoformat() ) @@ -181,7 +206,7 @@ class CollectionManager: async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse: """ - Update collection metadata (creates if doesn't exist) + Update collection metadata via config service (creates if doesn't exist) Args: request: Collection management request @@ -190,120 +215,42 @@ class CollectionManager: CollectionManagementResponse with updated collection """ try: - # Check if collection exists, create if it doesn't - existing = await self.table_store.get_collection(request.user, request.collection) - if not existing: - # Create new collection with provided metadata - logger.info(f"Creating new collection {request.user}/{request.collection}") + # Create metadata from request + name = request.name if request.name else request.collection + description = request.description if request.description else "" + tags = list(request.tags) if request.tags else [] - name = request.name if request.name else request.collection - description = request.description if request.description else "" - tags = set(request.tags) if request.tags else set() + metadata = CollectionMetadata( + user=request.user, + collection=request.collection, + name=name, + description=description, + tags=tags + ) - await self.table_store.create_collection( - user=request.user, - collection=request.collection, - name=name, - description=description, - tags=tags - ) + # Send put request to config service + config_request = ConfigRequest( + operation='put', + values=[ConfigValue( + type='collection', + key=f'{request.user}:{request.collection}', + value=json.dumps(metadata_to_dict(metadata)) + )] + ) - # Broadcast collection creation to all storage backends - creation_key = (request.user, request.collection) - logger.info(f"Broadcasting create-collection for {creation_key}") + response = await self.send_config_request(config_request) - self.pending_deletions[creation_key] = { - "responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples - "responses_received": [], - "all_successful": True, - "error_messages": [], - "deletion_complete": asyncio.Event() - } + if response.error: + raise RuntimeError(f"Config update failed: {response.error.message}") - storage_request = StorageManagementRequest( - operation="create-collection", - user=request.user, - collection=request.collection - ) + logger.info(f"Collection {request.user}/{request.collection} updated in config service") - # Send creation requests to all storage types - if self.vector_storage_producer: - await self.vector_storage_producer.send(storage_request) - if self.object_storage_producer: - await self.object_storage_producer.send(storage_request) - if self.triples_storage_producer: - await self.triples_storage_producer.send(storage_request) - - # Wait for all storage creations to complete (with timeout) - creation_info = self.pending_deletions[creation_key] - try: - await asyncio.wait_for( - creation_info["deletion_complete"].wait(), - timeout=30.0 # 30 second timeout - ) - except asyncio.TimeoutError: - logger.error(f"Timeout waiting for storage creation responses for {creation_key}") - creation_info["all_successful"] = False - creation_info["error_messages"].append("Timeout waiting for storage creation") - - # Check if all creations succeeded - if not creation_info["all_successful"]: - error_msg = f"Storage creation failed: {'; '.join(creation_info['error_messages'])}" - logger.error(error_msg) - - # Clean up metadata on failure - await self.table_store.delete_collection(request.user, request.collection) - - # Clean up tracking - del self.pending_deletions[creation_key] - - return CollectionManagementResponse( - error=Error( - type="storage_creation_error", - message=error_msg - ), - timestamp=datetime.now().isoformat() - ) - - # Clean up tracking - del self.pending_deletions[creation_key] - logger.info(f"Collection {creation_key} created successfully in all storage backends") - - # Get the newly created collection for response - created_collection = await self.table_store.get_collection(request.user, request.collection) - - collection_metadata = CollectionMetadata( - user=created_collection["user"], - collection=created_collection["collection"], - name=created_collection["name"], - description=created_collection["description"], - tags=created_collection["tags"], - created_at=created_collection["created_at"], - updated_at=created_collection["updated_at"] - ) - else: - # Collection exists, update it - name = request.name if request.name else None - description = request.description if request.description else None - tags = list(request.tags) if request.tags else None - - updated_collection = await self.table_store.update_collection( - request.user, request.collection, name, description, tags - ) - - collection_metadata = CollectionMetadata( - user=updated_collection["user"], - collection=updated_collection["collection"], - name=updated_collection["name"], - description=updated_collection["description"], - tags=updated_collection["tags"], - created_at="", # Not returned by update - updated_at=updated_collection["updated_at"] - ) + # Config service will trigger config push automatically + # Storage services will receive update and create/update collections return CollectionManagementResponse( error=None, - collections=[collection_metadata], + collections=[metadata], timestamp=datetime.now().isoformat() ) @@ -313,7 +260,7 @@ class CollectionManager: async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse: """ - Delete collection with cascade to all storage types + Delete collection via config service Args: request: Collection management request @@ -322,68 +269,23 @@ class CollectionManager: CollectionManagementResponse indicating success or failure """ try: - deletion_key = (request.user, request.collection) + logger.info(f"Deleting collection {request.user}/{request.collection}") - logger.info(f"Starting cascade deletion for {request.user}/{request.collection}") - - # Track this deletion request - self.pending_deletions[deletion_key] = { - "responses_pending": 4, # doc-embeddings, graph-embeddings, object, triples - "responses_received": [], - "all_successful": True, - "error_messages": [], - "deletion_complete": asyncio.Event() - } - - # Create storage management request - storage_request = StorageManagementRequest( - operation="delete-collection", - user=request.user, - collection=request.collection + # Send delete request to config service + config_request = ConfigRequest( + operation='delete', + keys=[ConfigKey(type='collection', key=f'{request.user}:{request.collection}')] ) - # Send deletion requests to all storage types - if self.vector_storage_producer: - await self.vector_storage_producer.send(storage_request) - if self.object_storage_producer: - await self.object_storage_producer.send(storage_request) - if self.triples_storage_producer: - await self.triples_storage_producer.send(storage_request) + response = await self.send_config_request(config_request) - # Wait for all storage deletions to complete (with timeout) - deletion_info = self.pending_deletions[deletion_key] - try: - await asyncio.wait_for( - deletion_info["deletion_complete"].wait(), - timeout=30.0 # 30 second timeout - ) - except asyncio.TimeoutError: - logger.error(f"Timeout waiting for storage deletion responses for {deletion_key}") - deletion_info["all_successful"] = False - deletion_info["error_messages"].append("Timeout waiting for storage deletion") + if response.error: + raise RuntimeError(f"Config delete failed: {response.error.message}") - # Check if all deletions succeeded - if not deletion_info["all_successful"]: - error_msg = f"Storage deletion failed: {'; '.join(deletion_info['error_messages'])}" - logger.error(error_msg) + logger.info(f"Collection {request.user}/{request.collection} deleted from config service") - # Clean up tracking - del self.pending_deletions[deletion_key] - - return CollectionManagementResponse( - error=Error( - type="storage_deletion_error", - message=error_msg - ), - timestamp=datetime.now().isoformat() - ) - - # All storage deletions succeeded, now delete metadata - logger.info(f"Storage deletions complete, removing metadata for {deletion_key}") - await self.table_store.delete_collection(request.user, request.collection) - - # Clean up tracking - del self.pending_deletions[deletion_key] + # Config service will trigger config push automatically + # Storage services will receive update and delete collections return CollectionManagementResponse( error=None, @@ -392,39 +294,4 @@ class CollectionManager: except Exception as e: logger.error(f"Error deleting collection: {e}") - # Clean up tracking on error - if deletion_key in self.pending_deletions: - del self.pending_deletions[deletion_key] raise RequestError(f"Failed to delete collection: {str(e)}") - - async def on_storage_response(self, response: StorageManagementResponse): - """ - Handle storage management responses for deletion tracking - - Args: - response: Storage management response - """ - logger.debug(f"Received storage response: error={response.error}") - - # Find matching deletion by checking all pending deletions - # Note: This is simplified correlation - in production we'd want better correlation - for deletion_key, info in list(self.pending_deletions.items()): - if info["responses_pending"] > 0: - # Record this response - info["responses_received"].append(response) - info["responses_pending"] -= 1 - - # Check if this response indicates failure - if response.error and response.error.message: - info["all_successful"] = False - info["error_messages"].append(response.error.message) - logger.warning(f"Storage operation failed for {deletion_key}: {response.error.message}") - else: - logger.debug(f"Storage operation succeeded for {deletion_key}") - - # If all responses received, signal completion - if info["responses_pending"] == 0: - logger.info(f"All storage responses received for {deletion_key}") - info["deletion_complete"].set() - - break # Only process for first matching deletion \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py index 56fcb040..8835cc73 100644 --- a/trustgraph-flow/trustgraph/librarian/librarian.py +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -17,12 +17,14 @@ class Librarian: def __init__( self, cassandra_host, cassandra_username, cassandra_password, - minio_host, minio_access_key, minio_secret_key, + object_store_endpoint, object_store_access_key, object_store_secret_key, bucket_name, keyspace, load_document, + object_store_use_ssl=False, object_store_region=None, ): self.blob_store = BlobStore( - minio_host, minio_access_key, minio_secret_key, bucket_name + object_store_endpoint, object_store_access_key, object_store_secret_key, bucket_name, + use_ssl=object_store_use_ssl, region=object_store_region, ) self.table_store = LibraryTableStore( diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index 00d64010..1d04ee06 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -18,9 +18,8 @@ from .. schema import LibrarianRequest, LibrarianResponse, Error from .. schema import librarian_request_queue, librarian_response_queue from .. schema import CollectionManagementRequest, CollectionManagementResponse from .. schema import collection_request_queue, collection_response_queue -from .. schema import StorageManagementRequest, StorageManagementResponse -from .. schema import vector_storage_management_topic, object_storage_management_topic -from .. schema import triples_storage_management_topic, storage_management_response_topic +from .. schema import ConfigRequest, ConfigResponse +from .. schema import config_request_queue, config_response_queue from .. schema import Document, Metadata from .. schema import TextDocument, Metadata @@ -39,17 +38,18 @@ default_librarian_request_queue = librarian_request_queue default_librarian_response_queue = librarian_response_queue default_collection_request_queue = collection_request_queue default_collection_response_queue = collection_response_queue +default_config_request_queue = config_request_queue +default_config_response_queue = config_response_queue -default_minio_host = "minio:9000" -default_minio_access_key = "minioadmin" -default_minio_secret_key = "minioadmin" +default_object_store_endpoint = "ceph-rgw:7480" +default_object_store_access_key = "object-user" +default_object_store_secret_key = "object-password" +default_object_store_use_ssl = False +default_object_store_region = None default_cassandra_host = "cassandra" bucket_name = "library" -# FIXME: How to ensure this doesn't conflict with other usage? -keyspace = "librarian" - class Processor(AsyncProcessor): def __init__(self, **params): @@ -74,27 +74,44 @@ class Processor(AsyncProcessor): "collection_response_queue", default_collection_response_queue ) - minio_host = params.get("minio_host", default_minio_host) - minio_access_key = params.get( - "minio_access_key", - default_minio_access_key + config_request_queue = params.get( + "config_request_queue", default_config_request_queue ) - minio_secret_key = params.get( - "minio_secret_key", - default_minio_secret_key + + config_response_queue = params.get( + "config_response_queue", default_config_response_queue + ) + + object_store_endpoint = params.get("object_store_endpoint", default_object_store_endpoint) + object_store_access_key = params.get( + "object_store_access_key", + default_object_store_access_key + ) + object_store_secret_key = params.get( + "object_store_secret_key", + default_object_store_secret_key + ) + object_store_use_ssl = params.get( + "object_store_use_ssl", + default_object_store_use_ssl + ) + object_store_region = params.get( + "object_store_region", + default_object_store_region ) cassandra_host = params.get("cassandra_host") cassandra_username = params.get("cassandra_username") cassandra_password = params.get("cassandra_password") - + # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, - password=cassandra_password + password=cassandra_password, + default_keyspace="librarian" ) - + # Store resolved configuration self.cassandra_host = hosts self.cassandra_username = username @@ -106,8 +123,8 @@ class Processor(AsyncProcessor): "librarian_response_queue": librarian_response_queue, "collection_request_queue": collection_request_queue, "collection_response_queue": collection_response_queue, - "minio_host": minio_host, - "minio_access_key": minio_access_key, + "object_store_endpoint": object_store_endpoint, + "object_store_access_key": object_store_access_key, "cassandra_host": self.cassandra_host, "cassandra_username": self.cassandra_username, "cassandra_password": self.cassandra_password, @@ -136,7 +153,7 @@ class Processor(AsyncProcessor): self.librarian_request_consumer = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub, flow = None, topic = librarian_request_queue, subscriber = id, @@ -146,7 +163,7 @@ class Processor(AsyncProcessor): ) self.librarian_response_producer = Producer( - client = self.pulsar_client, + backend = self.pubsub, topic = librarian_response_queue, schema = LibrarianResponse, metrics = librarian_response_metrics, @@ -154,7 +171,7 @@ class Processor(AsyncProcessor): self.collection_request_consumer = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub, flow = None, topic = collection_request_queue, subscriber = id, @@ -164,63 +181,57 @@ class Processor(AsyncProcessor): ) self.collection_response_producer = Producer( - client = self.pulsar_client, + backend = self.pubsub, topic = collection_response_queue, schema = CollectionManagementResponse, metrics = collection_response_metrics, ) - # Storage management producers for collection deletion - self.vector_storage_producer = Producer( - client = self.pulsar_client, - topic = vector_storage_management_topic, - schema = StorageManagementRequest, + # Config service client for collection management + config_request_metrics = ProducerMetrics( + processor = id, flow = None, name = "config-request" ) - self.object_storage_producer = Producer( - client = self.pulsar_client, - topic = object_storage_management_topic, - schema = StorageManagementRequest, + self.config_request_producer = Producer( + backend = self.pubsub, + topic = config_request_queue, + schema = ConfigRequest, + metrics = config_request_metrics, ) - self.triples_storage_producer = Producer( - client = self.pulsar_client, - topic = triples_storage_management_topic, - schema = StorageManagementRequest, + config_response_metrics = ConsumerMetrics( + processor = id, flow = None, name = "config-response" ) - self.storage_response_consumer = Consumer( + self.config_response_consumer = Consumer( taskgroup = self.taskgroup, - client = self.pulsar_client, + backend = self.pubsub, flow = None, - topic = storage_management_response_topic, - subscriber = id, - schema = StorageManagementResponse, - handler = self.on_storage_response, - metrics = storage_response_metrics, + topic = config_response_queue, + subscriber = f"{id}-config", + schema = ConfigResponse, + handler = self.on_config_response, + metrics = config_response_metrics, ) self.librarian = Librarian( cassandra_host = self.cassandra_host, cassandra_username = self.cassandra_username, cassandra_password = self.cassandra_password, - minio_host = minio_host, - minio_access_key = minio_access_key, - minio_secret_key = minio_secret_key, + object_store_endpoint = object_store_endpoint, + object_store_access_key = object_store_access_key, + object_store_secret_key = object_store_secret_key, bucket_name = bucket_name, keyspace = keyspace, load_document = self.load_document, + object_store_use_ssl = object_store_use_ssl, + object_store_region = object_store_region, ) self.collection_manager = CollectionManager( - cassandra_host = self.cassandra_host, - cassandra_username = self.cassandra_username, - cassandra_password = self.cassandra_password, - keyspace = keyspace, - vector_storage_producer = self.vector_storage_producer, - object_storage_producer = self.object_storage_producer, - triples_storage_producer = self.triples_storage_producer, - storage_response_consumer = self.storage_response_consumer, + config_request_producer = self.config_request_producer, + config_response_consumer = self.config_response_consumer, + taskgroup = self.taskgroup, ) self.register_config_handler(self.on_librarian_config) @@ -236,10 +247,12 @@ class Processor(AsyncProcessor): await self.librarian_response_producer.start() await self.collection_request_consumer.start() await self.collection_response_producer.start() - await self.vector_storage_producer.start() - await self.object_storage_producer.start() - await self.triples_storage_producer.start() - await self.storage_response_consumer.start() + await self.config_request_producer.start() + await self.config_response_consumer.start() + + async def on_config_response(self, message, consumer, flow): + """Forward config responses to collection manager""" + await self.collection_manager.on_config_response(message, consumer, flow) async def on_librarian_config(self, config, version): @@ -298,14 +311,13 @@ class Processor(AsyncProcessor): collection = processing.collection ), data = base64.b64encode(content).decode("utf-8") - ) schema = Document logger.debug(f"Submitting to queue {q}...") pub = Publisher( - self.pulsar_client, q, schema=schema + self.pubsub, q, schema=schema ) await pub.start() @@ -464,14 +476,6 @@ class Processor(AsyncProcessor): logger.debug("Collection request processing complete") - async def on_storage_response(self, msg, consumer, flow): - """ - Handle storage management response messages - """ - v = msg.value() - logger.debug("Received storage management response") - await self.collection_manager.on_storage_response(v) - @staticmethod def add_args(parser): @@ -502,23 +506,36 @@ class Processor(AsyncProcessor): ) parser.add_argument( - '--minio-host', - default=default_minio_host, - help=f'Minio hostname (default: {default_minio_host})', + '--object-store-endpoint', + default=default_object_store_endpoint, + help=f'Object storage endpoint (default: {default_object_store_endpoint})', ) parser.add_argument( - '--minio-access-key', - default='minioadmin', - help='Minio access key / username ' - f'(default: {default_minio_access_key})', + '--object-store-access-key', + default=default_object_store_access_key, + help='Object storage access key / username ' + f'(default: {default_object_store_access_key})', ) parser.add_argument( - '--minio-secret-key', - default='minioadmin', - help='Minio secret key / password ' - f'(default: {default_minio_access_key})', + '--object-store-secret-key', + default=default_object_store_secret_key, + help='Object storage secret key / password ' + f'(default: {default_object_store_secret_key})', + ) + + parser.add_argument( + '--object-store-use-ssl', + action='store_true', + default=default_object_store_use_ssl, + help=f'Use SSL/TLS for object storage connection (default: {default_object_store_use_ssl})', + ) + + parser.add_argument( + '--object-store-region', + default=default_object_store_region, + help='Object storage region (optional)', ) add_cassandra_args(parser) diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 35449151..07dea8ba 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -20,24 +20,18 @@ class Processor(FlowProcessor): id = params.get("id", default_ident) - if not hasattr(__class__, "input_token_metric"): - __class__.input_token_metric = Counter( - 'input_tokens', 'Input token count' + if not hasattr(__class__, "token_metric"): + __class__.token_metric = Counter( + 'tokens', + 'Token count', + ['model', 'direction'] ) - if not hasattr(__class__, "output_token_metric"): - __class__.output_token_metric = Counter( - 'output_tokens', 'Output token count' - ) - - if not hasattr(__class__, "input_cost_metric"): - __class__.input_cost_metric = Counter( - 'input_cost', 'Input cost' - ) - - if not hasattr(__class__, "output_cost_metric"): - __class__.output_cost_metric = Counter( - 'output_cost', 'Output cost' + if not hasattr(__class__, "cost_metric"): + __class__.cost_metric = Counter( + 'cost', + 'Cost in USD', + ['model', 'direction'] ) super(Processor, self).__init__( @@ -87,12 +81,13 @@ class Processor(FlowProcessor): v = msg.value() - modelname = v.model - num_in = v.in_token - num_out = v.out_token + modelname = v.model or "unknown" + num_in = v.in_token or 0 + num_out = v.out_token or 0 - __class__.input_token_metric.inc(num_in) - __class__.output_token_metric.inc(num_out) + # Increment token metrics with model and direction labels + __class__.token_metric.labels(model=modelname, direction="input").inc(num_in) + __class__.token_metric.labels(model=modelname, direction="output").inc(num_out) model_input_price, model_output_price = self.get_prices(modelname) @@ -103,9 +98,11 @@ class Processor(FlowProcessor): cost_out = num_out * model_output_price cost_per_call = round(cost_in + cost_out, 6) - __class__.input_cost_metric.inc(cost_in) - __class__.output_cost_metric.inc(cost_out) + # Increment cost metrics with model and direction labels + __class__.cost_metric.labels(model=modelname, direction="input").inc(cost_in) + __class__.cost_metric.labels(model=modelname, direction="output").inc(cost_out) + logger.info(f"Model: {modelname}") logger.info(f"Input Tokens: {num_in}") logger.info(f"Output Tokens: {num_out}") logger.info(f"Cost for call: ${cost_per_call}") diff --git a/trustgraph-flow/trustgraph/query/objects/cassandra/service.py b/trustgraph-flow/trustgraph/query/objects/cassandra/service.py index a4726d90..a6683c40 100644 --- a/trustgraph-flow/trustgraph/query/objects/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/objects/cassandra/service.py @@ -74,7 +74,7 @@ class Processor(FlowProcessor): cassandra_password = params.get("cassandra_password") # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, password=cassandra_password diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index cf2757af..13726ac3 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -28,7 +28,7 @@ class Processor(TriplesQueryService): cassandra_password = params.get("cassandra_password") # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, password=cassandra_password diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 670d71a1..14d71d97 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -95,19 +95,20 @@ class Processor(FlowProcessor): # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks - async def send_chunk(chunk): + # Receives chunk text and end_of_stream flag from prompt client + async def send_chunk(chunk, end_of_stream): await flow("response").send( DocumentRagResponse( - chunk=chunk, - end_of_stream=False, - response=None, + response=chunk, + end_of_stream=end_of_stream, error=None ), properties={"id": id} ) # Query with streaming enabled - full_response = await self.rag.query( + # All chunks (including final one with end_of_stream=True) are sent via callback + await self.rag.query( v.query, user=v.user, collection=v.collection, @@ -115,17 +116,6 @@ class Processor(FlowProcessor): streaming=True, chunk_callback=send_chunk, ) - - # Send final message with complete response - await flow("response").send( - DocumentRagResponse( - chunk=None, - end_of_stream=True, - response=full_response, - error=None - ), - properties={"id": id} - ) else: # Non-streaming path (existing behavior) response = await self.rag.query( diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 565921a3..d159dbae 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -138,19 +138,20 @@ class Processor(FlowProcessor): # Check if streaming is requested if v.streaming: # Define async callback for streaming chunks - async def send_chunk(chunk): + # Receives chunk text and end_of_stream flag from prompt client + async def send_chunk(chunk, end_of_stream): await flow("response").send( GraphRagResponse( - chunk=chunk, - end_of_stream=False, - response=None, + response=chunk, + end_of_stream=end_of_stream, error=None ), properties={"id": id} ) # Query with streaming enabled - full_response = await rag.query( + # All chunks (including final one with end_of_stream=True) are sent via callback + await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, @@ -158,17 +159,6 @@ class Processor(FlowProcessor): streaming = True, chunk_callback = send_chunk, ) - - # Send final message with complete response - await flow("response").send( - GraphRagResponse( - chunk=None, - end_of_stream=True, - response=full_response, - error=None - ), - properties={"id": id} - ) else: # Non-streaming path (existing behavior) response = await rag.query( diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py index 03e79c0d..986558ec 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py +++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py @@ -26,19 +26,19 @@ class WebSocketResponder: self.completed = True class MessageDispatcher: - - def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None): + + def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): self.max_workers = max_workers self.semaphore = asyncio.Semaphore(max_workers) self.active_tasks = set() - self.pulsar_client = pulsar_client - + self.backend = backend + # Use DispatcherManager for flow and service management - if pulsar_client and config_receiver: - self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway") + if backend and config_receiver: + self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway") else: self.dispatcher_manager = None - logger.warning("No pulsar_client or config_receiver provided - using fallback mode") + logger.warning("No backend or config_receiver provided - using fallback mode") # Service name mapping from websocket protocol to translator registry self.service_mapping = { @@ -78,7 +78,7 @@ class MessageDispatcher: try: if not self.dispatcher_manager: - raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required") + raise RuntimeError("DispatcherManager not available - backend and config_receiver required") # Use DispatcherManager for flow-based processing responder = WebSocketResponder() diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py index c8e78af2..cc905172 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/service.py +++ b/trustgraph-flow/trustgraph/rev_gateway/service.py @@ -7,10 +7,10 @@ import os from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse from typing import Optional from urllib.parse import urlparse, urlunparse -import pulsar from .dispatcher import MessageDispatcher from ..gateway.config.receiver import ConfigReceiver +from ..base import get_pubsub logger = logging.getLogger("rev_gateway") logger.setLevel(logging.INFO) @@ -56,25 +56,20 @@ class ReverseGateway: self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None) self.pulsar_listener = pulsar_listener - - # Initialize Pulsar client - if self.pulsar_api_key: - self.pulsar_client = pulsar.Client( - self.pulsar_host, - listener_name=self.pulsar_listener, - authentication=pulsar.AuthenticationToken(self.pulsar_api_key) - ) - else: - self.pulsar_client = pulsar.Client( - self.pulsar_host, - listener_name=self.pulsar_listener - ) - + + # Create backend using factory + backend_params = { + 'pulsar_host': self.pulsar_host, + 'pulsar_api_key': self.pulsar_api_key, + 'pulsar_listener': self.pulsar_listener, + } + self.backend = get_pubsub(**backend_params) + # Initialize config receiver - self.config_receiver = ConfigReceiver(self.pulsar_client) - - # Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver - self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client) + self.config_receiver = ConfigReceiver(self.backend) + + # Initialize dispatcher with config_receiver and backend - must be created after config_receiver + self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend) async def connect(self) -> bool: try: @@ -170,10 +165,10 @@ class ReverseGateway: self.running = False await self.dispatcher.shutdown() await self.disconnect() - - # Close Pulsar client - if hasattr(self, 'pulsar_client'): - self.pulsar_client.close() + + # Close backend + if hasattr(self, 'backend'): + self.backend.close() def stop(self): self.running = False diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 012d91b7..07dbf0eb 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -6,11 +6,9 @@ Accepts entity/vector pairs and writes them to a Milvus store. import logging from .... direct.milvus_doc_embeddings import DocVectors -from .... base import DocumentEmbeddingsStoreService +from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -18,7 +16,7 @@ logger = logging.getLogger(__name__) default_ident = "de-write" default_store_uri = 'http://localhost:19530' -class Processor(DocumentEmbeddingsStoreService): +class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): def __init__(self, **params): @@ -32,51 +30,11 @@ class Processor(DocumentEmbeddingsStoreService): self.vecstore = DocVectors(store_uri) - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=vector_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) - - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() + # Register for config push notifications + self.register_config_handler(self.on_collection_config) async def store_document_embeddings(self, message): - # Validate collection exists before accepting writes - if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection): - error_msg = ( - f"Collection {message.metadata.collection} does not exist. " - f"Create it first with tg-set-collection." - ) - logger.error(error_msg) - raise ValueError(error_msg) - for emb in message.chunks: if emb.chunk is None or emb.chunk == b"": continue @@ -102,72 +60,27 @@ class Processor(DocumentEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): + async def create_collection(self, user: str, collection: str, metadata: dict): """ - No-op for collection creation - collections are created lazily on first write + Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") - self.vecstore.create_collection(request.user, request.collection) - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + self.vecstore.create_collection(user, collection) except Exception as e: - logger.error(f"Failed to handle create collection request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """Delete the collection for document embeddings""" + async def delete_collection(self, user: str, collection: str): + """Delete the collection for document embeddings via config push""" try: - self.vecstore.delete_collection(request.user, request.collection) - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") + self.vecstore.delete_collection(user, collection) + logger.info(f"Successfully deleted collection {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 4d3c43bb..6d1b23ba 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -11,11 +11,9 @@ import uuid import os import logging -from .... base import DocumentEmbeddingsStoreService +from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -25,7 +23,7 @@ default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" -class Processor(DocumentEmbeddingsStoreService): +class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): def __init__(self, **params): @@ -59,33 +57,8 @@ class Processor(DocumentEmbeddingsStoreService): self.last_index_name = None - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=vector_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) + # Register for config push notifications + self.register_config_handler(self.on_collection_config) def create_index(self, index_name, dim): @@ -115,14 +88,17 @@ class Processor(DocumentEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() - async def store_document_embeddings(self, message): + # Validate collection exists in config before processing + if not self.collection_exists(message.metadata.user, message.metadata.collection): + logger.warning( + f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"does not exist in config (likely deleted while data was in-flight). " + f"Dropping message." + ) + return + for emb in message.chunks: if emb.chunk is None or emb.chunk == b"": continue @@ -138,7 +114,7 @@ class Processor(DocumentEmbeddingsStoreService): f"d-{message.metadata.user}-{message.metadata.collection}-{dim}" ) - # Lazily create index if it doesn't exist + # Lazily create index if it doesn't exist (but only if authorized in config) if not self.pinecone.has_index(index_name): logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") self.create_index(index_name, dim) @@ -188,65 +164,22 @@ class Processor(DocumentEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): + async def create_collection(self, user: str, collection: str, metadata: dict): """ - No-op for collection creation - indexes are created lazily on first write + Create collection via config push - indexes are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to handle create collection request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """ - Delete all dimension variants of the index for document embeddings. - Since indexes are created with dimension suffixes (e.g., d-user-coll-384), - we need to find and delete all matching indexes. - """ + async def delete_collection(self, user: str, collection: str): + """Delete the collection for document embeddings via config push""" try: - prefix = f"d-{request.user}-{request.collection}-" + prefix = f"d-{user}-{collection}-" # Get all indexes and filter for matches all_indexes = self.pinecone.list_indexes() @@ -261,16 +194,10 @@ class Processor(DocumentEmbeddingsStoreService): for index_name in matching_indexes: self.pinecone.delete_index(index_name) logger.info(f"Deleted Pinecone index: {index_name}") - logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) + logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 225beb9c..edfa8aa9 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -9,11 +9,9 @@ from qdrant_client.models import Distance, VectorParams import uuid import logging -from .... base import DocumentEmbeddingsStoreService +from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -22,7 +20,7 @@ default_ident = "de-write" default_store_uri = 'http://localhost:6333' -class Processor(DocumentEmbeddingsStoreService): +class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): def __init__(self, **params): @@ -38,47 +36,20 @@ class Processor(DocumentEmbeddingsStoreService): self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - # Set up storage management if base class attributes are available - # (they may not be in unit tests) - if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'): - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=vector_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) - - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - if hasattr(self, 'storage_request_consumer'): - await self.storage_request_consumer.start() - if hasattr(self, 'storage_response_producer'): - await self.storage_response_producer.start() + # Register for config push notifications + self.register_config_handler(self.on_collection_config) async def store_document_embeddings(self, message): + # Validate collection exists in config before processing + if not self.collection_exists(message.metadata.user, message.metadata.collection): + logger.warning( + f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"does not exist in config (likely deleted while data was in-flight). " + f"Dropping message." + ) + return + for emb in message.chunks: chunk = emb.chunk.decode("utf-8") @@ -92,7 +63,7 @@ class Processor(DocumentEmbeddingsStoreService): f"d_{message.metadata.user}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist + # Lazily create collection if it doesn't exist (but only if authorized in config) if not self.qdrant.collection_exists(collection): logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") self.qdrant.create_collection( @@ -133,65 +104,22 @@ class Processor(DocumentEmbeddingsStoreService): help=f'Qdrant API key (default: None)' ) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): + async def create_collection(self, user: str, collection: str, metadata: dict): """ - No-op for collection creation - collections are created lazily on first write + Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to handle create collection request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """ - Delete all dimension variants of the collection for document embeddings. - Since collections are created with dimension suffixes (e.g., d_user_coll_384), - we need to find and delete all matching collections. - """ + async def delete_collection(self, user: str, collection: str): + """Delete the collection for document embeddings via config push""" try: - prefix = f"d_{request.user}_{request.collection}_" + prefix = f"d_{user}_{collection}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -206,16 +134,10 @@ class Processor(DocumentEmbeddingsStoreService): for collection_name in matching_collections: self.qdrant.delete_collection(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) + logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index cca0de95..2e192cd6 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -6,11 +6,9 @@ Accepts entity/vector pairs and writes them to a Milvus store. import logging from .... direct.milvus_graph_embeddings import EntityVectors -from .... base import GraphEmbeddingsStoreService +from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -18,7 +16,7 @@ logger = logging.getLogger(__name__) default_ident = "ge-write" default_store_uri = 'http://localhost:19530' -class Processor(GraphEmbeddingsStoreService): +class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): def __init__(self, **params): @@ -32,51 +30,11 @@ class Processor(GraphEmbeddingsStoreService): self.vecstore = EntityVectors(store_uri) - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=vector_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) - - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() + # Register for config push notifications + self.register_config_handler(self.on_collection_config) async def store_graph_embeddings(self, message): - # Validate collection exists before accepting writes - if not self.vecstore.collection_exists(message.metadata.user, message.metadata.collection): - error_msg = ( - f"Collection {message.metadata.collection} does not exist. " - f"Create it first with tg-set-collection." - ) - logger.error(error_msg) - raise ValueError(error_msg) - for entity in message.entities: if entity.entity.value != "" and entity.entity.value is not None: @@ -98,72 +56,27 @@ class Processor(GraphEmbeddingsStoreService): help=f'Milvus store URI (default: {default_store_uri})' ) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): + async def create_collection(self, user: str, collection: str, metadata: dict): """ - No-op for collection creation - collections are created lazily on first write + Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") - self.vecstore.create_collection(request.user, request.collection) - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") + self.vecstore.create_collection(user, collection) except Exception as e: - logger.error(f"Failed to handle create collection request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """Delete the collection for graph embeddings""" + async def delete_collection(self, user: str, collection: str): + """Delete the collection for graph embeddings via config push""" try: - self.vecstore.delete_collection(request.user, request.collection) - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") + self.vecstore.delete_collection(user, collection) + logger.info(f"Successfully deleted collection {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index 30d3d3e5..0bee6ceb 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -11,11 +11,9 @@ import uuid import os import logging -from .... base import GraphEmbeddingsStoreService +from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -25,7 +23,7 @@ default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" default_region = "us-east-1" -class Processor(GraphEmbeddingsStoreService): +class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): def __init__(self, **params): @@ -59,33 +57,8 @@ class Processor(GraphEmbeddingsStoreService): self.last_index_name = None - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=vector_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) + # Register for config push notifications + self.register_config_handler(self.on_collection_config) def create_index(self, index_name, dim): @@ -115,14 +88,17 @@ class Processor(GraphEmbeddingsStoreService): "Gave up waiting for index creation" ) - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() - async def store_graph_embeddings(self, message): + # Validate collection exists in config before processing + if not self.collection_exists(message.metadata.user, message.metadata.collection): + logger.warning( + f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"does not exist in config (likely deleted while data was in-flight). " + f"Dropping message." + ) + return + for entity in message.entities: if entity.entity.value == "" or entity.entity.value is None: @@ -136,7 +112,7 @@ class Processor(GraphEmbeddingsStoreService): f"t-{message.metadata.user}-{message.metadata.collection}-{dim}" ) - # Lazily create index if it doesn't exist + # Lazily create index if it doesn't exist (but only if authorized in config) if not self.pinecone.has_index(index_name): logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}") self.create_index(index_name, dim) @@ -186,65 +162,22 @@ class Processor(GraphEmbeddingsStoreService): help=f'Pinecone region, (default: {default_region}' ) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): + async def create_collection(self, user: str, collection: str, metadata: dict): """ - No-op for collection creation - indexes are created lazily on first write + Create collection via config push - indexes are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to handle create collection request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """ - Delete all dimension variants of the index for graph embeddings. - Since indexes are created with dimension suffixes (e.g., t-user-coll-384), - we need to find and delete all matching indexes. - """ + async def delete_collection(self, user: str, collection: str): + """Delete the collection for graph embeddings via config push""" try: - prefix = f"t-{request.user}-{request.collection}-" + prefix = f"t-{user}-{collection}-" # Get all indexes and filter for matches all_indexes = self.pinecone.list_indexes() @@ -259,16 +192,10 @@ class Processor(GraphEmbeddingsStoreService): for index_name in matching_indexes: self.pinecone.delete_index(index_name) logger.info(f"Deleted Pinecone index: {index_name}") - logger.info(f"Deleted {len(matching_indexes)} index(es) for {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) + logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 0b15996f..e3c2b6bc 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -9,11 +9,9 @@ from qdrant_client.models import Distance, VectorParams import uuid import logging -from .... base import GraphEmbeddingsStoreService +from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import vector_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -22,7 +20,7 @@ default_ident = "ge-write" default_store_uri = 'http://localhost:6333' -class Processor(GraphEmbeddingsStoreService): +class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): def __init__(self, **params): @@ -38,47 +36,20 @@ class Processor(GraphEmbeddingsStoreService): self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - # Set up storage management if base class attributes are available - # (they may not be in unit tests) - if hasattr(self, 'id') and hasattr(self, 'taskgroup') and hasattr(self, 'pulsar_client'): - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=vector_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) - - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - if hasattr(self, 'storage_request_consumer'): - await self.storage_request_consumer.start() - if hasattr(self, 'storage_response_producer'): - await self.storage_response_producer.start() + # Register for config push notifications + self.register_config_handler(self.on_collection_config) async def store_graph_embeddings(self, message): + # Validate collection exists in config before processing + if not self.collection_exists(message.metadata.user, message.metadata.collection): + logger.warning( + f"Collection {message.metadata.collection} for user {message.metadata.user} " + f"does not exist in config (likely deleted while data was in-flight). " + f"Dropping message." + ) + return + for entity in message.entities: if entity.entity.value == "" or entity.entity.value is None: return @@ -91,7 +62,7 @@ class Processor(GraphEmbeddingsStoreService): f"t_{message.metadata.user}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist + # Lazily create collection if it doesn't exist (but only if authorized in config) if not self.qdrant.collection_exists(collection): logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") self.qdrant.create_collection( @@ -132,65 +103,22 @@ class Processor(GraphEmbeddingsStoreService): help=f'Qdrant API key' ) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): + async def create_collection(self, user: str, collection: str, metadata: dict): """ - No-op for collection creation - collections are created lazily on first write + Create collection via config push - collections are created lazily on first write with the correct dimension determined from the actual embeddings. """ try: - logger.info(f"Collection create request for {request.user}/{request.collection} - will be created lazily on first write") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") except Exception as e: - logger.error(f"Failed to handle create collection request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """ - Delete all dimension variants of the collection for graph embeddings. - Since collections are created with dimension suffixes (e.g., t_user_coll_384), - we need to find and delete all matching collections. - """ + async def delete_collection(self, user: str, collection: str): + """Delete the collection for graph embeddings via config push""" try: - prefix = f"t_{request.user}_{request.collection}_" + prefix = f"t_{user}_{collection}_" # Get all collections and filter for matches all_collections = self.qdrant.get_collections().collections @@ -205,16 +133,10 @@ class Processor(GraphEmbeddingsStoreService): for collection_name in matching_collections: self.qdrant.delete_collection(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") - logger.info(f"Deleted {len(matching_collections)} collection(s) for {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) + logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py index b39fe09f..a79b7b83 100644 --- a/trustgraph-flow/trustgraph/storage/knowledge/store.py +++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py @@ -23,10 +23,11 @@ class Processor(FlowProcessor): id = params.get("id") # Use helper to resolve configuration - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=params.get("cassandra_host"), username=params.get("cassandra_username"), - password=params.get("cassandra_password") + password=params.get("cassandra_password"), + default_keyspace='knowledge' ) super(Processor, self).__init__( diff --git a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py index e9dda4d6..bcb0d57f 100644 --- a/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/objects/cassandra/write.py @@ -13,9 +13,8 @@ from cassandra import ConsistencyLevel from .... schema import ExtractedObject from .... schema import RowSchema, Field -from .... schema import StorageManagementRequest, StorageManagementResponse -from .... schema import object_storage_management_topic, storage_management_response_topic from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base import CollectionConfigHandler from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config # Module logger @@ -23,7 +22,7 @@ logger = logging.getLogger(__name__) default_ident = "objects-write" -class Processor(FlowProcessor): +class Processor(CollectionConfigHandler, FlowProcessor): def __init__(self, **params): @@ -35,7 +34,7 @@ class Processor(FlowProcessor): cassandra_password = params.get("cassandra_password") # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, password=cassandra_password @@ -55,7 +54,7 @@ class Processor(FlowProcessor): "config_type": self.config_key, } ) - + self.register_specification( ConsumerSpec( name = "input", @@ -64,39 +63,9 @@ class Processor(FlowProcessor): ) ) - # Set up storage management consumer and producer directly - # (FlowProcessor doesn't support topic-based specs outside of flows) - from .... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics - - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Create storage management consumer - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=object_storage_management_topic, - subscriber=f"{id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Create storage management response producer - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) - - # Register config handler for schema updates + # Register config handlers self.register_config_handler(self.on_schema_config) + self.register_config_handler(self.on_collection_config) # Cache of known keyspaces/tables self.known_keyspaces: Set[str] = set() @@ -341,41 +310,20 @@ class Processor(FlowProcessor): except Exception as e: logger.warning(f"Failed to convert value {value} to type {field_type}: {e}") return str(value) - - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() - async def on_object(self, msg, consumer, flow): """Process incoming ExtractedObject and store in Cassandra""" obj = msg.value() logger.info(f"Storing {len(obj.values)} objects for schema {obj.schema_name} from {obj.metadata.id}") - # Validate collection/keyspace exists before accepting writes - safe_keyspace = self.sanitize_name(obj.metadata.user) - if safe_keyspace not in self.known_keyspaces: - # Check if keyspace actually exists in Cassandra - self.connect_cassandra() - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) - # Check if result is None (mock case) or has no rows - if result is None or not result.one(): - error_msg = ( - f"Collection {obj.metadata.collection} does not exist. " - f"Create it first with tg-set-collection." - ) - logger.error(error_msg) - raise ValueError(error_msg) - # Cache it if it exists - self.known_keyspaces.add(safe_keyspace) - if safe_keyspace not in self.known_tables: - self.known_tables[safe_keyspace] = set() + # Validate collection exists before accepting writes + if not self.collection_exists(obj.metadata.user, obj.metadata.collection): + error_msg = ( + f"Collection {obj.metadata.collection} does not exist. " + f"Create it first via collection management API." + ) + logger.error(error_msg) + raise ValueError(error_msg) # Get schema definition schema = self.schemas.get(obj.schema_name) @@ -454,55 +402,7 @@ class Processor(FlowProcessor): logger.error(f"Failed to insert object {obj_index}: {e}", exc_info=True) raise - async def on_storage_management(self, msg, consumer, flow): - """Handle storage management requests for collection operations""" - request = msg.value() - logger.info(f"Received storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.create_collection(request.user, request.collection) - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully created collection {request.user}/{request.collection}") - elif request.operation == "delete-collection": - await self.delete_collection(request.user, request.collection) - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") - else: - logger.warning(f"Unknown storage management operation: {request.operation}") - # Send error response - from .... schema import Error - response = StorageManagementResponse( - error=Error( - type="unknown_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error handling storage management request: {e}", exc_info=True) - # Send error response - from .... schema import Error - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def create_collection(self, user: str, collection: str): + async def create_collection(self, user: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra object store""" # Connect if not already connected self.connect_cassandra() diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index ef79e605..1576b70c 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -42,7 +42,7 @@ class Processor(Consumer): cassandra_password = params.get("cassandra_password") # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, password=cassandra_password diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 6497f95c..b9b42375 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -11,12 +11,10 @@ import time import logging from .... direct.cassandra_kg import KnowledgeGraph -from .... base import TriplesStoreService +from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -24,10 +22,10 @@ logger = logging.getLogger(__name__) default_ident = "triples-write" -class Processor(TriplesStoreService): +class Processor(CollectionConfigHandler, TriplesStoreService): def __init__(self, **params): - + id = params.get("id", default_ident) # Get Cassandra parameters @@ -36,7 +34,7 @@ class Processor(TriplesStoreService): cassandra_password = params.get("cassandra_password") # Resolve configuration with environment variable fallback - hosts, username, password = resolve_cassandra_config( + hosts, username, password, keyspace = resolve_cassandra_config( host=cassandra_host, username=cassandra_username, password=cassandra_password @@ -48,39 +46,15 @@ class Processor(TriplesStoreService): "cassandra_username": username } ) - + self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password self.table = None + self.tg = None - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=triples_storage_management_topic, - subscriber=f"{id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) + # Register for config push notifications + self.register_config_handler(self.on_collection_config) async def store_triples(self, message): @@ -109,15 +83,6 @@ class Processor(TriplesStoreService): self.table = user - # Validate collection exists before accepting writes - if not self.tg.collection_exists(message.metadata.collection): - error_msg = ( - f"Collection {message.metadata.collection} does not exist. " - f"Create it first with tg-set-collection." - ) - logger.error(error_msg) - raise ValueError(error_msg) - for t in message.triples: self.tg.insert( message.metadata.collection, @@ -126,133 +91,77 @@ class Processor(TriplesStoreService): t.o.value ) - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() - - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_create_collection(self, request): - """Create a collection in Cassandra triple store""" + async def create_collection(self, user: str, collection: str, metadata: dict): + """Create a collection in Cassandra triple store via config push""" try: # Create or reuse connection for this user's keyspace - if self.table is None or self.table != request.user: + if self.table is None or self.table != user: self.tg = None try: if self.cassandra_username and self.cassandra_password: self.tg = KnowledgeGraph( hosts=self.cassandra_host, - keyspace=request.user, + keyspace=user, username=self.cassandra_username, password=self.cassandra_password ) else: self.tg = KnowledgeGraph( hosts=self.cassandra_host, - keyspace=request.user, + keyspace=user, ) except Exception as e: - logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}") + logger.error(f"Failed to connect to Cassandra for user {user}: {e}") raise - self.table = request.user + self.table = user # Create collection using the built-in method - logger.info(f"Creating collection {request.collection} for user {request.user}") + logger.info(f"Creating collection {collection} for user {user}") - if self.tg.collection_exists(request.collection): - logger.info(f"Collection {request.collection} already exists") + if self.tg.collection_exists(collection): + logger.info(f"Collection {collection} already exists") else: - self.tg.create_collection(request.collection) - logger.info(f"Created collection {request.collection}") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + self.tg.create_collection(collection) + logger.info(f"Created collection {collection}") except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): + async def delete_collection(self, user: str, collection: str): """Delete all data for a specific collection from the unified triples table""" try: # Create or reuse connection for this user's keyspace - if self.table is None or self.table != request.user: + if self.table is None or self.table != user: self.tg = None try: if self.cassandra_username and self.cassandra_password: self.tg = KnowledgeGraph( hosts=self.cassandra_host, - keyspace=request.user, + keyspace=user, username=self.cassandra_username, password=self.cassandra_password ) else: self.tg = KnowledgeGraph( hosts=self.cassandra_host, - keyspace=request.user, + keyspace=user, ) except Exception as e: - logger.error(f"Failed to connect to Cassandra for user {request.user}: {e}") + logger.error(f"Failed to connect to Cassandra for user {user}: {e}") raise - self.table = request.user + self.table = user # Delete all triples for this collection using the built-in method - try: - self.tg.delete_collection(request.collection) - logger.info(f"Deleted all triples for collection {request.collection} from keyspace {request.user}") - except Exception as e: - logger.error(f"Failed to delete collection data: {e}") - raise - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") + self.tg.delete_collection(collection) + logger.info(f"Deleted all triples for collection {collection} from keyspace {user}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise @staticmethod diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index d0800b67..f08eeb91 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -12,11 +12,9 @@ import logging from falkordb import FalkorDB -from .... base import TriplesStoreService +from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -26,10 +24,10 @@ default_ident = "triples-write" default_graph_url = 'falkor://falkordb:6379' default_database = 'falkordb' -class Processor(TriplesStoreService): +class Processor(CollectionConfigHandler, TriplesStoreService): def __init__(self, **params): - + graph_url = params.get("graph_url", default_graph_url) database = params.get("database", default_database) @@ -44,33 +42,8 @@ class Processor(TriplesStoreService): self.io = FalkorDB.from_url(graph_url).select_graph(database) - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=triples_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) + # Register for config push notifications + self.register_config_handler(self.on_collection_config) def create_node(self, uri, user, collection): @@ -184,7 +157,7 @@ class Processor(TriplesStoreService): if not self.collection_exists(user, collection): error_msg = ( f"Collection {collection} does not exist. " - f"Create it first with tg-set-collection." + f"Create it first via collection management API." ) logger.error(error_msg) raise ValueError(error_msg) @@ -217,95 +190,58 @@ class Processor(TriplesStoreService): help=f'FalkorDB database (default: {default_database})' ) - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() - - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - + async def create_collection(self, user: str, collection: str, metadata: dict): + """Create collection metadata in FalkorDB via config push""" try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) + # Check if collection exists + result = self.io.query( + "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) RETURN c LIMIT 1", + params={"user": user, "collection": collection} + ) + if result.result_set: + logger.info(f"Collection {user}/{collection} already exists") else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) + # Create collection metadata node + import datetime + self.io.query( + "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "SET c.created_at = $created_at", + params={ + "user": user, + "collection": collection, + "created_at": datetime.datetime.now().isoformat() + } ) - await self.storage_response_producer.send(response) + logger.info(f"Created collection {user}/{collection}") except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_create_collection(self, request): - """Create collection metadata in FalkorDB""" - try: - if self.collection_exists(request.user, request.collection): - logger.info(f"Collection {request.user}/{request.collection} already exists") - else: - self.create_collection(request.user, request.collection) - logger.info(f"Created collection {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_delete_collection(self, request): - """Delete the collection for FalkorDB triples""" + async def delete_collection(self, user: str, collection: str): + """Delete the collection for FalkorDB triples via config push""" try: # Delete all nodes and literals for this user/collection node_result = self.io.query( "MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n", - params={"user": request.user, "collection": request.collection} + params={"user": user, "collection": collection} ) literal_result = self.io.query( "MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n", - params={"user": request.user, "collection": request.collection} + params={"user": user, "collection": collection} ) # Delete collection metadata node metadata_result = self.io.query( "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c", - params={"user": request.user, "collection": request.collection} + params={"user": user, "collection": collection} ) - logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") + logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index 84248952..8105b14e 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -12,11 +12,9 @@ import logging from neo4j import GraphDatabase -from .... base import TriplesStoreService +from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -28,10 +26,10 @@ default_username = 'memgraph' default_password = 'password' default_database = 'memgraph' -class Processor(TriplesStoreService): +class Processor(CollectionConfigHandler, TriplesStoreService): def __init__(self, **params): - + graph_host = params.get("graph_host", default_graph_host) username = params.get("username", default_username) password = params.get("password", default_password) @@ -53,33 +51,8 @@ class Processor(TriplesStoreService): with self.io.session(database=self.db) as session: self.create_indexes(session) - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=triples_storage_management_topic, - subscriber=f"{self.id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) + # Register for config push notifications + self.register_config_handler(self.on_collection_config) def create_indexes(self, session): @@ -267,28 +240,6 @@ class Processor(TriplesStoreService): src=t.s.value, dest=t.o.value, uri=t.p.value, user=user, collection=collection, ) - def collection_exists(self, user, collection): - """Check if collection metadata node exists""" - with self.io.session(database=self.db) as session: - result = session.run( - "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " - "RETURN c LIMIT 1", - user=user, collection=collection - ) - return bool(list(result)) - - def create_collection(self, user, collection): - """Create collection metadata node""" - import datetime - with self.io.session(database=self.db) as session: - session.run( - "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " - "SET c.created_at = $created_at", - user=user, collection=collection, - created_at=datetime.datetime.now().isoformat() - ) - logger.info(f"Created collection metadata node for {user}/{collection}") - async def store_triples(self, message): # Extract user and collection from metadata @@ -299,7 +250,7 @@ class Processor(TriplesStoreService): if not self.collection_exists(user, collection): error_msg = ( f"Collection {collection} does not exist. " - f"Create it first with tg-set-collection." + f"Create it first via collection management API." ) logger.error(error_msg) raise ValueError(error_msg) @@ -348,73 +299,50 @@ class Processor(TriplesStoreService): help=f'Memgraph database (default: {default_database})' ) - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() + def _collection_exists_in_db(self, user, collection): + """Check if collection metadata node exists""" + with self.io.session(database=self.db) as session: + result = session.run( + "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " + "RETURN c LIMIT 1", + user=user, collection=collection + ) + return bool(list(result)) - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") + def _create_collection_in_db(self, user, collection): + """Create collection metadata node""" + import datetime + with self.io.session(database=self.db) as session: + session.run( + "MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " + "SET c.created_at = $created_at", + user=user, collection=collection, + created_at=datetime.datetime.now().isoformat() + ) + logger.info(f"Created collection metadata node for {user}/{collection}") + async def create_collection(self, user: str, collection: str, metadata: dict): + """Create collection metadata in Memgraph via config push""" try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) + if self._collection_exists_in_db(user, collection): + logger.info(f"Collection {user}/{collection} already exists") else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) + self._create_collection_in_db(user, collection) + logger.info(f"Created collection {user}/{collection}") except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_create_collection(self, request): - """Create collection metadata in Memgraph""" - try: - if self.collection_exists(request.user, request.collection): - logger.info(f"Collection {request.user}/{request.collection} already exists") - else: - self.create_collection(request.user, request.collection) - logger.info(f"Created collection {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - async def handle_delete_collection(self, request): - """Delete all data for a specific collection""" + async def delete_collection(self, user: str, collection: str): + """Delete all data for a specific collection via config push""" try: with self.io.session(database=self.db) as session: # Delete all nodes for this user and collection node_result = session.run( "MATCH (n:Node {user: $user, collection: $collection}) " "DETACH DELETE n", - user=request.user, collection=request.collection + user=user, collection=collection ) nodes_deleted = node_result.consume().counters.nodes_deleted @@ -422,7 +350,7 @@ class Processor(TriplesStoreService): literal_result = session.run( "MATCH (n:Literal {user: $user, collection: $collection}) " "DETACH DELETE n", - user=request.user, collection=request.collection + user=user, collection=collection ) literals_deleted = literal_result.consume().counters.nodes_deleted @@ -430,20 +358,13 @@ class Processor(TriplesStoreService): metadata_result = session.run( "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "DELETE c", - user=request.user, collection=request.collection + user=user, collection=collection ) metadata_deleted = metadata_result.consume().counters.nodes_deleted # Note: Relationships are automatically deleted with DETACH DELETE - logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") + logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") except Exception as e: logger.error(f"Failed to delete collection: {e}") diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index 227356ce..e33b26ca 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -11,11 +11,9 @@ import time import logging from neo4j import GraphDatabase -from .... base import TriplesStoreService +from .... base import TriplesStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer from .... base import ConsumerMetrics, ProducerMetrics -from .... schema import StorageManagementRequest, StorageManagementResponse, Error -from .... schema import triples_storage_management_topic, storage_management_response_topic # Module logger logger = logging.getLogger(__name__) @@ -27,10 +25,10 @@ default_username = 'neo4j' default_password = 'password' default_database = 'neo4j' -class Processor(TriplesStoreService): +class Processor(CollectionConfigHandler, TriplesStoreService): def __init__(self, **params): - + id = params.get("id", default_ident) graph_host = params.get("graph_host", default_graph_host) @@ -53,33 +51,8 @@ class Processor(TriplesStoreService): with self.io.session(database=self.db) as session: self.create_indexes(session) - # Set up metrics for storage management - storage_request_metrics = ConsumerMetrics( - processor=self.id, flow=None, name="storage-request" - ) - storage_response_metrics = ProducerMetrics( - processor=self.id, flow=None, name="storage-response" - ) - - # Set up consumer for storage management requests - self.storage_request_consumer = Consumer( - taskgroup=self.taskgroup, - client=self.pulsar_client, - flow=None, - topic=triples_storage_management_topic, - subscriber=f"{id}-storage", - schema=StorageManagementRequest, - handler=self.on_storage_management, - metrics=storage_request_metrics, - ) - - # Set up producer for storage management responses - self.storage_response_producer = Producer( - client=self.pulsar_client, - topic=storage_management_response_topic, - schema=StorageManagementResponse, - metrics=storage_response_metrics, - ) + # Register for config push notifications + self.register_config_handler(self.on_collection_config) def create_indexes(self, session): @@ -232,7 +205,7 @@ class Processor(TriplesStoreService): if not self.collection_exists(user, collection): error_msg = ( f"Collection {collection} does not exist. " - f"Create it first with tg-set-collection." + f"Create it first via collection management API." ) logger.error(error_msg) raise ValueError(error_msg) @@ -277,42 +250,7 @@ class Processor(TriplesStoreService): help=f'Neo4j database (default: {default_database})' ) - async def start(self): - """Start the processor and its storage management consumer""" - await super().start() - await self.storage_request_consumer.start() - await self.storage_response_producer.start() - - async def on_storage_management(self, message, consumer, flow): - """Handle storage management requests""" - request = message.value() - logger.info(f"Storage management request: {request.operation} for {request.user}/{request.collection}") - - try: - if request.operation == "create-collection": - await self.handle_create_collection(request) - elif request.operation == "delete-collection": - await self.handle_delete_collection(request) - else: - response = StorageManagementResponse( - error=Error( - type="invalid_operation", - message=f"Unknown operation: {request.operation}" - ) - ) - await self.storage_response_producer.send(response) - - except Exception as e: - logger.error(f"Error processing storage management request: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="processing_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) - - def collection_exists(self, user, collection): + def _collection_exists_in_db(self, user, collection): """Check if collection metadata node exists""" with self.io.session(database=self.db) as session: result = session.run( @@ -322,7 +260,7 @@ class Processor(TriplesStoreService): ) return bool(list(result)) - def create_collection(self, user, collection): + def _create_collection_in_db(self, user, collection): """Create collection metadata node""" import datetime with self.io.session(database=self.db) as session: @@ -334,38 +272,28 @@ class Processor(TriplesStoreService): ) logger.info(f"Created collection metadata node for {user}/{collection}") - async def handle_create_collection(self, request): - """Create collection metadata in Neo4j""" + async def create_collection(self, user: str, collection: str, metadata: dict): + """Create collection metadata in Neo4j via config push""" try: - if self.collection_exists(request.user, request.collection): - logger.info(f"Collection {request.user}/{request.collection} already exists") + if self._collection_exists_in_db(user, collection): + logger.info(f"Collection {user}/{collection} already exists") else: - self.create_collection(request.user, request.collection) - logger.info(f"Created collection {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse(error=None) - await self.storage_response_producer.send(response) + self._create_collection_in_db(user, collection) + logger.info(f"Created collection {user}/{collection}") except Exception as e: - logger.error(f"Failed to create collection: {e}", exc_info=True) - response = StorageManagementResponse( - error=Error( - type="creation_error", - message=str(e) - ) - ) - await self.storage_response_producer.send(response) + logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) + raise - async def handle_delete_collection(self, request): - """Delete all data for a specific collection""" + async def delete_collection(self, user: str, collection: str): + """Delete all data for a specific collection via config push""" try: with self.io.session(database=self.db) as session: # Delete all nodes for this user and collection node_result = session.run( "MATCH (n:Node {user: $user, collection: $collection}) " "DETACH DELETE n", - user=request.user, collection=request.collection + user=user, collection=collection ) nodes_deleted = node_result.consume().counters.nodes_deleted @@ -373,7 +301,7 @@ class Processor(TriplesStoreService): literal_result = session.run( "MATCH (n:Literal {user: $user, collection: $collection}) " "DETACH DELETE n", - user=request.user, collection=request.collection + user=user, collection=collection ) literals_deleted = literal_result.consume().counters.nodes_deleted @@ -383,21 +311,14 @@ class Processor(TriplesStoreService): metadata_result = session.run( "MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "DELETE c", - user=request.user, collection=request.collection + user=user, collection=collection ) metadata_deleted = metadata_result.consume().counters.nodes_deleted - logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {request.user}/{request.collection}") - - # Send success response - response = StorageManagementResponse( - error=None # No error means success - ) - await self.storage_response_producer.send(response) - logger.info(f"Successfully deleted collection {request.user}/{request.collection}") + logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") except Exception as e: - logger.error(f"Failed to delete collection: {e}") + logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) raise def run(): diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py index 839f3afa..0a7c6081 100644 --- a/trustgraph-flow/trustgraph/tables/library.py +++ b/trustgraph-flow/trustgraph/tables/library.py @@ -111,21 +111,6 @@ class LibraryTableStore: ); """); - logger.debug("collections table...") - - self.cassandra.execute(""" - CREATE TABLE IF NOT EXISTS collections ( - user text, - collection text, - name text, - description text, - tags set, - created_at timestamp, - updated_at timestamp, - PRIMARY KEY (user, collection) - ); - """); - logger.info("Cassandra schema OK.") def prepare_statements(self): @@ -202,43 +187,6 @@ class LibraryTableStore: LIMIT 1 """) - # Collection management statements - self.insert_collection_stmt = self.cassandra.prepare(""" - INSERT INTO collections - (user, collection, name, description, tags, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?) - """) - - self.update_collection_stmt = self.cassandra.prepare(""" - UPDATE collections - SET name = ?, description = ?, tags = ?, updated_at = ? - WHERE user = ? AND collection = ? - """) - - self.get_collection_stmt = self.cassandra.prepare(""" - SELECT collection, name, description, tags, created_at, updated_at - FROM collections - WHERE user = ? AND collection = ? - """) - - self.list_collections_stmt = self.cassandra.prepare(""" - SELECT collection, name, description, tags, created_at, updated_at - FROM collections - WHERE user = ? - """) - - self.delete_collection_stmt = self.cassandra.prepare(""" - DELETE FROM collections - WHERE user = ? AND collection = ? - """) - - self.collection_exists_stmt = self.cassandra.prepare(""" - SELECT collection - FROM collections - WHERE user = ? AND collection = ? - LIMIT 1 - """) - self.list_processing_stmt = self.cassandra.prepare(""" SELECT id, document_id, time, flow, collection, tags @@ -390,7 +338,6 @@ class LibraryTableStore: for m in row[5] ], tags = row[6] if row[6] else [], - object_id = row[7], ) for row in resp ] @@ -436,7 +383,6 @@ class LibraryTableStore: for m in row[4] ], tags = row[5] if row[5] else [], - object_id = row[6], ) logger.debug("Done") @@ -572,146 +518,3 @@ class LibraryTableStore: logger.debug("Done") return lst - - - - # Collection management methods - - async def ensure_collection_exists(self, user, collection): - """Ensure collection metadata record exists, create if not""" - try: - resp = await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.collection_exists_stmt, [user, collection] - ) - if resp: - return - import datetime - now = datetime.datetime.now() - await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.insert_collection_stmt, - [user, collection, collection, "", set(), now, now] - ) - logger.debug(f"Created collection metadata for {user}/{collection}") - except Exception as e: - logger.error(f"Error ensuring collection exists: {e}") - raise - - async def list_collections(self, user, tag_filter=None): - """List collections for a user, optionally filtered by tags""" - try: - resp = await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.list_collections_stmt, [user] - ) - collections = [] - for row in resp: - collection_data = { - "user": user, - "collection": row[0], - "name": row[1] or row[0], - "description": row[2] or "", - "tags": list(row[3]) if row[3] else [], - "created_at": row[4].isoformat() if row[4] else "", - "updated_at": row[5].isoformat() if row[5] else "" - } - if tag_filter: - collection_tags = set(collection_data["tags"]) - filter_tags = set(tag_filter) - if not filter_tags.intersection(collection_tags): - continue - collections.append(collection_data) - return collections - except Exception as e: - logger.error(f"Error listing collections: {e}") - raise - - async def update_collection(self, user, collection, name=None, description=None, tags=None): - """Update collection metadata""" - try: - resp = await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.get_collection_stmt, [user, collection] - ) - if not resp: - raise RequestError(f"Collection {collection} not found") - row = resp.one() - current_name = row[1] or collection - current_description = row[2] or "" - current_tags = set(row[3]) if row[3] else set() - new_name = name if name is not None else current_name - new_description = description if description is not None else current_description - new_tags = set(tags) if tags is not None else current_tags - import datetime - now = datetime.datetime.now() - await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.update_collection_stmt, - [new_name, new_description, new_tags, now, user, collection] - ) - return { - "user": user, "collection": collection, "name": new_name, - "description": new_description, "tags": list(new_tags), - "updated_at": now.isoformat() - } - except Exception as e: - logger.error(f"Error updating collection: {e}") - raise - - async def delete_collection(self, user, collection): - """Delete collection metadata record""" - try: - await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.delete_collection_stmt, [user, collection] - ) - logger.debug(f"Deleted collection metadata for {user}/{collection}") - except Exception as e: - logger.error(f"Error deleting collection metadata: {e}") - raise - - async def get_collection(self, user, collection): - """Get collection metadata""" - try: - resp = await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.get_collection_stmt, [user, collection] - ) - if not resp: - return None - row = resp.one() - return { - "user": user, "collection": row[0], "name": row[1] or row[0], - "description": row[2] or "", "tags": list(row[3]) if row[3] else [], - "created_at": row[4].isoformat() if row[4] else "", - "updated_at": row[5].isoformat() if row[5] else "" - } - except Exception as e: - logger.error(f"Error getting collection: {e}") - raise - - async def create_collection(self, user, collection, name=None, description=None, tags=None): - """Create a new collection metadata record""" - try: - import datetime - now = datetime.datetime.now() - - # Set defaults for optional parameters - name = name if name is not None else collection - description = description if description is not None else "" - tags = tags if tags is not None else set() - - await asyncio.get_event_loop().run_in_executor( - None, self.cassandra.execute, self.insert_collection_stmt, - [user, collection, name, description, tags, now, now] - ) - - logger.info(f"Created collection {user}/{collection}") - - # Return the created collection data - return { - "user": user, - "collection": collection, - "name": name, - "description": description, - "tags": list(tags) if isinstance(tags, set) else tags, - "created_at": now.isoformat(), - "updated_at": now.isoformat() - } - except Exception as e: - logger.error(f"Error creating collection: {e}") - raise diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index bf74291b..2c84d21c 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -16,6 +16,8 @@ from mcp.server.fastmcp import FastMCP, Context from mcp.types import TextContent from websockets.asyncio.client import connect +from trustgraph.base.logging import add_logging_args, setup_logging + from . tg_socket import WebSocketManager @dataclass @@ -2040,9 +2042,15 @@ def main(): parser.add_argument('--host', default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)') parser.add_argument('--port', type=int, default=8000, help='Port to bind to (default: 8000)') parser.add_argument('--websocket-url', default='ws://api-gateway:8088/api/v1/socket', help='WebSocket URL to connect to (default: ws://api-gateway:8088/api/v1/socket)') - + + # Add logging arguments + add_logging_args(parser) + args = parser.parse_args() - + + # Setup logging before creating server + setup_logging(vars(args)) + # Create and run the MCP server server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) server.run() diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index 4646df38..1068d91e 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 09a92a44..a96f8338 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", "pulsar-client", "google-cloud-aiplatform", "prometheus-client", diff --git a/trustgraph/pyproject.toml b/trustgraph/pyproject.toml index 0e2d6089..4ddbf562 100644 --- a/trustgraph/pyproject.toml +++ b/trustgraph/pyproject.toml @@ -10,12 +10,12 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=1.7,<1.8", - "trustgraph-bedrock>=1.7,<1.8", - "trustgraph-cli>=1.7,<1.8", - "trustgraph-embeddings-hf>=1.7,<1.8", - "trustgraph-flow>=1.7,<1.8", - "trustgraph-vertexai>=1.7,<1.8", + "trustgraph-base>=1.8,<1.9", + "trustgraph-bedrock>=1.8,<1.9", + "trustgraph-cli>=1.8,<1.9", + "trustgraph-embeddings-hf>=1.8,<1.9", + "trustgraph-flow>=1.8,<1.9", + "trustgraph-vertexai>=1.8,<1.9", ] classifiers = [ "Programming Language :: Python :: 3",