mirror of https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00

Merge pull request #604 from trustgraph-ai/release/v1.8
Merge 1.8 into master

Commit 8dff90f36f
233 changed files with 13294 additions and 4542 deletions

`.github/workflows/pull-request.yaml` (vendored, 2 lines changed)
```diff
@@ -22,7 +22,7 @@ jobs:
       uses: actions/checkout@v3

     - name: Setup packages
-      run: make update-package-versions VERSION=1.7.999
+      run: make update-package-versions VERSION=1.8.999

     - name: Setup environment
       run: python3 -m venv env
```
@@ -1,8 +1,8 @@

# TrustGraph Librarian API

This API provides document library management for TrustGraph. It handles document storage, metadata management, and processing orchestration using hybrid storage (S3-compatible object storage for content, Cassandra for metadata) with multi-user support.

## Request/response

@@ -374,13 +374,14 @@ await client.add_processing(

## Features

- **Hybrid Storage**: S3-compatible object storage (MinIO, Ceph RGW, AWS S3, etc.) for content, Cassandra for metadata
- **Multi-user Support**: User-based document ownership and access control
- **Rich Metadata**: RDF-style metadata triples and tagging system
- **Processing Integration**: Automatic triggering of document processing workflows
- **Content Types**: Support for multiple document formats (PDF, text, etc.)
- **Collection Management**: Optional document grouping by collection
- **Metadata Search**: Query documents by metadata criteria
- **Flexible Storage Backend**: Works with any S3-compatible storage (MinIO, Ceph RADOS Gateway, AWS S3, Cloudflare R2, etc.)

## Use Cases
@@ -233,9 +233,13 @@ When a user initiates collection deletion through the librarian service:

#### Collection Management Interface

**⚠️ LEGACY APPROACH - REPLACED BY CONFIG-BASED PATTERN**

The queue-based architecture described below has been replaced with a config-based approach using `CollectionConfigHandler`. All storage backends now receive collection updates via config push messages instead of dedicated management queues.

~~All store writers implement a standardized collection management interface with a common schema:~~

~~**Message Schema (`StorageManagementRequest`):**~~

```json
{
  "operation": "create-collection" | "delete-collection",
  ...
}
```

~~**Queue Architecture:**~~
- ~~**Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores~~
- ~~**Object Store Management Queue** (`object-storage-management`): Object/document stores~~
- ~~**Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores~~
- ~~**Storage Response Queue** (`storage-management-response`): All responses sent here~~

**Current Implementation:**

All storage backends now use `CollectionConfigHandler`:
- **Config Push Integration**: Storage services register for config push notifications
- **Automatic Synchronization**: Collections created/deleted based on config changes
- **Declarative Model**: Collections defined in the config service; backends sync to match
- **No Request/Response**: Eliminates coordination overhead and response tracking
- **Collection State Tracking**: Maintained via a `known_collections` cache
- **Idempotent Operations**: Safe to process the same config multiple times

Each storage backend implements:
- `create_collection(user: str, collection: str, metadata: dict)` - Create collection structures
- `delete_collection(user: str, collection: str)` - Remove all collection data
- `collection_exists(user: str, collection: str) -> bool` - Validate before writes

#### Cassandra Triple Store Refactor
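The declarative sync described above can be sketched in a few lines. This is an illustrative sketch only: `ExampleBackend` and the shape of the config payload are assumptions made for illustration, not TrustGraph's actual `CollectionConfigHandler` API.

```python
# Hypothetical backend following the config-push pattern: the config
# service pushes the full desired collection list, and the backend
# syncs its state to match. Not TrustGraph's real implementation.

class ExampleBackend:
    def __init__(self):
        # Cache of collections known to exist, keyed by (user, collection)
        self.known_collections = set()

    def on_collection_config(self, config):
        # Declarative sync: create anything in config we don't know,
        # delete anything we know that config no longer lists.
        desired = {(c["user"], c["collection"]): c.get("metadata", {})
                   for c in config}
        for key, metadata in desired.items():
            if key not in self.known_collections:
                self.create_collection(*key, metadata)
        for key in self.known_collections - set(desired):
            self.delete_collection(*key)

    def create_collection(self, user: str, collection: str, metadata: dict):
        # Idempotent: safe to call repeatedly for the same collection
        self.known_collections.add((user, collection))

    def delete_collection(self, user: str, collection: str):
        self.known_collections.discard((user, collection))

    def collection_exists(self, user: str, collection: str) -> bool:
        # Validate before writes
        return (user, collection) in self.known_collections
```

Because the handler is driven by the full desired state rather than individual requests, processing the same config twice is a no-op, which is what makes the operations idempotent.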
@@ -365,62 +371,33 @@ Comprehensive testing will cover:

- `triples_collection` table for SPO queries and deletion tracking
- Collection deletion implemented with read-then-delete pattern

### ✅ Migration to Config-Based Pattern - COMPLETED

**All storage backends have been migrated from the queue-based pattern to the config-based `CollectionConfigHandler` pattern.**

Completed migrations:
- ✅ `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
- ✅ `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`

All backends now:
- Inherit from `CollectionConfigHandler`
- Register for config push notifications via `self.register_config_handler(self.on_collection_config)`
- Implement `create_collection(user, collection, metadata)` and `delete_collection(user, collection)`
- Use `collection_exists(user, collection)` to validate before writes
- Automatically sync with config service changes

Legacy queue-based infrastructure removed:
- ✅ Removed `StorageManagementRequest` and `StorageManagementResponse` schemas
- ✅ Removed storage management queue topic definitions
- ✅ Removed storage management consumer/producer from all backends
- ✅ Removed `on_storage_management` handlers from all backends
@@ -2,17 +2,29 @@

## Overview

TrustGraph uses Python's built-in `logging` module for all logging operations, with centralized configuration and optional Loki integration for log aggregation. This provides a standardized, flexible approach to logging across all components of the system.

## Default Configuration

### Logging Level
- **Default Level**: `INFO`
- **Configurable via**: `--log-level` command-line argument
- **Choices**: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`

### Output Destinations
1. **Console (stdout)**: Always enabled - ensures compatibility with containerized environments
2. **Loki**: Optional centralized log aggregation (enabled by default, can be disabled)

## Centralized Logging Module

All logging configuration is managed by the `trustgraph.base.logging` module, which provides:
- `add_logging_args(parser)` - Adds standard logging CLI arguments
- `setup_logging(args)` - Configures logging from parsed arguments

This module is used by all server-side components:
- AsyncProcessor-based services
- API Gateway
- MCP Server

## Implementation Guidelines
@@ -26,39 +38,80 @@ import logging

```python
import logging

logger = logging.getLogger(__name__)
```

The logger name is automatically used as a label in Loki for filtering and searching.

### 2. Service Initialization

All server-side services automatically get logging configuration through the centralized module:

```python
import argparse
import logging

from trustgraph.base import add_logging_args, setup_logging

def main():
    parser = argparse.ArgumentParser()

    # Add standard logging arguments (includes Loki configuration)
    add_logging_args(parser)

    # Add your service-specific arguments
    parser.add_argument('--port', type=int, default=8080)

    args = parser.parse_args()
    args = vars(args)

    # Setup logging early in startup
    setup_logging(args)

    # Rest of your service initialization
    logger = logging.getLogger(__name__)
    logger.info("Service starting...")
```

### 3. Command-Line Arguments

All services support these logging arguments:

**Log Level:**
```bash
--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
```

**Loki Configuration:**
```bash
--loki-enabled           # Enable Loki (default)
--no-loki-enabled        # Disable Loki
--loki-url URL           # Loki push URL (default: http://loki:3100/loki/api/v1/push)
--loki-username USERNAME # Optional authentication
--loki-password PASSWORD # Optional authentication
```

**Examples:**
```bash
# Default - INFO level, Loki enabled
./my-service

# Debug mode, console only
./my-service --log-level DEBUG --no-loki-enabled

# Custom Loki server with auth
./my-service --loki-url http://loki.prod:3100/loki/api/v1/push \
    --loki-username admin --loki-password secret
```

### 4. Environment Variables

Loki configuration supports environment variable fallbacks:

```bash
export LOKI_URL=http://loki.prod:3100/loki/api/v1/push
export LOKI_USERNAME=admin
export LOKI_PASSWORD=secret
```

Command-line arguments take precedence over environment variables.
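That precedence rule is commonly implemented by using the environment variable as the argparse default, so an explicit flag always wins. A sketch of the pattern (illustrative only, not necessarily how `add_logging_args` does it):

```python
import argparse
import os

# The environment variable supplies the default; an explicit CLI flag
# overrides it, giving command-line arguments precedence.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--loki-url',
    default=os.environ.get(
        'LOKI_URL', 'http://loki:3100/loki/api/v1/push'),
)

# With no flag, the env var (or the built-in default) is used:
args = parser.parse_args([])
print(args.loki_url)

# An explicit flag wins over the environment:
args = parser.parse_args(['--loki-url', 'http://other:3100'])
print(args.loki_url)  # prints http://other:3100
```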

### 5. Logging Best Practices

#### Log Levels Usage
- **DEBUG**: Detailed information for diagnosing problems (variable values, function entry/exit)
@@ -89,20 +142,25 @@ if logger.isEnabledFor(logging.DEBUG):

```python
if logger.isEnabledFor(logging.DEBUG):
    logger.debug(f"Debug data: {debug_data}")
```

### 6. Structured Logging with Loki

For complex data, use structured logging with extra tags for Loki:

```python
logger.info("Request processed", extra={
    'tags': {
        'request_id': request_id,
        'user_id': user_id,
        'status': 'success'
    }
})
```

These tags become searchable labels in Loki, in addition to automatic labels:
- `severity` - Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
- `logger` - Module name (from `__name__`)

### 7. Exception Logging

Always include stack traces for exceptions:
@@ -114,9 +172,13 @@ except Exception as e:

```python
    raise
```

### 8. Async Logging Considerations

The logging system uses non-blocking queued handlers for Loki:
- Console output is synchronous (fast)
- Loki output is queued with a 500-message buffer
- A background thread handles Loki transmission
- No blocking of main application code

```python
import asyncio
import logging

async def async_operation():
    logger = logging.getLogger(__name__)
    # Logging is thread-safe and won't block async operations
    logger.info(f"Starting async operation in task: {asyncio.current_task().get_name()}")
```

## Loki Integration

### Architecture

The logging system uses Python's built-in `QueueHandler` and `QueueListener` for non-blocking Loki integration:

1. **QueueHandler**: Logs are placed in a 500-message queue (non-blocking)
2. **Background Thread**: QueueListener sends logs to Loki asynchronously
3. **Graceful Degradation**: If Loki is unavailable, console logging continues
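The queued-handler arrangement can be sketched with the standard library alone. The Loki handler is represented here by a stand-in `StreamHandler` (any `logging.Handler` slots in the same way); the 500-message bound matches the buffer described above, but the rest of the names are illustrative, not TrustGraph's actual module:

```python
import logging
import queue
from logging.handlers import QueueHandler, QueueListener

# Bounded queue: records are handed off without blocking the
# application thread on the slow handler
log_queue = queue.Queue(maxsize=500)

# Stand-in for the Loki handler; any logging.Handler works here
slow_handler = logging.StreamHandler()

# The root logger writes into the queue...
root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(QueueHandler(log_queue))

# ...and a background thread drains it into the real handler
listener = QueueListener(log_queue, slow_handler)
listener.start()

logging.getLogger(__name__).info("handled without blocking")
listener.stop()  # flushes remaining records on shutdown
```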

### Automatic Labels

Every log sent to Loki includes:
- `processor`: Processor identity (e.g., `config-svc`, `text-completion`, `embeddings`)
- `severity`: Log level (DEBUG, INFO, etc.)
- `logger`: Module name (e.g., `trustgraph.gateway.service`, `trustgraph.agent.react.service`)

### Custom Labels

Add custom labels via the `extra` parameter:

```python
logger.info("User action", extra={
    'tags': {
        'user_id': user_id,
        'action': 'document_upload',
        'collection': collection_name
    }
})
```

### Querying Logs in Loki

```logql
# All logs from a specific processor (recommended - matches Prometheus metrics)
{processor="config-svc"}
{processor="text-completion"}
{processor="embeddings"}

# Error logs from a specific processor
{processor="config-svc", severity="ERROR"}

# Error logs from all processors
{severity="ERROR"}

# Logs from a specific processor with text filter
{processor="text-completion"} |= "Processing"

# All logs from API gateway
{processor="api-gateway"}

# Logs from processors matching pattern
{processor=~".*-completion"}

# Logs with custom tags
{processor="api-gateway"} | json | user_id="12345"
```

### Graceful Degradation

If Loki is unavailable or `python-logging-loki` is not installed:
- Warning message printed to console
- Console logging continues normally
- Application continues running
- No retry logic for Loki connection (fail fast, degrade gracefully)

## Testing

During tests, consider using a different logging configuration:

```python
# In test setup
import logging

# Reduce noise during tests
logging.getLogger().setLevel(logging.WARNING)

# Or disable Loki for tests
setup_logging({'log_level': 'WARNING', 'loki_enabled': False})
```

## Monitoring Integration

### Standard Format

All logs use a consistent format:
```
2025-01-09 10:30:45,123 - trustgraph.gateway.service - INFO - Request processed
```

Format components:
- Timestamp (ISO format with milliseconds)
- Logger name (module path)
- Log level
- Message

### Loki Queries for Monitoring

Common monitoring queries:

```logql
# Error rate by processor
rate({severity="ERROR"}[5m]) by (processor)

# Top error-producing processors
topk(5, count_over_time({severity="ERROR"}[1h]) by (processor))

# Recent errors with processor name
{severity="ERROR"} | line_format "{{.processor}}: {{.message}}"

# All agent processors
{processor=~".*agent.*"} |= "exception"

# Specific processor error count
count_over_time({processor="config-svc", severity="ERROR"}[1h])
```

## Security Considerations

- **Never log sensitive information** (passwords, API keys, personal data, tokens)
- **Sanitize user input** before logging
- **Use placeholders** for sensitive fields: `user_id=****1234`
- **Loki authentication**: Use `--loki-username` and `--loki-password` for secure deployments
- **Secure transport**: Use HTTPS for the Loki URL in production: `https://loki.prod:3100/loki/api/v1/push`

## Dependencies

The centralized logging module requires:
- `python-logging-loki` - For Loki integration (optional, graceful degradation if missing)

Already included in `trustgraph-base/pyproject.toml` and `requirements.txt`.

## Migration Path

For existing code:

1. **Services already using AsyncProcessor**: No changes needed, Loki support is automatic
2. **Services not using AsyncProcessor** (api-gateway, mcp-server): Already updated
3. **CLI tools**: Out of scope - continue using print() or simple logging

### From print() to logging:

```python
# Before
print(f"Processing document {doc_id}")

# After
logger = logging.getLogger(__name__)
logger.info(f"Processing document {doc_id}")
```

## Configuration Summary

| Argument | Default | Environment Variable | Description |
|----------|---------|---------------------|-------------|
| `--log-level` | `INFO` | - | Console and Loki log level |
| `--loki-enabled` | `True` | - | Enable Loki logging |
| `--loki-url` | `http://loki:3100/loki/api/v1/push` | `LOKI_URL` | Loki push endpoint |
| `--loki-username` | `None` | `LOKI_USERNAME` | Loki auth username |
| `--loki-password` | `None` | `LOKI_PASSWORD` | Loki auth password |
`docs/tech-specs/minio-to-s3-migration.md` (new file, 258 lines)

@@ -0,0 +1,258 @@

# Tech Spec: S3-Compatible Storage Backend Support

## Overview

The Librarian service uses S3-compatible object storage for document blob storage. This spec documents the implementation that enables support for any S3-compatible backend, including MinIO, Ceph RADOS Gateway (RGW), AWS S3, Cloudflare R2, DigitalOcean Spaces, and others.

## Architecture

### Storage Components
- **Blob Storage**: S3-compatible object storage via the `minio` Python client library
- **Metadata Storage**: Cassandra (stores the object_id mapping and document metadata)
- **Affected Component**: Librarian service only
- **Storage Pattern**: Hybrid storage with metadata in Cassandra, content in S3-compatible storage

### Implementation
- **Library**: `minio` Python client (supports any S3-compatible API)
- **Location**: `trustgraph-flow/trustgraph/librarian/blob_store.py`
- **Operations**:
  - `add()` - Store blob with UUID object_id
  - `get()` - Retrieve blob by object_id
  - `remove()` - Delete blob by object_id
  - `ensure_bucket()` - Create bucket if it does not exist
- **Bucket**: `library`
- **Object Path**: `doc/{object_id}`
- **Supported MIME Types**: `text/plain`, `application/pdf`
|
### Key Files
|
||||||
|
1. `trustgraph-flow/trustgraph/librarian/blob_store.py` - BlobStore implementation
|
||||||
|
2. `trustgraph-flow/trustgraph/librarian/librarian.py` - BlobStore initialization
|
||||||
|
3. `trustgraph-flow/trustgraph/librarian/service.py` - Service configuration
|
||||||
|
4. `trustgraph-flow/pyproject.toml` - Dependencies (`minio` package)
|
||||||
|
5. `docs/apis/api-librarian.md` - API documentation
|
||||||
|
|
||||||
|
## Supported Storage Backends
|
||||||
|
|
||||||
|
The implementation works with any S3-compatible object storage system:
|
||||||
|
|
||||||
|
### Tested/Supported
|
||||||
|
- **Ceph RADOS Gateway (RGW)** - Distributed storage system with S3 API (default configuration)
|
||||||
|
- **MinIO** - Lightweight self-hosted object storage
|
||||||
|
- **Garage** - Lightweight geo-distributed S3-compatible storage
|
||||||
|
|
||||||
|
### Should Work (S3-Compatible)
|
||||||
|
- **AWS S3** - Amazon's cloud object storage
|
||||||
|
- **Cloudflare R2** - Cloudflare's S3-compatible storage
|
||||||
|
- **DigitalOcean Spaces** - DigitalOcean's object storage
|
||||||
|
- **Wasabi** - S3-compatible cloud storage
|
||||||
|
- **Backblaze B2** - S3-compatible backup storage
|
||||||
|
- Any other service implementing the S3 REST API
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### CLI Arguments
|
||||||
|
|
||||||
|
```bash
|
||||||
|
librarian \
|
||||||
|
--object-store-endpoint <hostname:port> \
|
||||||
|
--object-store-access-key <access_key> \
|
||||||
|
--object-store-secret-key <secret_key> \
|
||||||
|
[--object-store-use-ssl] \
|
||||||
|
[--object-store-region <region>]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** Do not include `http://` or `https://` in the endpoint. Use `--object-store-use-ssl` to enable HTTPS.

### Environment Variables (Alternative)

```bash
OBJECT_STORE_ENDPOINT=<hostname:port>
OBJECT_STORE_ACCESS_KEY=<access_key>
OBJECT_STORE_SECRET_KEY=<secret_key>
OBJECT_STORE_USE_SSL=true|false    # Optional, default: false
OBJECT_STORE_REGION=<region>       # Optional
```

### Examples

**Ceph RADOS Gateway (default):**

```bash
--object-store-endpoint ceph-rgw:7480 \
--object-store-access-key object-user \
--object-store-secret-key object-password
```

**MinIO:**

```bash
--object-store-endpoint minio:9000 \
--object-store-access-key minioadmin \
--object-store-secret-key minioadmin
```

**Garage (S3-compatible):**

```bash
--object-store-endpoint garage:3900 \
--object-store-access-key GK000000000000000000000001 \
--object-store-secret-key b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
```

**AWS S3 with SSL:**

```bash
--object-store-endpoint s3.amazonaws.com \
--object-store-access-key AKIAIOSFODNN7EXAMPLE \
--object-store-secret-key wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY \
--object-store-use-ssl \
--object-store-region us-east-1
```

## Authentication

All S3-compatible backends require AWS Signature Version 4 (or v2) authentication:

- **Access Key** - Public identifier (like a username)
- **Secret Key** - Private signing key (like a password)

The MinIO Python client handles all signature calculation automatically.

### Creating Credentials

**For MinIO:**

```bash
# Use default credentials or create a user via the MinIO Console
minioadmin / minioadmin
```

**For Ceph RGW:**

```bash
radosgw-admin user create --uid="trustgraph" --display-name="TrustGraph Service"
# Returns access_key and secret_key
```

**For AWS S3:**

- Create an IAM user with S3 permissions
- Generate an access key in the AWS Console

## Library Selection: MinIO Python Client

**Rationale:**

- Lightweight (~500KB vs boto3's ~50MB)
- S3-compatible - works with any S3 API endpoint
- Simpler API than boto3 for basic operations
- Already in use, no migration needed
- Battle-tested with MinIO and other S3 systems

## BlobStore Implementation

**Location:** `trustgraph-flow/trustgraph/librarian/blob_store.py`

```python
from minio import Minio
import io
import logging

logger = logging.getLogger(__name__)

class BlobStore:
    """
    S3-compatible blob storage for document content.
    Supports MinIO, Ceph RGW, AWS S3, and other S3-compatible backends.
    """

    def __init__(self, endpoint, access_key, secret_key, bucket_name,
                 use_ssl=False, region=None):
        """
        Initialize S3-compatible blob storage.

        Args:
            endpoint: S3 endpoint (e.g., "minio:9000", "ceph-rgw:7480")
            access_key: S3 access key
            secret_key: S3 secret key
            bucket_name: Bucket name for storage
            use_ssl: Use HTTPS instead of HTTP (default: False)
            region: S3 region (optional, e.g., "us-east-1")
        """
        self.client = Minio(
            endpoint=endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=use_ssl,
            region=region,
        )

        self.bucket_name = bucket_name

        protocol = "https" if use_ssl else "http"
        logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}")

        self.ensure_bucket()

    def ensure_bucket(self):
        """Create bucket if it doesn't exist"""
        found = self.client.bucket_exists(bucket_name=self.bucket_name)
        if not found:
            self.client.make_bucket(bucket_name=self.bucket_name)
            logger.info(f"Created bucket {self.bucket_name}")
        else:
            logger.debug(f"Bucket {self.bucket_name} already exists")

    async def add(self, object_id, blob, kind):
        """Store blob in S3-compatible storage"""
        self.client.put_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
            length=len(blob),
            data=io.BytesIO(blob),
            content_type=kind,
        )
        logger.debug("Add blob complete")

    async def remove(self, object_id):
        """Delete blob from S3-compatible storage"""
        self.client.remove_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
        )
        logger.debug("Remove blob complete")

    async def get(self, object_id):
        """Retrieve blob from S3-compatible storage"""
        resp = self.client.get_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
        )
        try:
            return resp.read()
        finally:
            # Release the HTTP connection back to the pool
            resp.close()
            resp.release_conn()
```
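Because the class only touches the client through `put_object`/`get_object`/`remove_object` and the `doc/{object_id}` key scheme, a dict-backed stand-in with the same interface is handy for unit tests. A minimal sketch (hypothetical; not part of the codebase):

```python
import asyncio

class InMemoryBlobStore:
    """Dict-backed stand-in mirroring the BlobStore interface
    (hypothetical; intended for unit tests, not production)."""

    def __init__(self):
        self.objects = {}  # object_name -> (blob, content_type)

    async def add(self, object_id, blob, kind):
        self.objects[f"doc/{object_id}"] = (blob, kind)

    async def remove(self, object_id):
        self.objects.pop(f"doc/{object_id}", None)

    async def get(self, object_id):
        return self.objects[f"doc/{object_id}"][0]

async def demo():
    store = InMemoryBlobStore()
    await store.add("doc-1", b"hello", "text/plain")
    data = await store.get("doc-1")
    await store.remove("doc-1")
    return data, len(store.objects)

print(asyncio.run(demo()))  # (b'hello', 0)
```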

## Key Benefits

1. **No Vendor Lock-in** - Works with any S3-compatible storage
2. **Lightweight** - The MinIO client is only ~500KB
3. **Simple Configuration** - Just endpoint + credentials
4. **No Data Migration** - Drop-in replacement between backends
5. **Battle-Tested** - The MinIO client works with all major S3 implementations

## Implementation Status

All code has been updated to use generic S3 parameter names:

- ✅ `blob_store.py` - Updated to accept `endpoint`, `access_key`, `secret_key`
- ✅ `librarian.py` - Updated parameter names
- ✅ `service.py` - Updated CLI arguments and configuration
- ✅ Documentation updated

## Future Enhancements

1. **Retry Logic** - Implement exponential backoff for transient failures
2. **Presigned URLs** - Generate temporary upload/download URLs
3. **Multi-region Support** - Replicate blobs across regions
4. **CDN Integration** - Serve blobs via a CDN
5. **Storage Classes** - Use S3 storage classes for cost optimization
6. **Lifecycle Policies** - Automatic archival/deletion
7. **Versioning** - Store multiple versions of blobs

## References

- MinIO Python Client: https://min.io/docs/minio/linux/developers/python/API.html
- Ceph RGW S3 API: https://docs.ceph.com/en/latest/radosgw/s3/
- S3 API Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/Welcome.html

772 docs/tech-specs/multi-tenant-support.md Normal file

@ -0,0 +1,772 @@
# Technical Specification: Multi-Tenant Support

## Overview

Enable multi-tenant deployments by fixing parameter name mismatches that prevent queue customization and adding Cassandra keyspace parameterization.

## Architecture Context

### Flow-Based Queue Resolution

The TrustGraph system uses a **flow-based architecture** for dynamic queue resolution, which inherently supports multi-tenancy:

- **Flow Definitions** are stored in Cassandra and specify queue names via interface definitions
- **Queue names use templates** with `{id}` variables that are replaced with flow instance IDs
- **Services dynamically resolve queues** by looking up flow configurations at request time
- **Each tenant can have unique flows** with different queue names, providing isolation

Example flow interface definition:

```json
{
    "interfaces": {
        "triples-store": "persistent://tg/flow/triples-store:{id}",
        "graph-embeddings-store": "persistent://tg/flow/graph-embeddings-store:{id}"
    }
}
```

When tenant A starts flow `tenant-a-prod` and tenant B starts flow `tenant-b-prod`, they automatically get isolated queues:

- `persistent://tg/flow/triples-store:tenant-a-prod`
- `persistent://tg/flow/triples-store:tenant-b-prod`
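The `{id}` substitution above is plain template replacement; a minimal sketch (the helper name is illustrative, not the actual TrustGraph code):

```python
def resolve_queue(template: str, flow_id: str) -> str:
    """Substitute the {id} placeholder in a flow interface template
    with a concrete flow instance ID."""
    return template.replace("{id}", flow_id)

template = "persistent://tg/flow/triples-store:{id}"
print(resolve_queue(template, "tenant-a-prod"))
# persistent://tg/flow/triples-store:tenant-a-prod
print(resolve_queue(template, "tenant-b-prod"))
# persistent://tg/flow/triples-store:tenant-b-prod
```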

**Services correctly designed for multi-tenancy:**

- ✅ **Knowledge Management (cores)** - Dynamically resolves queues from the flow configuration passed in requests

**Services needing fixes:**

- 🔴 **Config Service** - Parameter name mismatch prevents queue customization
- 🔴 **Librarian Service** - Hardcoded storage management topics (discussed below)
- 🔴 **All Services** - Cannot customize the Cassandra keyspace

## Problem Statement

### Issue #1: Parameter Name Mismatch in AsyncProcessor

- **CLI defines:** `--config-queue` (unclear naming)
- **Argparse converts to:** `config_queue` (in the params dict)
- **Code looks for:** `config_push_queue`
- **Result:** Parameter is ignored; defaults to `persistent://tg/config/config`
- **Impact:** Affects all 32+ services inheriting from AsyncProcessor
- **Blocks:** Multi-tenant deployments cannot use tenant-specific config queues
- **Solution:** Rename the CLI parameter to `--config-push-queue` for clarity (breaking change acceptable since the feature is currently broken)
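The mismatch can be reproduced with argparse alone (a standalone sketch, not the actual service code):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--config-queue', default='persistent://tg/config/config')
params = vars(parser.parse_args(['--config-queue', 'persistent://tg/tenant-a/config']))

# argparse stores the value under 'config_queue'...
print(params['config_queue'])           # persistent://tg/tenant-a/config
# ...but the code reads 'config_push_queue', so the override is silently lost
print(params.get('config_push_queue'))  # None
```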

### Issue #2: Parameter Name Mismatch in Config Service

- **CLI defines:** `--push-queue` (ambiguous naming)
- **Argparse converts to:** `push_queue` (in the params dict)
- **Code looks for:** `config_push_queue`
- **Result:** Parameter is ignored
- **Impact:** Config service cannot use a custom push queue
- **Solution:** Rename the CLI parameter to `--config-push-queue` for consistency and clarity (breaking change acceptable)

### Issue #3: Hardcoded Cassandra Keyspace

- **Current:** Keyspace hardcoded as `"config"`, `"knowledge"`, `"librarian"` in various services
- **Result:** Cannot customize the keyspace for multi-tenant deployments
- **Impact:** Config, cores, and librarian services
- **Blocks:** Multiple tenants cannot use separate Cassandra keyspaces

### Issue #4: Collection Management Architecture ✅ COMPLETED

- **Previous:** Collections stored in the Cassandra librarian keyspace via a separate collections table
- **Previous:** Librarian used 4 hardcoded storage management topics to coordinate collection create/delete:
  - `vector_storage_management_topic`
  - `object_storage_management_topic`
  - `triples_storage_management_topic`
  - `storage_management_response_topic`
- **Problems (Resolved):**
  - Hardcoded topics could not be customized for multi-tenant deployments
  - Complex async coordination between the librarian and 4+ storage services
  - Separate Cassandra table and management infrastructure
  - Non-persistent request/response queues for critical operations
- **Solution Implemented:** Migrated collections to config service storage; config push is used for distribution
- **Status:** All storage backends migrated to the `CollectionConfigHandler` pattern

## Solution

This spec addresses Issues #1, #2, #3, and #4.

### Part 1: Fix Parameter Name Mismatches

#### Change 1: AsyncProcessor Base Class - Rename CLI Parameter

**File:** `trustgraph-base/trustgraph/base/async_processor.py`
**Line:** 260-264

**Current:**

```python
parser.add_argument(
    '--config-queue',
    default=default_config_queue,
    help=f'Config push queue {default_config_queue}',
)
```

**Fixed:**

```python
parser.add_argument(
    '--config-push-queue',
    default=default_config_queue,
    help=f'Config push queue (default: {default_config_queue})',
)
```

**Rationale:**

- Clearer, more explicit naming
- Matches the internal variable name `config_push_queue`
- Breaking change acceptable since the feature is currently non-functional
- No code change needed in `params.get()` - it already looks for the correct name

#### Change 2: Config Service - Rename CLI Parameter

**File:** `trustgraph-flow/trustgraph/config/service/service.py`
**Line:** 276-279

**Current:**

```python
parser.add_argument(
    '--push-queue',
    default=default_config_push_queue,
    help=f'Config push queue (default: {default_config_push_queue})'
)
```

**Fixed:**

```python
parser.add_argument(
    '--config-push-queue',
    default=default_config_push_queue,
    help=f'Config push queue (default: {default_config_push_queue})'
)
```

**Rationale:**

- Clearer naming - "config-push-queue" is more explicit than just "push-queue"
- Matches the internal variable name `config_push_queue`
- Consistent with AsyncProcessor's `--config-push-queue` parameter
- Breaking change acceptable since the feature is currently non-functional
- No code change needed in `params.get()` - it already looks for the correct name

### Part 2: Add Cassandra Keyspace Parameterization

#### Change 3: Add Keyspace Parameter to cassandra_config Module

**File:** `trustgraph-base/trustgraph/base/cassandra_config.py`

**Add CLI argument** (in `add_cassandra_args()` function):

```python
parser.add_argument(
    '--cassandra-keyspace',
    default=None,
    help='Cassandra keyspace (default: service-specific)'
)
```

**Add environment variable support** (in `resolve_cassandra_config()` function):

```python
keyspace = params.get(
    "cassandra_keyspace",
    os.environ.get("CASSANDRA_KEYSPACE")
)
```

**Update return value** of `resolve_cassandra_config()`:

- Currently returns: `(hosts, username, password)`
- Change to return: `(hosts, username, password, keyspace)`

**Rationale:**

- Consistent with the existing Cassandra configuration pattern
- Available to all services via `add_cassandra_args()`
- Supports both CLI and environment variable configuration
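The intended resolution order (CLI parameter, then environment variable, then service default) can be sketched as a standalone function (illustrative only; not the actual module code):

```python
import os

def resolve_keyspace(params: dict, default_keyspace: str) -> str:
    """Keyspace resolution order: CLI parameter, then the
    CASSANDRA_KEYSPACE environment variable, then the service default."""
    return (
        params.get("cassandra_keyspace")
        or os.environ.get("CASSANDRA_KEYSPACE")
        or default_keyspace
    )

os.environ.pop("CASSANDRA_KEYSPACE", None)
print(resolve_keyspace({}, "config"))                                  # config
print(resolve_keyspace({"cassandra_keyspace": "tenant_a"}, "config"))  # tenant_a
os.environ["CASSANDRA_KEYSPACE"] = "shared"
print(resolve_keyspace({}, "librarian"))                               # shared
```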

#### Change 4: Config Service - Use Parameterized Keyspace

**File:** `trustgraph-flow/trustgraph/config/service/service.py`

**Line 30** - Remove hardcoded keyspace:

```python
# DELETE THIS LINE:
keyspace = "config"
```

**Lines 69-73** - Update cassandra config resolution:

**Current:**

```python
cassandra_host, cassandra_username, cassandra_password = \
    resolve_cassandra_config(params)
```

**Fixed:**

```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="config")
```

**Rationale:**

- Maintains backward compatibility with "config" as the default
- Allows override via `--cassandra-keyspace` or `CASSANDRA_KEYSPACE`

#### Change 5: Cores/Knowledge Service - Use Parameterized Keyspace

**File:** `trustgraph-flow/trustgraph/cores/service.py`

**Line 37** - Remove hardcoded keyspace:

```python
# DELETE THIS LINE:
keyspace = "knowledge"
```

**Update cassandra config resolution** (similar location as the config service):

```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="knowledge")
```

#### Change 6: Librarian Service - Use Parameterized Keyspace

**File:** `trustgraph-flow/trustgraph/librarian/service.py`

**Line 51** - Remove hardcoded keyspace:

```python
# DELETE THIS LINE:
keyspace = "librarian"
```

**Update cassandra config resolution** (similar location as the config service):

```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="librarian")
```

### Part 3: Migrate Collection Management to Config Service

#### Overview

Migrate collections from the Cassandra librarian keyspace to config service storage. This eliminates hardcoded storage management topics and simplifies the architecture by using the existing config push mechanism for distribution.

#### Current Architecture

```
API Request → Gateway → Librarian Service
    ↓
CollectionManager
    ↓
Cassandra Collections Table (librarian keyspace)
    ↓
Broadcast to 4 Storage Management Topics (hardcoded)
    ↓
Wait for 4+ Storage Service Responses
    ↓
Response to Gateway
```

#### New Architecture

```
API Request → Gateway → Librarian Service
    ↓
CollectionManager
    ↓
Config Service API (put/delete/getvalues)
    ↓
Cassandra Config Table (class='collections', key='user:collection')
    ↓
Config Push (to all subscribers on config-push-queue)
    ↓
All Storage Services receive config update independently
```

#### Change 7: Collection Manager - Use Config Service API

**File:** `trustgraph-flow/trustgraph/librarian/collection_manager.py`

**Remove:**

- `LibraryTableStore` usage (Lines 33, 40-41)
- Storage management producers initialization (Lines 86-140)
- `on_storage_response` method (Lines 400-430)
- `pending_deletions` tracking (Lines 57, 90-96, and usage throughout)

**Add:**

- Config service client for API calls (request/response pattern)

**Config Client Setup:**

```python
# In __init__, add config request/response producers/consumers
from trustgraph.schema.services.config import ConfigRequest, ConfigResponse

# Producer for config requests
self.config_request_producer = Producer(
    client=pulsar_client,
    topic=config_request_queue,
    schema=ConfigRequest,
)

# Consumer for config responses (with correlation ID)
self.config_response_consumer = Consumer(
    taskgroup=taskgroup,
    client=pulsar_client,
    flow=None,
    topic=config_response_queue,
    subscriber=f"{id}-config",
    schema=ConfigResponse,
    handler=self.on_config_response,
)

# Tracking for pending config requests
self.pending_config_requests = {}  # request_id -> asyncio.Event
```

**Modify `list_collections` (Lines 145-180):**

```python
async def list_collections(self, user, tag_filter=None, limit=None):
    """List collections from the config service"""
    # Send getvalues request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='getvalues',
        type='collections',
    )

    # Send request and wait for response
    response = await self.send_config_request(request)

    # Parse collections from response
    collections = []
    for key, value_json in response.values.items():
        if ":" in key:
            coll_user, collection = key.split(":", 1)
            if coll_user == user:
                metadata = json.loads(value_json)
                collections.append(CollectionMetadata(**metadata))

    # Apply tag filtering in-memory (as before)
    if tag_filter:
        collections = [
            c for c in collections
            if any(tag in c.tags for tag in tag_filter)
        ]

    # Apply limit
    if limit:
        collections = collections[:limit]

    return collections

async def send_config_request(self, request):
    """Send config request and wait for response"""
    event = asyncio.Event()
    self.pending_config_requests[request.id] = event

    await self.config_request_producer.send(request)
    await event.wait()

    # Clean up both the event and the stored response
    self.pending_config_requests.pop(request.id, None)
    return self.pending_config_requests.pop(request.id + "_response")

async def on_config_response(self, message, consumer, flow):
    """Handle config response"""
    response = message.value()
    if response.id in self.pending_config_requests:
        self.pending_config_requests[response.id + "_response"] = response
        self.pending_config_requests[response.id].set()
```
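The correlation pattern above — key a waiter by request ID, resolve it from the response handler — can be sketched in-process (illustrative only; the real code exchanges messages over Pulsar topics, and the `_fake_service` task here is a stand-in for the config service):

```python
import asyncio
import uuid

class ConfigClient:
    """Minimal sketch of request/response correlation by request ID."""

    def __init__(self):
        self.pending = {}  # request_id -> asyncio.Future

    async def request(self, payload):
        request_id = str(uuid.uuid4())
        fut = asyncio.get_running_loop().create_future()
        self.pending[request_id] = fut
        # Stand-in for producer.send(); the echo task plays the config service
        asyncio.get_running_loop().create_task(self._fake_service(request_id, payload))
        try:
            return await fut  # resolved by on_response()
        finally:
            self.pending.pop(request_id, None)  # always clean up

    async def _fake_service(self, request_id, payload):
        await asyncio.sleep(0)  # simulate a round trip
        self.on_response(request_id, {"echo": payload})

    def on_response(self, request_id, response):
        fut = self.pending.get(request_id)
        if fut is not None and not fut.done():
            fut.set_result(response)

async def main():
    client = ConfigClient()
    return await client.request("getvalues")

print(asyncio.run(main()))  # {'echo': 'getvalues'}
```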

**Modify `update_collection` (Lines 182-312):**

```python
async def update_collection(self, user, collection, name, description, tags):
    """Update collection via the config service"""
    # Create metadata
    metadata = CollectionMetadata(
        user=user,
        collection=collection,
        name=name,
        description=description,
        tags=tags,
    )

    # Send put request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='put',
        type='collections',
        key=f'{user}:{collection}',
        value=json.dumps(metadata.to_dict()),
    )

    response = await self.send_config_request(request)

    if response.error:
        raise RuntimeError(f"Config update failed: {response.error.message}")

    # Config service will trigger a config push automatically;
    # storage services will receive the update and create collections
```

**Modify `delete_collection` (Lines 314-398):**

```python
async def delete_collection(self, user, collection):
    """Delete collection via the config service"""
    # Send delete request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='delete',
        type='collections',
        key=f'{user}:{collection}',
    )

    response = await self.send_config_request(request)

    if response.error:
        raise RuntimeError(f"Config delete failed: {response.error.message}")

    # Config service will trigger a config push automatically;
    # storage services will receive the update and delete collections
```

**Collection Metadata Format:**

- Stored in the config table as: `class='collections', key='user:collection'`
- Value is a JSON-serialized CollectionMetadata (without timestamp fields)
- Fields: `user`, `collection`, `name`, `description`, `tags`
- Example: `class='collections', key='alice:my-docs', value='{"user":"alice","collection":"my-docs","name":"My Documents","description":"...","tags":["work"]}'`
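The key/value shape can be exercised with a plain dataclass standing in for the Pulsar Record schema (a sketch; the field defaults and helper names are assumptions, not the actual schema code):

```python
import json
from dataclasses import dataclass, asdict, field

@dataclass
class CollectionMetadata:
    # Plain-dataclass stand-in for the Pulsar Record schema
    user: str
    collection: str
    name: str = ""
    description: str = ""
    tags: list = field(default_factory=list)

def to_config_entry(m: CollectionMetadata):
    """Serialize to the (key, value) shape stored under class='collections'."""
    return f"{m.user}:{m.collection}", json.dumps(asdict(m))

def from_config_entry(key: str, value: str) -> CollectionMetadata:
    return CollectionMetadata(**json.loads(value))

key, value = to_config_entry(
    CollectionMetadata("alice", "my-docs", "My Documents", tags=["work"]))
print(key)  # alice:my-docs
assert from_config_entry(key, value).tags == ["work"]
```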

#### Change 8: Librarian Service - Remove Storage Management Infrastructure

**File:** `trustgraph-flow/trustgraph/librarian/service.py`

**Remove:**

- Storage management producers (Lines 173-190):
  - `vector_storage_management_producer`
  - `object_storage_management_producer`
  - `triples_storage_management_producer`
- Storage response consumer (Lines 192-201)
- `on_storage_response` handler (Lines 467-473)

**Modify:**

- CollectionManager initialization (Lines 215-224) - remove storage producer parameters

**Note:** The external collection API remains unchanged:

- `list-collections`
- `update-collection`
- `delete-collection`

#### Change 9: Remove Collections Table from LibraryTableStore

**File:** `trustgraph-flow/trustgraph/tables/library.py`

**Delete:**

- Collections table CREATE statement (Lines 114-127)
- Collections prepared statements (Lines 205-240)
- All collection methods (Lines 578-717):
  - `ensure_collection_exists`
  - `list_collections`
  - `update_collection`
  - `delete_collection`
  - `get_collection`
  - `create_collection`

**Rationale:**

- Collections are now stored in the config table
- Breaking change acceptable - no data migration needed
- Simplifies the librarian service significantly

#### Change 10: Storage Services - Config-Based Collection Management ✅ COMPLETED

**Status:** All 11 storage backends have been migrated to use `CollectionConfigHandler`.

**Affected Services (11 total):**

- Document embeddings: milvus, pinecone, qdrant
- Graph embeddings: milvus, pinecone, qdrant
- Object storage: cassandra
- Triples storage: cassandra, falkordb, memgraph, neo4j

**Files:**

- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`

**Implementation Pattern (all services):**

1. **Register config handler in `__init__`:**

```python
# Add after AsyncProcessor initialization
self.register_config_handler(self.on_collection_config)
self.known_collections = set()  # Track (user, collection) tuples
```

2. **Implement config handler:**

```python
async def on_collection_config(self, config, version):
    """Handle collection configuration updates"""
    logger.info(f"Collection config version: {version}")

    if "collections" not in config:
        return

    # Parse collections from config
    # Key format: "user:collection" in config["collections"]
    config_collections = set()
    for key in config["collections"].keys():
        if ":" in key:
            user, collection = key.split(":", 1)
            config_collections.add((user, collection))

    # Determine changes
    to_create = config_collections - self.known_collections
    to_delete = self.known_collections - config_collections

    # Create new collections (idempotent)
    for user, collection in to_create:
        try:
            await self.create_collection_internal(user, collection)
            self.known_collections.add((user, collection))
            logger.info(f"Created collection: {user}/{collection}")
        except Exception as e:
            logger.error(f"Failed to create {user}/{collection}: {e}")

    # Delete removed collections (idempotent)
    for user, collection in to_delete:
        try:
            await self.delete_collection_internal(user, collection)
            self.known_collections.discard((user, collection))
            logger.info(f"Deleted collection: {user}/{collection}")
        except Exception as e:
            logger.error(f"Failed to delete {user}/{collection}: {e}")
```

3. **Initialize known collections on startup:**

```python
async def start(self):
    """Start the processor"""
    await super().start()
    await self.sync_known_collections()

async def sync_known_collections(self):
    """Query the backend to populate the known_collections set"""
    # Backend-specific implementation:
    # - Milvus/Pinecone/Qdrant: List collections/indexes matching the naming pattern
    # - Cassandra: Query keyspaces or collection metadata
    # - Neo4j/Memgraph/FalkorDB: Query CollectionMetadata nodes
    pass
```

4. **Refactor existing handler methods:**

```python
# Rename and remove response sending:
# handle_create_collection → create_collection_internal
# handle_delete_collection → delete_collection_internal

async def create_collection_internal(self, user, collection):
    """Create collection (idempotent)"""
    # Same logic as the current handle_create_collection,
    # but remove response producer calls and
    # handle "already exists" gracefully
    pass

async def delete_collection_internal(self, user, collection):
    """Delete collection (idempotent)"""
    # Same logic as the current handle_delete_collection,
    # but remove response producer calls and
    # handle "not found" gracefully
    pass
```
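The "idempotent" requirement above amounts to tolerating already-exists and not-found errors so retries are safe. A minimal sketch against a hypothetical backend client (all names here are illustrative, not real TrustGraph APIs):

```python
class AlreadyExists(Exception): pass
class NotFound(Exception): pass

class FakeBackend:
    """Hypothetical backend client used only to illustrate the pattern."""
    def __init__(self):
        self.collections = set()
    def create(self, name):
        if name in self.collections:
            raise AlreadyExists(name)
        self.collections.add(name)
    def delete(self, name):
        if name not in self.collections:
            raise NotFound(name)
        self.collections.discard(name)

def create_collection_idempotent(backend, user, collection):
    try:
        backend.create(f"{user}_{collection}")
    except AlreadyExists:
        pass  # safe to retry: treat "already exists" as success

def delete_collection_idempotent(backend, user, collection):
    try:
        backend.delete(f"{user}_{collection}")
    except NotFound:
        pass  # safe to retry: treat "not found" as success

b = FakeBackend()
create_collection_idempotent(b, "alice", "my-docs")
create_collection_idempotent(b, "alice", "my-docs")  # no error on retry
delete_collection_idempotent(b, "alice", "my-docs")
delete_collection_idempotent(b, "alice", "my-docs")  # no error on retry
```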

5. **Remove storage management infrastructure:**

- Remove `self.storage_request_consumer` setup and start
- Remove `self.storage_response_producer` setup
- Remove the `on_storage_management` dispatcher method
- Remove metrics for storage management
- Remove imports: `StorageManagementRequest`, `StorageManagementResponse`

**Backend-Specific Considerations:**

- **Vector stores (Milvus, Pinecone, Qdrant):** Track logical `(user, collection)` in `known_collections`, but may create multiple backend collections per dimension. Continue the lazy creation pattern. Delete operations must remove all dimension variants.

- **Cassandra Objects:** Collections are row properties, not structures. Track keyspace-level information.

- **Graph stores (Neo4j, Memgraph, FalkorDB):** Query `CollectionMetadata` nodes on startup. Create/delete metadata nodes on sync.

- **Cassandra Triples:** Use the `KnowledgeGraph` API for collection operations.

**Key Design Points:**

- **Eventual consistency:** No request/response mechanism; config push is a broadcast
- **Idempotency:** All create/delete operations must be safe to retry
- **Error handling:** Log errors but don't block config updates
- **Self-healing:** Failed operations will retry on the next config push
- **Collection key format:** `"user:collection"` in `config["collections"]`
|
||||||
|
|
||||||
|
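Taken together, the design points amount to a reconcile step on each config push: parse the `"user:collection"` keys, diff them against `known_collections`, and derive idempotent create/delete sets. A minimal sketch (the value layout of `config["collections"]` beyond its keys is an assumption; only the keys matter here):

```python
def reconcile_collections(config: dict, known: set[tuple[str, str]]):
    """Diff desired collections (from a config push) against known state.

    Returns (to_create, to_delete) as sets of (user, collection) tuples.
    """
    desired = set()
    for key in config.get("collections", {}):
        user, _, collection = key.partition(":")  # key format "user:collection"
        desired.add((user, collection))
    return desired - known, known - desired
```

Because the result is a pure set difference, re-running the reconcile after a partial failure naturally retries only what is still missing, which is the self-healing property described above.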
#### Change 11: Update Collection Schema - Remove Timestamps

**File:** `trustgraph-base/trustgraph/schema/services/collection.py`

**Modify CollectionMetadata (Lines 13-21):**

Remove `created_at` and `updated_at` fields:

```python
class CollectionMetadata(Record):
    user = String()
    collection = String()
    name = String()
    description = String()
    tags = Array(String())
    # Remove: created_at = String()
    # Remove: updated_at = String()
```

**Modify CollectionManagementRequest (Lines 25-47):**

Remove timestamp fields:

```python
class CollectionManagementRequest(Record):
    operation = String()
    user = String()
    collection = String()
    timestamp = String()
    name = String()
    description = String()
    tags = Array(String())
    # Remove: created_at = String()
    # Remove: updated_at = String()
    tag_filter = Array(String())
    limit = Integer()
```

**Rationale:**
- Timestamps don't add value for collections
- The config service maintains its own version tracking
- Simplifies the schema and reduces storage

#### Benefits of Config Service Migration

1. ✅ **Eliminates hardcoded storage management topics** - Solves multi-tenant blocker
2. ✅ **Simpler coordination** - No complex async waiting for 4+ storage responses
3. ✅ **Eventual consistency** - Storage services update independently via config push
4. ✅ **Better reliability** - Persistent config push vs non-persistent request/response
5. ✅ **Unified configuration model** - Collections treated as configuration
6. ✅ **Reduces complexity** - Removes ~300 lines of coordination code
7. ✅ **Multi-tenant ready** - Config already supports tenant isolation via keyspace
8. ✅ **Version tracking** - Config service version mechanism provides audit trail

## Implementation Notes

### Backward Compatibility

**Parameter Changes:**
- CLI parameter renames are breaking changes but acceptable (feature currently non-functional)
- Services work without parameters (use defaults)
- Default keyspaces preserved: "config", "knowledge", "librarian"
- Default queue: `persistent://tg/config/config`

**Collection Management:**
- **Breaking change:** Collections table removed from librarian keyspace
- **No data migration provided** - acceptable for this phase
- External collection API unchanged (list/update/delete operations)
- Collection metadata format simplified (timestamps removed)

### Testing Requirements

**Parameter Testing:**
1. Verify `--config-push-queue` parameter works on graph-embeddings service
2. Verify `--config-push-queue` parameter works on text-completion service
3. Verify `--config-push-queue` parameter works on config service
4. Verify `--cassandra-keyspace` parameter works for config service
5. Verify `--cassandra-keyspace` parameter works for cores service
6. Verify `--cassandra-keyspace` parameter works for librarian service
7. Verify services work without parameters (uses defaults)
8. Verify multi-tenant deployment with custom queue names and keyspaces

**Collection Management Testing:**
9. Verify `list-collections` operation via config service
10. Verify `update-collection` creates/updates in config table
11. Verify `delete-collection` removes from config table
12. Verify config push is triggered on collection updates
13. Verify tag filtering works with config-based storage
14. Verify collection operations work without timestamp fields

### Multi-Tenant Deployment Example

```bash
# Tenant: tg-dev
graph-embeddings \
    -p pulsar+ssl://broker:6651 \
    --pulsar-api-key <KEY> \
    --config-push-queue persistent://tg-dev/config/config

config-service \
    -p pulsar+ssl://broker:6651 \
    --pulsar-api-key <KEY> \
    --config-push-queue persistent://tg-dev/config/config \
    --cassandra-keyspace tg_dev_config
```

## Impact Analysis

### Services Affected by Changes 1-2 (CLI Parameter Rename)
All services inheriting from AsyncProcessor or FlowProcessor:
- config-service
- cores-service
- librarian-service
- graph-embeddings
- document-embeddings
- text-completion-* (all providers)
- extract-* (all extractors)
- query-* (all query services)
- retrieval-* (all RAG services)
- storage-* (all storage services)
- And 20+ more services

### Services Affected by Changes 3-6 (Cassandra Keyspace)
- config-service
- cores-service
- librarian-service

### Services Affected by Changes 7-11 (Collection Management)

**Immediate Changes:**
- librarian-service (collection_manager.py, service.py)
- tables/library.py (collections table removal)
- schema/services/collection.py (timestamp removal)

**Completed Changes (Change 10):** ✅
- All storage services (11 total) - migrated to config push for collection updates via `CollectionConfigHandler`
- Storage management schema removed from `storage.py`

## Future Considerations

### Per-User Keyspace Model

Some services use **per-user keyspaces** dynamically, where each user gets their own Cassandra keyspace:

**Services with per-user keyspaces:**
1. **Triples Query Service** (`trustgraph-flow/trustgraph/query/triples/cassandra/service.py:65`)
   - Uses `keyspace=query.user`
2. **Objects Query Service** (`trustgraph-flow/trustgraph/query/objects/cassandra/service.py:479`)
   - Uses `keyspace=self.sanitize_name(user)`
3. **KnowledgeGraph Direct Access** (`trustgraph-flow/trustgraph/direct/cassandra_kg.py:18`)
   - Default parameter `keyspace="trustgraph"`

**Status:** These are **not modified** in this specification.

**Future Review Required:**
- Evaluate whether the per-user keyspace model creates tenant isolation issues
- Consider if multi-tenant deployments need keyspace prefix patterns (e.g., `tenant_a_user1`)
- Review for potential user ID collisions across tenants
- Assess if a single shared keyspace per tenant with user-based row isolation is preferable

**Note:** This does not block the current multi-tenant implementation but should be reviewed before production multi-tenant deployments.

## Implementation Phases

### Phase 1: Parameter Fixes (Changes 1-6)
- Fix `--config-push-queue` parameter naming
- Add `--cassandra-keyspace` parameter support
- **Outcome:** Multi-tenant queue and keyspace configuration enabled

### Phase 2: Collection Management Migration (Changes 7-9, 11)
- Migrate collection storage to config service
- Remove collections table from librarian
- Update collection schema (remove timestamps)
- **Outcome:** Eliminates hardcoded storage management topics, simplifies librarian

### Phase 3: Storage Service Updates (Change 10) ✅ COMPLETED
- Updated all storage services to use config push for collections via `CollectionConfigHandler`
- Removed storage management request/response infrastructure
- Removed legacy schema definitions
- **Outcome:** Complete config-based collection management achieved

## References

- GitHub Issue: https://github.com/trustgraph-ai/trustgraph/issues/582
- Related Files:
  - `trustgraph-base/trustgraph/base/async_processor.py`
  - `trustgraph-base/trustgraph/base/cassandra_config.py`
  - `trustgraph-base/trustgraph/schema/core/topic.py`
  - `trustgraph-base/trustgraph/schema/services/collection.py`
  - `trustgraph-flow/trustgraph/config/service/service.py`
  - `trustgraph-flow/trustgraph/cores/service.py`
  - `trustgraph-flow/trustgraph/librarian/service.py`
  - `trustgraph-flow/trustgraph/librarian/collection_manager.py`
  - `trustgraph-flow/trustgraph/tables/library.py`
761
docs/tech-specs/ontology-extract-phase-2.md
Normal file

@@ -0,0 +1,761 @@
# Ontology Knowledge Extraction - Phase 2 Refactor

**Status**: Draft
**Author**: Analysis Session 2025-12-03
**Related**: `ontology.md`, `ontorag.md`

## Overview

This document identifies inconsistencies in the current ontology-based knowledge extraction system and proposes a refactor to improve LLM performance and reduce information loss.

## Current Implementation

### How It Works Now

1. **Ontology Loading** (`ontology_loader.py`)
   - Loads ontology JSON with keys like `"fo/Recipe"`, `"fo/Food"`, `"fo/produces"`
   - Class IDs include the namespace prefix in the key itself
   - Example from `food.ontology`:
   ```json
   "classes": {
     "fo/Recipe": {
       "uri": "http://purl.org/ontology/fo/Recipe",
       "rdfs:comment": "A Recipe is a combination..."
     }
   }
   ```

2. **Prompt Construction** (`extract.py:299-307`, `ontology-prompt.md`)
   - Template receives `classes`, `object_properties`, `datatype_properties` dicts
   - Template iterates: `{% for class_id, class_def in classes.items() %}`
   - LLM sees: `**fo/Recipe**: A Recipe is a combination...`
   - Example output format shows:
   ```json
   {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}
   {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"}
   ```

3. **Response Parsing** (`extract.py:382-428`)
   - Expects JSON array: `[{"subject": "...", "predicate": "...", "object": "..."}]`
   - Validates against the ontology subset
   - Expands URIs via `expand_uri()` (`extract.py:473-521`)

4. **URI Expansion** (`extract.py:473-521`)
   - Checks if the value is a key in the `ontology_subset.classes` dict
   - If found, extracts the URI from the class definition
   - If not found, constructs a URI: `f"https://trustgraph.ai/ontology/{ontology_id}#{value}"`

### Data Flow Example

**Ontology JSON → Loader → Prompt:**
```
"fo/Recipe" → classes["fo/Recipe"] → LLM sees "**fo/Recipe**"
```

**LLM → Parser → Output:**
```
"Recipe"    → not found (key is "fo/Recipe") → constructs fallback URI → LOSES original URI
"fo/Recipe" → found in classes               → uses original URI       → PRESERVES URI
```

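The loss is easy to reproduce in a few lines. This is a sketch of the lookup-then-fallback behaviour described above, not the real `expand_uri()` code:

```python
# Ontology dict keyed by prefixed ID, as loaded from food.ontology
classes = {"fo/Recipe": {"uri": "http://purl.org/ontology/fo/Recipe"}}

def expand(value: str, ontology_id: str = "food") -> str:
    """Look up the value's URI in the ontology; fall back to a constructed URI."""
    if value in classes:
        return classes[value]["uri"]          # only hit for "fo/Recipe"
    return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"  # fallback
```

`expand("fo/Recipe")` preserves the original URI, while `expand("Recipe")` — the form the prompt's example teaches the LLM to emit — silently produces the fallback `https://trustgraph.ai/ontology/food#Recipe`.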
## Problems Identified

### 1. **Inconsistent Examples in Prompt**

**Issue**: The prompt template shows class IDs with prefixes (`fo/Recipe`), but the example output uses unprefixed class names (`Recipe`).

**Location**: `ontology-prompt.md:5-52`

```markdown
## Ontology Classes:
- **fo/Recipe**: A Recipe is...

## Example Output:
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}
```

**Impact**: The LLM receives conflicting signals about what format to use.

### 2. **Information Loss in URI Expansion**

**Issue**: When the LLM returns unprefixed class names following the example, `expand_uri()` can't find them in the ontology dict and constructs fallback URIs, losing the original proper URIs.

**Location**: `extract.py:494-500`

```python
if value in ontology_subset.classes:  # Looks for "Recipe"
    class_def = ontology_subset.classes[value]  # But key is "fo/Recipe"
    if isinstance(class_def, dict) and 'uri' in class_def:
        return class_def['uri']  # Never reached!
return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"  # Fallback
```

**Impact**:
- Original URI: `http://purl.org/ontology/fo/Recipe`
- Constructed URI: `https://trustgraph.ai/ontology/food#Recipe`
- Semantic meaning lost, breaks interoperability

### 3. **Ambiguous Entity Instance Format**

**Issue**: No clear guidance on the entity instance URI format.

**Examples in prompt**:
- `"recipe:cornish-pasty"` (namespace-like prefix)
- `"ingredient:flour"` (different prefix)

**Actual behavior** (`extract.py:517-520`):
```python
# Treat as entity instance - construct unique URI
normalized = value.replace(" ", "-").lower()
return f"https://trustgraph.ai/{ontology_id}/{normalized}"
```

**Impact**: The LLM must guess the prefixing convention with no ontology context.

### 4. **No Namespace Prefix Guidance**

**Issue**: The ontology JSON contains namespace definitions (lines 10-25 in `food.ontology`):
```json
"namespaces": {
  "fo": "http://purl.org/ontology/fo/",
  "rdf": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
  ...
}
```

But these are never surfaced to the LLM. The LLM doesn't know:
- What "fo" means
- What prefix to use for entities
- Which namespace applies to which elements

### 5. **Labels Not Used in Prompt**

**Issue**: Every class has `rdfs:label` fields (e.g., `{"value": "Recipe", "lang": "en-gb"}`), but the prompt template doesn't use them.

**Current**: Shows only `class_id` and `comment`
```jinja
- **{{class_id}}**{% if class_def.comment %}: {{class_def.comment}}{% endif %}
```

**Available but unused**:
```python
"rdfs:label": [{"value": "Recipe", "lang": "en-gb"}]
```

**Impact**: Could provide human-readable names alongside technical IDs.

## Proposed Solutions

### Option A: Normalize to Unprefixed IDs

**Approach**: Strip prefixes from class IDs before showing them to the LLM.

**Changes**:
1. Modify `build_extraction_variables()` to transform the keys:
   ```python
   classes_for_prompt = {
       k.split('/')[-1]: v  # "fo/Recipe" → "Recipe"
       for k, v in ontology_subset.classes.items()
   }
   ```

2. Update the prompt example to match (it already uses unprefixed names)

3. Modify `expand_uri()` to handle both formats:
   ```python
   # Try exact match first
   if value in ontology_subset.classes:
       return ontology_subset.classes[value]['uri']

   # Try with prefix
   for prefix in ['fo/', 'rdf:', 'rdfs:']:
       prefixed = f"{prefix}{value}"
       if prefixed in ontology_subset.classes:
           return ontology_subset.classes[prefixed]['uri']
   ```

**Pros**:
- Cleaner, more human-readable
- Matches existing prompt examples
- LLMs work better with simpler tokens

**Cons**:
- Class name collisions if multiple ontologies have the same class name
- Loses namespace information
- Requires fallback logic for lookups

### Option B: Use Full Prefixed IDs Consistently

**Approach**: Update the examples to use prefixed IDs matching what's shown in the class list.

**Changes**:
1. Update the prompt example (`ontology-prompt.md:46-52`):
   ```json
   [
     {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "fo/Recipe"},
     {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
     {"subject": "recipe:cornish-pasty", "predicate": "fo/produces", "object": "food:cornish-pasty"},
     {"subject": "food:cornish-pasty", "predicate": "rdf:type", "object": "fo/Food"}
   ]
   ```

2. Add a namespace explanation to the prompt:
   ```markdown
   ## Namespace Prefixes:
   - **fo/**: Food Ontology (http://purl.org/ontology/fo/)
   - **rdf:**: RDF
   - **rdfs:**: RDF Schema

   Use these prefixes exactly as shown when referencing classes and properties.
   ```

3. Keep `expand_uri()` as-is (it works correctly when matches are found)

**Pros**:
- Input = output consistency
- No information loss
- Preserves namespace semantics
- Works with multiple ontologies

**Cons**:
- More verbose tokens for the LLM
- Requires the LLM to track prefixes

### Option C: Hybrid - Show Both Label and ID

**Approach**: Enhance the prompt to show both human-readable labels and technical IDs.

**Changes**:
1. Update the prompt template:
   ```jinja
   {% for class_id, class_def in classes.items() %}
   - **{{class_id}}** (label: "{{class_def.labels[0].value if class_def.labels else class_id}}"){% if class_def.comment %}: {{class_def.comment}}{% endif %}
   {% endfor %}
   ```

   Example output:
   ```markdown
   - **fo/Recipe** (label: "Recipe"): A Recipe is a combination...
   ```

2. Update the instructions:
   ```markdown
   When referencing classes:
   - Use the full prefixed ID (e.g., "fo/Recipe") in JSON output
   - The label (e.g., "Recipe") is for human understanding only
   ```

**Pros**:
- Clearest for the LLM
- Preserves all information
- Explicit about what to use

**Cons**:
- Longer prompt
- More complex template

## Implemented Approach

**Simplified Entity-Relationship-Attribute Format** - completely replaces the old triple-based format.

The new approach was chosen because:

1. **No Information Loss**: Original URIs preserved correctly
2. **Simpler Logic**: No transformation needed, direct dict lookups work
3. **Namespace Safety**: Handles multiple ontologies without collisions
4. **Semantic Correctness**: Maintains RDF/OWL semantics

## Implementation Complete

### What Was Built:

1. **New Prompt Template** (`prompts/ontology-extract-v2.txt`)
   - ✅ Clear sections: Entity Types, Relationships, Attributes
   - ✅ Example using full type identifiers (`fo/Recipe`, `fo/has_ingredient`)
   - ✅ Instructions to use exact identifiers from the schema
   - ✅ New JSON format with entities/relationships/attributes arrays

2. **Entity Normalization** (`entity_normalizer.py`)
   - ✅ `normalize_entity_name()` - Converts names to URI-safe format
   - ✅ `normalize_type_identifier()` - Handles slashes in types (`fo/Recipe` → `fo-recipe`)
   - ✅ `build_entity_uri()` - Creates unique URIs using the (name, type) tuple
   - ✅ `EntityRegistry` - Tracks entities for deduplication

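The normalization behaviour can be sketched as follows. This is an illustrative reconstruction matching the URI examples later in this document, not the actual `entity_normalizer.py` source:

```python
def normalize_entity_name(name: str) -> str:
    """Convert an entity name to a URI-safe, lowercase, hyphenated form."""
    return name.strip().lower().replace(" ", "-")

def normalize_type_identifier(type_id: str) -> str:
    """Handle slashes in type identifiers, e.g. "fo/Recipe" → "fo-recipe"."""
    return type_id.strip().lower().replace("/", "-")

def build_entity_uri(name: str, entity_type: str, ontology_id: str) -> str:
    """Build a unique entity URI from the (name, type) tuple, so the same
    name with different types yields different URIs."""
    # The later examples key on the bare type name ("Recipe" → "recipe-...")
    type_part = normalize_type_identifier(entity_type.split("/")[-1])
    return f"https://trustgraph.ai/{ontology_id}/{type_part}-{normalize_entity_name(name)}"
```

With this sketch, `build_entity_uri("Cornish pasty", "Recipe", "food")` yields `https://trustgraph.ai/food/recipe-cornish-pasty`, matching the disambiguation example below.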
3. **JSON Parser** (`simplified_parser.py`)
   - ✅ Parses the new format: `{entities: [...], relationships: [...], attributes: [...]}`
   - ✅ Supports kebab-case and snake_case field names
   - ✅ Returns structured dataclasses
   - ✅ Graceful error handling with logging

4. **Triple Converter** (`triple_converter.py`)
   - ✅ `convert_entity()` - Generates type + label triples automatically
   - ✅ `convert_relationship()` - Connects entity URIs via properties
   - ✅ `convert_attribute()` - Adds literal values
   - ✅ Looks up full URIs from the ontology definitions

5. **Updated Main Processor** (`extract.py`)
   - ✅ Removed old triple-based extraction code
   - ✅ Added `extract_with_simplified_format()` method
   - ✅ Now exclusively uses the new simplified format
   - ✅ Calls prompt with `extract-with-ontologies-v2` ID

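The kebab/snake-case handling in item 3 can be sketched for the relationships array; the `Relationship` dataclass and `_field` helper here are illustrative (shaped after the JSON examples later in this document), not the real `simplified_parser.py` types:

```python
import json
from dataclasses import dataclass

@dataclass
class Relationship:
    subject: str
    subject_type: str
    relation: str
    object: str
    object_type: str

def _field(d: dict, name: str):
    """Accept both kebab-case ("subject-type") and snake_case ("subject_type")."""
    return d.get(name.replace("_", "-"), d.get(name))

def parse_relationships(raw: str) -> list[Relationship]:
    """Parse the relationships array of the simplified extraction format."""
    doc = json.loads(raw)
    return [
        Relationship(
            subject=_field(r, "subject"),
            subject_type=_field(r, "subject_type"),
            relation=_field(r, "relation"),
            object=_field(r, "object"),
            object_type=_field(r, "object_type"),
        )
        for r in doc.get("relationships", [])
    ]
```

Checking the kebab-case variant first means LLM output in either convention maps onto the same dataclass fields.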
## Test Cases

### Test 1: URI Preservation
```python
# Given ontology class
classes = {"fo/Recipe": {"uri": "http://purl.org/ontology/fo/Recipe", ...}}

# When LLM returns
llm_output = {"subject": "x", "predicate": "rdf:type", "object": "fo/Recipe"}

# Then expanded URI should be
assert expanded == "http://purl.org/ontology/fo/Recipe"
# Not: "https://trustgraph.ai/ontology/food#Recipe"
```

### Test 2: Multi-Ontology Collision
```python
# Given two ontologies
ont1 = {"fo/Recipe": {...}}
ont2 = {"cooking/Recipe": {...}}

# LLM should use full prefix to disambiguate
llm_output = {"object": "fo/Recipe"}  # Not just "Recipe"
```

### Test 3: Entity Instance Format
```python
# Given prompt with food ontology
# LLM should create instances like
{"subject": "recipe:cornish-pasty"}  # Namespace-style
{"subject": "food:beef"}  # Consistent prefix
```

## Open Questions

1. **Should entity instances use namespace prefixes?**
   - Current: `"recipe:cornish-pasty"` (arbitrary)
   - Alternative: use the ontology prefix, `"fo:cornish-pasty"`?
   - Alternative: no prefix, expand `"cornish-pasty"` to a full URI?

2. **How to handle domain/range in the prompt?**
   - Currently shows: `(Recipe → Food)`
   - Should it be: `(fo/Recipe → fo/Food)`?

3. **Should we validate domain/range constraints?**
   - TODO comment at `extract.py:470`
   - Would catch more errors but is more complex

4. **What about inverse properties and equivalences?**
   - The ontology has `owl:inverseOf`, `owl:equivalentClass`
   - Not currently used in extraction
   - Should they be?

## Success Metrics

- ✅ Zero URI information loss (100% preservation of original URIs)
- ✅ LLM output format matches input format
- ✅ No ambiguous examples in prompt
- ✅ Tests pass with multiple ontologies
- ✅ Improved extraction quality (measured by valid triple %)

## Alternative Approach: Simplified Extraction Format

### Philosophy

Instead of asking the LLM to understand RDF/OWL semantics, ask it to do what it's good at: **find entities and relationships in text**.

Let the code handle URI construction, RDF conversion, and semantic web formalities.

### Example: Entity Classification

**Input Text:**
```
Cornish pasty is a traditional British pastry filled with meat and vegetables.
```

**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
- Food: A food is something that can be eaten
- Ingredient: An ingredient combines a quantity and a food
```

**What LLM Returns (Simple JSON):**
```json
{
  "entities": [
    {
      "entity": "Cornish pasty",
      "type": "Recipe"
    }
  ]
}
```

**What Code Produces (RDF Triples):**
```python
# 1. Normalize entity name + type to an ID (the type prevents collisions)
entity_id = "recipe-cornish-pasty"  # normalize("Cornish pasty", "Recipe")
entity_uri = "https://trustgraph.ai/food/recipe-cornish-pasty"

# Note: same name, different type = different URI
# "Cornish pasty" (Recipe) → recipe-cornish-pasty
# "Cornish pasty" (Food)   → food-cornish-pasty

# 2. Generate triples
triples = [
    # Type triple
    Triple(
        s=Value(value=entity_uri, is_uri=True),
        p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
        o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True)
    ),
    # Label triple (automatic)
    Triple(
        s=Value(value=entity_uri, is_uri=True),
        p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
        o=Value(value="Cornish pasty", is_uri=False)
    )
]
```

### Benefits

1. **The LLM doesn't need to:**
   - Understand URI syntax
   - Invent identifier prefixes (`recipe:`, `ingredient:`)
   - Know about `rdf:type` or `rdfs:label`
   - Construct semantic web identifiers

2. **The LLM just needs to:**
   - Find entities in text
   - Map them to ontology classes
   - Extract relationships and attributes

3. **The code handles:**
   - URI normalization and construction
   - RDF triple generation
   - Automatic label assignment
   - Namespace management

### Why This Works Better

- **Simpler prompt** = less confusion = fewer errors
- **Consistent IDs** = the code controls normalization rules
- **Auto-generated labels** = no missing `rdfs:label` triples
- **LLM focuses on extraction** = what it's actually good at

### Example: Entity Relationships

**Input Text:**
```
Cornish pasty is a traditional British pastry filled with beef and potatoes.
```

**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
- Food: A food is something that can be eaten
- Ingredient: An ingredient combines a quantity and a food

## Relationships:
- has_ingredient: Relates a recipe to an ingredient it uses (Recipe → Ingredient)
- food: Relates an ingredient to the food that is required (Ingredient → Food)
```

**What LLM Returns (Simple JSON):**
```json
{
  "entities": [
    {
      "entity": "Cornish pasty",
      "type": "Recipe"
    },
    {
      "entity": "beef",
      "type": "Food"
    },
    {
      "entity": "potatoes",
      "type": "Food"
    }
  ],
  "relationships": [
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "has_ingredient",
      "object": "beef",
      "object-type": "Food"
    },
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "has_ingredient",
      "object": "potatoes",
      "object-type": "Food"
    }
  ]
}
```

**What Code Produces (RDF Triples):**
```python
# Normalize entity names to URIs
cornish_pasty_uri = "https://trustgraph.ai/food/cornish-pasty"
beef_uri = "https://trustgraph.ai/food/beef"
potatoes_uri = "https://trustgraph.ai/food/potatoes"

# Look up the relation URI from the ontology
has_ingredient_uri = "http://purl.org/ontology/fo/ingredients"  # from fo/has_ingredient

triples = [
    # Entity type triples (as before)
    Triple(s=cornish_pasty_uri, p=rdf_type, o="http://purl.org/ontology/fo/Recipe"),
    Triple(s=cornish_pasty_uri, p=rdfs_label, o="Cornish pasty"),

    Triple(s=beef_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"),
    Triple(s=beef_uri, p=rdfs_label, o="beef"),

    Triple(s=potatoes_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"),
    Triple(s=potatoes_uri, p=rdfs_label, o="potatoes"),

    # Relationship triples
    Triple(
        s=Value(value=cornish_pasty_uri, is_uri=True),
        p=Value(value=has_ingredient_uri, is_uri=True),
        o=Value(value=beef_uri, is_uri=True)
    ),
    Triple(
        s=Value(value=cornish_pasty_uri, is_uri=True),
        p=Value(value=has_ingredient_uri, is_uri=True),
        o=Value(value=potatoes_uri, is_uri=True)
    )
]
```

**Key Points:**
|
||||||
|
- LLM returns natural language entity names: `"Cornish pasty"`, `"beef"`, `"potatoes"`
|
||||||
|
- LLM includes types to disambiguate: `subject-type`, `object-type`
|
||||||
|
- LLM uses relation name from schema: `"has_ingredient"`
|
||||||
|
- Code derives consistent IDs using (name, type): `("Cornish pasty", "Recipe")` → `recipe-cornish-pasty`
|
||||||
|
- Code looks up relation URI from ontology: `fo/has_ingredient` → full URI
|
||||||
|
- Same (name, type) tuple always gets same URI (deduplication)
|
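The (name, type) → URI derivation described in the points above can be sketched as follows. The `normalize` helper, the slug scheme, and the base namespace are illustrative assumptions, not the project's actual implementation:

```python
import re

# Illustrative base namespace (assumption for this sketch)
BASE = "https://trustgraph.ai/food/"

def normalize(name: str, type_: str) -> str:
    # Slugify "<type> <name>" so the same (name, type) pair always
    # yields the same URI, e.g. ("Cornish pasty", "Recipe")
    # -> .../recipe-cornish-pasty
    slug = re.sub(r"[^a-z0-9]+", "-", f"{type_} {name}".lower()).strip("-")
    return BASE + slug
```

Because the slug is a pure function of the (name, type) tuple, repeated extractions of the same entity deduplicate to a single node.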
||||||
|
|
||||||
|
### Example: Entity Name Disambiguation
|
||||||
|
|
||||||
|
**Problem:** Same name can refer to different entity types.
|
||||||
|
|
||||||
|
**Real-world case:**
|
||||||
|
```
|
||||||
|
"Cornish pasty" can be:
|
||||||
|
- A Recipe (instructions for making it)
|
||||||
|
- A Food (the dish itself)
|
||||||
|
```
|
||||||
|
|
||||||
|
**How It's Handled:**
|
||||||
|
|
||||||
|
LLM returns both as separate entities:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
{"entity": "Cornish pasty", "type": "Recipe"},
|
||||||
|
{"entity": "Cornish pasty", "type": "Food"}
|
||||||
|
],
|
||||||
|
"relationships": [
|
||||||
|
{
|
||||||
|
"subject": "Cornish pasty",
|
||||||
|
"subject-type": "Recipe",
|
||||||
|
"relation": "produces",
|
||||||
|
"object": "Cornish pasty",
|
||||||
|
"object-type": "Food"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Code Resolution:**
|
||||||
|
```python
|
||||||
|
# Different types → different URIs
|
||||||
|
recipe_uri = normalize("Cornish pasty", "Recipe")
|
||||||
|
# → "https://trustgraph.ai/food/recipe-cornish-pasty"
|
||||||
|
|
||||||
|
food_uri = normalize("Cornish pasty", "Food")
|
||||||
|
# → "https://trustgraph.ai/food/food-cornish-pasty"
|
||||||
|
|
||||||
|
# Relationship connects them correctly
|
||||||
|
triple = Triple(
|
||||||
|
s=recipe_uri, # The Recipe
|
||||||
|
p="http://purl.org/ontology/fo/produces",
|
||||||
|
o=food_uri # The Food
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why This Works:**
|
||||||
|
- Type is included in ALL references (entities, relationships, attributes)
|
||||||
|
- Code uses `(name, type)` tuple as lookup key
|
||||||
|
- No ambiguity, no collisions
|
||||||
|
|
||||||
|
### Example: Entity Attributes
|
||||||
|
|
||||||
|
**Input Text:**
|
||||||
|
```
|
||||||
|
This Cornish pasty recipe serves 4-6 people and takes 45 minutes to prepare.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ontology Schema (shown to LLM):**
|
||||||
|
```markdown
|
||||||
|
## Entity Types:
|
||||||
|
- Recipe: A recipe is a combination of ingredients and a method
|
||||||
|
|
||||||
|
## Attributes:
|
||||||
|
- serves: Indicates what the recipe is intended to serve (Recipe → text)
|
||||||
|
- preparation_time: Time needed to prepare the recipe (Recipe → text)
|
||||||
|
```
|
||||||
|
|
||||||
|
**What LLM Returns (Simple JSON):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"entity": "Cornish pasty recipe",
|
||||||
|
"type": "Recipe"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"attributes": [
|
||||||
|
{
|
||||||
|
"entity": "Cornish pasty recipe",
|
||||||
|
"entity-type": "Recipe",
|
||||||
|
"attribute": "serves",
|
||||||
|
"value": "4-6 people"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity": "Cornish pasty recipe",
|
||||||
|
"entity-type": "Recipe",
|
||||||
|
"attribute": "preparation_time",
|
||||||
|
"value": "45 minutes"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**What Code Produces (RDF Triples):**
|
||||||
|
```python
|
||||||
|
# Normalize entity name to URI
|
||||||
|
recipe_uri = "https://trustgraph.ai/food/cornish-pasty-recipe"
|
||||||
|
|
||||||
|
# Look up attribute URIs from ontology
|
||||||
|
serves_uri = "http://purl.org/ontology/fo/serves" # from fo/serves
|
||||||
|
prep_time_uri = "http://purl.org/ontology/fo/preparation_time" # from fo/preparation_time
|
||||||
|
|
||||||
|
triples = [
|
||||||
|
# Entity type triple
|
||||||
|
Triple(
|
||||||
|
s=Value(value=recipe_uri, is_uri=True),
|
||||||
|
p=Value(value=rdf_type, is_uri=True),
|
||||||
|
o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True)
|
||||||
|
),
|
||||||
|
|
||||||
|
# Label triple (automatic)
|
||||||
|
Triple(
|
||||||
|
s=Value(value=recipe_uri, is_uri=True),
|
||||||
|
p=Value(value=rdfs_label, is_uri=True),
|
||||||
|
o=Value(value="Cornish pasty recipe", is_uri=False)
|
||||||
|
),
|
||||||
|
|
||||||
|
# Attribute triples (objects are literals, not URIs)
|
||||||
|
Triple(
|
||||||
|
s=Value(value=recipe_uri, is_uri=True),
|
||||||
|
p=Value(value=serves_uri, is_uri=True),
|
||||||
|
o=Value(value="4-6 people", is_uri=False) # Literal value!
|
||||||
|
),
|
||||||
|
Triple(
|
||||||
|
s=Value(value=recipe_uri, is_uri=True),
|
||||||
|
p=Value(value=prep_time_uri, is_uri=True),
|
||||||
|
o=Value(value="45 minutes", is_uri=False) # Literal value!
|
||||||
|
)
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Points:**
|
||||||
|
- LLM extracts literal values: `"4-6 people"`, `"45 minutes"`
|
||||||
|
- LLM includes entity type for disambiguation: `entity-type`
|
||||||
|
- LLM uses attribute name from schema: `"serves"`, `"preparation_time"`
|
||||||
|
- Code looks up attribute URI from ontology datatype properties
|
||||||
|
- **Object is literal** (`is_uri=False`), not a URI reference
|
||||||
|
- Values stay as natural text, no normalization needed
|
||||||
|
|
||||||
|
**Difference from Relationships:**
|
||||||
|
- Relationships: both subject and object are entities (URIs)
|
||||||
|
- Attributes: subject is entity (URI), object is literal value (string/number)
|
||||||
|
|
||||||
|
### Complete Example: Entities + Relationships + Attributes
|
||||||
|
|
||||||
|
**Input Text:**
|
||||||
|
```
|
||||||
|
Cornish pasty is a savory pastry filled with beef and potatoes.
|
||||||
|
This recipe serves 4 people.
|
||||||
|
```
|
||||||
|
|
||||||
|
**What LLM Returns:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"entity": "Cornish pasty",
|
||||||
|
"type": "Recipe"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity": "beef",
|
||||||
|
"type": "Food"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity": "potatoes",
|
||||||
|
"type": "Food"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"relationships": [
|
||||||
|
{
|
||||||
|
"subject": "Cornish pasty",
|
||||||
|
"subject-type": "Recipe",
|
||||||
|
"relation": "has_ingredient",
|
||||||
|
"object": "beef",
|
||||||
|
"object-type": "Food"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"subject": "Cornish pasty",
|
||||||
|
"subject-type": "Recipe",
|
||||||
|
"relation": "has_ingredient",
|
||||||
|
"object": "potatoes",
|
||||||
|
"object-type": "Food"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"attributes": [
|
||||||
|
{
|
||||||
|
"entity": "Cornish pasty",
|
||||||
|
"entity-type": "Recipe",
|
||||||
|
"attribute": "serves",
|
||||||
|
"value": "4 people"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result:** 9 RDF triples generated:
|
||||||
|
- 3 entity type triples (rdf:type)
|
||||||
|
- 3 entity label triples (rdfs:label) - automatic
|
||||||
|
- 2 relationship triples (has_ingredient)
|
||||||
|
- 1 attribute triple (serves)
|
||||||
|
|
||||||
|
All from simple, natural language extractions by the LLM!
|
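As a sanity check, the triple count follows mechanically from the extraction JSON: two triples per entity (rdf:type plus rdfs:label), one per relationship, one per attribute.

```python
# Counting the triples produced from the extraction above
extraction = {
    "entities": ["Cornish pasty", "beef", "potatoes"],  # type + label each
    "relationships": ["has_ingredient beef", "has_ingredient potatoes"],
    "attributes": ["serves"],
}
total = (2 * len(extraction["entities"])
         + len(extraction["relationships"])
         + len(extraction["attributes"]))
```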
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- Current implementation: `trustgraph-flow/trustgraph/extract/kg/ontology/extract.py`
|
||||||
|
- Prompt template: `ontology-prompt.md`
|
||||||
|
- Test cases: `tests/unit/test_extract/test_ontology/`
|
||||||
|
- Example ontology: `e2e/test-data/food.ontology`
|
||||||
docs/tech-specs/pubsub.md (new file, 958 lines)
|
||||||
|
# Pub/Sub Infrastructure
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document catalogs all connections between the TrustGraph codebase and the pub/sub infrastructure. Currently, the system is hardcoded to use Apache Pulsar. This analysis identifies all integration points to inform future refactoring toward a configurable pub/sub abstraction.
|
||||||
|
|
||||||
|
## Current State: Pulsar Integration Points
|
||||||
|
|
||||||
|
### 1. Direct Pulsar Client Usage
|
||||||
|
|
||||||
|
**Location:** `trustgraph-flow/trustgraph/gateway/service.py`
|
||||||
|
|
||||||
|
The API gateway directly imports and instantiates the Pulsar client:
|
||||||
|
|
||||||
|
- **Line 20:** `import pulsar`
|
||||||
|
- **Lines 54-61:** Direct instantiation of `pulsar.Client()` with optional `pulsar.AuthenticationToken()`
|
||||||
|
- **Lines 33-35:** Default Pulsar host configuration from environment variables
|
||||||
|
- **Lines 178-192:** CLI arguments for `--pulsar-host`, `--pulsar-api-key`, and `--pulsar-listener`
|
||||||
|
- **Lines 78, 124:** Passes `pulsar_client` to `ConfigReceiver` and `DispatcherManager`
|
||||||
|
|
||||||
|
This is the only location that directly instantiates a Pulsar client outside of the abstraction layer.
|
||||||
|
|
||||||
|
### 2. Base Processor Framework
|
||||||
|
|
||||||
|
**Location:** `trustgraph-base/trustgraph/base/async_processor.py`
|
||||||
|
|
||||||
|
The base class for all processors provides Pulsar connectivity:
|
||||||
|
|
||||||
|
- **Line 9:** `import _pulsar` (for exception handling)
|
||||||
|
- **Line 18:** `from . pubsub import PulsarClient`
|
||||||
|
- **Line 38:** Creates `pulsar_client_object = PulsarClient(**params)`
|
||||||
|
- **Lines 104-108:** Properties exposing `pulsar_host` and `pulsar_client`
|
||||||
|
- **Line 250:** Static method `add_args()` calls `PulsarClient.add_args(parser)` for CLI arguments
|
||||||
|
- **Lines 223-225:** Exception handling for `_pulsar.Interrupted`
|
||||||
|
|
||||||
|
All processors inherit from `AsyncProcessor`, making this the central integration point.
|
||||||
|
|
||||||
|
### 3. Consumer Abstraction
|
||||||
|
|
||||||
|
**Location:** `trustgraph-base/trustgraph/base/consumer.py`
|
||||||
|
|
||||||
|
Consumes messages from queues and invokes handler functions:
|
||||||
|
|
||||||
|
**Pulsar imports:**
|
||||||
|
- **Line 12:** `from pulsar.schema import JsonSchema`
|
||||||
|
- **Line 13:** `import pulsar`
|
||||||
|
- **Line 14:** `import _pulsar`
|
||||||
|
|
||||||
|
**Pulsar-specific usage:**
|
||||||
|
- **Lines 100, 102:** `pulsar.InitialPosition.Earliest` / `pulsar.InitialPosition.Latest`
|
||||||
|
- **Line 108:** `JsonSchema(self.schema)` wrapper
|
||||||
|
- **Line 110:** `pulsar.ConsumerType.Shared`
|
||||||
|
- **Lines 104-111:** `self.client.subscribe()` with Pulsar-specific parameters
|
||||||
|
- **Lines 143, 150, 65:** `consumer.unsubscribe()` and `consumer.close()` methods
|
||||||
|
- **Line 162:** `_pulsar.Timeout` exception
|
||||||
|
- **Lines 182, 205, 232:** `consumer.acknowledge()` / `consumer.negative_acknowledge()`
|
||||||
|
|
||||||
|
**Spec file:** `trustgraph-base/trustgraph/base/consumer_spec.py`
|
||||||
|
- **Line 22:** References `processor.pulsar_client`
|
||||||
|
|
||||||
|
### 4. Producer Abstraction
|
||||||
|
|
||||||
|
**Location:** `trustgraph-base/trustgraph/base/producer.py`
|
||||||
|
|
||||||
|
Sends messages to queues:
|
||||||
|
|
||||||
|
**Pulsar imports:**
|
||||||
|
- **Line 2:** `from pulsar.schema import JsonSchema`
|
||||||
|
|
||||||
|
**Pulsar-specific usage:**
|
||||||
|
- **Line 49:** `JsonSchema(self.schema)` wrapper
|
||||||
|
- **Lines 47-51:** `self.client.create_producer()` with Pulsar-specific parameters (topic, schema, chunking_enabled)
|
||||||
|
- **Lines 31, 76:** `producer.close()` method
|
||||||
|
- **Lines 64-65:** `producer.send()` with message and properties
|
||||||
|
|
||||||
|
**Spec file:** `trustgraph-base/trustgraph/base/producer_spec.py`
|
||||||
|
- **Line 18:** References `processor.pulsar_client`
|
||||||
|
|
||||||
|
### 5. Publisher Abstraction
|
||||||
|
|
||||||
|
**Location:** `trustgraph-base/trustgraph/base/publisher.py`
|
||||||
|
|
||||||
|
Asynchronous message publishing with queue buffering:
|
||||||
|
|
||||||
|
**Pulsar imports:**
|
||||||
|
- **Line 2:** `from pulsar.schema import JsonSchema`
|
||||||
|
- **Line 6:** `import pulsar`
|
||||||
|
|
||||||
|
**Pulsar-specific usage:**
|
||||||
|
- **Line 52:** `JsonSchema(self.schema)` wrapper
|
||||||
|
- **Lines 50-54:** `self.client.create_producer()` with Pulsar-specific parameters
|
||||||
|
- **Lines 101, 103:** `producer.send()` with message and optional properties
|
||||||
|
- **Lines 106-107:** `producer.flush()` and `producer.close()` methods
|
||||||
|
|
||||||
|
### 6. Subscriber Abstraction
|
||||||
|
|
||||||
|
**Location:** `trustgraph-base/trustgraph/base/subscriber.py`
|
||||||
|
|
||||||
|
Provides multi-recipient message distribution from queues:
|
||||||
|
|
||||||
|
**Pulsar imports:**
|
||||||
|
- **Line 6:** `from pulsar.schema import JsonSchema`
|
||||||
|
- **Line 8:** `import _pulsar`
|
||||||
|
|
||||||
|
**Pulsar-specific usage:**
|
||||||
|
- **Line 55:** `JsonSchema(self.schema)` wrapper
|
||||||
|
- **Line 57:** `self.client.subscribe(**subscribe_args)`
|
||||||
|
- **Lines 101, 136, 160, 167-172:** Pulsar exceptions: `_pulsar.Timeout`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed`
|
||||||
|
- **Lines 159, 166, 170:** Consumer methods: `negative_acknowledge()`, `unsubscribe()`, `close()`
|
||||||
|
- **Lines 247, 251:** Message acknowledgment: `acknowledge()`, `negative_acknowledge()`
|
||||||
|
|
||||||
|
**Spec file:** `trustgraph-base/trustgraph/base/subscriber_spec.py`
|
||||||
|
- **Line 19:** References `processor.pulsar_client`
|
||||||
|
|
||||||
|
### 7. Schema System (Heart of Darkness)
|
||||||
|
|
||||||
|
**Location:** `trustgraph-base/trustgraph/schema/`
|
||||||
|
|
||||||
|
Every message schema in the system is defined using Pulsar's schema framework.
|
||||||
|
|
||||||
|
**Core primitives:** `schema/core/primitives.py`
|
||||||
|
- **Line 2:** `from pulsar.schema import Record, String, Boolean, Array, Integer`
|
||||||
|
- All schemas inherit from Pulsar's `Record` base class
|
||||||
|
- All field types are Pulsar types: `String()`, `Integer()`, `Boolean()`, `Array()`, `Map()`, `Double()`
|
||||||
|
|
||||||
|
**Example schemas:**
|
||||||
|
- `schema/services/llm.py` (Line 2): `from pulsar.schema import Record, String, Array, Double, Integer, Boolean`
|
||||||
|
- `schema/services/config.py` (Line 2): `from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer`
|
||||||
|
|
||||||
|
**Topic naming:** `schema/core/topic.py`
|
||||||
|
- **Lines 2-3:** Topic format: `{kind}://{tenant}/{namespace}/{topic}`
|
||||||
|
- This URI structure is Pulsar-specific (e.g., `persistent://tg/flow/config`)
|
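A minimal sketch of that URI construction (the helper name is an assumption, not code from `topic.py`):

```python
def pulsar_topic(kind: str, tenant: str, namespace: str, name: str) -> str:
    # Pulsar topic URIs follow {kind}://{tenant}/{namespace}/{topic}
    return f"{kind}://{tenant}/{namespace}/{name}"
```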
||||||
|
|
||||||
|
**Impact:**
|
||||||
|
- All request/response message definitions throughout the codebase use Pulsar schemas
|
||||||
|
- This includes services for: config, flow, llm, prompt, query, storage, agent, collection, diagnosis, library, lookup, nlp_query, objects_query, retrieval, structured_query
|
||||||
|
- Schema definitions are imported and used extensively across all processors and services
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
### Pulsar Dependencies by Category
|
||||||
|
|
||||||
|
1. **Client instantiation:**
|
||||||
|
- Direct: `gateway/service.py`
|
||||||
|
- Abstracted: `async_processor.py` → `pubsub.py` (PulsarClient)
|
||||||
|
|
||||||
|
2. **Message transport:**
|
||||||
|
- Consumer: `consumer.py`, `consumer_spec.py`
|
||||||
|
- Producer: `producer.py`, `producer_spec.py`
|
||||||
|
- Publisher: `publisher.py`
|
||||||
|
- Subscriber: `subscriber.py`, `subscriber_spec.py`
|
||||||
|
|
||||||
|
3. **Schema system:**
|
||||||
|
- Base types: `schema/core/primitives.py`
|
||||||
|
- All service schemas: `schema/services/*.py`
|
||||||
|
- Topic naming: `schema/core/topic.py`
|
||||||
|
|
||||||
|
4. **Pulsar-specific concepts required:**
|
||||||
|
- Topic-based messaging
|
||||||
|
- Schema system (Record, field types)
|
||||||
|
- Shared subscriptions
|
||||||
|
- Message acknowledgment (positive/negative)
|
||||||
|
- Consumer positioning (earliest/latest)
|
||||||
|
- Message properties
|
||||||
|
- Initial positions and consumer types
|
||||||
|
- Chunking support
|
||||||
|
- Persistent vs non-persistent topics
|
||||||
|
|
||||||
|
### Refactoring Challenges
|
||||||
|
|
||||||
|
The good news: The abstraction layer (Consumer, Producer, Publisher, Subscriber) provides a clean encapsulation of most Pulsar interactions.
|
||||||
|
|
||||||
|
The challenges:
|
||||||
|
1. **Schema system pervasiveness:** Every message definition uses `pulsar.schema.Record` and Pulsar field types
|
||||||
|
2. **Pulsar-specific enums:** `InitialPosition`, `ConsumerType`
|
||||||
|
3. **Pulsar exceptions:** `_pulsar.Timeout`, `_pulsar.Interrupted`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed`
|
||||||
|
4. **Method signatures:** `acknowledge()`, `negative_acknowledge()`, `subscribe()`, `create_producer()`, etc.
|
||||||
|
5. **Topic URI format:** Pulsar's `kind://tenant/namespace/topic` structure
|
||||||
|
|
||||||
|
### Next Steps
|
||||||
|
|
||||||
|
To make the pub/sub infrastructure configurable, we need to:
|
||||||
|
|
||||||
|
1. Create an abstraction interface for the client/schema system
|
||||||
|
2. Abstract Pulsar-specific enums and exceptions
|
||||||
|
3. Create schema wrappers or alternative schema definitions
|
||||||
|
4. Implement the interface for both Pulsar and alternative systems (Kafka, RabbitMQ, Redis Streams, etc.)
|
||||||
|
5. Update `pubsub.py` to be configurable and support multiple backends
|
||||||
|
6. Provide migration path for existing deployments
|
||||||
|
|
||||||
|
## Approach Draft 1: Adapter Pattern with Schema Translation Layer
|
||||||
|
|
||||||
|
### Key Insight
|
||||||
|
The **schema system** is the deepest integration point - everything else flows from it. We need to solve this first, or we'll be rewriting the entire codebase.
|
||||||
|
|
||||||
|
### Strategy: Minimal Disruption with Adapters
|
||||||
|
|
||||||
|
**1. Keep Pulsar schemas as the internal representation**
|
||||||
|
- Don't rewrite all the schema definitions
|
||||||
|
- Schemas remain `pulsar.schema.Record` internally
|
||||||
|
- Use adapters to translate at the boundary between our code and the pub/sub backend
|
||||||
|
|
||||||
|
**2. Create a pub/sub abstraction layer:**
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ Existing Code (unchanged) │
|
||||||
|
│ - Uses Pulsar schemas internally │
|
||||||
|
│ - Consumer/Producer/Publisher │
|
||||||
|
└──────────────┬──────────────────────┘
|
||||||
|
│
|
||||||
|
┌──────────────┴──────────────────────┐
|
||||||
|
│ PubSubFactory (configurable) │
|
||||||
|
│ - Creates backend-specific client │
|
||||||
|
└──────────────┬──────────────────────┘
|
||||||
|
│
|
||||||
|
┌──────┴──────┐
|
||||||
|
│ │
|
||||||
|
┌───────▼─────┐ ┌────▼─────────┐
|
||||||
|
│ PulsarAdapter│ │ KafkaAdapter │ etc...
|
||||||
|
│ (passthrough)│ │ (translates) │
|
||||||
|
└──────────────┘ └──────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Define abstract interfaces:**
|
||||||
|
- `PubSubClient` - client connection
|
||||||
|
- `PubSubProducer` - sending messages
|
||||||
|
- `PubSubConsumer` - receiving messages
|
||||||
|
- `SchemaAdapter` - translating Pulsar schemas to/from JSON or backend-specific formats
|
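One possible shape for these interfaces, sketched with `abc`. The method names mirror the Pulsar usage cataloged earlier but are assumptions, not a settled API:

```python
from abc import ABC, abstractmethod

class PubSubProducer(ABC):
    @abstractmethod
    def send(self, message, properties=None): ...

class PubSubConsumer(ABC):
    @abstractmethod
    def receive(self, timeout_ms=None): ...
    @abstractmethod
    def acknowledge(self, msg): ...
    @abstractmethod
    def negative_acknowledge(self, msg): ...

class PubSubClient(ABC):
    @abstractmethod
    def create_producer(self, topic, schema) -> PubSubProducer: ...
    @abstractmethod
    def subscribe(self, topic, subscription, schema) -> PubSubConsumer: ...
```

Each backend adapter then subclasses these, keeping backend-specific enums and exceptions behind the interface.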
||||||
|
|
||||||
|
**4. Implementation details:**
|
||||||
|
|
||||||
|
For **Pulsar adapter**: Nearly passthrough, minimal translation
|
||||||
|
|
||||||
|
For **other backends** (Kafka, RabbitMQ, etc.):
|
||||||
|
- Serialize Pulsar Record objects to JSON/bytes
|
||||||
|
- Map concepts like:
|
||||||
|
- `InitialPosition.Earliest/Latest` → Kafka's auto.offset.reset
|
||||||
|
- `acknowledge()` → Kafka's commit
|
||||||
|
- `negative_acknowledge()` → Re-queue or DLQ pattern
|
||||||
|
- Topic URIs → Backend-specific topic names
|
||||||
|
|
||||||
|
### Analysis
|
||||||
|
|
||||||
|
**Pros:**
|
||||||
|
- ✅ Minimal code changes to existing services
|
||||||
|
- ✅ Schemas stay as-is (no massive rewrite)
|
||||||
|
- ✅ Gradual migration path
|
||||||
|
- ✅ Pulsar users see no difference
|
||||||
|
- ✅ New backends added via adapters
|
||||||
|
|
||||||
|
**Cons:**
|
||||||
|
- ⚠️ Still carries Pulsar dependency (for schema definitions)
|
||||||
|
- ⚠️ Some impedance mismatch translating concepts
|
||||||
|
|
||||||
|
### Alternative Consideration
|
||||||
|
|
||||||
|
Create a **TrustGraph schema system** that's pub/sub agnostic (using dataclasses or Pydantic), then generate Pulsar/Kafka/etc schemas from it. This requires rewriting every schema file and potentially breaking changes.
|
||||||
|
|
||||||
|
### Recommendation for Draft 1
|
||||||
|
|
||||||
|
Start with the **adapter approach** because:
|
||||||
|
1. It's pragmatic - works with existing code
|
||||||
|
2. Proves the concept with minimal risk
|
||||||
|
3. Can evolve to a native schema system later if needed
|
||||||
|
4. Configuration-driven: one env var switches backends
|
||||||
|
|
||||||
|
## Approach Draft 2: Backend-Agnostic Schema System with Dataclasses
|
||||||
|
|
||||||
|
### Core Concept
|
||||||
|
|
||||||
|
Use Python **dataclasses** as the neutral schema definition format. Each pub/sub backend provides its own serialization/deserialization for dataclasses, eliminating the need for Pulsar schemas to remain in the codebase.
|
||||||
|
|
||||||
|
### Schema Polymorphism at the Factory Level
|
||||||
|
|
||||||
|
Instead of translating Pulsar schemas, **each backend provides its own schema handling** that works with standard Python dataclasses.
|
||||||
|
|
||||||
|
### Publisher Flow
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 1. Get the configured backend from factory
|
||||||
|
pubsub = get_pubsub() # Returns PulsarBackend, MQTTBackend, etc.
|
||||||
|
|
||||||
|
# 2. Get schema class from the backend
|
||||||
|
# (Can be imported directly - backend-agnostic)
|
||||||
|
from trustgraph.schema.services.llm import TextCompletionRequest
|
||||||
|
|
||||||
|
# 3. Create a producer/publisher for a specific topic
|
||||||
|
producer = pubsub.create_producer(
|
||||||
|
topic="text-completion-requests",
|
||||||
|
schema=TextCompletionRequest # Tells backend what schema to use
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Create message instances (same API regardless of backend)
|
||||||
|
request = TextCompletionRequest(
|
||||||
|
system="You are helpful",
|
||||||
|
prompt="Hello world",
|
||||||
|
streaming=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Send the message
|
||||||
|
producer.send(request) # Backend serializes appropriately
|
||||||
|
```
|
||||||
|
|
||||||
|
### Consumer Flow
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 1. Get the configured backend
|
||||||
|
pubsub = get_pubsub()
|
||||||
|
|
||||||
|
# 2. Create a consumer
|
||||||
|
consumer = pubsub.subscribe(
|
||||||
|
topic="text-completion-requests",
|
||||||
|
schema=TextCompletionRequest # Tells backend how to deserialize
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Receive and deserialize
|
||||||
|
msg = consumer.receive()
|
||||||
|
request = msg.value() # Returns TextCompletionRequest dataclass instance
|
||||||
|
|
||||||
|
# 4. Use the data (type-safe access)
|
||||||
|
print(request.system) # "You are helpful"
|
||||||
|
print(request.prompt) # "Hello world"
|
||||||
|
print(request.streaming) # False
|
||||||
|
```
|
||||||
|
|
||||||
|
### What Happens Behind the Scenes
|
||||||
|
|
||||||
|
**For Pulsar backend:**
|
||||||
|
- `create_producer()` → creates Pulsar producer with JSON schema or dynamically generated Record
|
||||||
|
- `send(request)` → serializes dataclass to JSON/Pulsar format, sends to Pulsar
|
||||||
|
- `receive()` → gets Pulsar message, deserializes back to dataclass
|
||||||
|
|
||||||
|
**For MQTT backend:**
|
||||||
|
- `create_producer()` → connects to MQTT broker, no schema registration needed
|
||||||
|
- `send(request)` → converts dataclass to JSON, publishes to MQTT topic
|
||||||
|
- `receive()` → subscribes to MQTT topic, deserializes JSON to dataclass
|
||||||
|
|
||||||
|
**For Kafka backend:**
|
||||||
|
- `create_producer()` → creates Kafka producer, registers Avro schema if needed
|
||||||
|
- `send(request)` → serializes dataclass to Avro format, sends to Kafka
|
||||||
|
- `receive()` → gets Kafka message, deserializes Avro back to dataclass
|
||||||
|
|
||||||
|
### Key Design Points
|
||||||
|
|
||||||
|
1. **Schema object creation**: The dataclass instance (`TextCompletionRequest(...)`) is identical regardless of backend
|
||||||
|
2. **Backend handles encoding**: Each backend knows how to serialize its dataclass to the wire format
|
||||||
|
3. **Schema definition at creation**: When creating producer/consumer, you specify the schema type
|
||||||
|
4. **Type safety preserved**: You get back a proper `TextCompletionRequest` object, not a dict
|
||||||
|
5. **No backend leakage**: Application code never imports backend-specific libraries
|
||||||
|
|
||||||
|
### Example Transformation
|
||||||
|
|
||||||
|
**Current (Pulsar-specific):**
|
||||||
|
```python
|
||||||
|
# schema/services/llm.py
|
||||||
|
from pulsar.schema import Record, String, Boolean, Integer
|
||||||
|
|
||||||
|
class TextCompletionRequest(Record):
|
||||||
|
system = String()
|
||||||
|
prompt = String()
|
||||||
|
streaming = Boolean()
|
||||||
|
```
|
||||||
|
|
||||||
|
**New (Backend-agnostic):**
|
||||||
|
```python
|
||||||
|
# schema/services/llm.py
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextCompletionRequest:
|
||||||
|
system: str
|
||||||
|
prompt: str
|
||||||
|
streaming: bool = False
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend Integration
|
||||||
|
|
||||||
|
Each backend handles serialization/deserialization of dataclasses:
|
||||||
|
|
||||||
|
**Pulsar backend:**
|
||||||
|
- Dynamically generate `pulsar.schema.Record` classes from dataclasses
|
||||||
|
- Or serialize dataclasses to JSON and use Pulsar's JSON schema
|
||||||
|
- Maintains compatibility with existing Pulsar deployments
|
||||||
|
|
||||||
|
**MQTT/Redis backend:**
|
||||||
|
- Direct JSON serialization of dataclass instances
|
||||||
|
- Use `dataclasses.asdict()` for encoding and a small `from_dict()` helper for decoding (the stdlib provides no `from_dict()`; libraries such as `dacite` offer one)
|
||||||
|
- Lightweight, no schema registry needed
|
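For the MQTT/Redis case the round-trip is little more than JSON over `asdict()`. A hedged sketch; the `to_wire`/`from_wire` names are assumptions:

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class TextCompletionRequest:
    system: str
    prompt: str
    streaming: bool = False

def to_wire(msg) -> bytes:
    # dataclass -> dict -> JSON bytes
    return json.dumps(asdict(msg)).encode()

def from_wire(cls, data: bytes):
    # JSON bytes -> dict -> dataclass
    return cls(**json.loads(data.decode()))
```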
||||||
|
|
||||||
|
**Kafka backend:**
|
||||||
|
- Generate Avro schemas from dataclass definitions
|
||||||
|
- Use Confluent's schema registry
|
||||||
|
- Type-safe serialization with schema evolution support
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ Application Code │
|
||||||
|
│ - Uses dataclass schemas │
|
||||||
|
│ - Backend-agnostic │
|
||||||
|
└──────────────┬──────────────────────┘
|
||||||
|
│
|
||||||
|
┌──────────────┴──────────────────────┐
|
||||||
|
│ PubSubFactory (configurable) │
|
||||||
|
│ - get_pubsub() returns backend │
|
||||||
|
└──────────────┬──────────────────────┘
|
||||||
|
│
|
||||||
|
┌──────┴──────┐
|
||||||
|
│ │
|
||||||
|
┌───────▼─────────┐ ┌────▼──────────────┐
|
||||||
|
│ PulsarBackend │ │ MQTTBackend │
|
||||||
|
│ - JSON schema │ │ - JSON serialize │
|
||||||
|
│ - or dynamic │ │ - Simple queues │
|
||||||
|
│ Record gen │ │ │
|
||||||
|
└─────────────────┘ └───────────────────┘
|
||||||
|
```
|
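The `PubSubFactory` box above could be as small as a registry keyed by an environment variable. Everything here (the variable name, `register_backend`, `get_pubsub`) is illustrative, not the project's API:

```python
import os

_BACKENDS = {}

def register_backend(name):
    """Class decorator registering a backend under a configuration name."""
    def deco(cls):
        _BACKENDS[name] = cls
        return cls
    return deco

def get_pubsub():
    # One env var switches backends (variable name is an assumption)
    name = os.environ.get("TRUSTGRAPH_PUBSUB", "pulsar")
    return _BACKENDS[name]()
```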
||||||
|
|
||||||
|
### Implementation Details
|
||||||
|
|
||||||
|
**1. Schema definitions:** Plain dataclasses with type hints
|
||||||
|
- `str`, `int`, `bool`, `float` for primitives
|
||||||
|
- `list[T]` for arrays
|
||||||
|
- `dict[str, T]` for maps
|
||||||
|
- Nested dataclasses for complex types
|
||||||
|
|
||||||
|
**2. Each backend provides:**
|
||||||
|
- Serializer: `dataclass → bytes/wire format`
|
||||||
|
- Deserializer: `bytes/wire format → dataclass`
|
||||||
|
- Schema registration (if needed, like Pulsar/Kafka)
|
||||||
|
|
||||||
|
**3. Consumer/Producer abstraction:**
|
||||||
|
- Already exists (consumer.py, producer.py)
|
||||||
|
- Update to use backend's serialization
|
||||||
|
- Remove direct Pulsar imports
|
||||||
|
|
||||||
|
**4. Type mappings:**
|
||||||
|
- Pulsar `String()` → Python `str`
|
||||||
|
- Pulsar `Integer()` → Python `int`
|
||||||
|
- Pulsar `Boolean()` → Python `bool`
|
||||||
|
- Pulsar `Array(T)` → Python `list[T]`
|
||||||
|
- Pulsar `Map(K, V)` → Python `dict[K, V]`
|
||||||
|
- Pulsar `Double()` → Python `float`
|
||||||
|
- Pulsar `Bytes()` → Python `bytes`
|
||||||
|
|
||||||
|
### Migration Path
|
||||||
|
|
||||||
|
1. **Create dataclass versions** of all schemas in `trustgraph/schema/`
|
||||||
|
2. **Update backend classes** (Consumer, Producer, Publisher, Subscriber) to use backend-provided serialization
|
||||||
|
3. **Implement PulsarBackend** with JSON schema or dynamic Record generation
|
||||||
|
4. **Test with Pulsar** to ensure backward compatibility with existing deployments
|
||||||
|
5. **Add new backends** (MQTT, Kafka, Redis, etc.) as needed
|
||||||
|
6. **Remove Pulsar imports** from schema files
|
||||||
|
|
||||||
|
### Benefits
|
||||||
|
|
||||||
|
✅ **No pub/sub dependency** in schema definitions
|
||||||
|
✅ **Standard Python** - easy to understand, type-check, document
|
||||||
|
✅ **Modern tooling** - works with mypy, IDE autocomplete, linters
|
||||||
|
✅ **Backend-optimized** - each backend uses native serialization
|
||||||
|
✅ **No translation overhead** - direct serialization, no adapters
|
||||||
|
✅ **Type safety** - real objects with proper types
|
||||||
|
✅ **Easy validation** - can use Pydantic if needed
|
||||||
|
|
||||||
|
### Challenges & Solutions
|
||||||
|
|
||||||
|
**Challenge:** Pulsar's `Record` has runtime field validation
|
||||||
|
**Solution:** Use Pydantic dataclasses for validation if needed, or plain dataclasses with a `__post_init__` validation hook (available since Python 3.7)
|
||||||
|
|
||||||
|
**Challenge:** Some Pulsar-specific features (like `Bytes` type)
|
||||||
|
**Solution:** Map to `bytes` type in dataclass, backend handles encoding appropriately
|
||||||
|
|
||||||
|
**Challenge:** Topic naming (`persistent://tenant/namespace/topic`)
|
||||||
|
**Solution:** Abstract topic names in schema definitions, backend converts to proper format
|
||||||
|
|
||||||
|
**Challenge:** Schema evolution and versioning
|
||||||
|
**Solution:** Each backend handles this according to its capabilities (Pulsar schema versions, Kafka schema registry, etc.)
|
||||||
|
|
||||||
|
**Challenge:** Nested complex types
|
||||||
|
**Solution:** Use nested dataclasses, backends recursively serialize/deserialize
|
||||||
|
|
||||||
|
### Design Decisions

1. **Plain dataclasses or Pydantic?**

   - ✅ **Decision: Use plain Python dataclasses**
   - Simpler, no additional dependencies
   - Validation not required in practice
   - Easier to understand and maintain

2. **Schema evolution:**

   - ✅ **Decision: No versioning mechanism needed**
   - Schemas are stable and long-lasting
   - Updates typically add new fields (backward compatible)
   - Backends handle schema evolution according to their capabilities

3. **Backward compatibility:**

   - ✅ **Decision: Major version change, no backward compatibility required**
   - Will be a breaking change with migration instructions
   - Clean break allows for better design
   - Migration guide will be provided for existing deployments

4. **Nested types and complex structures:**

   - ✅ **Decision: Use nested dataclasses naturally**
   - Python dataclasses handle nesting well
   - `list[T]` for arrays, `dict[K, V]` for maps
   - Backends recursively serialize/deserialize
   - Example:

   ```python
   @dataclass
   class Value:
       value: str
       is_uri: bool

   @dataclass
   class Triple:
       s: Value  # Nested dataclass
       p: Value
       o: Value

   @dataclass
   class GraphQuery:
       triples: list[Triple]     # Array of nested dataclasses
       metadata: dict[str, str]
   ```

5. **Default values and optional fields:**

   - ✅ **Decision: Mix of required, defaulted, and optional fields**
   - Required fields: no default value
   - Fields with defaults: always present, with a sensible default
   - Truly optional fields: `T | None = None`, omitted from serialization when `None`
   - Example:

   ```python
   @dataclass
   class TextCompletionRequest:
       system: str                   # Required, no default
       prompt: str                   # Required, no default
       streaming: bool = False       # Optional with default value
       metadata: dict | None = None  # Truly optional, can be absent
   ```

**Important serialization semantics:**

When `metadata = None`, the field is not present in the output:

```json
{
    "system": "...",
    "prompt": "...",
    "streaming": false
}
```

When `metadata = {}` (explicitly empty), the field is present but empty:

```json
{
    "system": "...",
    "prompt": "...",
    "streaming": false,
    "metadata": {}
}
```

**Key distinction:**
- `None` → field absent from JSON (not serialized)
- Empty value (`{}`, `[]`, `""`) → field present with empty value
- This matters semantically: "not provided" vs "explicitly empty"
- Serialization backends must skip `None` fields, not encode them as `null`

## Approach Draft 3: Implementation Details

### Generic Queue Naming Format

Replace backend-specific queue names with a generic format that backends can map appropriately.

**Format:** `{qos}/{tenant}/{namespace}/{queue-name}`

Where:
- `qos`: Quality of Service level
  - `q0` = best-effort (fire and forget, no acknowledgment)
  - `q1` = at-least-once (requires acknowledgment)
  - `q2` = exactly-once (two-phase acknowledgment)
- `tenant`: Logical grouping for multi-tenancy
- `namespace`: Sub-grouping within a tenant
- `queue-name`: Actual queue/topic name

**Examples:**
```
q1/tg/flow/text-completion-requests
q2/tg/config/config-push
q0/tg/metrics/stats
```

### Backend Topic Mapping

Each backend maps the generic format to its native format:

**Pulsar Backend:**
```python
def map_topic(self, generic_topic: str) -> str:
    # Parse: q1/tg/flow/text-completion-requests
    qos, tenant, namespace, queue = generic_topic.split('/', 3)

    # Map QoS to persistence
    persistence = 'persistent' if qos in ['q1', 'q2'] else 'non-persistent'

    # Return Pulsar URI: persistent://tg/flow/text-completion-requests
    return f"{persistence}://{tenant}/{namespace}/{queue}"
```

**MQTT Backend:**
```python
def map_topic(self, generic_topic: str) -> tuple[str, int]:
    # Parse: q1/tg/flow/text-completion-requests
    qos, tenant, namespace, queue = generic_topic.split('/', 3)

    # Map QoS level
    qos_level = {'q0': 0, 'q1': 1, 'q2': 2}[qos]

    # Build MQTT topic including tenant/namespace for proper namespacing
    mqtt_topic = f"{tenant}/{namespace}/{queue}"

    return mqtt_topic, qos_level
```

### Updated Topic Helper Function

```python
# schema/core/topic.py
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
    """
    Create a generic topic identifier that can be mapped by backends.

    Args:
        queue_name: The queue/topic name
        qos: Quality of service
            - 'q0' = best-effort (no ack)
            - 'q1' = at-least-once (ack required)
            - 'q2' = exactly-once (two-phase ack)
        tenant: Tenant identifier for multi-tenancy
        namespace: Namespace within tenant

    Returns:
        Generic topic string: qos/tenant/namespace/queue_name

    Examples:
        topic('my-queue')                              # q1/tg/flow/my-queue
        topic('config', qos='q2', namespace='config')  # q2/tg/config/config
    """
    return f"{qos}/{tenant}/{namespace}/{queue_name}"
```

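Chaining the helper with a condensed version of the Pulsar mapping shown earlier illustrates the end-to-end naming flow (a sketch, not the production code):

```python
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
    # Generic identifier: qos/tenant/namespace/queue_name
    return f"{qos}/{tenant}/{namespace}/{queue_name}"

def pulsar_topic(generic_topic: str) -> str:
    # Condensed Pulsar mapping: q1/q2 are persistent, q0 is not
    qos, tenant, namespace, queue = generic_topic.split('/', 3)
    persistence = 'persistent' if qos in ['q1', 'q2'] else 'non-persistent'
    return f"{persistence}://{tenant}/{namespace}/{queue}"

t = pulsar_topic(topic('text-completion-requests'))
# t == 'persistent://tg/flow/text-completion-requests'
```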
### Configuration and Initialization

**Command-Line Arguments + Environment Variables:**

```python
# In base/async_processor.py - add_args() method
@staticmethod
def add_args(parser):
    # Pub/sub backend selection
    parser.add_argument(
        '--pubsub-backend',
        default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
        choices=['pulsar', 'mqtt'],
        help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)'
    )

    # Pulsar-specific configuration
    parser.add_argument(
        '--pulsar-host',
        default=os.getenv('PULSAR_HOST', 'pulsar://localhost:6650'),
        help='Pulsar host (default: pulsar://localhost:6650, env: PULSAR_HOST)'
    )

    parser.add_argument(
        '--pulsar-api-key',
        default=os.getenv('PULSAR_API_KEY', None),
        help='Pulsar API key (env: PULSAR_API_KEY)'
    )

    parser.add_argument(
        '--pulsar-listener',
        default=os.getenv('PULSAR_LISTENER', None),
        help='Pulsar listener name (env: PULSAR_LISTENER)'
    )

    # MQTT-specific configuration
    parser.add_argument(
        '--mqtt-host',
        default=os.getenv('MQTT_HOST', 'localhost'),
        help='MQTT broker host (default: localhost, env: MQTT_HOST)'
    )

    parser.add_argument(
        '--mqtt-port',
        type=int,
        default=int(os.getenv('MQTT_PORT', '1883')),
        help='MQTT broker port (default: 1883, env: MQTT_PORT)'
    )

    parser.add_argument(
        '--mqtt-username',
        default=os.getenv('MQTT_USERNAME', None),
        help='MQTT username (env: MQTT_USERNAME)'
    )

    parser.add_argument(
        '--mqtt-password',
        default=os.getenv('MQTT_PASSWORD', None),
        help='MQTT password (env: MQTT_PASSWORD)'
    )
```

**Factory Function:**

```python
# In base/pubsub.py or base/pubsub_factory.py
def get_pubsub(**config) -> PubSubBackend:
    """
    Create and return a pub/sub backend based on configuration.

    Args:
        config: Configuration dict from command-line args;
            must include a 'pubsub_backend' key

    Returns:
        Backend instance (PulsarBackend, MQTTBackend, etc.)
    """
    backend_type = config.get('pubsub_backend', 'pulsar')

    if backend_type == 'pulsar':
        return PulsarBackend(
            host=config.get('pulsar_host'),
            api_key=config.get('pulsar_api_key'),
            listener=config.get('pulsar_listener'),
        )
    elif backend_type == 'mqtt':
        return MQTTBackend(
            host=config.get('mqtt_host'),
            port=config.get('mqtt_port'),
            username=config.get('mqtt_username'),
            password=config.get('mqtt_password'),
        )
    else:
        raise ValueError(f"Unknown pub/sub backend: {backend_type}")
```

**Usage in AsyncProcessor:**

```python
# In async_processor.py
class AsyncProcessor:
    def __init__(self, **params):
        self.id = params.get("id")

        # Create backend from config (replaces PulsarClient)
        self.pubsub = get_pubsub(**params)

        # Rest of initialization...
```

### Backend Interface

```python
class PubSubBackend(Protocol):
    """Protocol defining the interface all pub/sub backends must implement."""

    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
        """
        Create a producer for a topic.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            schema: Dataclass type for messages
            options: Backend-specific options (e.g., chunking_enabled)

        Returns:
            Backend-specific producer instance
        """
        ...

    def create_consumer(
        self,
        topic: str,
        subscription: str,
        schema: type,
        initial_position: str = 'latest',
        consumer_type: str = 'shared',
        **options
    ) -> BackendConsumer:
        """
        Create a consumer for a topic.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            subscription: Subscription/consumer group name
            schema: Dataclass type for messages
            initial_position: 'earliest' or 'latest' (MQTT may ignore)
            consumer_type: 'shared', 'exclusive', 'failover' (MQTT may ignore)
            options: Backend-specific options

        Returns:
            Backend-specific consumer instance
        """
        ...

    def close(self) -> None:
        """Close the backend connection."""
        ...
```

```python
class BackendProducer(Protocol):
    """Protocol for a backend-specific producer."""

    def send(self, message: Any, properties: dict = {}) -> None:
        """Send a message (dataclass instance) with optional properties."""
        ...

    def flush(self) -> None:
        """Flush any buffered messages."""
        ...

    def close(self) -> None:
        """Close the producer."""
        ...
```

```python
class BackendConsumer(Protocol):
    """Protocol for a backend-specific consumer."""

    def receive(self, timeout_millis: int = 2000) -> Message:
        """
        Receive a message from the topic.

        Raises:
            TimeoutError: If no message is received within the timeout
        """
        ...

    def acknowledge(self, message: Message) -> None:
        """Acknowledge successful processing of a message."""
        ...

    def negative_acknowledge(self, message: Message) -> None:
        """Negative acknowledge - triggers redelivery."""
        ...

    def unsubscribe(self) -> None:
        """Unsubscribe from the topic."""
        ...

    def close(self) -> None:
        """Close the consumer."""
        ...
```

```python
class Message(Protocol):
    """Protocol for a received message."""

    def value(self) -> Any:
        """Get the deserialized message (dataclass instance)."""
        ...

    def properties(self) -> dict:
        """Get message properties/metadata."""
        ...
```

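A minimal in-memory implementation satisfying these protocols might look like the sketch below (hypothetical, intended for unit tests; it ignores QoS, subscriptions, and schemas):

```python
from collections import defaultdict
from typing import Any

class InMemoryMessage:
    """Message wrapper matching the Message protocol."""
    def __init__(self, value, properties):
        self._value, self._properties = value, properties
    def value(self) -> Any:
        return self._value
    def properties(self) -> dict:
        return self._properties

class InMemoryProducer:
    """Producer that appends to a shared per-topic list."""
    def __init__(self, queues, topic):
        self.queues, self.topic = queues, topic
    def send(self, message: Any, properties: dict = {}) -> None:
        self.queues[self.topic].append((message, dict(properties)))
    def flush(self) -> None:
        pass
    def close(self) -> None:
        pass

class InMemoryConsumer:
    """Consumer that pops from the shared per-topic list."""
    def __init__(self, queues, topic):
        self.queues, self.topic = queues, topic
    def receive(self, timeout_millis: int = 2000) -> InMemoryMessage:
        if not self.queues[self.topic]:
            raise TimeoutError("no message available")
        value, props = self.queues[self.topic].pop(0)
        return InMemoryMessage(value, props)
    def acknowledge(self, message) -> None:
        pass
    def negative_acknowledge(self, message) -> None:
        pass
    def unsubscribe(self) -> None:
        pass
    def close(self) -> None:
        pass

class InMemoryBackend:
    """Backend wiring producers and consumers to shared in-process queues."""
    def __init__(self):
        self.queues = defaultdict(list)
    def create_producer(self, topic: str, schema: type, **options) -> InMemoryProducer:
        return InMemoryProducer(self.queues, topic)
    def create_consumer(self, topic: str, subscription: str, schema: type,
                        initial_position: str = 'latest',
                        consumer_type: str = 'shared', **options) -> InMemoryConsumer:
        return InMemoryConsumer(self.queues, topic)
    def close(self) -> None:
        pass

# Round-trip a message through the fake backend
backend = InMemoryBackend()
producer = backend.create_producer("q1/tg/flow/test", schema=dict)
consumer = backend.create_consumer("q1/tg/flow/test", "sub", schema=dict)
producer.send({"x": 1}, {"trace": "t1"})
msg = consumer.receive()
```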
### Existing Classes Refactoring

The existing `Consumer`, `Producer`, `Publisher`, and `Subscriber` classes remain largely intact:

**Current responsibilities (keep):**
- Async threading model and task groups
- Reconnection logic and retry handling
- Metrics collection
- Rate limiting
- Concurrency management

**Changes needed:**
- Remove direct Pulsar imports (`pulsar.schema`, `pulsar.InitialPosition`, etc.)
- Accept `BackendProducer`/`BackendConsumer` instead of a Pulsar client
- Delegate actual pub/sub operations to backend instances
- Map generic concepts to backend calls

**Example refactoring:**

```python
# OLD - consumer.py
class Consumer:
    def __init__(self, client, topic, subscriber, schema, ...):
        self.client = client  # Direct Pulsar client
        # ...

    async def consumer_run(self):
        # Uses pulsar.InitialPosition, pulsar.ConsumerType
        self.consumer = self.client.subscribe(
            topic=self.topic,
            schema=JsonSchema(self.schema),
            initial_position=pulsar.InitialPosition.Earliest,
            consumer_type=pulsar.ConsumerType.Shared,
        )

# NEW - consumer.py
class Consumer:
    def __init__(self, backend_consumer, schema, ...):
        self.backend_consumer = backend_consumer  # Backend-specific consumer
        self.schema = schema
        # ...

    async def consumer_run(self):
        # Backend consumer already created with the right settings;
        # just use it directly
        while self.running:
            msg = await asyncio.to_thread(
                self.backend_consumer.receive,
                timeout_millis=2000
            )
            await self.handle_message(msg)
```

### Backend-Specific Behaviors

**Pulsar Backend:**
- Maps `q0` → `non-persistent://`, `q1`/`q2` → `persistent://`
- Supports all consumer types (shared, exclusive, failover)
- Supports initial position (earliest/latest)
- Native message acknowledgment
- Schema registry support

**MQTT Backend:**
- Maps `q0`/`q1`/`q2` → MQTT QoS levels 0/1/2
- Includes tenant/namespace in the topic path for namespacing
- Auto-generates client IDs from subscription names
- Ignores initial position (no message history in basic MQTT)
- Ignores consumer type (MQTT uses client IDs, not consumer groups)
- Simple publish/subscribe model

### Design Decisions Summary

1. ✅ **Generic queue naming**: `qos/tenant/namespace/queue-name` format
2. ✅ **QoS in queue ID**: Determined by the queue definition, not configuration
3. ✅ **Reconnection**: Handled by Consumer/Producer classes, not backends
4. ✅ **MQTT topics**: Include tenant/namespace for proper namespacing
5. ✅ **Message history**: MQTT ignores the `initial_position` parameter (future enhancement)
6. ✅ **Client IDs**: MQTT backend auto-generates them from the subscription name

### Future Enhancements

**MQTT message history:**
- Could add an optional persistence layer (e.g., retained messages, external store)
- Would allow supporting `initial_position='earliest'`
- Not required for the initial implementation

docs/tech-specs/python-api-refactor.md (new file, 1508 lines; diff suppressed because it is too large)

ontology-prompt.md (new file, 54 lines)
@@ -0,0 +1,54 @@
You are a knowledge extraction expert. Extract structured triples from text using ONLY the provided ontology elements.

## Ontology Classes:

{% for class_id, class_def in classes.items() %}
- **{{class_id}}**{% if class_def.subclass_of %} (subclass of {{class_def.subclass_of}}){% endif %}{% if class_def.comment %}: {{class_def.comment}}{% endif %}
{% endfor %}

## Object Properties (connect entities):

{% for prop_id, prop_def in object_properties.items() %}
- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %}
{% endfor %}

## Datatype Properties (entity attributes):

{% for prop_id, prop_def in datatype_properties.items() %}
- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %}
{% endfor %}

## Text to Analyze:

{{text}}

## Extraction Rules:

1. Only use classes defined above for entity types
2. Only use properties defined above for relationships and attributes
3. Respect domain and range constraints where specified
4. For class instances, use `rdf:type` as the predicate
5. Include `rdfs:label` for new entities to provide human-readable names
6. Extract all relevant triples that can be inferred from the text
7. Use entity URIs or meaningful identifiers as subjects/objects

## Output Format:

Return ONLY a valid JSON array (no markdown, no code blocks) containing objects with these fields:
- "subject": the subject entity (URI or identifier)
- "predicate": the property (from the ontology, or rdf:type/rdfs:label)
- "object": the object entity or literal value

Important: Return raw JSON only, with no markdown formatting, no code blocks, and no backticks.

## Example Output:

[
  {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},
  {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
  {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"},
  {"subject": "ingredient:flour", "predicate": "rdf:type", "object": "Ingredient"},
  {"subject": "ingredient:flour", "predicate": "rdfs:label", "object": "plain flour"}
]

Now extract triples from the text above.
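The per-class bullet logic in the template above can be mirrored in plain Python, e.g. to unit-test the expected rendering without a template engine (the sample ontology entries are made up):

```python
def render_class_lines(classes: dict) -> list[str]:
    # Mirrors the template: bold id, optional subclass note, optional comment
    lines = []
    for class_id, class_def in classes.items():
        line = f"- **{class_id}**"
        if class_def.get("subclass_of"):
            line += f" (subclass of {class_def['subclass_of']})"
        if class_def.get("comment"):
            line += f": {class_def['comment']}"
        lines.append(line)
    return lines

classes = {
    "Recipe": {"comment": "A prepared dish"},
    "Pasty": {"subclass_of": "Recipe", "comment": "A filled pastry"},
}
lines = render_class_lines(classes)
# lines[1] == "- **Pasty** (subclass of Recipe): A filled pastry"
```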
@@ -21,3 +21,4 @@ prometheus-client
 pyarrow
 boto3
 ollama
+python-logging-loki
tests/conftest.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""
Global pytest configuration for all tests.

This conftest.py applies to all test directories.
"""

import pytest
# import asyncio
# import tracemalloc
# import warnings
from unittest.mock import MagicMock

# Uncomment the lines below to enable asyncio debug mode and tracemalloc
# for tracing unawaited coroutines and their creation points
# tracemalloc.start()
# asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
# warnings.simplefilter("always", ResourceWarning)
# warnings.simplefilter("always", RuntimeWarning)


@pytest.fixture(scope="session", autouse=True)
def mock_loki_handler(session_mocker=None):
    """
    Mock LokiHandler to prevent connection attempts during tests.

    This fixture runs once per test session and prevents the logging
    module from trying to connect to a Loki server that doesn't exist
    in the test environment.
    """
    # Try to import logging_loki and mock it if available
    try:
        import logging_loki

        # Create a mock LokiHandler that does nothing
        original_loki_handler = logging_loki.LokiHandler

        class MockLokiHandler:
            """Mock LokiHandler that doesn't make network calls."""
            def __init__(self, *args, **kwargs):
                pass

            def emit(self, record):
                pass

            def flush(self):
                pass

            def close(self):
                pass

        # Replace the real LokiHandler with our mock
        logging_loki.LokiHandler = MockLokiHandler

        yield

        # Restore the original after tests
        logging_loki.LokiHandler = original_loki_handler

    except ImportError:
        # If logging_loki isn't installed, there's no need to mock
        yield
@@ -257,7 +257,6 @@ class TestAgentMessageContracts:
         # Act
         request = AgentRequest(
             question="What comes next?",
-            plan="Multi-step plan",
             state="processing",
             history=history_steps
         )

@@ -588,7 +587,6 @@ class TestSerializationContracts:

         request = AgentRequest(
             question="Test with array",
-            plan="Test plan",
             state="Test state",
             history=steps
         )
@@ -189,6 +189,7 @@ class TestObjectsCassandraContracts:
         assert result == expected_val
         assert isinstance(result, expected_type) or result is None

+    @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
     def test_extracted_object_serialization_contract(self):
         """Test that ExtractedObject can be serialized/deserialized correctly"""
         # Create test object

@@ -408,6 +409,7 @@ class TestObjectsCassandraContractsBatch:
         assert isinstance(single_batch_object.values[0], dict)
         assert single_batch_object.values[0]["customer_id"] == "CUST999"

+    @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
     def test_extracted_object_batch_serialization_contract(self):
         """Test that batched ExtractedObject can be serialized/deserialized correctly"""
         # Create batch object
@@ -480,11 +480,15 @@ def streaming_chunk_collector():
     class ChunkCollector:
         def __init__(self):
             self.chunks = []
+            self.end_of_stream_flags = []
             self.complete = False

-        async def collect(self, chunk):
-            """Async callback to collect chunks"""
+        async def collect(self, chunk, end_of_stream=False):
+            """Async callback to collect chunks with end_of_stream flag"""
             self.chunks.append(chunk)
+            self.end_of_stream_flags.append(end_of_stream)
+            if end_of_stream:
+                self.complete = True

         def get_full_text(self):
             """Concatenate all chunk content"""

@@ -496,6 +500,14 @@ def streaming_chunk_collector():
                 return [c.get("chunk_type") for c in self.chunks]
             return []

+        def verify_streaming_protocol(self):
+            """Verify that the streaming protocol is correct"""
+            assert len(self.chunks) > 0, "Should have received at least one chunk"
+            assert len(self.chunks) == len(self.end_of_stream_flags), "Each chunk should have an end_of_stream flag"
+            assert self.end_of_stream_flags.count(True) == 1, "Exactly one chunk should have end_of_stream=True"
+            assert self.end_of_stream_flags[-1] is True, "Last chunk should have end_of_stream=True"
+            assert self.complete is True, "Should be marked complete after final chunk"
+
     return ChunkCollector
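The protocol the collector now enforces — every chunk carries an `end_of_stream` flag, and only the final one sets it — can be exercised in isolation (a condensed sketch, not the fixture itself):

```python
import asyncio

class ChunkCollector:
    # Condensed from the fixture above
    def __init__(self):
        self.chunks = []
        self.end_of_stream_flags = []
        self.complete = False

    async def collect(self, chunk, end_of_stream=False):
        self.chunks.append(chunk)
        self.end_of_stream_flags.append(end_of_stream)
        if end_of_stream:
            self.complete = True

async def stream(words, callback):
    # Producer side: only the final chunk carries end_of_stream=True
    for i, w in enumerate(words):
        await callback(w, i == len(words) - 1)

collector = ChunkCollector()
asyncio.run(stream(["machine", "learning"], collector.collect))
```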
@@ -47,8 +47,9 @@ Args: {
                 "}"
             ]

-            for chunk in chunks:
-                await chunk_callback(chunk)
+            for i, chunk in enumerate(chunks):
+                is_final = (i == len(chunks) - 1)
+                await chunk_callback(chunk, is_final)

             return full_text
         else:

@@ -312,8 +313,10 @@ Final Answer: AI is the simulation of human intelligence in machines."""
             call_count += 1

             if streaming and chunk_callback:
-                for chunk in response.split():
-                    await chunk_callback(chunk + " ")
+                chunks = response.split()
+                for i, chunk in enumerate(chunks):
+                    is_final = (i == len(chunks) - 1)
+                    await chunk_callback(chunk + " ", is_final)
                 return response
             return response
@@ -373,13 +373,13 @@ class TestMultipleHostsHandling:
         from trustgraph.base.cassandra_config import resolve_cassandra_config

         # Test various whitespace scenarios
-        hosts1, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
+        hosts1, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
         assert hosts1 == ['host1', 'host2', 'host3']

-        hosts2, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
+        hosts2, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
         assert hosts2 == ['host1', 'host2', 'host3']

-        hosts3, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
+        hosts3, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
         assert hosts3 == ['host1', 'host2']
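The whitespace and trailing-comma handling these tests exercise amounts to something like the following (a guess at the behavior, not the actual `resolve_cassandra_config` code):

```python
def parse_hosts(host: str) -> list[str]:
    # Split on commas, trim surrounding whitespace, drop empty entries
    return [h.strip() for h in host.split(',') if h.strip()]
```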
@@ -46,9 +46,16 @@ class TestDocumentRagStreaming:
             full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

             if streaming and chunk_callback:
-                # Simulate streaming chunks
+                # Simulate streaming chunks with end_of_stream flags
+                chunks = []
                 async for chunk in mock_streaming_llm_response():
-                    await chunk_callback(chunk)
+                    chunks.append(chunk)
+
+                # Send all chunks with end_of_stream=False except the last
+                for i, chunk in enumerate(chunks):
+                    is_final = (i == len(chunks) - 1)
+                    await chunk_callback(chunk, is_final)

                 return full_text
             else:
                 # Non-streaming response - same text

@@ -89,6 +96,9 @@ class TestDocumentRagStreaming:
         assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
         assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

+        # Verify streaming protocol compliance
+        collector.verify_streaming_protocol()
+
         # Verify full response matches concatenated chunks
         full_from_chunks = collector.get_full_text()
         assert result == full_from_chunks

@@ -117,7 +127,7 @@ class TestDocumentRagStreaming:
         # Act - Streaming
         streaming_chunks = []

-        async def collect(chunk):
+        async def collect(chunk, end_of_stream):
             streaming_chunks.append(chunk)

         streaming_result = await document_rag_streaming.query(
@ -59,9 +59,16 @@ class TestGraphRagStreaming:
|
||||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||||
|
|
||||||
if streaming and chunk_callback:
|
if streaming and chunk_callback:
|
||||||
# Simulate streaming chunks
|
# Simulate streaming chunks with end_of_stream flags
|
||||||
|
chunks = []
|
||||||
async for chunk in mock_streaming_llm_response():
|
async for chunk in mock_streaming_llm_response():
|
||||||
await chunk_callback(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Send all chunks with end_of_stream=False except the last
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
is_final = (i == len(chunks) - 1)
|
||||||
|
await chunk_callback(chunk, is_final)
|
||||||
|
|
||||||
return full_text
|
return full_text
|
||||||
else:
|
else:
|
||||||
# Non-streaming response - same text
|
# Non-streaming response - same text
|
||||||
|
|
@@ -102,6 +109,9 @@ class TestGraphRagStreaming:
         assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
         assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

+        # Verify streaming protocol compliance
+        collector.verify_streaming_protocol()
+
         # Verify full response matches concatenated chunks
         full_from_chunks = collector.get_full_text()
         assert result == full_from_chunks
@@ -128,7 +138,7 @@ class TestGraphRagStreaming:
         # Act - Streaming
         streaming_chunks = []

-        async def collect(chunk):
+        async def collect(chunk, end_of_stream):
             streaming_chunks.append(chunk)

         streaming_result = await graph_rag_streaming.query(
@@ -59,16 +59,16 @@ class MockWebSocket:


 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for integration testing."""
-    client = MagicMock()
+def mock_backend():
+    """Mock backend for integration testing."""
+    backend = MagicMock()

     # Mock producer
     producer = MagicMock()
     producer.send = MagicMock()
     producer.flush = MagicMock()
     producer.close = MagicMock()
-    client.create_producer.return_value = producer
+    backend.create_producer.return_value = producer

     # Mock consumer
     consumer = MagicMock()
@@ -78,17 +78,15 @@ def mock_pulsar_client():
     consumer.pause_message_listener = MagicMock()
     consumer.unsubscribe = MagicMock()
     consumer.close = MagicMock()
-    client.subscribe.return_value = consumer
+    backend.create_consumer.return_value = consumer

-    return client
+    return backend


 @pytest.mark.asyncio
-async def test_import_graceful_shutdown_integration():
+async def test_import_graceful_shutdown_integration(mock_backend):
     """Test import path handles shutdown gracefully with real message flow."""
-    mock_client = MagicMock()
-    mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_producer = mock_backend.create_producer.return_value

     # Track sent messages
     sent_messages = []
@@ -104,7 +102,7 @@ async def test_import_graceful_shutdown_integration():
     import_handler = TriplesImport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend,
         queue="test-triples-import"
     )

@@ -151,11 +149,9 @@ async def test_import_graceful_shutdown_integration():


 @pytest.mark.asyncio
-async def test_export_no_message_loss_integration():
+async def test_export_no_message_loss_integration(mock_backend):
     """Test export path doesn't lose acknowledged messages."""
-    mock_client = MagicMock()
-    mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_consumer = mock_backend.create_consumer.return_value

     # Create test messages
     test_messages = []
@@ -202,7 +198,7 @@ async def test_export_no_message_loss_integration():
     export_handler = TriplesExport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend,
         queue="test-triples-export",
         consumer="test-consumer",
         subscriber="test-subscriber"
@@ -245,14 +241,14 @@ async def test_export_no_message_loss_integration():
 async def test_concurrent_import_export_shutdown():
     """Test concurrent import and export shutdown scenarios."""
     # Setup mock clients
-    import_client = MagicMock()
-    export_client = MagicMock()
+    import_backend = MagicMock()
+    export_backend = MagicMock()

     import_producer = MagicMock()
     export_consumer = MagicMock()

-    import_client.create_producer.return_value = import_producer
-    export_client.subscribe.return_value = export_consumer
+    import_backend.create_producer.return_value = import_producer
+    export_backend.subscribe.return_value = export_consumer

     # Track operations
     import_operations = []
@@ -280,14 +276,14 @@ async def test_concurrent_import_export_shutdown():
     import_handler = TriplesImport(
         ws=import_ws,
         running=import_running,
-        pulsar_client=import_client,
+        backend=import_backend,
         queue="concurrent-import"
     )

     export_handler = TriplesExport(
         ws=export_ws,
         running=export_running,
-        pulsar_client=export_client,
+        backend=export_backend,
         queue="concurrent-export",
         consumer="concurrent-consumer",
         subscriber="concurrent-subscriber"
@@ -328,9 +324,9 @@ async def test_concurrent_import_export_shutdown():
 @pytest.mark.asyncio
 async def test_websocket_close_during_message_processing():
     """Test graceful handling when websocket closes during active message processing."""
-    mock_client = MagicMock()
+    mock_backend_local = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend_local.create_producer.return_value = mock_producer

     # Simulate slow message processing
     processed_messages = []
@@ -346,7 +342,7 @@ async def test_websocket_close_during_message_processing():
     import_handler = TriplesImport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend_local,
         queue="slow-processing-import"
     )

@@ -395,9 +391,9 @@ async def test_websocket_close_during_message_processing():
 @pytest.mark.asyncio
 async def test_backpressure_during_shutdown():
     """Test graceful shutdown under backpressure conditions."""
-    mock_client = MagicMock()
+    mock_backend_local = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend_local.subscribe.return_value = mock_consumer

     # Mock slow websocket
     class SlowWebSocket(MockWebSocket):
@@ -411,7 +407,7 @@ async def test_backpressure_during_shutdown():
     export_handler = TriplesExport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend_local,
         queue="backpressure-export",
         consumer="backpressure-consumer",
         subscriber="backpressure-subscriber"
@@ -117,7 +117,7 @@ class TestObjectsCassandraIntegration:
         assert "customer_records" in processor.schemas

         # Step 1.5: Create the collection first (simulate tg-set-collection)
-        await processor.create_collection("test_user", "import_2024")
+        await processor.create_collection("test_user", "import_2024", {})

         # Step 2: Process an ExtractedObject
         test_obj = ExtractedObject(
@@ -213,8 +213,8 @@ class TestObjectsCassandraIntegration:
         assert len(processor.schemas) == 2

         # Create collections first
-        await processor.create_collection("shop", "catalog")
-        await processor.create_collection("shop", "sales")
+        await processor.create_collection("shop", "catalog", {})
+        await processor.create_collection("shop", "sales", {})

         # Process objects for different schemas
         product_obj = ExtractedObject(
@@ -263,7 +263,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("test", "test")
+        await processor.create_collection("test", "test", {})

         # Create object missing required field
         test_obj = ExtractedObject(
@@ -302,7 +302,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("logger", "app_events")
+        await processor.create_collection("logger", "app_events", {})

         # Process object
         test_obj = ExtractedObject(
@@ -407,7 +407,7 @@ class TestObjectsCassandraIntegration:

         # Create all collections first
         for coll in collections:
-            await processor.create_collection("analytics", coll)
+            await processor.create_collection("analytics", coll, {})

         for coll in collections:
             obj = ExtractedObject(
@@ -486,7 +486,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("test_user", "batch_import")
+        await processor.create_collection("test_user", "batch_import", {})

         msg = MagicMock()
         msg.value.return_value = batch_obj
@@ -532,7 +532,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("test", "empty")
+        await processor.create_collection("test", "empty", {})

         # Process empty batch object
         empty_obj = ExtractedObject(
@@ -573,7 +573,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("test", "mixed")
+        await processor.create_collection("test", "mixed", {})

         # Single object (backward compatibility)
         single_obj = ExtractedObject(
tests/integration/test_rag_streaming_protocol.py (new file, 351 lines)
@@ -0,0 +1,351 @@
+"""
+Integration tests for RAG service streaming protocol compliance.
+
+These tests verify that RAG services correctly forward end_of_stream flags
+and don't duplicate final chunks, ensuring proper streaming semantics.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, call
+from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
+from trustgraph.retrieval.document_rag.document_rag import DocumentRag
+
+
+class TestGraphRagStreamingProtocol:
+    """Integration tests for GraphRAG streaming protocol"""
+
+    @pytest.fixture
+    def mock_embeddings_client(self):
+        """Mock embeddings client"""
+        client = AsyncMock()
+        client.embed.return_value = [[0.1, 0.2, 0.3]]
+        return client
+
+    @pytest.fixture
+    def mock_graph_embeddings_client(self):
+        """Mock graph embeddings client"""
+        client = AsyncMock()
+        client.query.return_value = ["entity1", "entity2"]
+        return client
+
+    @pytest.fixture
+    def mock_triples_client(self):
+        """Mock triples client"""
+        client = AsyncMock()
+        client.query.return_value = []
+        return client
+
+    @pytest.fixture
+    def mock_streaming_prompt_client(self):
+        """Mock prompt client that simulates realistic streaming with end_of_stream flags"""
+        client = AsyncMock()
+
+        async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
+            if streaming and chunk_callback:
+                # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
+                await chunk_callback("The", False)
+                await chunk_callback(" answer", False)
+                await chunk_callback(" is here.", False)
+                await chunk_callback("", True)  # Empty final chunk with end_of_stream=True
+                return ""  # Return value not used since callback handles everything
+            else:
+                return "The answer is here."
+
+        client.kg_prompt.side_effect = kg_prompt_side_effect
+        return client
+
+    @pytest.fixture
+    def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
+                  mock_triples_client, mock_streaming_prompt_client):
+        """Create GraphRag instance with mocked dependencies"""
+        return GraphRag(
+            embeddings_client=mock_embeddings_client,
+            graph_embeddings_client=mock_graph_embeddings_client,
+            triples_client=mock_triples_client,
+            prompt_client=mock_streaming_prompt_client,
+            verbose=False
+        )
+
+    @pytest.mark.asyncio
+    async def test_callback_receives_end_of_stream_parameter(self, graph_rag):
+        """Test that callback receives end_of_stream parameter"""
+        # Arrange
+        callback = AsyncMock()
+
+        # Act
+        await graph_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert - callback should receive (chunk, end_of_stream) signature
+        assert callback.call_count == 4
+        # All calls should have 2 arguments
+        for call_args in callback.call_args_list:
+            assert len(call_args.args) == 2, "Callback should receive (chunk, end_of_stream)"
+
+    @pytest.mark.asyncio
+    async def test_end_of_stream_flag_forwarded_correctly(self, graph_rag):
+        """Test that end_of_stream flags are forwarded correctly"""
+        # Arrange
+        chunks_with_flags = []
+
+        async def collect(chunk, end_of_stream):
+            chunks_with_flags.append((chunk, end_of_stream))
+
+        # Act
+        await graph_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert
+        assert len(chunks_with_flags) == 4
+
+        # First three chunks should have end_of_stream=False
+        assert chunks_with_flags[0] == ("The", False)
+        assert chunks_with_flags[1] == (" answer", False)
+        assert chunks_with_flags[2] == (" is here.", False)
+
+        # Final chunk should have end_of_stream=True
+        assert chunks_with_flags[3] == ("", True)
+
+    @pytest.mark.asyncio
+    async def test_no_duplicate_final_chunk(self, graph_rag):
+        """Test that final chunk is not duplicated"""
+        # Arrange
+        chunks = []
+
+        async def collect(chunk, end_of_stream):
+            chunks.append(chunk)
+
+        # Act
+        await graph_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert - should have exactly 4 chunks, no duplicates
+        assert len(chunks) == 4
+        assert chunks == ["The", " answer", " is here.", ""]
+
+        # The last chunk appears exactly once
+        assert chunks.count("") == 1
+
+    @pytest.mark.asyncio
+    async def test_exactly_one_end_of_stream_true(self, graph_rag):
+        """Test that exactly one message has end_of_stream=True"""
+        # Arrange
+        end_of_stream_flags = []
+
+        async def collect(chunk, end_of_stream):
+            end_of_stream_flags.append(end_of_stream)
+
+        # Act
+        await graph_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert - exactly one True
+        assert end_of_stream_flags.count(True) == 1
+        assert end_of_stream_flags.count(False) == 3
+
+    @pytest.mark.asyncio
+    async def test_empty_final_chunk_preserved(self, graph_rag):
+        """Test that empty final chunks are preserved and forwarded"""
+        # Arrange
+        final_chunk = None
+        final_flag = None
+
+        async def collect(chunk, end_of_stream):
+            nonlocal final_chunk, final_flag
+            if end_of_stream:
+                final_chunk = chunk
+                final_flag = end_of_stream
+
+        # Act
+        await graph_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert
+        assert final_flag is True
+        assert final_chunk == "", "Empty final chunk should be preserved"
+
+
+class TestDocumentRagStreamingProtocol:
+    """Integration tests for DocumentRAG streaming protocol"""
+
+    @pytest.fixture
+    def mock_embeddings_client(self):
+        """Mock embeddings client"""
+        client = AsyncMock()
+        client.embed.return_value = [[0.1, 0.2, 0.3]]
+        return client
+
+    @pytest.fixture
+    def mock_doc_embeddings_client(self):
+        """Mock document embeddings client"""
+        client = AsyncMock()
+        client.query.return_value = ["doc1", "doc2"]
+        return client
+
+    @pytest.fixture
+    def mock_streaming_prompt_client(self):
+        """Mock prompt client with streaming support"""
+        client = AsyncMock()
+
+        async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None):
+            if streaming and chunk_callback:
+                # Simulate streaming with non-empty final chunk (some LLMs do this)
+                await chunk_callback("Document", False)
+                await chunk_callback(" summary", False)
+                await chunk_callback(".", True)  # Non-empty final chunk
+                return ""
+            else:
+                return "Document summary."
+
+        client.document_prompt.side_effect = document_prompt_side_effect
+        return client
+
+    @pytest.fixture
+    def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
+                     mock_streaming_prompt_client):
+        """Create DocumentRag instance with mocked dependencies"""
+        return DocumentRag(
+            embeddings_client=mock_embeddings_client,
+            doc_embeddings_client=mock_doc_embeddings_client,
+            prompt_client=mock_streaming_prompt_client,
+            verbose=False
+        )
+
+    @pytest.mark.asyncio
+    async def test_callback_receives_end_of_stream_parameter(self, document_rag):
+        """Test that callback receives end_of_stream parameter"""
+        # Arrange
+        callback = AsyncMock()
+
+        # Act
+        await document_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert
+        assert callback.call_count == 3
+        for call_args in callback.call_args_list:
+            assert len(call_args.args) == 2
+
+    @pytest.mark.asyncio
+    async def test_non_empty_final_chunk_preserved(self, document_rag):
+        """Test that non-empty final chunks are preserved with correct flag"""
+        # Arrange
+        chunks_with_flags = []
+
+        async def collect(chunk, end_of_stream):
+            chunks_with_flags.append((chunk, end_of_stream))
+
+        # Act
+        await document_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert
+        assert len(chunks_with_flags) == 3
+        assert chunks_with_flags[0] == ("Document", False)
+        assert chunks_with_flags[1] == (" summary", False)
+        assert chunks_with_flags[2] == (".", True)  # Non-empty final chunk
+
+    @pytest.mark.asyncio
+    async def test_no_duplicate_final_chunk(self, document_rag):
+        """Test that final chunk is not duplicated"""
+        # Arrange
+        chunks = []
+
+        async def collect(chunk, end_of_stream):
+            chunks.append(chunk)
+
+        # Act
+        await document_rag.query(
+            query="test query",
+            user="test_user",
+            collection="test_collection",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert - final "." appears exactly once
+        assert chunks.count(".") == 1
+        assert chunks == ["Document", " summary", "."]
+
+
+class TestStreamingProtocolEdgeCases:
+    """Test edge cases in streaming protocol"""
+
+    @pytest.mark.asyncio
+    async def test_multiple_empty_chunks_before_final(self):
+        """Test handling of multiple empty chunks (edge case)"""
+        # Arrange
+        client = AsyncMock()
+
+        async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
+            if streaming and chunk_callback:
+                await chunk_callback("text", False)
+                await chunk_callback("", False)  # Empty but not final
+                await chunk_callback("more", False)
+                await chunk_callback("", True)  # Empty and final
+                return ""
+            else:
+                return "textmore"
+
+        client.kg_prompt.side_effect = kg_prompt_with_empties
+
+        rag = GraphRag(
+            embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
+            graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
+            triples_client=AsyncMock(query=AsyncMock(return_value=[])),
+            prompt_client=client,
+            verbose=False
+        )
+
+        chunks_with_flags = []
+
+        async def collect(chunk, end_of_stream):
+            chunks_with_flags.append((chunk, end_of_stream))
+
+        # Act
+        await rag.query(
+            query="test",
+            streaming=True,
+            chunk_callback=collect
+        )
+
+        # Assert
+        assert len(chunks_with_flags) == 4
+        assert chunks_with_flags[-1] == ("", True)  # Final empty chunk
+        end_of_stream_flags = [f for c, f in chunks_with_flags]
+        assert end_of_stream_flags.count(True) == 1
@@ -14,15 +14,16 @@ from trustgraph.base.async_processor import AsyncProcessor
 class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
     """Test AsyncProcessor base class functionality"""

-    @patch('trustgraph.base.async_processor.PulsarClient')
+    @patch('trustgraph.base.async_processor.get_pubsub')
     @patch('trustgraph.base.async_processor.Consumer')
     @patch('trustgraph.base.async_processor.ProcessorMetrics')
     @patch('trustgraph.base.async_processor.ConsumerMetrics')
     async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics,
-                                                        mock_consumer, mock_pulsar_client):
+                                                        mock_consumer, mock_get_pubsub):
         """Test basic AsyncProcessor initialization"""
         # Arrange
-        mock_pulsar_client.return_value = MagicMock()
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
         mock_consumer.return_value = MagicMock()
         mock_processor_metrics.return_value = MagicMock()
         mock_consumer_metrics.return_value = MagicMock()
@@ -43,8 +44,8 @@ class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
         assert hasattr(processor, 'config_handlers')
         assert processor.config_handlers == []

-        # Verify PulsarClient was created
-        mock_pulsar_client.assert_called_once_with(**config)
+        # Verify get_pubsub was called to create backend
+        mock_get_pubsub.assert_called_once_with(**config)

         # Verify metrics were initialized
         mock_processor_metrics.assert_called_once()
@@ -145,7 +145,7 @@ class TestResolveCassandraConfig:
     def test_default_configuration(self):
         """Test resolution with no parameters or environment variables."""
         with patch.dict(os.environ, {}, clear=True):
-            hosts, username, password = resolve_cassandra_config()
+            hosts, username, password, keyspace = resolve_cassandra_config()

             assert hosts == ['cassandra']
             assert username is None
@@ -160,7 +160,7 @@ class TestResolveCassandraConfig:
         }

         with patch.dict(os.environ, env_vars, clear=True):
-            hosts, username, password = resolve_cassandra_config()
+            hosts, username, password, keyspace = resolve_cassandra_config()

             assert hosts == ['env1', 'env2', 'env3']
             assert username == 'env-user'
@@ -175,7 +175,7 @@ class TestResolveCassandraConfig:
         }

         with patch.dict(os.environ, env_vars, clear=True):
-            hosts, username, password = resolve_cassandra_config(
+            hosts, username, password, keyspace = resolve_cassandra_config(
                 host='explicit-host',
                 username='explicit-user',
                 password='explicit-pass'
@@ -188,19 +188,19 @@ class TestResolveCassandraConfig:
     def test_host_list_parsing(self):
         """Test different host list formats."""
         # Single host
-        hosts, _, _ = resolve_cassandra_config(host='single-host')
+        hosts, _, _, _ = resolve_cassandra_config(host='single-host')
         assert hosts == ['single-host']

         # Multiple hosts with spaces
-        hosts, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
+        hosts, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
         assert hosts == ['host1', 'host2', 'host3']

         # Empty elements filtered out
-        hosts, _, _ = resolve_cassandra_config(host='host1,,host2,')
+        hosts, _, _, _ = resolve_cassandra_config(host='host1,,host2,')
         assert hosts == ['host1', 'host2']

         # Already a list
-        hosts, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
+        hosts, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
         assert hosts == ['list-host1', 'list-host2']

     def test_args_object_resolution(self):
@@ -212,7 +212,7 @@ class TestResolveCassandraConfig:
             cassandra_password = 'args-pass'

         args = MockArgs()
-        hosts, username, password = resolve_cassandra_config(args)
+        hosts, username, password, keyspace = resolve_cassandra_config(args)

         assert hosts == ['args-host1', 'args-host2']
         assert username == 'args-user'
@ -233,7 +233,7 @@ class TestResolveCassandraConfig:
|
||||||
|
|
||||||
with patch.dict(os.environ, env_vars, clear=True):
|
with patch.dict(os.environ, env_vars, clear=True):
|
||||||
args = PartialArgs()
|
args = PartialArgs()
|
||||||
hosts, username, password = resolve_cassandra_config(args)
|
hosts, username, password, keyspace = resolve_cassandra_config(args)
|
||||||
|
|
||||||
assert hosts == ['args-host'] # From args
|
assert hosts == ['args-host'] # From args
|
||||||
assert username == 'env-user' # From env
|
assert username == 'env-user' # From env
|
||||||
|
|
@ -251,7 +251,7 @@ class TestGetCassandraConfigFromParams:
|
||||||
'cassandra_password': 'new-pass'
|
'cassandra_password': 'new-pass'
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts, username, password = get_cassandra_config_from_params(params)
|
hosts, username, password, keyspace = get_cassandra_config_from_params(params)
|
||||||
|
|
||||||
assert hosts == ['new-host1', 'new-host2']
|
assert hosts == ['new-host1', 'new-host2']
|
||||||
assert username == 'new-user'
|
assert username == 'new-user'
|
||||||
|
|
@ -265,7 +265,7 @@ class TestGetCassandraConfigFromParams:
|
||||||
'graph_password': 'old-pass'
|
'graph_password': 'old-pass'
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts, username, password = get_cassandra_config_from_params(params)
|
hosts, username, password, keyspace = get_cassandra_config_from_params(params)
|
||||||
|
|
||||||
# Should use defaults since graph_* params are not recognized
|
# Should use defaults since graph_* params are not recognized
|
||||||
assert hosts == ['cassandra'] # Default
|
assert hosts == ['cassandra'] # Default
|
||||||
|
|
@ -280,7 +280,7 @@ class TestGetCassandraConfigFromParams:
|
||||||
'cassandra_password': 'compat-pass'
|
'cassandra_password': 'compat-pass'
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts, username, password = get_cassandra_config_from_params(params)
|
hosts, username, password, keyspace = get_cassandra_config_from_params(params)
|
||||||
|
|
||||||
assert hosts == ['compat-host']
|
assert hosts == ['compat-host']
|
||||||
assert username is None # cassandra_user is not recognized
|
assert username is None # cassandra_user is not recognized
|
||||||
|
|
@ -298,7 +298,7 @@ class TestGetCassandraConfigFromParams:
|
||||||
'graph_password': 'old-pass'
|
'graph_password': 'old-pass'
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts, username, password = get_cassandra_config_from_params(params)
|
hosts, username, password, keyspace = get_cassandra_config_from_params(params)
|
||||||
|
|
||||||
assert hosts == ['new-host'] # Only cassandra_* params work
|
assert hosts == ['new-host'] # Only cassandra_* params work
|
||||||
assert username == 'new-user' # Only cassandra_* params work
|
assert username == 'new-user' # Only cassandra_* params work
|
||||||
|
|
@ -314,7 +314,7 @@ class TestGetCassandraConfigFromParams:
|
||||||
|
|
||||||
with patch.dict(os.environ, env_vars, clear=True):
|
with patch.dict(os.environ, env_vars, clear=True):
|
||||||
params = {}
|
params = {}
|
||||||
hosts, username, password = get_cassandra_config_from_params(params)
|
hosts, username, password, keyspace = get_cassandra_config_from_params(params)
|
||||||
|
|
||||||
assert hosts == ['fallback-host1', 'fallback-host2']
|
assert hosts == ['fallback-host1', 'fallback-host2']
|
||||||
assert username == 'fallback-user'
|
assert username == 'fallback-user'
|
||||||
|
|
@ -334,7 +334,7 @@ class TestConfigurationPriority:
|
||||||
|
|
||||||
with patch.dict(os.environ, env_vars, clear=True):
|
with patch.dict(os.environ, env_vars, clear=True):
|
||||||
# CLI args should override everything
|
# CLI args should override everything
|
||||||
hosts, username, password = resolve_cassandra_config(
|
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||||
host='cli-host',
|
host='cli-host',
|
||||||
username='cli-user',
|
username='cli-user',
|
||||||
password='cli-pass'
|
password='cli-pass'
|
||||||
|
|
@ -354,7 +354,7 @@ class TestConfigurationPriority:
|
||||||
|
|
||||||
with patch.dict(os.environ, env_vars, clear=True):
|
with patch.dict(os.environ, env_vars, clear=True):
|
||||||
# Only provide host via CLI
|
# Only provide host via CLI
|
||||||
hosts, username, password = resolve_cassandra_config(
|
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||||
host='cli-host'
|
host='cli-host'
|
||||||
# username and password not provided
|
# username and password not provided
|
||||||
)
|
)
|
||||||
|
|
@ -366,7 +366,7 @@ class TestConfigurationPriority:
|
||||||
def test_no_config_defaults(self):
|
def test_no_config_defaults(self):
|
||||||
"""Test that defaults are used when no configuration is provided."""
|
"""Test that defaults are used when no configuration is provided."""
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
hosts, username, password = resolve_cassandra_config()
|
hosts, username, password, keyspace = resolve_cassandra_config()
|
||||||
|
|
||||||
assert hosts == ['cassandra'] # Default
|
assert hosts == ['cassandra'] # Default
|
||||||
assert username is None # Default
|
assert username is None # Default
|
||||||
|
|
@ -378,17 +378,17 @@ class TestEdgeCases:
|
||||||
|
|
||||||
def test_empty_host_string(self):
|
def test_empty_host_string(self):
|
||||||
"""Test handling of empty host string falls back to default."""
|
"""Test handling of empty host string falls back to default."""
|
||||||
hosts, _, _ = resolve_cassandra_config(host='')
|
hosts, _, _, _ = resolve_cassandra_config(host='')
|
||||||
assert hosts == ['cassandra'] # Falls back to default
|
assert hosts == ['cassandra'] # Falls back to default
|
||||||
|
|
||||||
def test_whitespace_only_host(self):
|
def test_whitespace_only_host(self):
|
||||||
"""Test handling of whitespace-only host string."""
|
"""Test handling of whitespace-only host string."""
|
||||||
hosts, _, _ = resolve_cassandra_config(host=' ')
|
hosts, _, _, _ = resolve_cassandra_config(host=' ')
|
||||||
assert hosts == [] # Empty after stripping whitespace
|
assert hosts == [] # Empty after stripping whitespace
|
||||||
|
|
||||||
def test_none_values_preserved(self):
|
def test_none_values_preserved(self):
|
||||||
"""Test that None values are preserved correctly."""
|
"""Test that None values are preserved correctly."""
|
||||||
hosts, username, password = resolve_cassandra_config(
|
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||||
host=None,
|
host=None,
|
||||||
username=None,
|
username=None,
|
||||||
password=None
|
password=None
|
||||||
|
|
@ -401,7 +401,7 @@ class TestEdgeCases:
|
||||||
|
|
||||||
def test_mixed_none_and_values(self):
|
def test_mixed_none_and_values(self):
|
||||||
"""Test mixing None and actual values."""
|
"""Test mixing None and actual values."""
|
||||||
hosts, username, password = resolve_cassandra_config(
|
hosts, username, password, keyspace = resolve_cassandra_config(
|
||||||
host='mixed-host',
|
host='mixed-host',
|
||||||
username=None,
|
username=None,
|
||||||
password='mixed-pass'
|
password='mixed-pass'
|
||||||
|
|
|
||||||
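The hunks above all make the same mechanical change: `resolve_cassandra_config` and `get_cassandra_config_from_params` now return a 4-tuple that includes a keyspace, so callers unpack one extra value. The host-list parsing behavior these tests assert can be sketched as follows — a minimal stand-in, not the TrustGraph implementation:

```python
# Sketch of the host-parsing rules the tests assert: comma-separated strings
# are split, whitespace-stripped, and empties dropped; a falsy host falls
# back to the default ['cassandra']; lists pass through unchanged. Note that
# a whitespace-only string is truthy, so it parses to [] rather than the default.
def parse_hosts(host=None):
    if not host:
        return ['cassandra']          # default when nothing is configured
    if isinstance(host, list):
        return host                   # already a list: pass through
    return [h.strip() for h in host.split(',') if h.strip()]
```

This matches the edge cases: `''` yields the default, while `'   '` yields an empty list.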
260 tests/unit/test_base/test_prompt_client_streaming.py Normal file
@@ -0,0 +1,260 @@
+"""
+Unit tests for PromptClient streaming callback behavior.
+
+These tests verify that the prompt client correctly passes the end_of_stream
+flag to chunk callbacks, ensuring proper streaming protocol compliance.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, call, patch
+from trustgraph.base.prompt_client import PromptClient
+from trustgraph.schema import PromptResponse
+
+
+class TestPromptClientStreamingCallback:
+    """Test PromptClient streaming callback behavior"""
+
+    @pytest.fixture
+    def prompt_client(self):
+        """Create a PromptClient with mocked dependencies"""
+        # Mock all the required initialization parameters
+        with patch.object(PromptClient, '__init__', lambda self: None):
+            client = PromptClient()
+            return client
+
+    @pytest.fixture
+    def mock_request_response(self):
+        """Create a mock request/response handler"""
+        async def mock_request(request, recipient=None, timeout=600):
+            if recipient:
+                # Simulate streaming responses
+                responses = [
+                    PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text=" world", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text="!", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text="", object=None, error=None, end_of_stream=True),
+                ]
+                for resp in responses:
+                    should_stop = await recipient(resp)
+                    if should_stop:
+                        break
+            else:
+                # Non-streaming response
+                return PromptResponse(text="Hello world!", object=None, error=None)
+
+        return mock_request
+
+    @pytest.mark.asyncio
+    async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response):
+        """Test that callback receives both chunk text and end_of_stream flag"""
+        # Arrange
+        prompt_client.request = mock_request_response
+
+        callback = AsyncMock()
+
+        # Act
+        await prompt_client.prompt(
+            id="test-prompt",
+            variables={"query": "test"},
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert - callback should be called with (chunk, end_of_stream) signature
+        assert callback.call_count == 4
+
+        # Verify first chunk: text + end_of_stream=False
+        assert callback.call_args_list[0] == call("Hello", False)
+
+        # Verify second chunk
+        assert callback.call_args_list[1] == call(" world", False)
+
+        # Verify third chunk
+        assert callback.call_args_list[2] == call("!", False)
+
+        # Verify final chunk: empty text + end_of_stream=True
+        assert callback.call_args_list[3] == call("", True)
+
+    @pytest.mark.asyncio
+    async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response):
+        """Test that empty final chunks are passed to callback"""
+        # Arrange
+        prompt_client.request = mock_request_response
+
+        chunks_received = []
+
+        async def collect_chunks(chunk, end_of_stream):
+            chunks_received.append((chunk, end_of_stream))
+
+        # Act
+        await prompt_client.prompt(
+            id="test-prompt",
+            variables={"query": "test"},
+            streaming=True,
+            chunk_callback=collect_chunks
+        )
+
+        # Assert - should receive the empty final chunk
+        final_chunk = chunks_received[-1]
+        assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True"
+
+    @pytest.mark.asyncio
+    async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client):
+        """Test callback signature when LLM sends non-empty final chunk"""
+        # Arrange
+        async def mock_request_non_empty_final(request, recipient=None, timeout=600):
+            if recipient:
+                # Some LLMs send content in the final chunk
+                responses = [
+                    PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text=" world!", object=None, error=None, end_of_stream=True),
+                ]
+                for resp in responses:
+                    should_stop = await recipient(resp)
+                    if should_stop:
+                        break
+
+        prompt_client.request = mock_request_non_empty_final
+
+        callback = AsyncMock()
+
+        # Act
+        await prompt_client.prompt(
+            id="test-prompt",
+            variables={"query": "test"},
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert
+        assert callback.call_count == 2
+        assert callback.call_args_list[0] == call("Hello", False)
+        assert callback.call_args_list[1] == call(" world!", True)
+
+    @pytest.mark.asyncio
+    async def test_callback_not_called_without_text(self, prompt_client):
+        """Test that callback is not called for responses without text"""
+        # Arrange
+        async def mock_request_no_text(request, recipient=None, timeout=600):
+            if recipient:
+                # Response with only end_of_stream, no text
+                responses = [
+                    PromptResponse(text="Content", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text=None, object=None, error=None, end_of_stream=True),
+                ]
+                for resp in responses:
+                    should_stop = await recipient(resp)
+                    if should_stop:
+                        break
+
+        prompt_client.request = mock_request_no_text
+
+        callback = AsyncMock()
+
+        # Act
+        await prompt_client.prompt(
+            id="test-prompt",
+            variables={"query": "test"},
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert - callback should only be called once (for "Content")
+        assert callback.call_count == 1
+        assert callback.call_args_list[0] == call("Content", False)
+
+    @pytest.mark.asyncio
+    async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client):
+        """Test that synchronous callbacks also receive end_of_stream parameter"""
+        # Arrange
+        async def mock_request(request, recipient=None, timeout=600):
+            if recipient:
+                responses = [
+                    PromptResponse(text="test", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text="", object=None, error=None, end_of_stream=True),
+                ]
+                for resp in responses:
+                    should_stop = await recipient(resp)
+                    if should_stop:
+                        break
+
+        prompt_client.request = mock_request
+
+        callback = MagicMock()  # Synchronous mock
+
+        # Act
+        await prompt_client.prompt(
+            id="test-prompt",
+            variables={"query": "test"},
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert - synchronous callback should also get both parameters
+        assert callback.call_count == 2
+        assert callback.call_args_list[0] == call("test", False)
+        assert callback.call_args_list[1] == call("", True)
+
+    @pytest.mark.asyncio
+    async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
+        """Test that kg_prompt correctly passes streaming parameters"""
+        # Arrange
+        async def mock_request(request, recipient=None, timeout=600):
+            if recipient:
+                responses = [
+                    PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text="", object=None, error=None, end_of_stream=True),
+                ]
+                for resp in responses:
+                    should_stop = await recipient(resp)
+                    if should_stop:
+                        break
+
+        prompt_client.request = mock_request
+
+        callback = AsyncMock()
+
+        # Act
+        await prompt_client.kg_prompt(
+            query="What is machine learning?",
+            kg=[("subject", "predicate", "object")],
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert
+        assert callback.call_count == 2
+        assert callback.call_args_list[0] == call("Answer", False)
+        assert callback.call_args_list[1] == call("", True)
+
+    @pytest.mark.asyncio
+    async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
+        """Test that document_prompt correctly passes streaming parameters"""
+        # Arrange
+        async def mock_request(request, recipient=None, timeout=600):
+            if recipient:
+                responses = [
+                    PromptResponse(text="Summary", object=None, error=None, end_of_stream=False),
+                    PromptResponse(text="", object=None, error=None, end_of_stream=True),
+                ]
+                for resp in responses:
+                    should_stop = await recipient(resp)
+                    if should_stop:
+                        break
+
+        prompt_client.request = mock_request
+
+        callback = AsyncMock()
+
+        # Act
+        await prompt_client.document_prompt(
+            query="Summarize this",
+            documents=["doc1", "doc2"],
+            streaming=True,
+            chunk_callback=callback
+        )
+
+        # Assert
+        assert callback.call_count == 2
+        assert callback.call_args_list[0] == call("Summary", False)
+        assert callback.call_args_list[1] == call("", True)
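The new test file above pins down a simple streaming contract: every chunk callback receives `(text, end_of_stream)`, and the stream ends with an explicit `end_of_stream=True` call that may carry empty text. A minimal sketch of that protocol, with a hypothetical `drive_stream` standing in for PromptClient's internal dispatch:

```python
import asyncio

# Hypothetical driver: delivers each intermediate chunk with
# end_of_stream=False, then signals completion with an explicit
# end_of_stream=True call carrying empty text (as the tests assert).
async def drive_stream(chunks, callback):
    for text in chunks:
        await callback(text, False)   # intermediate chunk
    await callback("", True)          # end-of-stream marker

received = []

async def collect(chunk, end_of_stream):
    received.append((chunk, end_of_stream))

asyncio.run(drive_stream(["Hello", " world", "!"], collect))
# received[-1] == ("", True)
```

The explicit terminal call lets consumers flush buffers or close connections without guessing whether more chunks are coming.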
@@ -8,22 +8,22 @@ from trustgraph.base.publisher import Publisher


 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for testing."""
-    client = MagicMock()
+def mock_pulsar_backend():
+    """Mock Pulsar backend for testing."""
+    backend = MagicMock()
     producer = AsyncMock()
     producer.send = MagicMock()
     producer.flush = MagicMock()
     producer.close = MagicMock()
-    client.create_producer.return_value = producer
-    return client
+    backend.create_producer.return_value = producer
+    return backend


 @pytest.fixture
-def publisher(mock_pulsar_client):
+def publisher(mock_pulsar_backend):
     """Create Publisher instance for testing."""
     return Publisher(
-        client=mock_pulsar_client,
+        backend=mock_pulsar_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -34,12 +34,12 @@ def publisher(mock_pulsar_client):
 @pytest.mark.asyncio
 async def test_publisher_queue_drain():
     """Verify Publisher drains queue on shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -85,12 +85,12 @@ async def test_publisher_queue_drain():
 @pytest.mark.asyncio
 async def test_publisher_rejects_messages_during_drain():
     """Verify Publisher rejects new messages during shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -113,12 +113,12 @@ async def test_publisher_rejects_messages_during_drain():
 @pytest.mark.asyncio
 async def test_publisher_drain_timeout():
     """Verify Publisher respects drain timeout."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -169,12 +169,12 @@ async def test_publisher_drain_timeout():
 @pytest.mark.asyncio
 async def test_publisher_successful_drain():
     """Verify Publisher drains successfully under normal conditions."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -224,12 +224,12 @@ async def test_publisher_successful_drain():
 @pytest.mark.asyncio
 async def test_publisher_state_transitions():
     """Test Publisher state transitions during graceful shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -276,9 +276,9 @@ async def test_publisher_state_transitions():
 @pytest.mark.asyncio
 async def test_publisher_exception_handling():
     """Test Publisher handles exceptions during drain gracefully."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     # Mock producer.send to raise exception on second call
     call_count = 0
@@ -291,7 +291,7 @@ async def test_publisher_exception_handling():
     mock_producer.send.side_effect = failing_send

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
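Every hunk in the publisher tests makes the same rename: the fixture now hands Publisher a `backend` exposing `create_producer()` rather than a raw Pulsar `client`. A hedged sketch of that wiring, with a hypothetical `MiniPublisher` standing in for the real class:

```python
from unittest.mock import MagicMock

# MiniPublisher is illustrative only: it shows the constructor shape the
# updated tests assume, where the messaging backend is injected and all
# producer creation goes through one factory method.
class MiniPublisher:
    def __init__(self, backend, topic, schema, max_size):
        # the backend abstraction hides the concrete Pulsar client
        self.producer = backend.create_producer()
        self.topic = topic
        self.max_size = max_size

backend = MagicMock()
producer = MagicMock()
backend.create_producer.return_value = producer

pub = MiniPublisher(backend=backend, topic="test-topic", schema=dict, max_size=10)
```

Because tests only stub one `create_producer()` call, swapping the backend (or mocking it) never touches Pulsar-specific APIs.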
@ -6,23 +6,11 @@ import uuid
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from trustgraph.base.subscriber import Subscriber
|
from trustgraph.base.subscriber import Subscriber
|
||||||
|
|
||||||
# Mock JsonSchema globally to avoid schema issues in tests
|
|
||||||
# Patch at the module level where it's imported in subscriber
|
|
||||||
@patch('trustgraph.base.subscriber.JsonSchema')
|
|
||||||
def mock_json_schema_global(mock_schema):
|
|
||||||
mock_schema.return_value = MagicMock()
|
|
||||||
return mock_schema
|
|
||||||
|
|
||||||
# Apply the global patch
|
|
||||||
_json_schema_patch = patch('trustgraph.base.subscriber.JsonSchema')
|
|
||||||
_mock_json_schema = _json_schema_patch.start()
|
|
||||||
_mock_json_schema.return_value = MagicMock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pulsar_client():
|
def mock_pulsar_backend():
|
||||||
"""Mock Pulsar client for testing."""
|
"""Mock Pulsar backend for testing."""
|
||||||
client = MagicMock()
|
backend = MagicMock()
|
||||||
consumer = MagicMock()
|
consumer = MagicMock()
|
||||||
consumer.receive = MagicMock()
|
consumer.receive = MagicMock()
|
||||||
consumer.acknowledge = MagicMock()
|
consumer.acknowledge = MagicMock()
|
||||||
|
|
@ -30,15 +18,15 @@ def mock_pulsar_client():
|
||||||
consumer.pause_message_listener = MagicMock()
|
consumer.pause_message_listener = MagicMock()
|
||||||
consumer.unsubscribe = MagicMock()
|
consumer.unsubscribe = MagicMock()
|
||||||
consumer.close = MagicMock()
|
consumer.close = MagicMock()
|
||||||
client.subscribe.return_value = consumer
|
backend.create_consumer.return_value = consumer
|
||||||
return client
|
return backend
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def subscriber(mock_pulsar_client):
|
def subscriber(mock_pulsar_backend):
|
||||||
"""Create Subscriber instance for testing."""
|
"""Create Subscriber instance for testing."""
|
||||||
return Subscriber(
|
return Subscriber(
|
||||||
client=mock_pulsar_client,
|
backend=mock_pulsar_backend,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
subscription="test-subscription",
|
subscription="test-subscription",
|
||||||
consumer_name="test-consumer",
|
consumer_name="test-consumer",
|
||||||
|
|
@ -60,12 +48,12 @@ def create_mock_message(message_id="test-id", data=None):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subscriber_deferred_acknowledgment_success():
|
async def test_subscriber_deferred_acknowledgment_success():
|
||||||
"""Verify Subscriber only acks on successful delivery."""
|
"""Verify Subscriber only acks on successful delivery."""
|
||||||
mock_client = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_consumer = MagicMock()
|
mock_consumer = MagicMock()
|
||||||
mock_client.subscribe.return_value = mock_consumer
|
mock_backend.create_consumer.return_value = mock_consumer
|
||||||
|
|
||||||
subscriber = Subscriber(
|
subscriber = Subscriber(
|
||||||
client=mock_client,
|
backend=mock_backend,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
subscription="test-subscription",
|
subscription="test-subscription",
|
||||||
consumer_name="test-consumer",
|
consumer_name="test-consumer",
|
||||||
|
|
@ -102,12 +90,12 @@ async def test_subscriber_deferred_acknowledgment_success():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subscriber_deferred_acknowledgment_failure():
|
async def test_subscriber_deferred_acknowledgment_failure():
|
||||||
"""Verify Subscriber negative acks on delivery failure."""
|
"""Verify Subscriber negative acks on delivery failure."""
|
||||||
mock_client = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_consumer = MagicMock()
|
mock_consumer = MagicMock()
|
||||||
mock_client.subscribe.return_value = mock_consumer
|
mock_backend.create_consumer.return_value = mock_consumer
|
||||||
|
|
||||||
subscriber = Subscriber(
|
subscriber = Subscriber(
|
||||||
client=mock_client,
|
backend=mock_backend,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
subscription="test-subscription",
|
subscription="test-subscription",
|
||||||
consumer_name="test-consumer",
|
consumer_name="test-consumer",
|
||||||
|
|
@ -140,13 +128,13 @@ async def test_subscriber_deferred_acknowledgment_failure():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_subscriber_backpressure_strategies():
|
async def test_subscriber_backpressure_strategies():
|
||||||
"""Test different backpressure strategies."""
|
"""Test different backpressure strategies."""
|
||||||
mock_client = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_consumer = MagicMock()
|
mock_consumer = MagicMock()
|
||||||
mock_client.subscribe.return_value = mock_consumer
|
mock_backend.create_consumer.return_value = mock_consumer
|
||||||
|
|
||||||
# Test drop_oldest strategy
|
# Test drop_oldest strategy
|
||||||
subscriber = Subscriber(
|
subscriber = Subscriber(
|
||||||
client=mock_client,
|
backend=mock_backend,
|
||||||
topic="test-topic",
|
topic="test-topic",
|
||||||
subscription="test-subscription",
|
subscription="test-subscription",
|
||||||
consumer_name="test-consumer",
|
consumer_name="test-consumer",
|
||||||
|
|
@ -187,12 +175,12 @@ async def test_subscriber_backpressure_strategies():
|
||||||
 @pytest.mark.asyncio
 async def test_subscriber_graceful_shutdown():
     """Test Subscriber graceful shutdown with queue draining."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",

@@ -253,12 +241,12 @@ async def test_subscriber_graceful_shutdown():
 @pytest.mark.asyncio
 async def test_subscriber_drain_timeout():
     """Test Subscriber respects drain timeout."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",

@@ -288,12 +276,12 @@ async def test_subscriber_drain_timeout():
 @pytest.mark.asyncio
 async def test_subscriber_pending_acks_cleanup():
     """Test Subscriber cleans up pending acknowledgments on shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",

@@ -342,12 +330,12 @@ async def test_subscriber_pending_acks_cleanup():
 @pytest.mark.asyncio
 async def test_subscriber_multiple_subscribers():
     """Test Subscriber with multiple concurrent subscribers."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
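The Subscriber hunks all make the same mechanical change: the tests stop wiring a raw Pulsar client (`mock_client.subscribe`) and instead inject a backend abstraction (`mock_backend.create_consumer`). A minimal sketch of that injection pattern, using a hypothetical stand-in `Subscriber` class (the real trustgraph class adds draining, acks, and shutdown handling):

```python
from unittest.mock import MagicMock

class Subscriber:
    # Hypothetical stand-in mirroring only the constructor seen in the
    # tests above; the real trustgraph class does much more.
    def __init__(self, backend, topic, subscription, consumer_name):
        self.backend = backend
        # The consumer comes from the injected backend, so a test can swap
        # in a MagicMock and never touch a live Pulsar broker.
        self.consumer = backend.create_consumer(
            topic, subscription, consumer_name)

mock_backend = MagicMock()
mock_consumer = MagicMock()
mock_backend.create_consumer.return_value = mock_consumer

sub = Subscriber(mock_backend, "test-topic", "test-subscription",
                 "test-consumer")
assert sub.consumer is mock_consumer
```

Because the dependency is passed in rather than constructed internally, no patching of module globals is needed.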
@@ -108,7 +108,8 @@ class TestListConfigItems:
         mock_list.assert_called_once_with(
             url='http://custom.com',
             config_type='prompt',
-            format_type='json'
+            format_type='json',
+            token=None
         )

     def test_list_main_uses_defaults(self):

@@ -126,7 +127,8 @@ class TestListConfigItems:
         mock_list.assert_called_once_with(
             url='http://localhost:8088/',
             config_type='prompt',
-            format_type='text'
+            format_type='text',
+            token=None
         )

@@ -193,7 +195,8 @@ class TestGetConfigItem:
             url='http://custom.com',
             config_type='prompt',
             key='template-1',
-            format_type='json'
+            format_type='json',
+            token=None
         )

@@ -249,7 +252,8 @@ class TestPutConfigItem:
             url='http://custom.com',
             config_type='prompt',
             key='new-template',
-            value='Custom prompt: {input}'
+            value='Custom prompt: {input}',
+            token=None
         )

     def test_put_main_with_stdin_arg(self):

@@ -273,7 +277,8 @@ class TestPutConfigItem:
             url='http://localhost:8088/',
             config_type='prompt',
             key='stdin-template',
-            value=stdin_content
+            value=stdin_content,
+            token=None
         )

     def test_put_main_mutually_exclusive_args(self):

@@ -328,7 +333,8 @@ class TestDeleteConfigItem:
         mock_delete.assert_called_once_with(
             url='http://custom.com',
             config_type='prompt',
-            key='old-template'
+            key='old-template',
+            token=None
         )
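The config-CLI hunks all append `token=None` to the expected keyword arguments, because `assert_called_once_with` compares the full argument set: once the caller starts forwarding a token, every expectation must include it. A sketch of the pattern with a hypothetical `list_config` helper:

```python
from unittest.mock import patch

def list_config(url, config_type, format_type, token=None):
    # Hypothetical stand-in for the helper the CLI tests above patch.
    return []

def list_main(url):
    # After the change in the diff, the caller always forwards token,
    # defaulting to None, so expectations must include token=None.
    return list_config(url=url, config_type='prompt',
                       format_type='json', token=None)

with patch(f'{__name__}.list_config') as mock_list:
    list_main('http://custom.com')
    # Fails if any expected keyword (including token) is missing.
    mock_list.assert_called_once_with(
        url='http://custom.com', config_type='prompt',
        format_type='json', token=None)
    captured = mock_list.call_args.kwargs
```

Had the test kept the old expectation without `token=None`, `assert_called_once_with` would raise, since mock compares arguments exactly.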
@@ -2,17 +2,16 @@
 Unit tests for the load_knowledge CLI module.

 Tests the business logic of loading triples and entity contexts from Turtle files
-while mocking WebSocket connections and external dependencies.
+using the BulkClient API.
 """

 import pytest
-import json
 import tempfile
-import asyncio
-from unittest.mock import AsyncMock, Mock, patch, mock_open, MagicMock
+from unittest.mock import Mock, patch, MagicMock, call
 from pathlib import Path

 from trustgraph.cli.load_knowledge import KnowledgeLoader, main
+from trustgraph.api import Triple


 @pytest.fixture

@@ -43,26 +42,6 @@ def temp_turtle_file(sample_turtle_content):
     Path(f.name).unlink(missing_ok=True)


-@pytest.fixture
-def mock_websocket():
-    """Mock WebSocket connection."""
-    mock_ws = MagicMock()
-
-    async def async_send(data):
-        return None
-
-    async def async_recv():
-        return ""
-
-    async def async_close():
-        return None
-
-    mock_ws.send = Mock(side_effect=async_send)
-    mock_ws.recv = Mock(side_effect=async_recv)
-    mock_ws.close = Mock(side_effect=async_close)
-    return mock_ws
-
-
 @pytest.fixture
 def knowledge_loader():
     """Create a KnowledgeLoader instance with test parameters."""

@@ -72,125 +51,66 @@ def knowledge_loader():
         user="test-user",
         collection="test-collection",
         document_id="test-doc-123",
-        url="ws://test.example.com/"
+        url="http://test.example.com/",
+        token=None
     )


 class TestKnowledgeLoader:
     """Test the KnowledgeLoader class business logic."""

-    def test_init_constructs_urls_correctly(self):
-        """Test that URLs are constructed properly."""
+    def test_init_stores_parameters_correctly(self):
+        """Test that initialization stores parameters correctly."""
         loader = KnowledgeLoader(
-            files=["test.ttl"],
+            files=["file1.ttl", "file2.ttl"],
             flow="my-flow",
             user="user1",
             collection="col1",
             document_id="doc1",
-            url="ws://example.com/"
+            url="http://example.com/",
+            token="test-token"
         )

-        assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples"
-        assert loader.entity_contexts_url == "ws://example.com/api/v1/flow/my-flow/import/entity-contexts"
+        assert loader.files == ["file1.ttl", "file2.ttl"]
+        assert loader.flow == "my-flow"
         assert loader.user == "user1"
         assert loader.collection == "col1"
         assert loader.document_id == "doc1"
+        assert loader.url == "http://example.com/"
+        assert loader.token == "test-token"

-    def test_init_adds_trailing_slash(self):
-        """Test that trailing slash is added to URL if missing."""
-        loader = KnowledgeLoader(
-            files=["test.ttl"],
-            flow="my-flow",
-            user="user1",
-            collection="col1",
-            document_id="doc1",
-            url="ws://example.com" # No trailing slash
-        )
+    def test_load_triples_from_file_yields_triples(self, temp_turtle_file, knowledge_loader):
+        """Test that load_triples_from_file yields Triple objects."""
+        triples = list(knowledge_loader.load_triples_from_file(temp_turtle_file))

-        assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples"
+        # Should have triples for all statements in the file
+        assert len(triples) > 0

-    @pytest.mark.asyncio
-    async def test_load_triples_sends_correct_messages(self, temp_turtle_file, mock_websocket):
-        """Test that triple loading sends correctly formatted messages."""
-        loader = KnowledgeLoader(
-            files=[temp_turtle_file],
-            flow="test-flow",
-            user="test-user",
-            collection="test-collection",
-            document_id="test-doc"
-        )
+        # Verify they are Triple objects
+        for triple in triples:
+            assert isinstance(triple, Triple)
+            assert hasattr(triple, 's')
+            assert hasattr(triple, 'p')
+            assert hasattr(triple, 'o')
+            assert isinstance(triple.s, str)
+            assert isinstance(triple.p, str)
+            assert isinstance(triple.o, str)

-        await loader.load_triples(temp_turtle_file, mock_websocket)
-
-        # Verify WebSocket send was called
-        assert mock_websocket.send.call_count > 0
-
-        # Check message format for one of the calls
-        sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]
-
-        # Verify message structure
-        sample_message = sent_messages[0]
-        assert "metadata" in sample_message
-        assert "triples" in sample_message
-
-        metadata = sample_message["metadata"]
-        assert metadata["id"] == "test-doc"
-        assert metadata["user"] == "test-user"
-        assert metadata["collection"] == "test-collection"
-        assert isinstance(metadata["metadata"], list)
-
-        triple = sample_message["triples"][0]
-        assert "s" in triple
-        assert "p" in triple
-        assert "o" in triple
-
-        # Check Value structure
-        assert "v" in triple["s"]
-        assert "e" in triple["s"]
-        assert triple["s"]["e"] is True # Subject should be URI
-
-    @pytest.mark.asyncio
-    async def test_load_entity_contexts_processes_literals_only(self, temp_turtle_file, mock_websocket):
+    def test_load_entity_contexts_from_file_yields_literals_only(self, temp_turtle_file, knowledge_loader):
         """Test that entity contexts are created only for literals."""
-        loader = KnowledgeLoader(
-            files=[temp_turtle_file],
-            flow="test-flow",
-            user="test-user",
-            collection="test-collection",
-            document_id="test-doc"
-        )
-
-        await loader.load_entity_contexts(temp_turtle_file, mock_websocket)
-
-        # Get all sent messages
-        sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]
-
-        # Verify we got entity context messages
-        assert len(sent_messages) > 0
-
-        for message in sent_messages:
-            assert "metadata" in message
-            assert "entities" in message
-
-            metadata = message["metadata"]
-            assert metadata["id"] == "test-doc"
-            assert metadata["user"] == "test-user"
-            assert metadata["collection"] == "test-collection"
-
-            entity_context = message["entities"][0]
-            assert "entity" in entity_context
-            assert "context" in entity_context
-
-            entity = entity_context["entity"]
-            assert "v" in entity
-            assert "e" in entity
-            assert entity["e"] is True # Entity should be URI (subject)
-
-            # Context should be a string (the literal value)
-            assert isinstance(entity_context["context"], str)
-
-    @pytest.mark.asyncio
-    async def test_load_entity_contexts_skips_uri_objects(self, mock_websocket):
+        contexts = list(knowledge_loader.load_entity_contexts_from_file(temp_turtle_file))
+
+        # Should have contexts for literal objects (foaf:name, foaf:age, foaf:email)
+        assert len(contexts) > 0
+
+        # Verify format: (entity, context) tuples
+        for entity, context in contexts:
+            assert isinstance(entity, str)
+            assert isinstance(context, str)
+            # Entity should be a URI (subject)
+            assert entity.startswith("http://")
+
+    def test_load_entity_contexts_skips_uri_objects(self):
         """Test that URI objects don't generate entity contexts."""
         # Create turtle with only URI objects (no literals)
         turtle_content = """

@@ -208,63 +128,68 @@ ex:mary ex:knows ex:bob .
                 flow="test-flow",
                 user="test-user",
                 collection="test-collection",
-                document_id="test-doc"
+                document_id="test-doc",
+                url="http://test.example.com/"
             )

-            await loader.load_entity_contexts(f.name, mock_websocket)
+            contexts = list(loader.load_entity_contexts_from_file(f.name))

             Path(f.name).unlink(missing_ok=True)

-            # Should not send any messages since there are no literals
-            mock_websocket.send.assert_not_called()
+            # Should have no contexts since there are no literals
+            assert len(contexts) == 0

-    @pytest.mark.asyncio
-    @patch('trustgraph.cli.load_knowledge.connect')
-    async def test_run_calls_both_loaders(self, mock_connect, knowledge_loader, temp_turtle_file):
-        """Test that run() calls both triple and entity context loaders."""
-        knowledge_loader.files = [temp_turtle_file]
+    @patch('trustgraph.cli.load_knowledge.Api')
+    def test_run_calls_bulk_api(self, mock_api_class, temp_turtle_file):
+        """Test that run() uses BulkClient API."""
+        # Setup mocks
+        mock_api = MagicMock()
+        mock_bulk = MagicMock()
+        mock_api_class.return_value = mock_api
+        mock_api.bulk.return_value = mock_bulk

-        # Create a simple mock websocket
-        mock_ws = MagicMock()
-        async def mock_send(data):
-            pass
-        mock_ws.send = mock_send
+        loader = KnowledgeLoader(
+            files=[temp_turtle_file],
+            flow="test-flow",
+            user="test-user",
+            collection="test-collection",
+            document_id="test-doc",
+            url="http://test.example.com/",
+            token="test-token"
+        )

-        # Create async context manager mock
-        async def mock_aenter(self):
-            return mock_ws
-
-        async def mock_aexit(self, exc_type, exc_val, exc_tb):
-            return None
-
-        mock_connection = MagicMock()
-        mock_connection.__aenter__ = mock_aenter
-        mock_connection.__aexit__ = mock_aexit
-        mock_connect.return_value = mock_connection
-
-        # Create AsyncMock objects that can track calls properly
-        mock_load_triples = AsyncMock(return_value=None)
-        mock_load_contexts = AsyncMock(return_value=None)
-
-        with patch.object(knowledge_loader, 'load_triples', mock_load_triples), \
-            patch.object(knowledge_loader, 'load_entity_contexts', mock_load_contexts):
-
-            await knowledge_loader.run()
-
-        # Verify both methods were called
-        mock_load_triples.assert_called_once_with(temp_turtle_file, mock_ws)
-        mock_load_contexts.assert_called_once_with(temp_turtle_file, mock_ws)
-
-        # Verify WebSocket connections were made to both URLs
-        assert mock_connect.call_count == 2
+        loader.run()
+
+        # Verify Api was created with correct parameters
+        mock_api_class.assert_called_once_with(
+            url="http://test.example.com/",
+            token="test-token"
+        )
+
+        # Verify bulk client was obtained
+        mock_api.bulk.assert_called_once()
+
+        # Verify import_triples was called
+        assert mock_bulk.import_triples.call_count == 1
+        call_args = mock_bulk.import_triples.call_args
+        assert call_args[1]['flow'] == "test-flow"
+        assert call_args[1]['metadata']['id'] == "test-doc"
+        assert call_args[1]['metadata']['user'] == "test-user"
+        assert call_args[1]['metadata']['collection'] == "test-collection"
+
+        # Verify import_entity_contexts was called
+        assert mock_bulk.import_entity_contexts.call_count == 1
+        call_args = mock_bulk.import_entity_contexts.call_args
+        assert call_args[1]['flow'] == "test-flow"
+        assert call_args[1]['metadata']['id'] == "test-doc"


 class TestCLIArgumentParsing:
     """Test CLI argument parsing and main function."""

     @patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
-    @patch('trustgraph.cli.load_knowledge.asyncio.run')
-    def test_main_parses_args_correctly(self, mock_asyncio_run, mock_loader_class):
+    @patch('trustgraph.cli.load_knowledge.time.sleep')
+    def test_main_parses_args_correctly(self, mock_sleep, mock_loader_class):
         """Test that main() parses arguments correctly."""
         mock_loader_instance = MagicMock()
         mock_loader_class.return_value = mock_loader_instance

@@ -275,7 +200,8 @@ class TestCLIArgumentParsing:
             '-f', 'my-flow',
             '-U', 'my-user',
             '-C', 'my-collection',
-            '-u', 'ws://custom.example.com/',
+            '-u', 'http://custom.example.com/',
+            '-t', 'my-token',
             'file1.ttl',
             'file2.ttl'
         ]

@@ -286,19 +212,20 @@ class TestCLIArgumentParsing:
         # Verify KnowledgeLoader was instantiated with correct args
         mock_loader_class.assert_called_once_with(
             document_id='doc-123',
-            url='ws://custom.example.com/',
+            url='http://custom.example.com/',
+            token='my-token',
             flow='my-flow',
             files=['file1.ttl', 'file2.ttl'],
             user='my-user',
             collection='my-collection'
         )

-        # Verify asyncio.run was called once
-        mock_asyncio_run.assert_called_once()
+        # Verify run was called
+        mock_loader_instance.run.assert_called_once()

     @patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
-    @patch('trustgraph.cli.load_knowledge.asyncio.run')
-    def test_main_uses_defaults(self, mock_asyncio_run, mock_loader_class):
+    @patch('trustgraph.cli.load_knowledge.time.sleep')
+    def test_main_uses_defaults(self, mock_sleep, mock_loader_class):
         """Test that main() uses default values when not specified."""
         mock_loader_instance = MagicMock()
         mock_loader_class.return_value = mock_loader_instance

@@ -317,80 +244,69 @@ class TestCLIArgumentParsing:
         assert call_args['flow'] == 'default'
         assert call_args['user'] == 'trustgraph'
         assert call_args['collection'] == 'default'
-        assert call_args['url'] == 'ws://localhost:8088/'
+        assert call_args['url'] == 'http://localhost:8088/'
+        assert call_args['token'] is None


 class TestErrorHandling:
     """Test error handling scenarios."""

-    @pytest.mark.asyncio
-    async def test_load_triples_handles_invalid_turtle(self, mock_websocket):
+    def test_load_triples_handles_invalid_turtle(self, knowledge_loader):
         """Test handling of invalid Turtle content."""
         # Create file with invalid Turtle content
         with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
             f.write("Invalid Turtle Content {{{")
             f.flush()

-            loader = KnowledgeLoader(
-                files=[f.name],
-                flow="test-flow",
-                user="test-user",
-                collection="test-collection",
-                document_id="test-doc"
-            )
-
             # Should raise an exception for invalid Turtle
             with pytest.raises(Exception):
-                await loader.load_triples(f.name, mock_websocket)
+                list(knowledge_loader.load_triples_from_file(f.name))

             Path(f.name).unlink(missing_ok=True)

-    @pytest.mark.asyncio
-    async def test_load_entity_contexts_handles_invalid_turtle(self, mock_websocket):
+    def test_load_entity_contexts_handles_invalid_turtle(self, knowledge_loader):
         """Test handling of invalid Turtle content in entity contexts."""
         # Create file with invalid Turtle content
         with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
             f.write("Invalid Turtle Content {{{")
             f.flush()

-            loader = KnowledgeLoader(
-                files=[f.name],
-                flow="test-flow",
-                user="test-user",
-                collection="test-collection",
-                document_id="test-doc"
-            )
-
             # Should raise an exception for invalid Turtle
             with pytest.raises(Exception):
-                await loader.load_entity_contexts(f.name, mock_websocket)
+                list(knowledge_loader.load_entity_contexts_from_file(f.name))

             Path(f.name).unlink(missing_ok=True)

-    @pytest.mark.asyncio
-    @patch('trustgraph.cli.load_knowledge.connect')
+    @patch('trustgraph.cli.load_knowledge.Api')
     @patch('builtins.print') # Mock print to avoid output during tests
-    async def test_run_handles_connection_errors(self, mock_print, mock_connect, knowledge_loader, temp_turtle_file):
-        """Test handling of WebSocket connection errors."""
-        knowledge_loader.files = [temp_turtle_file]
+    def test_run_handles_api_errors(self, mock_print, mock_api_class, temp_turtle_file):
+        """Test handling of API errors."""
+        # Mock API to raise an error
+        mock_api_class.side_effect = Exception("API connection failed")

-        # Mock connection failure
-        mock_connect.side_effect = ConnectionError("Failed to connect")
+        loader = KnowledgeLoader(
+            files=[temp_turtle_file],
+            flow="test-flow",
+            user="test-user",
+            collection="test-collection",
+            document_id="test-doc",
+            url="http://test.example.com/"
+        )

-        # Should not raise exception, just print error
-        await knowledge_loader.run()
+        # Should raise the exception
+        with pytest.raises(Exception, match="API connection failed"):
+            loader.run()

     @patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
-    @patch('trustgraph.cli.load_knowledge.asyncio.run')
     @patch('trustgraph.cli.load_knowledge.time.sleep')
     @patch('builtins.print') # Mock print to avoid output during tests
-    def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_asyncio_run, mock_loader_class):
+    def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_loader_class):
         """Test that main() retries on exceptions."""
         mock_loader_instance = MagicMock()
         mock_loader_class.return_value = mock_loader_instance

         # First call raises exception, second succeeds
-        mock_asyncio_run.side_effect = [Exception("Test error"), None]
+        mock_loader_instance.run.side_effect = [Exception("Test error"), None]

         test_args = [
             'tg-load-knowledge',

@@ -402,38 +318,29 @@ class TestErrorHandling:
             main()

         # Should have been called twice (first failed, second succeeded)
-        assert mock_asyncio_run.call_count == 2
+        assert mock_loader_instance.run.call_count == 2
         mock_sleep.assert_called_once_with(10)


 class TestDataValidation:
     """Test data validation and edge cases."""

-    @pytest.mark.asyncio
-    async def test_empty_turtle_file(self, mock_websocket):
+    def test_empty_turtle_file(self, knowledge_loader):
         """Test handling of empty Turtle files."""
         with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
             f.write("") # Empty file
             f.flush()

-            loader = KnowledgeLoader(
-                files=[f.name],
-                flow="test-flow",
-                user="test-user",
-                collection="test-collection",
-                document_id="test-doc"
-            )
-
-            await loader.load_triples(f.name, mock_websocket)
-            await loader.load_entity_contexts(f.name, mock_websocket)
-
-            # Should not send any messages for empty file
-            mock_websocket.send.assert_not_called()
+            triples = list(knowledge_loader.load_triples_from_file(f.name))
+            contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name))
+
+            # Should return empty lists for empty file
+            assert len(triples) == 0
+            assert len(contexts) == 0

             Path(f.name).unlink(missing_ok=True)

-    @pytest.mark.asyncio
-    async def test_turtle_with_mixed_literals_and_uris(self, mock_websocket):
+    def test_turtle_with_mixed_literals_and_uris(self, knowledge_loader):
         """Test handling of Turtle with mixed literal and URI objects."""
         turtle_content = """
 @prefix ex: <http://example.org/> .

@@ -448,32 +355,18 @@ ex:mary ex:name "Mary Johnson" .
             f.write(turtle_content)
             f.flush()

-            loader = KnowledgeLoader(
-                files=[f.name],
-                flow="test-flow",
-                user="test-user",
-                collection="test-collection",
-                document_id="test-doc"
-            )
-
-            await loader.load_entity_contexts(f.name, mock_websocket)
-
-            sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]
+            contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name))

             # Should have 4 entity contexts (for the 4 literals: "John Smith", "25", "New York", "Mary Johnson")
             # URI ex:mary should be skipped
-            assert len(sent_messages) == 4
+            assert len(contexts) == 4

             # Verify all contexts are for literals (subjects should be URIs)
-            contexts = []
-            for message in sent_messages:
-                entity_context = message["entities"][0]
-                assert entity_context["entity"]["e"] is True # Subject is URI
-                contexts.append(entity_context["context"])
+            context_values = [context for entity, context in contexts]

-            assert "John Smith" in contexts
-            assert "25" in contexts
-            assert "New York" in contexts
-            assert "Mary Johnson" in contexts
+            assert "John Smith" in context_values
+            assert "25" in context_values
+            assert "New York" in context_values
+            assert "Mary Johnson" in context_values

             Path(f.name).unlink(missing_ok=True)
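The rewritten tests treat `load_entity_contexts_from_file` as a generator of `(entity, context)` pairs that keeps only literal-valued objects and skips URI objects. A rough stdlib-only sketch of that shape, with an assumed `Triple` dataclass standing in for `trustgraph.api.Triple` and a deliberately naive literal check:

```python
from dataclasses import dataclass

@dataclass
class Triple:
    # Assumed shape of trustgraph.api.Triple as exercised by the tests:
    # plain string fields for subject, predicate, object.
    s: str
    p: str
    o: str

def entity_contexts(triples, is_literal):
    # Generator mirroring load_entity_contexts_from_file: yield
    # (entity, context) pairs only for literal-valued objects.
    for t in triples:
        if is_literal(t.o):
            yield t.s, t.o

triples = [
    Triple("http://example.org/john", "http://example.org/name",
           "John Smith"),
    Triple("http://example.org/john", "http://example.org/knows",
           "http://example.org/mary"),
]
# Treat anything that is not an http(s) URI as a literal; the real code
# would use the RDF term type from the parser instead.
contexts = list(entity_contexts(triples, lambda o: not o.startswith("http")))
assert contexts == [("http://example.org/john", "John Smith")]
```

Because the generator is pure (no sockets, no API client), the tests can drive it with `list(...)` and plain assertions, which is exactly what the rewritten test file does.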
@@ -135,7 +135,8 @@ class TestSetToolStructuredQuery:
             arguments=[],
             group=None,
             state=None,
-            applicable_states=None
+            applicable_states=None,
+            token=None
         )

     def test_set_main_structured_query_no_arguments_needed(self):

@@ -313,7 +314,7 @@ class TestShowToolsStructuredQuery:

         show_main()

-        mock_show.assert_called_once_with(url='http://custom.com')
+        mock_show.assert_called_once_with(url='http://custom.com', token=None)


 class TestStructuredQueryToolValidation:
@@ -22,18 +22,18 @@ class TestConfigReceiver:

     def test_config_receiver_initialization(self):
         """Test ConfigReceiver initialization"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()

-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        config_receiver = ConfigReceiver(mock_backend)

-        assert config_receiver.pulsar_client == mock_pulsar_client
+        assert config_receiver.backend == mock_backend
         assert config_receiver.flow_handlers == []
         assert config_receiver.flows == {}

     def test_add_handler(self):
         """Test adding flow handlers"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         handler1 = Mock()
         handler2 = Mock()
@@ -48,8 +48,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_with_new_flows(self):
         """Test on_config method with new flows"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Track calls manually instead of using AsyncMock
         start_flow_calls = []
@@ -87,8 +87,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_with_removed_flows(self):
         """Test on_config method with removed flows"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Pre-populate with existing flows
         config_receiver.flows = {
@@ -128,8 +128,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_with_no_flows(self):
         """Test on_config method with no flows in config"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Mock the start_flow and stop_flow methods with async functions
         async def mock_start_flow(*args):
@@ -158,8 +158,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_exception_handling(self):
         """Test on_config method handles exceptions gracefully"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Create mock message that will cause an exception
         mock_msg = Mock()
@@ -174,8 +174,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_start_flow_with_handlers(self):
         """Test start_flow method with multiple handlers"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handlers
         handler1 = Mock()
@@ -197,8 +197,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_start_flow_with_handler_exception(self):
         """Test start_flow method handles handler exceptions"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handler that raises exception
         handler = Mock()
@@ -217,8 +217,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_stop_flow_with_handlers(self):
         """Test stop_flow method with multiple handlers"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handlers
         handler1 = Mock()
@@ -240,8 +240,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_stop_flow_with_handler_exception(self):
         """Test stop_flow method handles handler exceptions"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handler that raises exception
         handler = Mock()
@@ -260,9 +260,9 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_config_loader_creates_consumer(self):
         """Test config_loader method creates Pulsar consumer"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()

-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        config_receiver = ConfigReceiver(mock_backend)
         # Temporarily restore the real config_loader for this test
         config_receiver.config_loader = _real_config_loader.__get__(config_receiver)

@@ -292,7 +292,7 @@ class TestConfigReceiver:
         mock_consumer_class.assert_called_once()
         call_args = mock_consumer_class.call_args

-        assert call_args[1]['client'] == mock_pulsar_client
+        assert call_args[1]['backend'] == mock_backend
         assert call_args[1]['subscriber'] == "gateway-test-uuid"
         assert call_args[1]['handler'] == config_receiver.on_config
         assert call_args[1]['start_of_messages'] is True
@@ -301,8 +301,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_start_creates_config_loader_task(self, mock_create_task):
         """Test start method creates config loader task"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Mock create_task to avoid actually creating tasks with real coroutines
         mock_task = Mock()
@@ -320,8 +320,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_mixed_flow_operations(self):
         """Test on_config with mixed add/remove operations"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Pre-populate with existing flows
         config_receiver.flows = {
@@ -380,8 +380,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_invalid_json_flow_data(self):
         """Test on_config handles invalid JSON in flow data"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Mock the start_flow method with an async function
         async def mock_start_flow(*args):
@@ -24,10 +24,10 @@ class TestConfigRequestor:
         mock_translator_registry.get_response_translator.return_value = mock_response_translator

         # Mock dependencies
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()

         requestor = ConfigRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             consumer="test-consumer",
             subscriber="test-subscriber",
             timeout=60
@@ -55,7 +55,7 @@ class TestConfigRequestor:
         with patch.object(ServiceRequestor, 'start', return_value=None), \
             patch.object(ServiceRequestor, 'process', return_value=None):
             requestor = ConfigRequestor(
-                pulsar_client=Mock(),
+                backend=Mock(),
                 consumer="test-consumer",
                 subscriber="test-subscriber"
             )
@@ -79,7 +79,7 @@ class TestConfigRequestor:
         mock_response_translator.from_response_with_completion.return_value = "translated_response"

         requestor = ConfigRequestor(
-            pulsar_client=Mock(),
+            backend=Mock(),
            consumer="test-consumer",
             subscriber="test-subscriber"
         )
@@ -39,12 +39,12 @@ class TestDispatcherManager:

     def test_dispatcher_manager_initialization(self):
         """Test DispatcherManager initialization"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

-        assert manager.pulsar_client == mock_pulsar_client
+        assert manager.backend == mock_backend
         assert manager.config_receiver == mock_config_receiver
         assert manager.prefix == "api-gateway"  # default prefix
         assert manager.flows == {}
@@ -55,19 +55,19 @@ class TestDispatcherManager:

     def test_dispatcher_manager_initialization_with_custom_prefix(self):
         """Test DispatcherManager initialization with custom prefix"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver, prefix="custom-prefix")
+        manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")

         assert manager.prefix == "custom-prefix"

     @pytest.mark.asyncio
     async def test_start_flow(self):
         """Test start_flow method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         flow_data = {"name": "test_flow", "steps": []}

@@ -79,9 +79,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_stop_flow(self):
         """Test stop_flow method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Pre-populate with a flow
         flow_data = {"name": "test_flow", "steps": []}
@@ -93,9 +93,9 @@ class TestDispatcherManager:

     def test_dispatch_global_service_returns_wrapper(self):
         """Test dispatch_global_service returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_global_service()

@@ -104,9 +104,9 @@ class TestDispatcherManager:

     def test_dispatch_core_export_returns_wrapper(self):
         """Test dispatch_core_export returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_core_export()

@@ -115,9 +115,9 @@ class TestDispatcherManager:

     def test_dispatch_core_import_returns_wrapper(self):
         """Test dispatch_core_import returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_core_import()

@@ -127,9 +127,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_core_import(self):
         """Test process_core_import method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
             mock_importer = Mock()
@@ -138,16 +138,16 @@ class TestDispatcherManager:

             result = await manager.process_core_import("data", "error", "ok", "request")

-            mock_core_import.assert_called_once_with(mock_pulsar_client)
+            mock_core_import.assert_called_once_with(mock_backend)
             mock_importer.process.assert_called_once_with("data", "error", "ok", "request")
             assert result == "import_result"

     @pytest.mark.asyncio
     async def test_process_core_export(self):
         """Test process_core_export method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
             mock_exporter = Mock()
@@ -156,16 +156,16 @@ class TestDispatcherManager:

             result = await manager.process_core_export("data", "error", "ok", "request")

-            mock_core_export.assert_called_once_with(mock_pulsar_client)
+            mock_core_export.assert_called_once_with(mock_backend)
             mock_exporter.process.assert_called_once_with("data", "error", "ok", "request")
             assert result == "export_result"

     @pytest.mark.asyncio
     async def test_process_global_service(self):
         """Test process_global_service method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         manager.invoke_global_service = AsyncMock(return_value="global_result")

@@ -178,9 +178,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_global_service_with_existing_dispatcher(self):
         """Test invoke_global_service with existing dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Pre-populate with existing dispatcher
         mock_dispatcher = Mock()
@@ -195,9 +195,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_global_service_creates_new_dispatcher(self):
         """Test invoke_global_service creates new dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
             mock_dispatcher_class = Mock()
@@ -211,10 +211,12 @@ class TestDispatcherManager:

             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 timeout=120,
                 consumer="api-gateway-config-request",
-                subscriber="api-gateway-config-request"
+                subscriber="api-gateway-config-request",
+                request_queue=None,
+                response_queue=None
             )
             mock_dispatcher.start.assert_called_once()
             mock_dispatcher.process.assert_called_once_with("data", "responder")
@@ -225,9 +227,9 @@ class TestDispatcherManager:

     def test_dispatch_flow_import_returns_method(self):
         """Test dispatch_flow_import returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_flow_import()

@@ -235,9 +237,9 @@ class TestDispatcherManager:

     def test_dispatch_flow_export_returns_method(self):
         """Test dispatch_flow_export returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_flow_export()

@@ -245,9 +247,9 @@ class TestDispatcherManager:

     def test_dispatch_socket_returns_method(self):
         """Test dispatch_socket returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_socket()

@@ -255,9 +257,9 @@ class TestDispatcherManager:

     def test_dispatch_flow_service_returns_wrapper(self):
         """Test dispatch_flow_service returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_flow_service()

@@ -267,9 +269,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_import_with_valid_flow_and_kind(self):
         """Test process_flow_import with valid flow and kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@@ -292,7 +294,7 @@ class TestDispatcherManager:
             result = await manager.process_flow_import("ws", "running", params)

             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 ws="ws",
                 running="running",
                 queue={"queue": "test_queue"}
@@ -303,9 +305,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_import_with_invalid_flow(self):
         """Test process_flow_import with invalid flow"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         params = {"flow": "invalid_flow", "kind": "triples"}

@@ -318,9 +320,9 @@ class TestDispatcherManager:
         import warnings
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", RuntimeWarning)
-            mock_pulsar_client = Mock()
+            mock_backend = Mock()
             mock_config_receiver = Mock()
-            manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+            manager = DispatcherManager(mock_backend, mock_config_receiver)

             # Setup test flow
             manager.flows["test_flow"] = {
@@ -340,9 +342,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_export_with_valid_flow_and_kind(self):
         """Test process_flow_export with valid flow and kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@@ -364,7 +366,7 @@ class TestDispatcherManager:
             result = await manager.process_flow_export("ws", "running", params)

             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 ws="ws",
                 running="running",
                 queue={"queue": "test_queue"},
@@ -376,9 +378,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_socket(self):
         """Test process_socket method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
             mock_mux_instance = Mock()
@@ -392,9 +394,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_service(self):
         """Test process_flow_service method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         manager.invoke_flow_service = AsyncMock(return_value="flow_result")

@@ -407,9 +409,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_with_existing_dispatcher(self):
         """Test invoke_flow_service with existing dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Add flow to the flows dictionary
         manager.flows["test_flow"] = {"services": {"agent": {}}}
@@ -427,9 +429,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_creates_request_response_dispatcher(self):
         """Test invoke_flow_service creates request-response dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@@ -454,7 +456,7 @@ class TestDispatcherManager:

             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 request_queue="agent_request_queue",
                 response_queue="agent_response_queue",
                 timeout=120,
@@ -471,9 +473,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_creates_sender_dispatcher(self):
         """Test invoke_flow_service creates sender dispatcher"""
|
"""Test invoke_flow_service creates sender dispatcher"""
|
||||||
mock_pulsar_client = Mock()
|
mock_backend = Mock()
|
||||||
mock_config_receiver = Mock()
|
mock_config_receiver = Mock()
|
||||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||||
|
|
||||||
# Setup test flow
|
# Setup test flow
|
||||||
manager.flows["test_flow"] = {
|
manager.flows["test_flow"] = {
|
||||||
|
|
@ -498,7 +500,7 @@ class TestDispatcherManager:
|
||||||
|
|
||||||
# Verify dispatcher was created with correct parameters
|
# Verify dispatcher was created with correct parameters
|
||||||
mock_dispatcher_class.assert_called_once_with(
|
mock_dispatcher_class.assert_called_once_with(
|
||||||
pulsar_client=mock_pulsar_client,
|
backend=mock_backend,
|
||||||
queue={"queue": "text_load_queue"}
|
queue={"queue": "text_load_queue"}
|
||||||
)
|
)
|
||||||
mock_dispatcher.start.assert_called_once()
|
mock_dispatcher.start.assert_called_once()
|
||||||
|
|
@ -511,9 +513,9 @@ class TestDispatcherManager:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invoke_flow_service_invalid_flow(self):
|
async def test_invoke_flow_service_invalid_flow(self):
|
||||||
"""Test invoke_flow_service with invalid flow"""
|
"""Test invoke_flow_service with invalid flow"""
|
||||||
mock_pulsar_client = Mock()
|
mock_backend = Mock()
|
||||||
mock_config_receiver = Mock()
|
mock_config_receiver = Mock()
|
||||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||||
|
|
@ -521,9 +523,9 @@ class TestDispatcherManager:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||||
"""Test invoke_flow_service with kind not supported by flow"""
|
"""Test invoke_flow_service with kind not supported by flow"""
|
||||||
mock_pulsar_client = Mock()
|
mock_backend = Mock()
|
||||||
mock_config_receiver = Mock()
|
mock_config_receiver = Mock()
|
||||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||||
|
|
||||||
# Setup test flow without agent interface
|
# Setup test flow without agent interface
|
||||||
manager.flows["test_flow"] = {
|
manager.flows["test_flow"] = {
|
||||||
|
|
@ -538,9 +540,9 @@ class TestDispatcherManager:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invoke_flow_service_invalid_kind(self):
|
async def test_invoke_flow_service_invalid_kind(self):
|
||||||
"""Test invoke_flow_service with invalid kind"""
|
"""Test invoke_flow_service with invalid kind"""
|
||||||
mock_pulsar_client = Mock()
|
mock_backend = Mock()
|
||||||
mock_config_receiver = Mock()
|
mock_config_receiver = Mock()
|
||||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
manager = DispatcherManager(mock_backend, mock_config_receiver)
|
||||||
|
|
||||||
# Setup test flow with interface but unsupported kind
|
# Setup test flow with interface but unsupported kind
|
||||||
manager.flows["test_flow"] = {
|
manager.flows["test_flow"] = {
|
||||||
|
|
|
||||||
|
|
@@ -15,12 +15,12 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_initialization(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor initialization"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_request_schema = MagicMock()
         mock_response_schema = MagicMock()

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-request-queue",
             request_schema=mock_request_schema,
             response_queue="test-response-queue",

@@ -32,12 +32,12 @@ class TestServiceRequestor:
         # Verify Publisher was created correctly
         mock_publisher.assert_called_once_with(
-            mock_pulsar_client, "test-request-queue", schema=mock_request_schema
+            mock_backend, "test-request-queue", schema=mock_request_schema
         )

         # Verify Subscriber was created correctly
         mock_subscriber.assert_called_once_with(
-            mock_pulsar_client, "test-response-queue",
+            mock_backend, "test-response-queue",
             "test-subscription", "test-consumer", mock_response_schema
         )

@@ -48,12 +48,12 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_with_defaults(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor initialization with default parameters"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_request_schema = MagicMock()
         mock_response_schema = MagicMock()

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=mock_request_schema,
             response_queue="response-queue",

@@ -62,7 +62,7 @@ class TestServiceRequestor:
         # Verify default values
         mock_subscriber.assert_called_once_with(
-            mock_pulsar_client, "response-queue",
+            mock_backend, "response-queue",
             "api-gateway", "api-gateway", mock_response_schema
         )
         assert requestor.timeout == 600  # Default timeout

@@ -72,14 +72,14 @@ class TestServiceRequestor:
     @pytest.mark.asyncio
     async def test_service_requestor_start(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor start method"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_sub_instance = AsyncMock()
         mock_pub_instance = AsyncMock()
         mock_subscriber.return_value = mock_sub_instance
         mock_publisher.return_value = mock_pub_instance

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=MagicMock(),
             response_queue="response-queue",

@@ -98,14 +98,14 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_attributes(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor has correct attributes"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_pub_instance = AsyncMock()
         mock_sub_instance = AsyncMock()
         mock_publisher.return_value = mock_pub_instance
         mock_subscriber.return_value = mock_sub_instance

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=MagicMock(),
             response_queue="response-queue",

@@ -14,18 +14,18 @@ class TestServiceSender:
     @patch('trustgraph.gateway.dispatch.sender.Publisher')
     def test_service_sender_initialization(self, mock_publisher):
         """Test ServiceSender initialization"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_schema = MagicMock()

         sender = ServiceSender(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue",
             schema=mock_schema
         )

         # Verify Publisher was created correctly
         mock_publisher.assert_called_once_with(
-            mock_pulsar_client, "test-queue", schema=mock_schema
+            mock_backend, "test-queue", schema=mock_schema
         )

     @patch('trustgraph.gateway.dispatch.sender.Publisher')

@@ -36,7 +36,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -55,7 +55,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -70,7 +70,7 @@ class TestServiceSender:
     def test_service_sender_to_request_not_implemented(self, mock_publisher):
         """Test ServiceSender to_request method raises RuntimeError"""
         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -91,7 +91,7 @@ class TestServiceSender:
             return {"processed": request}

         sender = ConcreteSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -111,7 +111,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -16,7 +16,7 @@ from trustgraph.schema import Metadata, ExtractedObject

 @pytest.fixture
-def mock_pulsar_client():
+def mock_backend():
     """Mock Pulsar client."""
     client = Mock()
     return client

@@ -96,7 +96,7 @@ class TestObjectsImportInitialization:
     """Test ObjectsImport initialization."""

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
-    def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that ObjectsImport creates Publisher with correct parameters."""
         mock_publisher_instance = Mock()
         mock_publisher_class.return_value = mock_publisher_instance

@@ -104,13 +104,13 @@ class TestObjectsImportInitialization:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-objects-queue"
         )

         # Verify Publisher was created with correct parameters
         mock_publisher_class.assert_called_once_with(
-            mock_pulsar_client,
+            mock_backend,
             topic="test-objects-queue",
             schema=ExtractedObject
         )

@@ -121,12 +121,12 @@ class TestObjectsImportInitialization:
         assert objects_import.publisher == mock_publisher_instance

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
-    def test_init_stores_references_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that ObjectsImport stores all required references."""
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="objects-queue"
         )

@@ -139,7 +139,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_start_calls_publisher_start(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that start() calls publisher.start()."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.start = AsyncMock()

@@ -148,7 +148,7 @@ class TestObjectsImportLifecycle:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -158,7 +158,7 @@ class TestObjectsImportLifecycle:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that destroy() properly stops publisher and closes websocket."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.stop = AsyncMock()

@@ -167,7 +167,7 @@ class TestObjectsImportLifecycle:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -180,7 +180,7 @@ class TestObjectsImportLifecycle:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_pulsar_client, mock_running):
+    async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
         """Test that destroy() handles None websocket gracefully."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.stop = AsyncMock()

@@ -189,7 +189,7 @@ class TestObjectsImportLifecycle:
         objects_import = ObjectsImport(
             ws=None,  # None websocket
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -205,7 +205,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
+    async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
         """Test that receive() processes complete message correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -214,7 +214,7 @@ class TestObjectsImportMessageProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -248,7 +248,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, minimal_objects_message):
+    async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
         """Test that receive() handles message with minimal required fields."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -257,7 +257,7 @@ class TestObjectsImportMessageProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -281,7 +281,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_uses_default_values(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that receive() uses appropriate default values for optional fields."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -290,7 +290,7 @@ class TestObjectsImportMessageProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -323,7 +323,7 @@ class TestObjectsImportRunMethod:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
     @pytest.mark.asyncio
-    async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that run() loops while running.get() returns True."""
         mock_sleep.return_value = None
         mock_publisher_class.return_value = Mock()

@@ -334,7 +334,7 @@ class TestObjectsImportRunMethod:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -353,7 +353,7 @@ class TestObjectsImportRunMethod:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
     @pytest.mark.asyncio
-    async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_running):
+    async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
         """Test that run() handles None websocket gracefully."""
         mock_sleep.return_value = None
         mock_publisher_class.return_value = Mock()

@@ -363,7 +363,7 @@ class TestObjectsImportRunMethod:
         objects_import = ObjectsImport(
             ws=None,  # None websocket
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -417,7 +417,7 @@ class TestObjectsImportBatchProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
+    async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
         """Test that receive() processes batch message correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -426,7 +426,7 @@ class TestObjectsImportBatchProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -467,7 +467,7 @@ class TestObjectsImportBatchProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that receive() handles empty batch correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -476,7 +476,7 @@ class TestObjectsImportBatchProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -507,7 +507,7 @@ class TestObjectsImportErrorHandling:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
+    async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
         """Test that receive() propagates publisher send errors."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))

@@ -516,7 +516,7 @@ class TestObjectsImportErrorHandling:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -528,14 +528,14 @@ class TestObjectsImportErrorHandling:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that receive() handles malformed JSON appropriately."""
         mock_publisher_class.return_value = Mock()

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -19,8 +19,9 @@ class TestApi:

     def test_api_initialization_with_defaults(self):
         """Test Api initialization with default values"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_backend = Mock()
+            mock_get_pubsub.return_value = mock_backend

             api = Api()

@@ -31,11 +32,8 @@ class TestApi:
             assert api.prometheus_url == default_prometheus_url + "/"
             assert api.auth.allow_all is True

-            # Verify Pulsar client was created without API key
-            mock_client.assert_called_once_with(
-                default_pulsar_host,
-                listener_name=None
-            )
+            # Verify get_pubsub was called
+            mock_get_pubsub.assert_called_once()

     def test_api_initialization_with_custom_config(self):
         """Test Api initialization with custom configuration"""

@@ -49,10 +47,9 @@ class TestApi:
             "api_token": "secret-token"
         }

-        with patch('pulsar.Client') as mock_client, \
-             patch('pulsar.AuthenticationToken') as mock_auth:
-            mock_client.return_value = Mock()
-            mock_auth.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_backend = Mock()
+            mock_get_pubsub.return_value = mock_backend

             api = Api(**config)

@@ -64,34 +61,24 @@ class TestApi:
             assert api.auth.token == "secret-token"
             assert api.auth.allow_all is False

-            # Verify Pulsar client was created with API key
-            mock_auth.assert_called_once_with("test-api-key")
-            mock_client.assert_called_once_with(
-                "pulsar://custom-host:6650",
-                listener_name="custom-listener",
-                authentication=mock_auth.return_value
-            )
+            # Verify get_pubsub was called with config
+            mock_get_pubsub.assert_called_once_with(**config)

     def test_api_initialization_with_pulsar_api_key(self):
         """Test Api initialization with Pulsar API key authentication"""
-        with patch('pulsar.Client') as mock_client, \
-             patch('pulsar.AuthenticationToken') as mock_auth:
-            mock_client.return_value = Mock()
-            mock_auth.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api(pulsar_api_key="test-key")

-            mock_auth.assert_called_once_with("test-key")
-            mock_client.assert_called_once_with(
-                default_pulsar_host,
-                listener_name=None,
-                authentication=mock_auth.return_value
-            )
+            # Verify api key was stored
+            assert api.pulsar_api_key == "test-key"
+            mock_get_pubsub.assert_called_once()

     def test_api_initialization_prometheus_url_normalization(self):
         """Test that prometheus_url gets normalized with trailing slash"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             # Test URL without trailing slash
             api = Api(prometheus_url="http://prometheus:9090")

@@ -103,16 +90,16 @@ class TestApi:

     def test_api_initialization_empty_api_token_means_no_auth(self):
         """Test that empty API token results in allow_all authentication"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api(api_token="")
             assert api.auth.allow_all is True

     def test_api_initialization_none_api_token_means_no_auth(self):
         """Test that None API token results in allow_all authentication"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api(api_token=None)
             assert api.auth.allow_all is True

@@ -120,8 +107,8 @@ class TestApi:
     @pytest.mark.asyncio
     async def test_app_factory_creates_application(self):
         """Test that app_factory creates aiohttp application"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api()

@@ -147,8 +134,8 @@ class TestApi:
     @pytest.mark.asyncio
     async def test_app_factory_with_custom_endpoints(self):
         """Test app_factory with custom endpoints"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api()

@@ -180,9 +167,9 @@ class TestApi:

     def test_run_method_calls_web_run_app(self):
         """Test that run method calls web.run_app"""
-        with patch('pulsar.Client') as mock_client, \
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
              patch('aiohttp.web.run_app') as mock_run_app:
-            mock_client.return_value = Mock()
+            mock_get_pubsub.return_value = Mock()

             api = Api(port=8080)
             api.run()

@@ -195,8 +182,8 @@ class TestApi:

     def test_api_components_initialization(self):
         """Test that all API components are properly initialized"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api()

@@ -207,7 +194,7 @@ class TestApi:
             assert api.endpoints == []

             # Verify component relationships
-            assert api.dispatcher_manager.pulsar_client == api.pulsar_client
+            assert api.dispatcher_manager.backend == api.pubsub_backend
             assert api.dispatcher_manager.config_receiver == api.config_receiver
             assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
             # EndpointManager doesn't store auth directly, it passes it to individual endpoints

@@ -129,7 +129,17 @@ async def test_handle_normal_flow():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(return_value=None)
-        mock_tg.create_task = MagicMock(return_value=AsyncMock())
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg

         result = await socket_endpoint.handle(request)

@@ -176,11 +186,25 @@ async def test_handle_exception_group_cleanup():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
-        mock_tg.create_task = MagicMock(side_effect=TestException("test"))
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg

-        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
-            mock_wait_for.return_value = None
+        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for:
+            # Make wait_for consume the coroutine passed to it
+            async def wait_for_side_effect(coro, timeout=None):
+                coro.close()  # Consume the coroutine
+                return None
+            mock_wait_for.side_effect = wait_for_side_effect

             result = await socket_endpoint.handle(request)

@@ -227,12 +251,26 @@ async def test_handle_dispatcher_cleanup_timeout():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
-        mock_tg.create_task = MagicMock(side_effect=Exception("test"))
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg

         # Mock asyncio.wait_for to raise TimeoutError
-        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
-            mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
+        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for:
+            # Make wait_for consume the coroutine before raising
+            async def wait_for_timeout(coro, timeout=None):
+                coro.close()  # Consume the coroutine
+                raise asyncio.TimeoutError("Cleanup timeout")
+            mock_wait_for.side_effect = wait_for_timeout

             result = await socket_endpoint.handle(request)

@@ -314,7 +352,17 @@ async def test_handle_websocket_already_closed():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(return_value=None)
-        mock_tg.create_task = MagicMock(return_value=AsyncMock())
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg

         result = await socket_endpoint.handle(request)

326  tests/unit/test_gateway/test_streaming_translators.py  Normal file
@@ -0,0 +1,326 @@
+"""
+Unit tests for streaming behavior in message translators.
+
+These tests verify that translators correctly handle empty strings and
+end_of_stream flags in streaming responses, preventing bugs where empty
+final chunks could be dropped due to falsy value checks.
+"""
+
+import pytest
+from unittest.mock import MagicMock
+from trustgraph.messaging.translators.retrieval import (
+    GraphRagResponseTranslator,
+    DocumentRagResponseTranslator,
+)
+from trustgraph.messaging.translators.prompt import PromptResponseTranslator
+from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator
+from trustgraph.schema import (
+    GraphRagResponse,
+    DocumentRagResponse,
+    PromptResponse,
+    TextCompletionResponse,
+)
+
+
+class TestGraphRagResponseTranslator:
+    """Test GraphRagResponseTranslator streaming behavior"""
+
+    def test_from_pulsar_with_empty_response(self):
+        """Test that empty response strings are preserved"""
+        # Arrange
+        translator = GraphRagResponseTranslator()
+        response = GraphRagResponse(
+            response="",
+            end_of_stream=True,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert - Empty string should be included in result
+        assert "response" in result
+        assert result["response"] == ""
+        assert result["end_of_stream"] is True
+
+    def test_from_pulsar_with_non_empty_response(self):
+        """Test that non-empty responses work correctly"""
+        # Arrange
+        translator = GraphRagResponseTranslator()
+        response = GraphRagResponse(
+            response="Some text",
+            end_of_stream=False,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert result["response"] == "Some text"
+        assert result["end_of_stream"] is False
+
+    def test_from_pulsar_with_none_response(self):
+        """Test that None response is handled correctly"""
+        # Arrange
+        translator = GraphRagResponseTranslator()
+        response = GraphRagResponse(
+            response=None,
+            end_of_stream=True,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert - None should not be included
+        assert "response" not in result
+        assert result["end_of_stream"] is True
+
+    def test_from_response_with_completion_returns_correct_flag(self):
+        """Test that from_response_with_completion returns correct is_final flag"""
+        # Arrange
+        translator = GraphRagResponseTranslator()
+
+        # Test non-final chunk
+        response_chunk = GraphRagResponse(
+            response="chunk",
+            end_of_stream=False,
+            error=None
+        )
+
+        # Act
+        result, is_final = translator.from_response_with_completion(response_chunk)
+
+        # Assert
+        assert is_final is False
+        assert result["end_of_stream"] is False
+
+        # Test final chunk with empty content
+        final_response = GraphRagResponse(
+            response="",
+            end_of_stream=True,
+            error=None
+        )
+
+        # Act
+        result, is_final = translator.from_response_with_completion(final_response)
+
+        # Assert
+        assert is_final is True
+        assert result["response"] == ""
+        assert result["end_of_stream"] is True
+
+
+class TestDocumentRagResponseTranslator:
+    """Test DocumentRagResponseTranslator streaming behavior"""
+
+    def test_from_pulsar_with_empty_response(self):
+        """Test that empty response strings are preserved"""
+        # Arrange
+        translator = DocumentRagResponseTranslator()
+        response = DocumentRagResponse(
+            response="",
+            end_of_stream=True,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert "response" in result
+        assert result["response"] == ""
+        assert result["end_of_stream"] is True
+
+    def test_from_pulsar_with_non_empty_response(self):
+        """Test that non-empty responses work correctly"""
+        # Arrange
+        translator = DocumentRagResponseTranslator()
+        response = DocumentRagResponse(
+            response="Document content",
+            end_of_stream=False,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert result["response"] == "Document content"
+        assert result["end_of_stream"] is False
+
+
+class TestPromptResponseTranslator:
+    """Test PromptResponseTranslator streaming behavior"""
+
+    def test_from_pulsar_with_empty_text(self):
+        """Test that empty text strings are preserved"""
+        # Arrange
+        translator = PromptResponseTranslator()
+        response = PromptResponse(
+            text="",
+            object=None,
+            end_of_stream=True,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert "text" in result
+        assert result["text"] == ""
+        assert result["end_of_stream"] is True
+
+    def test_from_pulsar_with_non_empty_text(self):
+        """Test that non-empty text works correctly"""
+        # Arrange
+        translator = PromptResponseTranslator()
+        response = PromptResponse(
+            text="Some prompt response",
+            object=None,
+            end_of_stream=False,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert result["text"] == "Some prompt response"
+        assert result["end_of_stream"] is False
+
+    def test_from_pulsar_with_none_text(self):
+        """Test that None text is handled correctly"""
+        # Arrange
+        translator = PromptResponseTranslator()
+        response = PromptResponse(
+            text=None,
+            object='{"result": "data"}',
+            end_of_stream=True,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert "text" not in result
+        assert "object" in result
+        assert result["end_of_stream"] is True
+
+    def test_from_pulsar_includes_end_of_stream(self):
+        """Test that end_of_stream flag is always included"""
+        # Arrange
+        translator = PromptResponseTranslator()
+
+        # Test with end_of_stream=False
+        response = PromptResponse(
+            text="chunk",
+            object=None,
+            end_of_stream=False,
+            error=None
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert "end_of_stream" in result
+        assert result["end_of_stream"] is False
+
+
+class TestTextCompletionResponseTranslator:
+    """Test TextCompletionResponseTranslator streaming behavior"""
+
+    def test_from_pulsar_always_includes_response(self):
+        """Test that response field is always included, even if empty"""
+        # Arrange
+        translator = TextCompletionResponseTranslator()
+        response = TextCompletionResponse(
+            response="",
+            end_of_stream=True,
+            error=None,
+            in_token=100,
+            out_token=5,
+            model="test-model"
+        )
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert - Response should always be present
+        assert "response" in result
+        assert result["response"] == ""
+
+    def test_from_response_with_completion_with_empty_final(self):
+        """Test that empty final response is handled correctly"""
+        # Arrange
+        translator = TextCompletionResponseTranslator()
+        response = TextCompletionResponse(
+            response="",
+            end_of_stream=True,
+            error=None,
+            in_token=100,
+            out_token=5,
+            model="test-model"
+        )
+
+        # Act
+        result, is_final = translator.from_response_with_completion(response)
+
+        # Assert
+        assert is_final is True
+        assert result["response"] == ""
+
+
+class TestStreamingProtocolCompliance:
+    """Test that all translators follow streaming protocol conventions"""
+
+    @pytest.mark.parametrize("translator_class,response_class,field_name", [
+        (GraphRagResponseTranslator, GraphRagResponse, "response"),
+        (DocumentRagResponseTranslator, DocumentRagResponse, "response"),
+        (PromptResponseTranslator, PromptResponse, "text"),
+        (TextCompletionResponseTranslator, TextCompletionResponse, "response"),
+    ])
+    def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name):
+        """Test that all translators preserve empty final chunks"""
+        # Arrange
+        translator = translator_class()
+        kwargs = {
+            field_name: "",
+            "end_of_stream": True,
+            "error": None,
+        }
+        response = response_class(**kwargs)
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
+        assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string"
+
+    @pytest.mark.parametrize("translator_class,response_class,field_name", [
+        (GraphRagResponseTranslator, GraphRagResponse, "response"),
+        (DocumentRagResponseTranslator, DocumentRagResponse, "response"),
+        (TextCompletionResponseTranslator, TextCompletionResponse, "response"),
+    ])
+    def test_end_of_stream_flag_included(self, translator_class, response_class, field_name):
+        """Test that end_of_stream flag is included in all response translators"""
+        # Arrange
+        translator = translator_class()
+        kwargs = {
+            field_name: "test content",
+            "end_of_stream": True,
+            "error": None,
+        }
+        response = response_class(**kwargs)
+
+        # Act
+        result = translator.from_pulsar(response)
+
+        # Assert
+        assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
+        assert result["end_of_stream"] is True
446  tests/unit/test_python_api_client.py  Normal file
@@ -0,0 +1,446 @@
+"""
+Unit tests for TrustGraph Python API client library
+
+These tests use mocks and do not require a running server.
+"""
+
+import pytest
+from unittest.mock import Mock, patch, MagicMock, call
+import json
+
+from trustgraph.api import (
+    Api,
+    Triple,
+    AgentThought,
+    AgentObservation,
+    AgentAnswer,
+    RAGChunk,
+)
+
+
+class TestApiInstantiation:
+    """Test Api class instantiation and configuration"""
+
+    def test_api_instantiation_defaults(self):
+        """Test Api with default parameters"""
+        api = Api()
+        assert api.url == "http://localhost:8088/api/v1/"
+        assert api.timeout == 60
+        assert api.token is None
+
+    def test_api_instantiation_with_url(self):
+        """Test Api with custom URL"""
+        api = Api(url="http://test-server:9000/")
+        assert api.url == "http://test-server:9000/api/v1/"
+
+    def test_api_instantiation_with_url_trailing_slash(self):
+        """Test Api adds trailing slash if missing"""
+        api = Api(url="http://test-server:9000")
+        assert api.url == "http://test-server:9000/api/v1/"
+
+    def test_api_instantiation_with_token(self):
+        """Test Api with authentication token"""
+        api = Api(token="test-token-123")
+        assert api.token == "test-token-123"
+
+    def test_api_instantiation_with_timeout(self):
+        """Test Api with custom timeout"""
+        api = Api(timeout=120)
+        assert api.timeout == 120
+
+
+class TestApiLazyInitialization:
+    """Test lazy initialization of client components"""
+
+    def test_socket_client_lazy_init(self):
+        """Test socket client is created on first access"""
+        api = Api(url="http://test/", token="token")
+
+        assert api._socket_client is None
+        socket = api.socket()
+        assert api._socket_client is not None
+        assert socket is api._socket_client
+
+        # Second access returns same instance
+        socket2 = api.socket()
+        assert socket2 is socket
+
+    def test_bulk_client_lazy_init(self):
+        """Test bulk client is created on first access"""
+        api = Api(url="http://test/")
+
+        assert api._bulk_client is None
+        bulk = api.bulk()
+        assert api._bulk_client is not None
+
+    def test_async_flow_lazy_init(self):
+        """Test async flow is created on first access"""
+        api = Api(url="http://test/")
+
+        assert api._async_flow is None
+        async_flow = api.async_flow()
+        assert api._async_flow is not None
+
+    def test_metrics_lazy_init(self):
+        """Test metrics client is created on first access"""
+        api = Api(url="http://test/")
+
+        assert api._metrics is None
+        metrics = api.metrics()
+        assert api._metrics is not None
+
+
+class TestApiContextManager:
+    """Test context manager functionality"""
+
+    def test_sync_context_manager(self):
+        """Test synchronous context manager"""
+        with Api(url="http://test/") as api:
+            assert api is not None
+            assert isinstance(api, Api)
+        # Should exit cleanly
+
+    @pytest.mark.asyncio
+    async def test_async_context_manager(self):
+        """Test asynchronous context manager"""
+        async with Api(url="http://test/") as api:
+            assert api is not None
+            assert isinstance(api, Api)
+        # Should exit cleanly
+
+
+class TestFlowClient:
+    """Test Flow client functionality"""
+
+    @patch('requests.post')
+    def test_flow_list(self, mock_post):
+        """Test listing flows"""
+        mock_post.return_value.status_code = 200
+        mock_post.return_value.json.return_value = {"flow-ids": ["flow1", "flow2"]}
+
+        api = Api(url="http://test/")
+        flows = api.flow().list()
+
+        assert flows == ["flow1", "flow2"]
+        assert mock_post.called
+
+    @patch('requests.post')
+    def test_flow_list_with_token(self, mock_post):
+        """Test flow listing includes auth token"""
+        mock_post.return_value.status_code = 200
+        mock_post.return_value.json.return_value = {"flow-ids": []}
+
+        api = Api(url="http://test/", token="my-token")
+        api.flow().list()
+
+        # Verify Authorization header was set
+        call_args = mock_post.call_args
+        headers = call_args[1]['headers'] if 'headers' in call_args[1] else {}
+        assert 'Authorization' in headers
+        assert headers['Authorization'] == 'Bearer my-token'
+
+    @patch('requests.post')
+    def test_flow_get(self, mock_post):
+        """Test getting flow definition"""
+        flow_def = {"name": "test-flow", "description": "Test"}
+        mock_post.return_value.status_code = 200
+        mock_post.return_value.json.return_value = {"flow": json.dumps(flow_def)}
+
+        api = Api(url="http://test/")
+        result = api.flow().get("test-flow")
+
+        assert result == flow_def
+
+    def test_flow_instance_creation(self):
+        """Test creating flow instance"""
+        api = Api(url="http://test/")
+        flow_instance = api.flow().id("my-flow")
+
+        assert flow_instance is not None
+        assert flow_instance.id == "my-flow"
|
||||||
|
|
||||||
|
def test_flow_instance_has_methods(self):
|
||||||
|
"""Test flow instance has expected methods"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
flow_instance = api.flow().id("my-flow")
|
||||||
|
|
||||||
|
expected_methods = [
|
||||||
|
'text_completion', 'agent', 'graph_rag', 'document_rag',
|
||||||
|
'graph_embeddings_query', 'embeddings', 'prompt',
|
||||||
|
'triples_query', 'objects_query'
|
||||||
|
]
|
||||||
|
|
||||||
|
for method in expected_methods:
|
||||||
|
assert hasattr(flow_instance, method), f"Missing method: {method}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSocketClient:
|
||||||
|
"""Test WebSocket client functionality"""
|
||||||
|
|
||||||
|
def test_socket_client_url_conversion_http(self):
|
||||||
|
"""Test HTTP URL converted to WebSocket"""
|
||||||
|
api = Api(url="http://test-server:8088/")
|
||||||
|
socket = api.socket()
|
||||||
|
|
||||||
|
assert socket.url.startswith("ws://")
|
||||||
|
assert "test-server" in socket.url
|
||||||
|
|
||||||
|
def test_socket_client_url_conversion_https(self):
|
||||||
|
"""Test HTTPS URL converted to secure WebSocket"""
|
||||||
|
api = Api(url="https://test-server:8088/")
|
||||||
|
socket = api.socket()
|
||||||
|
|
||||||
|
assert socket.url.startswith("wss://")
|
||||||
|
|
||||||
|
def test_socket_client_token_passed(self):
|
||||||
|
"""Test token is passed to socket client"""
|
||||||
|
api = Api(url="http://test/", token="socket-token")
|
||||||
|
socket = api.socket()
|
||||||
|
|
||||||
|
assert socket.token == "socket-token"
|
||||||
|
|
||||||
|
def test_socket_flow_instance(self):
|
||||||
|
"""Test creating socket flow instance"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
socket = api.socket()
|
||||||
|
flow_instance = socket.flow("test-flow")
|
||||||
|
|
||||||
|
assert flow_instance is not None
|
||||||
|
assert flow_instance.flow_id == "test-flow"
|
||||||
|
|
||||||
|
def test_socket_flow_has_methods(self):
|
||||||
|
"""Test socket flow instance has expected methods"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
flow_instance = api.socket().flow("test-flow")
|
||||||
|
|
||||||
|
expected_methods = [
|
||||||
|
'agent', 'text_completion', 'graph_rag', 'document_rag',
|
||||||
|
'prompt', 'graph_embeddings_query', 'embeddings',
|
||||||
|
'triples_query', 'objects_query', 'mcp_tool'
|
||||||
|
]
|
||||||
|
|
||||||
|
for method in expected_methods:
|
||||||
|
assert hasattr(flow_instance, method), f"Missing method: {method}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkClient:
|
||||||
|
"""Test bulk operations client"""
|
||||||
|
|
||||||
|
def test_bulk_client_url_conversion(self):
|
||||||
|
"""Test bulk client uses WebSocket URL"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
bulk = api.bulk()
|
||||||
|
|
||||||
|
assert bulk.url.startswith("ws://")
|
||||||
|
|
||||||
|
def test_bulk_client_has_import_methods(self):
|
||||||
|
"""Test bulk client has import methods"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
bulk = api.bulk()
|
||||||
|
|
||||||
|
import_methods = [
|
||||||
|
'import_triples',
|
||||||
|
'import_graph_embeddings',
|
||||||
|
'import_document_embeddings',
|
||||||
|
'import_entity_contexts',
|
||||||
|
'import_objects'
|
||||||
|
]
|
||||||
|
|
||||||
|
for method in import_methods:
|
||||||
|
assert hasattr(bulk, method), f"Missing method: {method}"
|
||||||
|
|
||||||
|
def test_bulk_client_has_export_methods(self):
|
||||||
|
"""Test bulk client has export methods"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
bulk = api.bulk()
|
||||||
|
|
||||||
|
export_methods = [
|
||||||
|
'export_triples',
|
||||||
|
'export_graph_embeddings',
|
||||||
|
'export_document_embeddings',
|
||||||
|
'export_entity_contexts'
|
||||||
|
]
|
||||||
|
|
||||||
|
for method in export_methods:
|
||||||
|
assert hasattr(bulk, method), f"Missing method: {method}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsClient:
|
||||||
|
"""Test metrics client"""
|
||||||
|
|
||||||
|
@patch('requests.get')
|
||||||
|
def test_metrics_get(self, mock_get):
|
||||||
|
"""Test getting metrics"""
|
||||||
|
mock_get.return_value.status_code = 200
|
||||||
|
mock_get.return_value.text = "# HELP metric_name\nmetric_name 42"
|
||||||
|
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
metrics_text = api.metrics().get()
|
||||||
|
|
||||||
|
assert "metric_name" in metrics_text
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
@patch('requests.get')
|
||||||
|
def test_metrics_with_token(self, mock_get):
|
||||||
|
"""Test metrics request includes token"""
|
||||||
|
mock_get.return_value.status_code = 200
|
||||||
|
mock_get.return_value.text = "metrics"
|
||||||
|
|
||||||
|
api = Api(url="http://test/", token="metrics-token")
|
||||||
|
api.metrics().get()
|
||||||
|
|
||||||
|
# Verify token in headers
|
||||||
|
call_args = mock_get.call_args
|
||||||
|
headers = call_args[1].get('headers', {})
|
||||||
|
assert 'Authorization' in headers
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingTypes:
|
||||||
|
"""Test streaming chunk types"""
|
||||||
|
|
||||||
|
def test_agent_thought_creation(self):
|
||||||
|
"""Test creating AgentThought chunk"""
|
||||||
|
chunk = AgentThought(content="thinking...", end_of_message=False)
|
||||||
|
|
||||||
|
assert chunk.content == "thinking..."
|
||||||
|
assert chunk.end_of_message is False
|
||||||
|
assert chunk.chunk_type == "thought"
|
||||||
|
|
||||||
|
def test_agent_observation_creation(self):
|
||||||
|
"""Test creating AgentObservation chunk"""
|
||||||
|
chunk = AgentObservation(content="observing...", end_of_message=False)
|
||||||
|
|
||||||
|
assert chunk.content == "observing..."
|
||||||
|
assert chunk.chunk_type == "observation"
|
||||||
|
|
||||||
|
def test_agent_answer_creation(self):
|
||||||
|
"""Test creating AgentAnswer chunk"""
|
||||||
|
chunk = AgentAnswer(
|
||||||
|
content="answer",
|
||||||
|
end_of_message=True,
|
||||||
|
end_of_dialog=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunk.content == "answer"
|
||||||
|
assert chunk.end_of_message is True
|
||||||
|
assert chunk.end_of_dialog is True
|
||||||
|
assert chunk.chunk_type == "final-answer"
|
||||||
|
|
||||||
|
def test_rag_chunk_creation(self):
|
||||||
|
"""Test creating RAGChunk"""
|
||||||
|
chunk = RAGChunk(
|
||||||
|
content="response chunk",
|
||||||
|
end_of_stream=False,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunk.content == "response chunk"
|
||||||
|
assert chunk.end_of_stream is False
|
||||||
|
assert chunk.error is None
|
||||||
|
|
||||||
|
def test_rag_chunk_with_error(self):
|
||||||
|
"""Test RAGChunk with error"""
|
||||||
|
error_dict = {"type": "error", "message": "failed"}
|
||||||
|
chunk = RAGChunk(
|
||||||
|
content="",
|
||||||
|
end_of_stream=True,
|
||||||
|
error=error_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chunk.error == error_dict
|
||||||
|
|
||||||
|
|
||||||
|
class TestTripleType:
|
||||||
|
"""Test Triple data structure"""
|
||||||
|
|
||||||
|
def test_triple_creation(self):
|
||||||
|
"""Test creating Triple"""
|
||||||
|
triple = Triple(s="subject", p="predicate", o="object")
|
||||||
|
|
||||||
|
assert triple.s == "subject"
|
||||||
|
assert triple.p == "predicate"
|
||||||
|
assert triple.o == "object"
|
||||||
|
|
||||||
|
def test_triple_with_uris(self):
|
||||||
|
"""Test Triple with URI values"""
|
||||||
|
triple = Triple(
|
||||||
|
s="http://example.org/entity1",
|
||||||
|
p="http://example.org/relation",
|
||||||
|
o="http://example.org/entity2"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert triple.s.startswith("http://")
|
||||||
|
assert triple.p.startswith("http://")
|
||||||
|
assert triple.o.startswith("http://")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncClients:
|
||||||
|
"""Test async client availability"""
|
||||||
|
|
||||||
|
def test_async_flow_creation(self):
|
||||||
|
"""Test creating async flow client"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
async_flow = api.async_flow()
|
||||||
|
|
||||||
|
assert async_flow is not None
|
||||||
|
|
||||||
|
def test_async_socket_creation(self):
|
||||||
|
"""Test creating async socket client"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
async_socket = api.async_socket()
|
||||||
|
|
||||||
|
assert async_socket is not None
|
||||||
|
assert async_socket.url.startswith("ws://")
|
||||||
|
|
||||||
|
def test_async_bulk_creation(self):
|
||||||
|
"""Test creating async bulk client"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
async_bulk = api.async_bulk()
|
||||||
|
|
||||||
|
assert async_bulk is not None
|
||||||
|
|
||||||
|
def test_async_metrics_creation(self):
|
||||||
|
"""Test creating async metrics client"""
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
async_metrics = api.async_metrics()
|
||||||
|
|
||||||
|
assert async_metrics is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
"""Test error handling"""
|
||||||
|
|
||||||
|
@patch('requests.post')
|
||||||
|
def test_protocol_exception_on_non_200(self, mock_post):
|
||||||
|
"""Test ProtocolException raised on non-200 status"""
|
||||||
|
from trustgraph.api.exceptions import ProtocolException
|
||||||
|
|
||||||
|
mock_post.return_value.status_code = 500
|
||||||
|
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
|
||||||
|
with pytest.raises(ProtocolException):
|
||||||
|
api.flow().list()
|
||||||
|
|
||||||
|
@patch('requests.post')
|
||||||
|
def test_application_exception_on_error_response(self, mock_post):
|
||||||
|
"""Test ApplicationException on error in response"""
|
||||||
|
from trustgraph.api.exceptions import ApplicationException
|
||||||
|
|
||||||
|
mock_post.return_value.status_code = 200
|
||||||
|
mock_post.return_value.json.return_value = {
|
||||||
|
"error": {
|
||||||
|
"type": "ValidationError",
|
||||||
|
"message": "Invalid input"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
api = Api(url="http://test/")
|
||||||
|
|
||||||
|
with pytest.raises(ApplicationException):
|
||||||
|
api.flow().list()
|
||||||
|
|
||||||
|
|
||||||
|
# Run tests with: pytest tests/unit/test_python_api_client.py -v
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
|
|
@@ -23,9 +23,9 @@ class TestStructuredDiagnosisSchemaContract:
 
         assert request.operation == "detect-type"
         assert request.sample == "test data"
-        assert request.type is None # Optional, defaults to None
-        assert request.schema_name is None # Optional, defaults to None
-        assert request.options is None # Optional, defaults to None
+        assert request.type == "" # Optional, defaults to empty string
+        assert request.schema_name == "" # Optional, defaults to empty string
+        assert request.options == {} # Optional, defaults to empty dict
 
     def test_request_schema_all_operations(self):
         """Test request schema supports all operations"""
@@ -66,9 +66,9 @@ class TestStructuredDiagnosisSchemaContract:
         assert response.detected_type == "xml"
         assert response.confidence == 0.9
         assert response.error is None
-        assert response.descriptor is None
-        assert response.metadata is None
-        assert response.schema_matches is None # New field, defaults to None
+        assert response.descriptor == "" # Defaults to empty string
+        assert response.metadata == {} # Defaults to empty dict
+        assert response.schema_matches == [] # Defaults to empty list
 
     def test_response_schema_with_error(self):
         """Test response schema with error"""
@@ -140,6 +140,7 @@ class TestStructuredDiagnosisSchemaContract:
         assert response.metadata == metadata
         assert response.metadata["field_count"] == "5"
 
+    @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses")
     def test_schema_serialization(self):
         """Test that schemas can be serialized and deserialized correctly"""
         # Test request serialization
@@ -158,6 +159,7 @@ class TestStructuredDiagnosisSchemaContract:
         assert deserialized.sample == request.sample
         assert deserialized.options == request.options
 
+    @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses")
    def test_response_serialization_with_schema_matches(self):
         """Test response serialization with schema_matches array"""
         response = StructuredDataDiagnosisResponse(
@@ -185,7 +187,7 @@ class TestStructuredDiagnosisSchemaContract:
         )
 
         # Verify default value for new field
-        assert response.schema_matches is None # Defaults to None when not set
+        assert response.schema_matches == [] # Defaults to empty list when not set
 
         # Verify old fields still work
         assert response.detected_type == "json"
@@ -221,7 +223,7 @@ class TestStructuredDiagnosisSchemaContract:
         )
 
         assert error_response.error is not None
-        assert error_response.schema_matches is None # Default None when not set
+        assert error_response.schema_matches == [] # Default empty list when not set
 
     def test_all_operations_supported(self):
         """Verify all operations are properly supported in the contract"""
@@ -72,7 +72,7 @@ class TestMessageDispatcher:
         assert dispatcher.max_workers == 10
         assert dispatcher.semaphore._value == 10
         assert dispatcher.active_tasks == set()
-        assert dispatcher.pulsar_client is None
+        assert dispatcher.backend is None
         assert dispatcher.dispatcher_manager is None
         assert len(dispatcher.service_mapping) > 0
 
@@ -86,7 +86,7 @@ class TestMessageDispatcher:
     @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
     def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
         """Test MessageDispatcher initialization with pulsar_client and config_receiver"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_config_receiver = MagicMock()
         mock_dispatcher_instance = MagicMock()
         mock_dispatcher_manager.return_value = mock_dispatcher_instance
@@ -94,14 +94,14 @@ class TestMessageDispatcher:
         dispatcher = MessageDispatcher(
             max_workers=8,
             config_receiver=mock_config_receiver,
-            pulsar_client=mock_pulsar_client
+            backend=mock_backend
         )
 
         assert dispatcher.max_workers == 8
-        assert dispatcher.pulsar_client == mock_pulsar_client
+        assert dispatcher.backend == mock_backend
         assert dispatcher.dispatcher_manager == mock_dispatcher_instance
         mock_dispatcher_manager.assert_called_once_with(
-            mock_pulsar_client, mock_config_receiver, prefix="rev-gateway"
+            mock_backend, mock_config_receiver, prefix="rev-gateway"
         )
 
     def test_message_dispatcher_service_mapping(self):
@@ -16,11 +16,11 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_defaults(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with default parameters"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway()
 
@@ -38,11 +38,11 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_custom_params(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with custom parameters"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway(
             websocket_uri="wss://example.com:8080/websocket",
@@ -65,11 +65,11 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_with_missing_path(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with WebSocket URI missing path"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway(websocket_uri="ws://example.com")
 
@@ -78,53 +78,49 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_invalid_scheme(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with invalid WebSocket scheme"""
         with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
             ReverseGateway(websocket_uri="http://example.com")
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_missing_hostname(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with missing hostname"""
         with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
             ReverseGateway(websocket_uri="ws://")
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_pulsar_client_with_auth(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
-        """Test ReverseGateway creates Pulsar client with authentication"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
-
-        with patch('pulsar.AuthenticationToken') as mock_auth:
-            mock_auth_instance = MagicMock()
-            mock_auth.return_value = mock_auth_instance
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
+        """Test ReverseGateway creates backend with authentication"""
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway(
             pulsar_api_key="test-key",
             pulsar_listener="test-listener"
         )
 
-        mock_auth.assert_called_once_with("test-key")
-        mock_pulsar_client.assert_called_once_with(
-            "pulsar://pulsar:6650",
-            listener_name="test-listener",
-            authentication=mock_auth_instance
+        # Verify get_pubsub was called with the correct parameters
+        mock_get_pubsub.assert_called_once_with(
+            pulsar_host="pulsar://pulsar:6650",
+            pulsar_api_key="test-key",
+            pulsar_listener="test-listener"
         )
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @patch('trustgraph.rev_gateway.service.ClientSession')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_connect_success(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway successful connection"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         mock_session = AsyncMock()
         mock_ws = AsyncMock()
 
@@ -142,13 +138,13 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @patch('trustgraph.rev_gateway.service.ClientSession')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway connection failure"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         mock_session = AsyncMock()
         mock_session.ws_connect.side_effect = Exception("Connection failed")
@@ -162,12 +158,12 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_disconnect(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway disconnect"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway()
 
@@ -189,12 +185,12 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_send_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway send message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway()
 
@@ -211,12 +207,12 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_send_message_closed_connection(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway send message with closed connection"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         gateway = ReverseGateway()
 
@@ -234,12 +230,12 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_handle_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway handle message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
 
         mock_dispatcher_instance = AsyncMock()
         mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
@@ -263,12 +259,12 @@ class TestReverseGateway:
 
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_handle_message_invalid_json(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway handle message with invalid JSON"""
|
"""Test ReverseGateway handle message with invalid JSON"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = ReverseGateway()
|
||||||
|
|
||||||
|
|
@ -285,12 +281,12 @@ class TestReverseGateway:
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||||
@patch('pulsar.Client')
|
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_listen_text_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway listen with text message"""
|
"""Test ReverseGateway listen with text message"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = ReverseGateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
@ -318,12 +314,12 @@ class TestReverseGateway:
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||||
@patch('pulsar.Client')
|
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_listen_binary_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway listen with binary message"""
|
"""Test ReverseGateway listen with binary message"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = ReverseGateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
@ -351,12 +347,12 @@ class TestReverseGateway:
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||||
@patch('pulsar.Client')
|
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_listen_close_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway listen with close message"""
|
"""Test ReverseGateway listen with close message"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = ReverseGateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
@ -383,12 +379,12 @@ class TestReverseGateway:
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||||
@patch('pulsar.Client')
|
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_shutdown(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway shutdown"""
|
"""Test ReverseGateway shutdown"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
mock_dispatcher_instance = AsyncMock()
|
mock_dispatcher_instance = AsyncMock()
|
||||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||||
|
|
@ -404,15 +400,15 @@ class TestReverseGateway:
|
||||||
assert gateway.running is False
|
assert gateway.running is False
|
||||||
mock_dispatcher_instance.shutdown.assert_called_once()
|
mock_dispatcher_instance.shutdown.assert_called_once()
|
||||||
gateway.disconnect.assert_called_once()
|
gateway.disconnect.assert_called_once()
|
||||||
mock_client_instance.close.assert_called_once()
|
mock_backend.close.assert_called_once()
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||||
@patch('pulsar.Client')
|
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||||
def test_reverse_gateway_stop(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway stop"""
|
"""Test ReverseGateway stop"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
gateway = ReverseGateway()
|
gateway = ReverseGateway()
|
||||||
gateway.running = True
|
gateway.running = True
|
||||||
|
|
@ -427,12 +423,12 @@ class TestReverseGatewayRun:
|
||||||
|
|
||||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||||
@patch('pulsar.Client')
|
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reverse_gateway_run_successful_cycle(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
||||||
mock_client_instance = MagicMock()
|
mock_backend = MagicMock()
|
||||||
mock_pulsar_client.return_value = mock_client_instance
|
mock_get_pubsub.return_value = mock_backend
|
||||||
|
|
||||||
mock_config_receiver_instance = AsyncMock()
|
mock_config_receiver_instance = AsyncMock()
|
||||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||||
|
|
|
||||||
|
|
@@ -15,11 +15,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     """Test Qdrant document embeddings storage functionality"""

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_basic(self, mock_qdrant_client):
         """Test basic Qdrant processor initialization"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -34,9 +32,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Assert
-        # Verify base class initialization was called
-        mock_base_init.assert_called_once()
-
         # Verify QdrantClient was created with correct parameters
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')

@@ -45,11 +40,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert processor.qdrant == mock_qdrant_instance

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
         """Test processor initialization with default values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -68,11 +61,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_basic(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with basic message"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True  # Collection already exists
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -88,6 +79,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('test_user', 'test_collection')] = {}
+
         # Create mock message with chunks and vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'test_user'
@@ -121,11 +115,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with multiple chunks"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -141,6 +133,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('multi_user', 'multi_collection')] = {}
+
         # Create mock message with multiple chunks
         mock_message = MagicMock()
         mock_message.metadata.user = 'multi_user'
@@ -180,11 +175,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with multiple vectors per chunk"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -200,6 +193,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('vector_user', 'vector_collection')] = {}
+
         # Create mock message with chunk having multiple vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'vector_user'
@@ -237,11 +233,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['doc'] == 'multi-vector document chunk'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
+    async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client):
         """Test storing document embeddings skips empty chunks"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True  # Collection exists
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -277,11 +271,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_collection_creation_when_not_exists(self, mock_uuid, mock_qdrant_client):
         """Test that writing to non-existent collection creates it lazily"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = False  # Collection doesn't exist
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -297,6 +289,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('new_user', 'new_collection')] = {}
+
         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.user = 'new_user'
@@ -326,11 +321,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_collection_creation_exception(self, mock_uuid, mock_qdrant_client):
         """Test that collection creation errors are propagated"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = False  # Collection doesn't exist
         # Simulate creation failure
@@ -348,6 +341,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('error_user', 'error_collection')] = {}
+
         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.user = 'error_user'
@@ -364,12 +360,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         await processor.store_document_embeddings(mock_message)

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
+    async def test_collection_validation_on_write(self, mock_uuid, mock_qdrant_client):
         """Test collection validation checks collection exists before writing"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -385,6 +379,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('cache_user', 'cache_collection')] = {}
+
         # Create first mock message
         mock_message1 = MagicMock()
         mock_message1.metadata.user = 'cache_user'
@@ -428,11 +425,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_different_dimensions_different_collections(self, mock_uuid, mock_qdrant_client):
         """Test that different vector dimensions create different collections"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -448,6 +443,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('dim_user', 'dim_collection')] = {}
+
         # Create mock message with different dimension vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'dim_user'
@@ -482,11 +480,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
+    async def test_add_args_calls_parent(self, mock_qdrant_client):
         """Test that add_args() calls parent add_args method"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_client.return_value = MagicMock()
         mock_parser = MagicMock()

@@ -502,11 +498,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client):
         """Test proper UTF-8 decoding of chunk text"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -522,6 +516,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('utf8_user', 'utf8_collection')] = {}
+
         # Create mock message with UTF-8 encoded text
         mock_message = MagicMock()
         mock_message.metadata.user = 'utf8_user'
@@ -546,11 +543,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
+    async def test_chunk_decode_exception_handling(self, mock_qdrant_client):
         """Test handling of chunk decode exceptions"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -563,6 +558,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('decode_user', 'decode_collection')] = {}
+
         # Create mock message with decode error
         mock_message = MagicMock()
         mock_message.metadata.user = 'decode_user'
@@ -15,11 +15,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
     """Test Qdrant graph embeddings storage functionality"""

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_basic(self, mock_qdrant_client):
         """Test basic Qdrant processor initialization"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -34,9 +32,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Assert
-        # Verify base class initialization was called
-        mock_base_init.assert_called_once()
-
         # Verify QdrantClient was created with correct parameters
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
@@ -46,11 +41,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_basic(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with basic message"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True  # Collection already exists
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -65,6 +58,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('test_user', 'test_collection')] = {}
+
         # Create mock message with entities and vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'test_user'
@@ -98,11 +94,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_multiple_entities(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with multiple entities"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -117,6 +111,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('multi_user', 'multi_collection')] = {}
+
         # Create mock message with multiple entities
         mock_message = MagicMock()
         mock_message.metadata.user = 'multi_user'
@@ -156,11 +153,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with multiple vectors per entity"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -175,6 +170,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('vector_user', 'vector_collection')] = {}
+
         # Create mock message with entity having multiple vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'vector_user'
@@ -212,11 +210,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['entity'] == 'multi_vector_entity'

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
+    async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
         """Test storing graph embeddings skips empty entity values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -253,11 +249,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         mock_qdrant_instance.collection_exists.assert_not_called()

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
         """Test processor initialization with default values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -275,11 +269,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
+    async def test_add_args_calls_parent(self, mock_qdrant_client):
         """Test that add_args() calls parent add_args method"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_client.return_value = MagicMock()
         mock_parser = MagicMock()
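The changed test signatures above rely on `unittest.mock.patch` injecting mocks bottom-up, which is why removing one decorator also removes the first mock parameter. A self-contained sketch of that ordering:

```python
import os
from unittest.mock import patch

# Decorators apply bottom-up: the innermost @patch becomes the FIRST mock
# argument. Dropping a decorator therefore drops the matching first parameter,
# as in the test signature changes above.
@patch("os.path.exists")  # outermost -> second argument
@patch("os.getcwd")       # innermost -> first argument
def demo(mock_getcwd, mock_exists):
    mock_getcwd.return_value = "/fake"
    mock_exists.return_value = True
    return os.getcwd(), os.path.exists("/anything")

print(demo())  # -> ('/fake', True)
```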
@@ -13,6 +13,7 @@ dependencies = [
     "pulsar-client",
     "prometheus-client",
     "requests",
+    "python-logging-loki",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -1,3 +1,114 @@
-from . api import *
+# Core API
+from .api import Api
+
+# Flow clients
+from .flow import Flow, FlowInstance
+from .async_flow import AsyncFlow, AsyncFlowInstance
+
+# WebSocket clients
+from .socket_client import SocketClient, SocketFlowInstance
+from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance
+
+# Bulk operation clients
+from .bulk_client import BulkClient
+from .async_bulk_client import AsyncBulkClient
+
+# Metrics clients
+from .metrics import Metrics
+from .async_metrics import AsyncMetrics
+
+# Types
+from .types import (
+    Triple,
+    ConfigKey,
+    ConfigValue,
+    DocumentMetadata,
+    ProcessingMetadata,
+    CollectionMetadata,
+    StreamingChunk,
+    AgentThought,
+    AgentObservation,
+    AgentAnswer,
+    RAGChunk,
+)
+
+# Exceptions
+from .exceptions import (
+    ProtocolException,
+    TrustGraphException,
+    AgentError,
+    ConfigError,
+    DocumentRagError,
+    FlowError,
+    GatewayError,
+    GraphRagError,
+    LLMError,
+    LoadError,
+    LookupError,
+    NLPQueryError,
+    ObjectsQueryError,
+    RequestError,
+    StructuredQueryError,
+    UnexpectedError,
+    # Legacy alias
+    ApplicationException,
+)
+
+__all__ = [
+    # Core API
+    "Api",
+
+    # Flow clients
+    "Flow",
+    "FlowInstance",
+    "AsyncFlow",
+    "AsyncFlowInstance",
+
+    # WebSocket clients
+    "SocketClient",
+    "SocketFlowInstance",
+    "AsyncSocketClient",
+    "AsyncSocketFlowInstance",
+
+    # Bulk operation clients
+    "BulkClient",
+    "AsyncBulkClient",
+
+    # Metrics clients
+    "Metrics",
+    "AsyncMetrics",
+
+    # Types
+    "Triple",
+    "ConfigKey",
+    "ConfigValue",
+    "DocumentMetadata",
+    "ProcessingMetadata",
+    "CollectionMetadata",
+    "StreamingChunk",
+    "AgentThought",
+    "AgentObservation",
+    "AgentAnswer",
+    "RAGChunk",
+
+    # Exceptions
+    "ProtocolException",
+    "TrustGraphException",
+    "AgentError",
+    "ConfigError",
+    "DocumentRagError",
+    "FlowError",
+    "GatewayError",
+    "GraphRagError",
+    "LLMError",
+    "LoadError",
+    "LookupError",
+    "NLPQueryError",
+    "ObjectsQueryError",
+    "RequestError",
+    "StructuredQueryError",
+    "UnexpectedError",
+    "ApplicationException",  # Legacy alias
+]
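The rewritten `__init__.py` replaces a wildcard re-export with explicit imports plus `__all__`. A small sketch of why `__all__` matters for `import *` consumers (the module name `sketch_pkg` is made up for the demo):

```python
import sys
import types

# Build a throwaway module with one exported and one unexported name
mod = types.ModuleType("sketch_pkg")
exec("Api = 'api'\nhidden = 'internal'\n__all__ = ['Api']", mod.__dict__)
sys.modules["sketch_pkg"] = mod

# Wildcard import honours __all__: only 'Api' comes across
ns = {}
exec("from sketch_pkg import *", ns)
print("Api" in ns, "hidden" in ns)  # -> True False
```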
@@ -3,6 +3,7 @@ import requests
 import json
 import base64
 import time
+from typing import Optional

 from . library import Library
 from . flow import Flow
@@ -26,7 +27,7 @@ def check_error(response):

 class Api:

-    def __init__(self, url="http://localhost:8088/", timeout=60):
+    def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None):

         self.url = url
@@ -36,6 +37,16 @@ class Api:
             self.url += "api/v1/"

         self.timeout = timeout
+        self.token = token
+
+        # Lazy initialization for new clients
+        self._socket_client = None
+        self._bulk_client = None
+        self._async_flow = None
+        self._async_socket_client = None
+        self._async_bulk_client = None
+        self._metrics = None
+        self._async_metrics = None

     def flow(self):
         return Flow(api=self)
@@ -50,8 +61,12 @@ class Api:

         url = f"{self.url}{path}"

+        headers = {}
+        if self.token:
+            headers["Authorization"] = f"Bearer {self.token}"
+
         # Invoke the API, input is passed as JSON
-        resp = requests.post(url, json=request, timeout=self.timeout)
+        resp = requests.post(url, json=request, timeout=self.timeout, headers=headers)

         # Should be a 200 status code
         if resp.status_code != 200:
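The header construction added above is easy to factor out and check in isolation. `make_headers` is a hypothetical helper name for this sketch, not part of the trustgraph API:

```python
from typing import Dict, Optional

def make_headers(token: Optional[str]) -> Dict[str, str]:
    # Same logic as the diff: attach Authorization only when a token is set
    headers: Dict[str, str] = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers

print(make_headers(None))      # -> {}
print(make_headers("abc123"))  # -> {'Authorization': 'Bearer abc123'}
```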
@@ -72,3 +87,96 @@ class Api:

     def collection(self):
         return Collection(self)
+
+    # New synchronous methods
+    def socket(self):
+        """Synchronous WebSocket-based interface for streaming operations"""
+        if self._socket_client is None:
+            from . socket_client import SocketClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._socket_client = SocketClient(base_url, self.timeout, self.token)
+        return self._socket_client
+
+    def bulk(self):
+        """Synchronous bulk operations interface for import/export"""
+        if self._bulk_client is None:
+            from . bulk_client import BulkClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._bulk_client = BulkClient(base_url, self.timeout, self.token)
+        return self._bulk_client
+
+    def metrics(self):
+        """Synchronous metrics interface"""
+        if self._metrics is None:
+            from . metrics import Metrics
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._metrics = Metrics(base_url, self.timeout, self.token)
+        return self._metrics
+
+    # New asynchronous methods
+    def async_flow(self):
+        """Asynchronous REST-based flow interface"""
+        if self._async_flow is None:
+            from . async_flow import AsyncFlow
+            self._async_flow = AsyncFlow(self.url, self.timeout, self.token)
+        return self._async_flow
+
+    def async_socket(self):
+        """Asynchronous WebSocket-based interface for streaming operations"""
+        if self._async_socket_client is None:
+            from . async_socket_client import AsyncSocketClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token)
+        return self._async_socket_client
+
+    def async_bulk(self):
+        """Asynchronous bulk operations interface for import/export"""
+        if self._async_bulk_client is None:
+            from . async_bulk_client import AsyncBulkClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
+        return self._async_bulk_client
+
+    def async_metrics(self):
+        """Asynchronous metrics interface"""
+        if self._async_metrics is None:
+            from . async_metrics import AsyncMetrics
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_metrics = AsyncMetrics(base_url, self.timeout, self.token)
+        return self._async_metrics
+
+    # Resource management
+    def close(self):
+        """Close all synchronous connections"""
+        if self._socket_client:
+            self._socket_client.close()
+        if self._bulk_client:
+            self._bulk_client.close()
+
+    async def aclose(self):
+        """Close all asynchronous connections"""
+        if self._async_socket_client:
+            await self._async_socket_client.aclose()
+        if self._async_bulk_client:
+            await self._async_bulk_client.aclose()
+        if self._async_flow:
+            await self._async_flow.aclose()
+
+    # Context manager support
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *args):
+        await self.aclose()
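Each lazy accessor above re-derives the gateway base URL with the same expression. A sketch of what that expression yields on the default URL (`base_url` is an illustrative name for the demo):

```python
def base_url(api_url: str) -> str:
    # Drop the trailing "api/v1/" segment, as the accessors in the diff do
    return api_url.rsplit("api/v1/", 1)[0].rstrip("/")

print(base_url("http://localhost:8088/api/v1/"))  # -> http://localhost:8088
```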
trustgraph-base/trustgraph/api/async_bulk_client.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+import json
+import websockets
+from typing import Optional, AsyncIterator, Dict, Any, Iterator
+
+from . types import Triple
+
+
+class AsyncBulkClient:
+    """Asynchronous bulk operations client"""
+
+    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+        self.url: str = self._convert_to_ws_url(url)
+        self.timeout: int = timeout
+        self.token: Optional[str] = token
+
+    def _convert_to_ws_url(self, url: str) -> str:
+        """Convert HTTP URL to WebSocket URL"""
+        if url.startswith("http://"):
+            return url.replace("http://", "ws://", 1)
+        elif url.startswith("https://"):
+            return url.replace("https://", "wss://", 1)
+        elif url.startswith("ws://") or url.startswith("wss://"):
+            return url
+        else:
+            return f"ws://{url}"
+
+    async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
+        """Bulk import triples via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for triple in triples:
+                message = {
+                    "s": triple.s,
+                    "p": triple.p,
+                    "o": triple.o
+                }
+                await websocket.send(json.dumps(message))
+
+    async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
+        """Bulk export triples via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                data = json.loads(raw_message)
+                yield Triple(
+                    s=data.get("s", ""),
+                    p=data.get("p", ""),
+                    o=data.get("o", "")
+                )
+
+    async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import graph embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for embedding in embeddings:
+                await websocket.send(json.dumps(embedding))
+
+    async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
+        """Bulk export graph embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                yield json.loads(raw_message)
+
+    async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import document embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for embedding in embeddings:
+                await websocket.send(json.dumps(embedding))
+
+    async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
+        """Bulk export document embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                yield json.loads(raw_message)
+
+    async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import entity contexts via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for context in contexts:
+                await websocket.send(json.dumps(context))
+
+    async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
+        """Bulk export entity contexts via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                yield json.loads(raw_message)
+
+    async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import objects via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for obj in objects:
+                await websocket.send(json.dumps(obj))
+
+    async def aclose(self) -> None:
+        """Close connections"""
+        # Cleanup handled by context managers
+        pass
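`_convert_to_ws_url` above maps HTTP schemes onto WebSocket schemes. The same mapping as a standalone sketch (`to_ws_url` is an illustrative name for the demo):

```python
def to_ws_url(url: str) -> str:
    # http -> ws, https -> wss; ws/wss pass through; bare hosts default to ws
    if url.startswith("http://"):
        return url.replace("http://", "ws://", 1)
    if url.startswith("https://"):
        return url.replace("https://", "wss://", 1)
    if url.startswith(("ws://", "wss://")):
        return url
    return f"ws://{url}"

print(to_ws_url("https://example.com"))  # -> wss://example.com
print(to_ws_url("localhost:8088"))       # -> ws://localhost:8088
```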
trustgraph-base/trustgraph/api/async_flow.py (new file, 245 lines)
@@ -0,0 +1,245 @@
+import aiohttp
+import json
+from typing import Optional, Dict, Any, List
+
+from . exceptions import ProtocolException, ApplicationException
+
+
+def check_error(response):
+    if "error" in response:
+        try:
+            msg = response["error"]["message"]
+            tp = response["error"]["type"]
+        except:
+            raise ApplicationException(response["error"])
+
+        raise ApplicationException(f"{tp}: {msg}")
+
+
+class AsyncFlow:
+    """Asynchronous REST-based flow interface"""
+
+    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+        self.url: str = url
+        self.timeout: int = timeout
+        self.token: Optional[str] = token
+
+    async def request(self, path: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Make async HTTP request to Gateway API"""
+        url = f"{self.url}{path}"
+
+        headers = {"Content-Type": "application/json"}
+        if self.token:
+            headers["Authorization"] = f"Bearer {self.token}"
+
+        timeout = aiohttp.ClientTimeout(total=self.timeout)
+
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.post(url, json=request_data, headers=headers) as resp:
+                if resp.status != 200:
+                    raise ProtocolException(f"Status code {resp.status}")
+
+                try:
+                    obj = await resp.json()
+                except:
+                    raise ProtocolException(f"Expected JSON response")
+
+                check_error(obj)
+                return obj
+
+    async def list(self) -> List[str]:
+        """List all flows"""
+        result = await self.request("flow", {"operation": "list-flows"})
+        return result.get("flow-ids", [])
+
+    async def get(self, id: str) -> Dict[str, Any]:
+        """Get flow definition"""
+        result = await self.request("flow", {
+            "operation": "get-flow",
+            "flow-id": id
+        })
+        return json.loads(result.get("flow", "{}"))
+
+    async def start(self, class_name: str, id: str, description: str, parameters: Optional[Dict] = None):
+        """Start a flow"""
+        request_data = {
+            "operation": "start-flow",
+            "flow-id": id,
+            "class-name": class_name,
+            "description": description
+        }
+        if parameters:
+            request_data["parameters"] = json.dumps(parameters)
+
+        await self.request("flow", request_data)
+
+    async def stop(self, id: str):
+        """Stop a flow"""
+        await self.request("flow", {
+            "operation": "stop-flow",
+            "flow-id": id
+        })
+
+    async def list_classes(self) -> List[str]:
+        """List flow classes"""
+        result = await self.request("flow", {"operation": "list-classes"})
+        return result.get("class-names", [])
+
+    async def get_class(self, class_name: str) -> Dict[str, Any]:
+        """Get flow class definition"""
+        result = await self.request("flow", {
+            "operation": "get-class",
+            "class-name": class_name
+        })
+        return json.loads(result.get("class-definition", "{}"))
+
+    async def put_class(self, class_name: str, definition: Dict[str, Any]):
+        """Create/update flow class"""
+        await self.request("flow", {
+            "operation": "put-class",
+            "class-name": class_name,
+            "class-definition": json.dumps(definition)
+        })
+
+    async def delete_class(self, class_name: str):
+        """Delete flow class"""
+        await self.request("flow", {
+            "operation": "delete-class",
+            "class-name": class_name
+        })
+
+    def id(self, flow_id: str):
+        """Get async flow instance"""
+        return AsyncFlowInstance(self, flow_id)
+
+    async def aclose(self) -> None:
+        """Close connection (cleanup handled by aiohttp session)"""
+        pass
+
+
+class AsyncFlowInstance:
+    """Asynchronous REST flow instance"""
+
+    def __init__(self, flow: AsyncFlow, flow_id: str):
+        self.flow = flow
+        self.flow_id = flow_id
+
+    async def request(self, service: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Make request to flow-scoped service"""
+        return await self.flow.request(f"flow/{self.flow_id}/service/{service}", request_data)
+
+    async def agent(self, question: str, user: str, state: Optional[Dict] = None,
+                    group: Optional[str] = None, history: Optional[List] = None, **kwargs: Any) -> Dict[str, Any]:
+        """Execute agent (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "question": question,
+            "user": user,
+            "streaming": False  # REST doesn't support streaming
+        }
+        if state is not None:
+            request_data["state"] = state
+        if group is not None:
+            request_data["group"] = group
+        if history is not None:
+            request_data["history"] = history
+        request_data.update(kwargs)
+
+        return await self.request("agent", request_data)
+
+    async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
+        """Text completion (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "system": system,
+            "prompt": prompt,
+            "streaming": False
+        }
+        request_data.update(kwargs)
+
+        result = await self.request("text-completion", request_data)
+        return result.get("response", "")
+
+    async def graph_rag(self, query: str, user: str, collection: str,
+                        max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
+                        max_entity_distance: int = 3, **kwargs: Any) -> str:
+        """Graph RAG (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "query": query,
+            "user": user,
+            "collection": collection,
+            "max-subgraph-size": max_subgraph_size,
+            "max-subgraph-count": max_subgraph_count,
+            "max-entity-distance": max_entity_distance,
+            "streaming": False
+        }
+        request_data.update(kwargs)
+
+        result = await self.request("graph-rag", request_data)
+        return result.get("response", "")
+
+    async def document_rag(self, query: str, user: str, collection: str,
+                           doc_limit: int = 10, **kwargs: Any) -> str:
+        """Document RAG (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "query": query,
+            "user": user,
+            "collection": collection,
+            "doc-limit": doc_limit,
+            "streaming": False
+        }
+        request_data.update(kwargs)
+
+        result = await self.request("document-rag", request_data)
+        return result.get("response", "")
+
+    async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any):
+        """Query graph embeddings for semantic search"""
+        request_data = {
+            "text": text,
+            "user": user,
|
||||||
|
"collection": collection,
|
||||||
|
"limit": limit
|
||||||
|
}
|
||||||
|
request_data.update(kwargs)
|
||||||
|
|
||||||
|
return await self.request("graph-embeddings", request_data)
|
||||||
|
|
||||||
|
async def embeddings(self, text: str, **kwargs: Any):
|
||||||
|
"""Generate text embeddings"""
|
||||||
|
request_data = {"text": text}
|
||||||
|
request_data.update(kwargs)
|
||||||
|
|
||||||
|
return await self.request("embeddings", request_data)
|
||||||
|
|
||||||
|
async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs: Any):
|
||||||
|
"""Triple pattern query"""
|
||||||
|
request_data = {"limit": limit}
|
||||||
|
if s is not None:
|
||||||
|
request_data["s"] = str(s)
|
||||||
|
if p is not None:
|
||||||
|
request_data["p"] = str(p)
|
||||||
|
if o is not None:
|
||||||
|
request_data["o"] = str(o)
|
||||||
|
if user is not None:
|
||||||
|
request_data["user"] = user
|
||||||
|
if collection is not None:
|
||||||
|
request_data["collection"] = collection
|
||||||
|
request_data.update(kwargs)
|
||||||
|
|
||||||
|
return await self.request("triples", request_data)
|
||||||
|
|
||||||
|
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
|
||||||
|
operation_name: Optional[str] = None, **kwargs: Any):
|
||||||
|
"""GraphQL query"""
|
||||||
|
request_data = {
|
||||||
|
"query": query,
|
||||||
|
"user": user,
|
||||||
|
"collection": collection
|
||||||
|
}
|
||||||
|
if variables:
|
||||||
|
request_data["variables"] = variables
|
||||||
|
if operation_name:
|
||||||
|
request_data["operationName"] = operation_name
|
||||||
|
request_data.update(kwargs)
|
||||||
|
|
||||||
|
return await self.request("objects", request_data)
|
||||||
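The optional-parameter pattern used by `triples_query` above — seed the payload with its defaults, add only the pattern terms the caller actually supplied, then merge keyword arguments last so they win — can be sketched as a standalone function. This is an illustration of the payload construction only, not part of the client:

```python
from typing import Any, Dict, Optional

def build_triples_request(s: Optional[str] = None, p: Optional[str] = None,
                          o: Optional[str] = None, user: Optional[str] = None,
                          collection: Optional[str] = None, limit: int = 100,
                          **kwargs: Any) -> Dict[str, Any]:
    """Mirror of the payload construction in triples_query (sketch)."""
    request: Dict[str, Any] = {"limit": limit}
    # Only include pattern terms the caller actually constrained
    for key, value in (("s", s), ("p", p), ("o", o)):
        if value is not None:
            request[key] = str(value)
    if user is not None:
        request["user"] = user
    if collection is not None:
        request["collection"] = collection
    request.update(kwargs)  # keyword overrides merge last
    return request
```

Leaving unconstrained terms out of the payload (rather than sending nulls) is what lets the triples service treat each absent term as a wildcard.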
33  trustgraph-base/trustgraph/api/async_metrics.py  Normal file
@@ -0,0 +1,33 @@

import aiohttp
from typing import Optional, Dict


class AsyncMetrics:
    """Asynchronous metrics client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = url
        self.timeout: int = timeout
        self.token: Optional[str] = token

    async def get(self) -> str:
        """Get Prometheus metrics as text"""
        url: str = f"{self.url}/api/metrics"

        headers: Dict[str, str] = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        timeout = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as resp:
                if resp.status != 200:
                    raise Exception(f"Status code {resp.status}")

                return await resp.text()

    async def aclose(self) -> None:
        """Close connections"""
        pass
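`AsyncMetrics.get()` returns raw Prometheus exposition-format text. A minimal sketch of pulling one metric's sample value out of that text — an illustrative helper, not part of the client, and it deliberately ignores label-aware edge cases such as label values containing spaces:

```python
from typing import Optional

def metric_value(exposition: str, name: str) -> Optional[float]:
    """Return the first sample value for `name` from Prometheus text format."""
    for line in exposition.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            # Skip blank lines and HELP/TYPE comment lines
            continue
        # Sample lines look like: metric_name{labels} value [timestamp]
        metric, _, rest = line.partition(" ")
        if metric.split("{", 1)[0] == name:
            return float(rest.split()[0])
    return None
```

For anything beyond a quick check, a real exposition-format parser (e.g. the one shipped with the Prometheus Python client) is the safer choice.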
343  trustgraph-base/trustgraph/api/async_socket_client.py  Normal file
@@ -0,0 +1,343 @@

import json
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union

from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
from . exceptions import ProtocolException, ApplicationException


class AsyncSocketClient:
    """Asynchronous WebSocket client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]):
        self.url = self._convert_to_ws_url(url)
        self.timeout = timeout
        self.token = token
        self._request_counter = 0

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            # Assume ws://
            return f"ws://{url}"

    def flow(self, flow_id: str):
        """Get async flow instance for WebSocket operations"""
        return AsyncSocketFlowInstance(self, flow_id)

    async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
        """Async WebSocket request implementation (non-streaming)"""
        # Generate unique request ID
        self._request_counter += 1
        request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Wait for single response
            raw_message = await websocket.recv()
            response = json.loads(raw_message)

            if response.get("id") != request_id:
                raise ProtocolException(f"Response ID mismatch")

            if "error" in response:
                raise ApplicationException(response["error"])

            if "response" not in response:
                raise ProtocolException(f"Missing response in message")

            return response["response"]

    async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
        """Async WebSocket request implementation (streaming)"""
        # Generate unique request ID
        self._request_counter += 1
        request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Yield chunks as they arrive
            async for raw_message in websocket:
                response = json.loads(raw_message)

                if response.get("id") != request_id:
                    continue  # Ignore messages for other requests

                if "error" in response:
                    raise ApplicationException(response["error"])

                if "response" in response:
                    resp = response["response"]

                    # Parse different chunk types
                    chunk = self._parse_chunk(resp)
                    yield chunk

                    # Check if this is the final chunk
                    if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
                        break

    def _parse_chunk(self, resp: Dict[str, Any]):
        """Parse response chunk into appropriate type"""
        chunk_type = resp.get("chunk_type")

        if chunk_type == "thought":
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "observation":
            return AgentObservation(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "answer" or chunk_type == "final-answer":
            return AgentAnswer(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False),
                end_of_dialog=resp.get("end_of_dialog", False)
            )
        elif chunk_type == "action":
            # Agent action chunks - treat as thoughts for display purposes
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        else:
            # RAG-style chunk (or generic chunk)
            # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
            content = resp.get("response", resp.get("chunk", resp.get("text", "")))
            return RAGChunk(
                content=content,
                end_of_stream=resp.get("end_of_stream", False),
                error=None  # Errors are always thrown, never stored
            )

    async def aclose(self):
        """Close WebSocket connection"""
        # Cleanup handled by context manager
        pass


class AsyncSocketFlowInstance:
    """Asynchronous WebSocket flow instance"""

    def __init__(self, client: AsyncSocketClient, flow_id: str):
        self.client = client
        self.flow_id = flow_id

    async def agent(self, question: str, user: str, state: Optional[Dict[str, Any]] = None,
                    group: Optional[str] = None, history: Optional[list] = None,
                    streaming: bool = False, **kwargs) -> Union[Dict[str, Any], AsyncIterator]:
        """Agent with optional streaming"""
        request = {
            "question": question,
            "user": user,
            "streaming": streaming
        }
        if state is not None:
            request["state"] = state
        if group is not None:
            request["group"] = group
        if history is not None:
            request["history"] = history
        request.update(kwargs)

        if streaming:
            return self.client._send_request_streaming("agent", self.flow_id, request)
        else:
            return await self.client._send_request("agent", self.flow_id, request)

    async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
        """Text completion with optional streaming"""
        request = {
            "system": system,
            "prompt": prompt,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._text_completion_streaming(request)
        else:
            result = await self.client._send_request("text-completion", self.flow_id, request)
            return result.get("response", "")

    async def _text_completion_streaming(self, request):
        """Helper for streaming text completion"""
        async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def graph_rag(self, query: str, user: str, collection: str,
                        max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
                        max_entity_distance: int = 3, streaming: bool = False, **kwargs):
        """Graph RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "max-subgraph-size": max_subgraph_size,
            "max-subgraph-count": max_subgraph_count,
            "max-entity-distance": max_entity_distance,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._graph_rag_streaming(request)
        else:
            result = await self.client._send_request("graph-rag", self.flow_id, request)
            return result.get("response", "")

    async def _graph_rag_streaming(self, request):
        """Helper for streaming graph RAG"""
        async for chunk in self.client._send_request_streaming("graph-rag", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def document_rag(self, query: str, user: str, collection: str,
                           doc_limit: int = 10, streaming: bool = False, **kwargs):
        """Document RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "doc-limit": doc_limit,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._document_rag_streaming(request)
        else:
            result = await self.client._send_request("document-rag", self.flow_id, request)
            return result.get("response", "")

    async def _document_rag_streaming(self, request):
        """Helper for streaming document RAG"""
        async for chunk in self.client._send_request_streaming("document-rag", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def prompt(self, id: str, variables: Dict[str, str], streaming: bool = False, **kwargs):
        """Execute prompt with optional streaming"""
        request = {
            "id": id,
            "variables": variables,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._prompt_streaming(request)
        else:
            result = await self.client._send_request("prompt", self.flow_id, request)
            return result.get("response", "")

    async def _prompt_streaming(self, request):
        """Helper for streaming prompt"""
        async for chunk in self.client._send_request_streaming("prompt", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
        """Query graph embeddings for semantic search"""
        request = {
            "text": text,
            "user": user,
            "collection": collection,
            "limit": limit
        }
        request.update(kwargs)

        return await self.client._send_request("graph-embeddings", self.flow_id, request)

    async def embeddings(self, text: str, **kwargs):
        """Generate text embeddings"""
        request = {"text": text}
        request.update(kwargs)

        return await self.client._send_request("embeddings", self.flow_id, request)

    async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs):
        """Triple pattern query"""
        request = {"limit": limit}
        if s is not None:
            request["s"] = str(s)
        if p is not None:
            request["p"] = str(p)
        if o is not None:
            request["o"] = str(o)
        if user is not None:
            request["user"] = user
        if collection is not None:
            request["collection"] = collection
        request.update(kwargs)

        return await self.client._send_request("triples", self.flow_id, request)

    async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
                            operation_name: Optional[str] = None, **kwargs):
        """GraphQL query"""
        request = {
            "query": query,
            "user": user,
            "collection": collection
        }
        if variables:
            request["variables"] = variables
        if operation_name:
            request["operationName"] = operation_name
        request.update(kwargs)

        return await self.client._send_request("objects", self.flow_id, request)

    async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
        """Execute MCP tool"""
        request = {
            "name": name,
            "parameters": parameters
        }
        request.update(kwargs)

        return await self.client._send_request("mcp-tool", self.flow_id, request)
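`_parse_chunk` above dispatches on the `chunk_type` field, and its fallback branch checks three content fields because the services name them differently: text-completion streams use `response`, RAG uses `chunk`, and prompt uses `text`. A self-contained sketch of that dispatch, with hypothetical plain dataclasses standing in for the client's own chunk types:

```python
from dataclasses import dataclass

@dataclass
class Thought:
    content: str

@dataclass
class Chunk:
    content: str

def parse_chunk(resp: dict):
    """Simplified dispatch mirroring AsyncSocketClient._parse_chunk."""
    if resp.get("chunk_type") in ("thought", "action"):
        # Agent reasoning steps; actions are displayed like thoughts
        return Thought(content=resp.get("content", ""))
    # Generic/RAG fallback: three services use three different content fields
    content = resp.get("response", resp.get("chunk", resp.get("text", "")))
    return Chunk(content=content)
```

Nesting the `dict.get` defaults means the first present field wins in a fixed priority order (`response`, then `chunk`, then `text`), with an empty string when none is set.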
270  trustgraph-base/trustgraph/api/bulk_client.py  Normal file
@@ -0,0 +1,270 @@

import json
import asyncio
import websockets
from typing import Optional, Iterator, Dict, Any, Coroutine

from . types import Triple
from . exceptions import ProtocolException


class BulkClient:
    """Synchronous bulk operations client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = self._convert_to_ws_url(url)
        self.timeout: int = timeout
        self.token: Optional[str] = token

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            return f"ws://{url}"

    def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
        """Run async coroutine synchronously"""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(coro)

    def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None:
        """Bulk import triples via WebSocket"""
        self._run_async(self._import_triples_async(flow, triples))

    async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None:
        """Async implementation of triple import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for triple in triples:
                message = {
                    "s": triple.s,
                    "p": triple.p,
                    "o": triple.o
                }
                await websocket.send(json.dumps(message))

    def export_triples(self, flow: str, **kwargs: Any) -> Iterator[Triple]:
        """Bulk export triples via WebSocket"""
        async_gen = self._export_triples_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    triple = loop.run_until_complete(async_gen.__anext__())
                    yield triple
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
        """Async implementation of triple export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                data = json.loads(raw_message)
                yield Triple(
                    s=data.get("s", ""),
                    p=data.get("p", ""),
                    o=data.get("o", "")
                )

    def import_graph_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import graph embeddings via WebSocket"""
        self._run_async(self._import_graph_embeddings_async(flow, embeddings))

    async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of graph embeddings import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for embedding in embeddings:
                await websocket.send(json.dumps(embedding))

    def export_graph_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
        """Bulk export graph embeddings via WebSocket"""
        async_gen = self._export_graph_embeddings_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    embedding = loop.run_until_complete(async_gen.__anext__())
                    yield embedding
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
        """Async implementation of graph embeddings export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    def import_document_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import document embeddings via WebSocket"""
        self._run_async(self._import_document_embeddings_async(flow, embeddings))

    async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of document embeddings import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for embedding in embeddings:
                await websocket.send(json.dumps(embedding))

    def export_document_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
        """Bulk export document embeddings via WebSocket"""
        async_gen = self._export_document_embeddings_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    embedding = loop.run_until_complete(async_gen.__anext__())
                    yield embedding
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
        """Async implementation of document embeddings export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import entity contexts via WebSocket"""
        self._run_async(self._import_entity_contexts_async(flow, contexts))

    async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of entity contexts import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for context in contexts:
                await websocket.send(json.dumps(context))

    def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
        """Bulk export entity contexts via WebSocket"""
        async_gen = self._export_entity_contexts_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    context = loop.run_until_complete(async_gen.__anext__())
                    yield context
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
        """Async implementation of entity contexts export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import objects via WebSocket"""
        self._run_async(self._import_objects_async(flow, objects))

    async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of objects import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for obj in objects:
                await websocket.send(json.dumps(obj))

    def close(self) -> None:
        """Close connections"""
        # Cleanup handled by context managers
        pass
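Both `AsyncSocketClient` and `BulkClient` normalize the gateway URL with the same `_convert_to_ws_url` logic before appending the `/api/v1/...` WebSocket paths. Extracted as a standalone function it behaves like this (an illustrative copy, not an import from the package):

```python
def to_ws_url(url: str) -> str:
    """HTTP(S) becomes WS(S); WebSocket URLs pass through; bare hosts get ws://."""
    if url.startswith("http://"):
        return url.replace("http://", "ws://", 1)
    if url.startswith("https://"):
        return url.replace("https://", "wss://", 1)
    if url.startswith(("ws://", "wss://")):
        return url
    # No recognized scheme: assume plaintext WebSocket
    return f"ws://{url}"
```

The `count=1` argument to `str.replace` matters: it guarantees only the scheme prefix is rewritten even if the string "http://" happens to appear later in the URL (e.g. in a query parameter).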
@@ -41,9 +41,7 @@ class Collection:
                 collection = v["collection"],
                 name = v["name"],
                 description = v["description"],
-                tags = v["tags"],
-                created_at = v["created_at"],
-                updated_at = v["updated_at"]
+                tags = v["tags"]
             )
             for v in collections
         ]
@@ -76,9 +74,7 @@ class Collection:
                 collection = v["collection"],
                 name = v["name"],
                 description = v["description"],
-                tags = v["tags"],
-                created_at = v["created_at"],
-                updated_at = v["updated_at"]
+                tags = v["tags"]
             )
         return None
     except Exception as e:
@@ -1,6 +1,134 @@
"""
TrustGraph API Exceptions

Exception hierarchy for errors returned by TrustGraph services.
Each service error type maps to a specific exception class.
"""

# Protocol-level exceptions (communication errors)
class ProtocolException(Exception):
    """Raised when WebSocket protocol errors occur"""
    pass


# Base class for all TrustGraph application errors
class TrustGraphException(Exception):
    """Base class for all TrustGraph service errors"""

    def __init__(self, message: str, error_type: str = None):
        super().__init__(message)
        self.message = message
        self.error_type = error_type


# Service-specific exceptions
class AgentError(TrustGraphException):
    """Agent service error"""
    pass


class ConfigError(TrustGraphException):
    """Configuration service error"""
    pass


class DocumentRagError(TrustGraphException):
    """Document RAG retrieval error"""
    pass


class FlowError(TrustGraphException):
    """Flow management error"""
    pass


class GatewayError(TrustGraphException):
    """API Gateway error"""
    pass


class GraphRagError(TrustGraphException):
    """Graph RAG retrieval error"""
    pass


class LLMError(TrustGraphException):
    """LLM service error"""
    pass


class LoadError(TrustGraphException):
    """Data loading error"""
    pass


class LookupError(TrustGraphException):
    """Lookup/search error"""
    pass


class NLPQueryError(TrustGraphException):
    """NLP query service error"""
    pass


class ObjectsQueryError(TrustGraphException):
    """Objects query service error"""
    pass


class RequestError(TrustGraphException):
    """Request processing error"""
    pass


class StructuredQueryError(TrustGraphException):
    """Structured query service error"""
    pass


class UnexpectedError(TrustGraphException):
    """Unexpected/unknown error"""
    pass


# Mapping from error type string to exception class
ERROR_TYPE_MAPPING = {
    "agent-error": AgentError,
    "config-error": ConfigError,
    "document-rag-error": DocumentRagError,
    "flow-error": FlowError,
    "gateway-error": GatewayError,
    "graph-rag-error": GraphRagError,
    "llm-error": LLMError,
    "load-error": LoadError,
    "lookup-error": LookupError,
    "nlp-query-error": NLPQueryError,
    "objects-query-error": ObjectsQueryError,
    "request-error": RequestError,
    "structured-query-error": StructuredQueryError,
    "unexpected-error": UnexpectedError,
}


def raise_from_error_dict(error_dict: dict) -> None:
    """
    Raise appropriate exception from TrustGraph error dictionary.

    Args:
        error_dict: Dictionary with 'type' and 'message' keys

    Raises:
        Appropriate TrustGraphException subclass based on error type
    """
    error_type = error_dict.get("type", "unexpected-error")
    message = error_dict.get("message", "Unknown error")

    # Look up exception class, default to UnexpectedError
    exception_class = ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)

    # Raise the appropriate exception
    raise exception_class(message, error_type)


# Legacy exception for backwards compatibility
ApplicationException = TrustGraphException
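The dispatch pattern here is a straight table lookup from a wire-level error-type string to an exception class. A trimmed-down, self-contained re-implementation for illustration (only two error classes are re-declared here rather than imported from trustgraph):

```python
class TrustGraphException(Exception):
    """Base class; carries the original message and wire error type."""
    def __init__(self, message, error_type=None):
        super().__init__(message)
        self.message = message
        self.error_type = error_type

class GraphRagError(TrustGraphException): pass
class UnexpectedError(TrustGraphException): pass

ERROR_TYPE_MAPPING = {"graph-rag-error": GraphRagError}

def raise_from_error_dict(error_dict):
    error_type = error_dict.get("type", "unexpected-error")
    message = error_dict.get("message", "Unknown error")
    # Unknown types fall back to UnexpectedError
    raise ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)(message, error_type)

# Callers can catch either the specific subclass or the common base:
try:
    raise_from_error_dict({"type": "graph-rag-error", "message": "no subgraph"})
except TrustGraphException as e:
    caught = (type(e).__name__, e.error_type, e.message)
```

Catching on the base class keeps callers working even when new service error types (and subclasses) are added later.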
@@ -160,14 +160,14 @@ class FlowInstance:
        )["answer"]

    def graph_rag(
            self, query, user="trustgraph", collection="default",
            entity_limit=50, triple_limit=30, max_subgraph_size=150,
            max_path_length=2,
    ):

        # The input consists of a question
        input = {
            "query": query,
            "user": user,
            "collection": collection,
            "entity-limit": entity_limit,

@@ -182,13 +182,13 @@ class FlowInstance:
        )["response"]

    def document_rag(
            self, query, user="trustgraph", collection="default",
            doc_limit=10,
    ):

        # The input consists of a question
        input = {
            "query": query,
            "user": user,
            "collection": collection,
            "doc-limit": doc_limit,

@@ -211,6 +211,21 @@ class FlowInstance:
            input
        )["vectors"]

    def graph_embeddings_query(self, text, user, collection, limit=10):

        # Query graph embeddings for semantic search
        input = {
            "text": text,
            "user": user,
            "collection": collection,
            "limit": limit
        }

        return self.request(
            "service/graph-embeddings",
            input
        )

    def prompt(self, id, variables):

        input = {
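The RAG methods above all build their request bodies the same way: snake_case Python parameters become kebab-case keys on the wire. A sketch of that mapping (`build_graph_rag_request` is illustrative, not a library function):

```python
def build_graph_rag_request(query, user="trustgraph", collection="default",
                            entity_limit=50, triple_limit=30):
    # snake_case parameters map onto kebab-case wire keys
    return {
        "query": query,
        "user": user,
        "collection": collection,
        "entity-limit": entity_limit,
        "triple-limit": triple_limit,
    }

req = build_graph_rag_request("Who founded ACME?", entity_limit=10)
```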
27  trustgraph-base/trustgraph/api/metrics.py  Normal file
@@ -0,0 +1,27 @@

import requests
from typing import Optional, Dict


class Metrics:
    """Synchronous metrics client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = url
        self.timeout: int = timeout
        self.token: Optional[str] = token

    def get(self) -> str:
        """Get Prometheus metrics as text"""
        url: str = f"{self.url}/api/metrics"

        headers: Dict[str, str] = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        resp = requests.get(url, timeout=self.timeout, headers=headers)

        if resp.status_code != 200:
            raise Exception(f"Status code {resp.status_code}")

        return resp.text
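`Metrics.get()` returns the raw Prometheus text exposition. A rough sketch of extracting sample values from that text (simplified: it ignores timestamps and label escaping, and is not part of the library):

```python
def parse_prometheus_text(text):
    """Map each sample line to its float value; skips comments and blanks."""
    samples = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        # Sample lines are "<name>{<labels>} <value>"; split at the last space
        name, _, value = line.rpartition(" ")
        samples[name] = float(value)
    return samples

exposition = """\
# HELP requests_total Total requests
# TYPE requests_total counter
requests_total{service="agent"} 42
"""
samples = parse_prometheus_text(exposition)
```

For anything beyond quick inspection, the `prometheus_client` parser utilities are a better fit than hand-rolled splitting.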
457  trustgraph-base/trustgraph/api/socket_client.py  Normal file
@@ -0,0 +1,457 @@

import json
import asyncio
import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock

from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
from . exceptions import ProtocolException, raise_from_error_dict


class SocketClient:
    """Synchronous WebSocket client (wraps async websockets library)"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = self._convert_to_ws_url(url)
        self.timeout: int = timeout
        self.token: Optional[str] = token
        self._connection: Optional[Any] = None
        self._request_counter: int = 0
        self._lock: Lock = Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            # Assume ws://
            return f"ws://{url}"

    def flow(self, flow_id: str) -> "SocketFlowInstance":
        """Get flow instance for WebSocket operations"""
        return SocketFlowInstance(self, flow_id)

    def _send_request_sync(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any],
        streaming: bool = False
    ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
        """Synchronous wrapper around async WebSocket communication"""
        # Create event loop if needed
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # If loop is running (e.g., in Jupyter), create new loop
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if streaming:
            # For streaming, we need to return an iterator
            # Create a generator that runs async code
            return self._streaming_generator(service, flow, request, loop)
        else:
            # For non-streaming, just run the async code and return result
            return loop.run_until_complete(self._send_request_async(service, flow, request))

    def _streaming_generator(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any],
        loop: asyncio.AbstractEventLoop
    ) -> Iterator[StreamingChunk]:
        """Generator that yields streaming chunks"""
        async_gen = self._send_request_async_streaming(service, flow, request)

        try:
            while True:
                try:
                    chunk = loop.run_until_complete(async_gen.__anext__())
                    yield chunk
                except StopAsyncIteration:
                    break
        finally:
            # Clean up async generator
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _send_request_async(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Async implementation of WebSocket request (non-streaming)"""
        # Generate unique request ID
        with self._lock:
            self._request_counter += 1
            request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Wait for single response
            raw_message = await websocket.recv()
            response = json.loads(raw_message)

            if response.get("id") != request_id:
                raise ProtocolException(f"Response ID mismatch")

            if "error" in response:
                raise_from_error_dict(response["error"])

            if "response" not in response:
                raise ProtocolException(f"Missing response in message")

            return response["response"]

    async def _send_request_async_streaming(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any]
    ) -> Iterator[StreamingChunk]:
        """Async implementation of WebSocket request (streaming)"""
        # Generate unique request ID
        with self._lock:
            self._request_counter += 1
            request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Yield chunks as they arrive
            async for raw_message in websocket:
                response = json.loads(raw_message)

                if response.get("id") != request_id:
                    continue  # Ignore messages for other requests

                if "error" in response:
                    raise_from_error_dict(response["error"])

                if "response" in response:
                    resp = response["response"]

                    # Check for errors in response chunks
                    if "error" in resp:
                        raise_from_error_dict(resp["error"])

                    # Parse different chunk types
                    chunk = self._parse_chunk(resp)
                    yield chunk

                    # Check if this is the final chunk
                    if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
                        break

    def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk:
        """Parse response chunk into appropriate type"""
        chunk_type = resp.get("chunk_type")

        if chunk_type == "thought":
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "observation":
            return AgentObservation(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "answer" or chunk_type == "final-answer":
            return AgentAnswer(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False),
                end_of_dialog=resp.get("end_of_dialog", False)
            )
        elif chunk_type == "action":
            # Agent action chunks - treat as thoughts for display purposes
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        else:
            # RAG-style chunk (or generic chunk)
            # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
            content = resp.get("response", resp.get("chunk", resp.get("text", "")))
            return RAGChunk(
                content=content,
                end_of_stream=resp.get("end_of_stream", False),
                error=None  # Errors are always thrown, never stored
            )

    def close(self) -> None:
        """Close WebSocket connection"""
        # Cleanup handled by context manager in async code
        pass


class SocketFlowInstance:
    """Synchronous WebSocket flow instance with same interface as REST FlowInstance"""

    def __init__(self, client: SocketClient, flow_id: str) -> None:
        self.client: SocketClient = client
        self.flow_id: str = flow_id

    def agent(
        self,
        question: str,
        user: str,
        state: Optional[Dict[str, Any]] = None,
        group: Optional[str] = None,
        history: Optional[List[Dict[str, Any]]] = None,
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
        """Agent with optional streaming"""
        request = {
            "question": question,
            "user": user,
            "streaming": streaming
        }
        if state is not None:
            request["state"] = state
        if group is not None:
            request["group"] = group
        if history is not None:
            request["history"] = history
        request.update(kwargs)

        return self.client._send_request_sync("agent", self.flow_id, request, streaming)

    def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
        """Text completion with optional streaming"""
        request = {
            "system": system,
            "prompt": prompt,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming)

        if streaming:
            # For text completion, yield just the content
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def graph_rag(
        self,
        query: str,
        user: str,
        collection: str,
        max_subgraph_size: int = 1000,
        max_subgraph_count: int = 5,
        max_entity_distance: int = 3,
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[str, Iterator[str]]:
        """Graph RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "max-subgraph-size": max_subgraph_size,
            "max-subgraph-count": max_subgraph_count,
            "max-entity-distance": max_entity_distance,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming)

        if streaming:
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def document_rag(
        self,
        query: str,
        user: str,
        collection: str,
        doc_limit: int = 10,
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[str, Iterator[str]]:
        """Document RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "doc-limit": doc_limit,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming)

        if streaming:
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def prompt(
        self,
        id: str,
        variables: Dict[str, str],
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[str, Iterator[str]]:
        """Execute prompt with optional streaming"""
        request = {
            "id": id,
            "variables": variables,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("prompt", self.flow_id, request, streaming)

        if streaming:
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def graph_embeddings_query(
        self,
        text: str,
        user: str,
        collection: str,
        limit: int = 10,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Query graph embeddings for semantic search"""
        request = {
            "text": text,
            "user": user,
            "collection": collection,
            "limit": limit
        }
        request.update(kwargs)

        return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False)

    def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
        """Generate text embeddings"""
        request = {"text": text}
        request.update(kwargs)

        return self.client._send_request_sync("embeddings", self.flow_id, request, False)

    def triples_query(
        self,
        s: Optional[str] = None,
        p: Optional[str] = None,
        o: Optional[str] = None,
        user: Optional[str] = None,
        collection: Optional[str] = None,
        limit: int = 100,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Triple pattern query"""
        request = {"limit": limit}
        if s is not None:
            request["s"] = str(s)
        if p is not None:
            request["p"] = str(p)
        if o is not None:
            request["o"] = str(o)
        if user is not None:
            request["user"] = user
        if collection is not None:
            request["collection"] = collection
        request.update(kwargs)

        return self.client._send_request_sync("triples", self.flow_id, request, False)

    def objects_query(
        self,
        query: str,
        user: str,
        collection: str,
        variables: Optional[Dict[str, Any]] = None,
        operation_name: Optional[str] = None,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """GraphQL query"""
        request = {
            "query": query,
            "user": user,
            "collection": collection
        }
        if variables:
            request["variables"] = variables
        if operation_name:
            request["operationName"] = operation_name
        request.update(kwargs)

        return self.client._send_request_sync("objects", self.flow_id, request, False)

    def mcp_tool(
        self,
        name: str,
        parameters: Dict[str, Any],
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Execute MCP tool"""
        request = {
            "name": name,
            "parameters": parameters
        }
        request.update(kwargs)

        return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
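The scheme mapping in `_convert_to_ws_url` is self-contained enough to demonstrate on its own. A standalone version of the same logic:

```python
def convert_to_ws_url(url: str) -> str:
    """HTTP(S) URLs become WS(S); existing WS URLs pass through."""
    if url.startswith("http://"):
        return url.replace("http://", "ws://", 1)
    elif url.startswith("https://"):
        return url.replace("https://", "wss://", 1)
    elif url.startswith(("ws://", "wss://")):
        return url
    # Bare host:port defaults to the unencrypted ws:// scheme
    return f"ws://{url}"
```

Note the count argument to `str.replace` (`1`): only the scheme prefix is rewritten, so a path that happens to contain "http://" elsewhere is left alone.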
@@ -1,7 +1,7 @@

import dataclasses
import datetime
from typing import List, Optional, Dict, Any
from .. knowledge import hash, Uri, Literal

@dataclasses.dataclass

@@ -49,5 +49,34 @@ class CollectionMetadata:
    name : str
    description : str
    tags : List[str]


# Streaming chunk types

@dataclasses.dataclass
class StreamingChunk:
    """Base class for streaming chunks"""
    content: str
    end_of_message: bool = False

@dataclasses.dataclass
class AgentThought(StreamingChunk):
    """Agent reasoning chunk"""
    chunk_type: str = "thought"

@dataclasses.dataclass
class AgentObservation(StreamingChunk):
    """Agent tool observation chunk"""
    chunk_type: str = "observation"

@dataclasses.dataclass
class AgentAnswer(StreamingChunk):
    """Agent final answer chunk"""
    chunk_type: str = "final-answer"
    end_of_dialog: bool = False

@dataclasses.dataclass
class RAGChunk(StreamingChunk):
    """RAG streaming chunk"""
    chunk_type: str = "rag"
    end_of_stream: bool = False
    error: Optional[Dict[str, str]] = None
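Because the chunk types are plain dataclasses sharing one base, a consumer can filter a mixed stream with `isinstance`. A self-contained sketch re-declaring the relevant classes locally (the sample stream contents are invented):

```python
import dataclasses
from typing import Dict, Optional

@dataclasses.dataclass
class StreamingChunk:
    content: str
    end_of_message: bool = False

@dataclasses.dataclass
class AgentThought(StreamingChunk):
    chunk_type: str = "thought"

@dataclasses.dataclass
class RAGChunk(StreamingChunk):
    chunk_type: str = "rag"
    end_of_stream: bool = False
    error: Optional[Dict[str, str]] = None

# Join only the RAG answer text, skipping agent reasoning chunks
stream = [
    AgentThought(content="look up entity"),
    RAGChunk(content="ACME was founded "),
    RAGChunk(content="in 1949.", end_of_stream=True),
]
answer = "".join(c.content for c in stream if isinstance(c, RAGChunk))
```

The inheritance works because every field after `content` carries a default, so subclasses can append defaulted fields like `chunk_type` without violating dataclass field-ordering rules.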
@@ -1,11 +1,12 @@

from . pubsub import PulsarClient, get_pubsub
from . async_processor import AsyncProcessor
from . consumer import Consumer
from . producer import Producer
from . publisher import Publisher
from . subscriber import Subscriber
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from . logging import add_logging_args, setup_logging
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . parameter_spec import ParameterSpec

@@ -33,4 +34,5 @@ from . tool_service import ToolService
from . tool_client import ToolClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . collection_config_handler import CollectionConfigHandler
@ -15,10 +15,11 @@ from prometheus_client import start_http_server, Info
|
||||||
|
|
||||||
from .. schema import ConfigPush, config_push_queue
|
from .. schema import ConfigPush, config_push_queue
|
||||||
from .. log_level import LogLevel
|
from .. log_level import LogLevel
|
||||||
from . pubsub import PulsarClient
|
from . pubsub import PulsarClient, get_pubsub
|
||||||
from . producer import Producer
|
from . producer import Producer
|
||||||
from . consumer import Consumer
|
from . consumer import Consumer
|
||||||
from . metrics import ProcessorMetrics, ConsumerMetrics
|
from . metrics import ProcessorMetrics, ConsumerMetrics
|
||||||
|
from . logging import add_logging_args, setup_logging
|
||||||
|
|
||||||
default_config_queue = config_push_queue
|
default_config_queue = config_push_queue
|
||||||
|
|
||||||
|
|
@ -33,8 +34,11 @@ class AsyncProcessor:
|
||||||
# Store the identity
|
# Store the identity
|
||||||
self.id = params.get("id")
|
self.id = params.get("id")
|
||||||
|
|
||||||
# Register a pulsar client
|
# Create pub/sub backend via factory
|
||||||
self.pulsar_client_object = PulsarClient(**params)
|
self.pubsub_backend = get_pubsub(**params)
|
||||||
|
|
||||||
|
# Store pulsar_host for backward compatibility
|
||||||
|
self._pulsar_host = params.get("pulsar_host", "pulsar://pulsar:6650")
|
||||||
|
|
||||||
# Initialise metrics, records the parameters
|
# Initialise metrics, records the parameters
|
||||||
ProcessorMetrics(processor = self.id).info({
|
ProcessorMetrics(processor = self.id).info({
|
||||||
|
|
@ -69,7 +73,7 @@ class AsyncProcessor:
|
||||||
self.config_sub_task = Consumer(
|
self.config_sub_task = Consumer(
|
||||||
|
|
||||||
taskgroup = self.taskgroup,
|
taskgroup = self.taskgroup,
|
||||||
client = self.pulsar_client,
|
backend = self.pubsub_backend, # Changed from client to backend
|
||||||
subscriber = config_subscriber_id,
|
subscriber = config_subscriber_id,
|
||||||
flow = None,
|
flow = None,
|
||||||
|
|
||||||
|
|
@ -95,16 +99,16 @@ class AsyncProcessor:
|
||||||
# This is called to stop all threads. An over-ride point for extra
|
# This is called to stop all threads. An over-ride point for extra
|
||||||
# functionality
|
# functionality
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.pulsar_client.close()
|
self.pubsub_backend.close()
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
# Returns the pulsar host
|
# Returns the pub/sub backend (new interface)
|
||||||
@property
|
@property
|
||||||
def pulsar_host(self): return self.pulsar_client_object.pulsar_host
|
def pubsub(self): return self.pubsub_backend
|
||||||
|
|
||||||
# Returns the pulsar client
|
# Returns the pulsar host (backward compatibility)
|
||||||
@property
|
@property
|
||||||
def pulsar_client(self): return self.pulsar_client_object.client
|
def pulsar_host(self): return self._pulsar_host
|
||||||
|
|
||||||
# Register a new event handler for configuration change
|
# Register a new event handler for configuration change
|
||||||
def register_config_handler(self, handler):
|
def register_config_handler(self, handler):
|
||||||
|
|
@ -165,18 +169,9 @@ class AsyncProcessor:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_logging(cls, log_level='INFO'):
|
def setup_logging(cls, args):
|
||||||
"""Configure logging for the entire application"""
|
"""Configure logging for the entire application"""
|
||||||
# Support environment variable override
|
setup_logging(args)
|
||||||
env_log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', log_level)
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=getattr(logging, env_log_level.upper()),
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
||||||
handlers=[logging.StreamHandler()]
|
|
||||||
)
|
|
||||||
logger.info(f"Logging configured with level: {env_log_level}")
|
|
||||||
|
|
||||||
# Startup fabric. launch calls launch_async in async mode.
|
# Startup fabric. launch calls launch_async in async mode.
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@@ -203,7 +198,7 @@ class AsyncProcessor:
         args = vars(args)

         # Setup logging before anything else
-        cls.setup_logging(args.get('log_level', 'INFO').upper())
+        cls.setup_logging(args)

         # Debug
         logger.debug(f"Arguments: {args}")

@@ -255,12 +250,21 @@ class AsyncProcessor:
     @staticmethod
     def add_args(parser):

+        # Pub/sub backend selection
+        parser.add_argument(
+            '--pubsub-backend',
+            default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
+            choices=['pulsar', 'mqtt'],
+            help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
+        )
+
         PulsarClient.add_args(parser)
+        add_logging_args(parser)

         parser.add_argument(
-            '--config-queue',
+            '--config-push-queue',
             default=default_config_queue,
-            help=f'Config push queue {default_config_queue}',
+            help=f'Config push queue (default: {default_config_queue})',
         )

         parser.add_argument(
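The new `--pubsub-backend` flag can be reproduced in isolation as a minimal sketch. The flag name, the `PUBSUB_BACKEND` environment variable, and the `pulsar`/`mqtt` choices come from the diff above; everything else here is standard `argparse`.

```python
import argparse
import os

# Ensure the env var is unset so the default is predictable for this demo
os.environ.pop('PUBSUB_BACKEND', None)

# Minimal reproduction of the new argument wiring: the backend choice comes
# from --pubsub-backend, with the PUBSUB_BACKEND env var as the default
parser = argparse.ArgumentParser()
parser.add_argument(
    '--pubsub-backend',
    default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
    choices=['pulsar', 'mqtt'],
    help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
)

print(parser.parse_args([]).pubsub_backend)                       # 'pulsar'
print(parser.parse_args(['--pubsub-backend', 'mqtt']).pubsub_backend)  # 'mqtt'
```

Note that the env var is read once, when `add_argument` runs, so setting `PUBSUB_BACKEND` after parser construction has no effect; the real services build their parsers at startup, where this is the desired behaviour.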
148	trustgraph-base/trustgraph/base/backend.py	Normal file
@@ -0,0 +1,148 @@
+"""
+Backend abstraction interfaces for pub/sub systems.
+
+This module defines Protocol classes that all pub/sub backends must implement,
+allowing TrustGraph to work with different messaging systems (Pulsar, MQTT, Kafka, etc.)
+"""
+
+from typing import Protocol, Any, runtime_checkable
+
+
+@runtime_checkable
+class Message(Protocol):
+    """Protocol for a received message."""
+
+    def value(self) -> Any:
+        """
+        Get the deserialized message content.
+
+        Returns:
+            Dataclass instance representing the message
+        """
+        ...
+
+    def properties(self) -> dict:
+        """
+        Get message properties/metadata.
+
+        Returns:
+            Dictionary of message properties
+        """
+        ...
+
+
+@runtime_checkable
+class BackendProducer(Protocol):
+    """Protocol for backend-specific producer."""
+
+    def send(self, message: Any, properties: dict = {}) -> None:
+        """
+        Send a message (dataclass instance) with optional properties.
+
+        Args:
+            message: Dataclass instance to send
+            properties: Optional metadata properties
+        """
+        ...
+
+    def flush(self) -> None:
+        """Flush any buffered messages."""
+        ...
+
+    def close(self) -> None:
+        """Close the producer."""
+        ...
+
+
+@runtime_checkable
+class BackendConsumer(Protocol):
+    """Protocol for backend-specific consumer."""
+
+    def receive(self, timeout_millis: int = 2000) -> Message:
+        """
+        Receive a message from the topic.
+
+        Args:
+            timeout_millis: Timeout in milliseconds
+
+        Returns:
+            Message object
+
+        Raises:
+            TimeoutError: If no message received within timeout
+        """
+        ...
+
+    def acknowledge(self, message: Message) -> None:
+        """
+        Acknowledge successful processing of a message.
+
+        Args:
+            message: The message to acknowledge
+        """
+        ...
+
+    def negative_acknowledge(self, message: Message) -> None:
+        """
+        Negative acknowledge - triggers redelivery.
+
+        Args:
+            message: The message to negatively acknowledge
+        """
+        ...
+
+    def unsubscribe(self) -> None:
+        """Unsubscribe from the topic."""
+        ...
+
+    def close(self) -> None:
+        """Close the consumer."""
+        ...
+
+
+@runtime_checkable
+class PubSubBackend(Protocol):
+    """Protocol defining the interface all pub/sub backends must implement."""
+
+    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
+        """
+        Create a producer for a topic.
+
+        Args:
+            topic: Generic topic format (qos/tenant/namespace/queue)
+            schema: Dataclass type for messages
+            **options: Backend-specific options (e.g., chunking_enabled)
+
+        Returns:
+            Backend-specific producer instance
+        """
+        ...
+
+    def create_consumer(
+        self,
+        topic: str,
+        subscription: str,
+        schema: type,
+        initial_position: str = 'latest',
+        consumer_type: str = 'shared',
+        **options
+    ) -> BackendConsumer:
+        """
+        Create a consumer for a topic.
+
+        Args:
+            topic: Generic topic format (qos/tenant/namespace/queue)
+            subscription: Subscription/consumer group name
+            schema: Dataclass type for messages
+            initial_position: 'earliest' or 'latest' (some backends may ignore)
+            consumer_type: 'shared', 'exclusive', 'failover' (some backends may ignore)
+            **options: Backend-specific options
+
+        Returns:
+            Backend-specific consumer instance
+        """
+        ...
+
+    def close(self) -> None:
+        """Close the backend connection."""
+        ...
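Because these interfaces are `runtime_checkable` `Protocol` classes, a backend satisfies them structurally: no inheritance is needed, an `isinstance` check just verifies the methods exist. A minimal sketch, with a hypothetical in-memory producer (`MemoryProducer` is not part of the diff; the abbreviated `BackendProducer` protocol mirrors the one above):

```python
from collections import deque
from typing import Any, Protocol, runtime_checkable

# Abbreviated copy of the BackendProducer protocol from backend.py
@runtime_checkable
class BackendProducer(Protocol):
    def send(self, message: Any, properties: dict = {}) -> None: ...
    def flush(self) -> None: ...
    def close(self) -> None: ...

# Hypothetical in-memory backend producer, used only to show that
# structural typing needs no subclassing of the protocol
class MemoryProducer:
    def __init__(self):
        self.queue = deque()

    def send(self, message, properties={}):
        # Store the message alongside a copy of its properties
        self.queue.append((message, dict(properties)))

    def flush(self):
        pass  # nothing buffered in this toy implementation

    def close(self):
        self.queue.clear()

p = MemoryProducer()
p.send({"text": "hello"}, {"trace": "abc"})
print(isinstance(p, BackendProducer))  # True - structural match
print(len(p.queue))                    # 1
```

One caveat of `runtime_checkable`: `isinstance` only checks that the method names exist, not their signatures, so the check is a smoke test rather than a full conformance guarantee.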
@@ -15,12 +15,13 @@ def get_cassandra_defaults() -> dict:
     Get default Cassandra configuration values from environment variables or fallback defaults.

     Returns:
-        dict: Dictionary with 'host', 'username', and 'password' keys
+        dict: Dictionary with 'host', 'username', 'password', and 'keyspace' keys
     """
     return {
         'host': os.getenv('CASSANDRA_HOST', 'cassandra'),
         'username': os.getenv('CASSANDRA_USERNAME'),
-        'password': os.getenv('CASSANDRA_PASSWORD')
+        'password': os.getenv('CASSANDRA_PASSWORD'),
+        'keyspace': os.getenv('CASSANDRA_KEYSPACE')
     }

@@ -54,6 +55,12 @@ def add_cassandra_args(parser: argparse.ArgumentParser) -> None:
     if 'CASSANDRA_PASSWORD' in os.environ:
         password_help += " [from CASSANDRA_PASSWORD]"

+    keyspace_help = "Cassandra keyspace (default: service-specific)"
+    if defaults['keyspace']:
+        keyspace_help = f"Cassandra keyspace (default: {defaults['keyspace']})"
+    if 'CASSANDRA_KEYSPACE' in os.environ:
+        keyspace_help += " [from CASSANDRA_KEYSPACE]"
+
     parser.add_argument(
         '--cassandra-host',
         default=defaults['host'],

@@ -72,13 +79,20 @@ def add_cassandra_args(parser: argparse.ArgumentParser) -> None:
         help=password_help
     )

+    parser.add_argument(
+        '--cassandra-keyspace',
+        default=defaults['keyspace'],
+        help=keyspace_help
+    )
+

 def resolve_cassandra_config(
     args: Optional[Any] = None,
     host: Optional[str] = None,
     username: Optional[str] = None,
-    password: Optional[str] = None
-) -> Tuple[List[str], Optional[str], Optional[str]]:
+    password: Optional[str] = None,
+    default_keyspace: Optional[str] = None
+) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
     """
     Resolve Cassandra configuration from various sources.

@@ -86,25 +100,29 @@ def resolve_cassandra_config(
     Converts host string to list format for Cassandra driver.

     Args:
-        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password
+        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace
         host: Optional explicit host parameter (overrides args)
         username: Optional explicit username parameter (overrides args)
         password: Optional explicit password parameter (overrides args)
+        default_keyspace: Optional default keyspace if not specified elsewhere

     Returns:
-        tuple: (hosts_list, username, password)
+        tuple: (hosts_list, username, password, keyspace)
     """
     # If args provided, extract values
+    keyspace = None
     if args is not None:
         host = host or getattr(args, 'cassandra_host', None)
         username = username or getattr(args, 'cassandra_username', None)
         password = password or getattr(args, 'cassandra_password', None)
+        keyspace = getattr(args, 'cassandra_keyspace', None)

     # Apply defaults if still None
     defaults = get_cassandra_defaults()
     host = host or defaults['host']
     username = username or defaults['username']
     password = password or defaults['password']
+    keyspace = keyspace or defaults['keyspace'] or default_keyspace

     # Convert host string to list
     if isinstance(host, str):

@@ -112,18 +130,22 @@ def resolve_cassandra_config(
     else:
         hosts = host

-    return hosts, username, password
+    return hosts, username, password, keyspace


-def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[str], Optional[str]]:
+def get_cassandra_config_from_params(
+    params: dict,
+    default_keyspace: Optional[str] = None
+) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
     """
     Extract and resolve Cassandra configuration from a parameters dictionary.

     Args:
         params: Dictionary of parameters that may contain Cassandra configuration
+        default_keyspace: Optional default keyspace if not specified in params

     Returns:
-        tuple: (hosts_list, username, password)
+        tuple: (hosts_list, username, password, keyspace)
     """
     # Get Cassandra parameters
     host = params.get('cassandra_host')

@@ -131,4 +153,9 @@ def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[
     password = params.get('cassandra_password')

     # Use resolve function to handle defaults and list conversion
-    return resolve_cassandra_config(host=host, username=username, password=password)
+    return resolve_cassandra_config(
+        host=host,
+        username=username,
+        password=password,
+        default_keyspace=default_keyspace
+    )
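The keyspace resolution order the change introduces (explicit argument, then `CASSANDRA_KEYSPACE` env var, then a service-specific default) can be condensed into a few lines. `resolve_keyspace` below is a hypothetical helper that isolates just that precedence chain, not a function from the diff:

```python
import os
from typing import Optional

def resolve_keyspace(arg_keyspace: Optional[str] = None,
                     default_keyspace: Optional[str] = None) -> Optional[str]:
    # Precedence: explicit arg > CASSANDRA_KEYSPACE env var > service default
    return arg_keyspace or os.getenv('CASSANDRA_KEYSPACE') or default_keyspace

os.environ.pop('CASSANDRA_KEYSPACE', None)
print(resolve_keyspace(None, 'librarian'))        # falls back to service default
os.environ['CASSANDRA_KEYSPACE'] = 'shared'
print(resolve_keyspace(None, 'librarian'))        # env var wins over default
print(resolve_keyspace('explicit', 'librarian'))  # explicit arg wins over both
```

Using `or` for the chain means an empty string is treated the same as `None`, which matches how the diff's `keyspace = keyspace or defaults['keyspace'] or default_keyspace` behaves.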
128	trustgraph-base/trustgraph/base/collection_config_handler.py	Normal file
@@ -0,0 +1,128 @@
+"""
+Handler for storage services to process collection configuration from config push
+"""
+
+import json
+import logging
+from typing import Dict, Set
+
+logger = logging.getLogger(__name__)
+
+class CollectionConfigHandler:
+    """
+    Handles collection configuration from config push messages for storage services.
+
+    Storage services should:
+    1. Inherit from this class along with their service base class
+    2. Call register_config_handler(self.on_collection_config) in __init__
+    3. Implement create_collection(user, collection, metadata) method
+    4. Implement delete_collection(user, collection) method
+    """
+
+    def __init__(self, **kwargs):
+        # Track known collections: {(user, collection): metadata_dict}
+        self.known_collections: Dict[tuple, dict] = {}
+        # Pass remaining kwargs up the inheritance chain
+        super().__init__(**kwargs)
+
+    async def on_collection_config(self, config: dict, version: int):
+        """
+        Handle config push messages and extract collection information
+
+        Args:
+            config: Configuration dictionary from ConfigPush message
+            version: Configuration version number
+        """
+        logger.info(f"Processing collection configuration (version {version})")
+
+        # Extract collections from config (treat missing key as empty)
+        collection_config = config.get("collection", {})
+
+        # Track which collections we've seen in this config
+        current_collections: Set[tuple] = set()
+
+        # Process each collection in the config
+        for key, value_json in collection_config.items():
+            try:
+                # Parse user:collection key
+                if ":" not in key:
+                    logger.warning(f"Invalid collection key format (expected user:collection): {key}")
+                    continue
+
+                user, collection = key.split(":", 1)
+                current_collections.add((user, collection))
+
+                # Parse metadata
+                metadata = json.loads(value_json)
+
+                # Check if this is a new collection or updated
+                collection_key = (user, collection)
+                if collection_key not in self.known_collections:
+                    logger.info(f"New collection detected: {user}/{collection}")
+                    await self.create_collection(user, collection, metadata)
+                    self.known_collections[collection_key] = metadata
+                else:
+                    # Collection already exists, update metadata if changed
+                    if self.known_collections[collection_key] != metadata:
+                        logger.info(f"Collection metadata updated: {user}/{collection}")
+                        # Most storage services don't need to do anything for metadata updates
+                        # They just need to know the collection exists
+                        self.known_collections[collection_key] = metadata
+
+            except Exception as e:
+                logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True)
+
+        # Find collections that were deleted (in known but not in current)
+        deleted_collections = set(self.known_collections.keys()) - current_collections
+        for user, collection in deleted_collections:
+            logger.info(f"Collection deleted: {user}/{collection}")
+            try:
+                # Remove from known_collections FIRST to immediately reject new writes
+                # This eliminates race condition with worker threads
+                del self.known_collections[(user, collection)]
+                # Physical deletion happens after - worker threads already rejecting writes
+                await self.delete_collection(user, collection)
+            except Exception as e:
+                logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True)
+                # If physical deletion failed, should we re-add to known_collections?
+                # For now, keep it removed - collection is logically deleted per config
+
+        logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}")
+
+    async def create_collection(self, user: str, collection: str, metadata: dict):
+        """
+        Create a collection in the storage backend.
+
+        Subclasses must implement this method.
+
+        Args:
+            user: User ID
+            collection: Collection ID
+            metadata: Collection metadata dictionary
+        """
+        raise NotImplementedError("Storage service must implement create_collection method")
+
+    async def delete_collection(self, user: str, collection: str):
+        """
+        Delete a collection from the storage backend.
+
+        Subclasses must implement this method.
+
+        Args:
+            user: User ID
+            collection: Collection ID
+        """
+        raise NotImplementedError("Storage service must implement delete_collection method")
+
+    def collection_exists(self, user: str, collection: str) -> bool:
+        """
+        Check if a collection is known to exist
+
+        Args:
+            user: User ID
+            collection: Collection ID
+
+        Returns:
+            True if collection exists, False otherwise
+        """
+        return (user, collection) in self.known_collections
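The core of this handler is a diff between the pushed config and an in-memory `known_collections` map: unseen keys trigger `create_collection`, keys that disappear trigger `delete_collection`. A self-contained sketch of that contract, with `FakeStore` as a hypothetical storage service (the diffing loop below is a simplified version of the handler's logic, without its error handling and metadata-change tracking):

```python
import asyncio
import json

# Hypothetical storage service showing the contract the handler expects:
# it records create/delete calls instead of touching a real backend
class FakeStore:
    def __init__(self):
        self.known_collections = {}
        self.created, self.deleted = [], []

    async def create_collection(self, user, collection, metadata):
        self.created.append((user, collection))

    async def delete_collection(self, user, collection):
        self.deleted.append((user, collection))

    async def on_collection_config(self, config, version):
        # Simplified version of CollectionConfigHandler's diffing logic
        current = set()
        for key, value_json in config.get("collection", {}).items():
            user, collection = key.split(":", 1)
            current.add((user, collection))
            if (user, collection) not in self.known_collections:
                await self.create_collection(user, collection, json.loads(value_json))
            self.known_collections[(user, collection)] = json.loads(value_json)
        # Anything we knew about but no longer see has been deleted
        for user, collection in set(self.known_collections) - current:
            del self.known_collections[(user, collection)]
            await self.delete_collection(user, collection)

store = FakeStore()
asyncio.run(store.on_collection_config(
    {"collection": {"alice:docs": '{"tag": "v1"}'}}, version=1))
asyncio.run(store.on_collection_config({"collection": {}}, version=2))
print(store.created, store.deleted)  # [('alice', 'docs')] [('alice', 'docs')]
```

The real handler deletes the key from `known_collections` before calling `delete_collection` for the same reason as the sketch: once the key is gone, `collection_exists` rejects new writes immediately, even while the physical deletion is still running.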
@@ -9,9 +9,6 @@
 # one handler, and a single thread of concurrency, nothing too outrageous
 # will happen if synchronous / blocking code is used

-from pulsar.schema import JsonSchema
-import pulsar
-import _pulsar
 import asyncio
 import time
 import logging

@@ -21,10 +18,14 @@ from .. exceptions import TooManyRequests
 # Module logger
 logger = logging.getLogger(__name__)

+# Timeout exception - can come from different backends
+class TimeoutError(Exception):
+    pass
+
 class Consumer:

     def __init__(
-        self, taskgroup, flow, client, topic, subscriber, schema,
+        self, taskgroup, flow, backend, topic, subscriber, schema,
         handler,
         metrics = None,
         start_of_messages=False,

@@ -35,7 +36,7 @@ class Consumer:

         self.taskgroup = taskgroup
         self.flow = flow
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
         self.topic = topic
         self.subscriber = subscriber
         self.schema = schema

@@ -96,18 +97,20 @@ class Consumer:

             logger.info(f"Subscribing to topic: {self.topic}")

+            # Determine initial position
             if self.start_of_messages:
-                pos = pulsar.InitialPosition.Earliest
+                initial_pos = 'earliest'
             else:
-                pos = pulsar.InitialPosition.Latest
+                initial_pos = 'latest'

+            # Create consumer via backend
             self.consumer = await asyncio.to_thread(
-                self.client.subscribe,
+                self.backend.create_consumer,
                 topic = self.topic,
-                subscription_name = self.subscriber,
-                schema = JsonSchema(self.schema),
-                initial_position = pos,
-                consumer_type = pulsar.ConsumerType.Shared,
+                subscription = self.subscriber,
+                schema = self.schema,
+                initial_position = initial_pos,
+                consumer_type = 'shared',
             )

         except Exception as e:

@@ -159,9 +162,10 @@ class Consumer:
                     self.consumer.receive,
                     timeout_millis=2000
                 )
-            except _pulsar.Timeout:
-                continue
             except Exception as e:
+                # Handle timeout from any backend
+                if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
+                    continue
                 raise e

             await self.handle_one_from_queue(msg)
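The last hunk replaces the Pulsar-specific `except _pulsar.Timeout` with a string-based check, so any backend's timeout exception is treated uniformly. The detection logic can be tried on its own; `PulsarTimeout` and `MqttError` below are hypothetical stand-ins for backend exception classes, not real library types:

```python
# Stand-ins for backend-specific exception classes
class PulsarTimeout(Exception):
    pass

class MqttError(Exception):
    pass

def is_timeout(e: Exception) -> bool:
    # Matches either the exception class name or its message, case-insensitively,
    # mirroring the check in the diff
    return 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower()

print(is_timeout(PulsarTimeout()))             # True: matched by class name
print(is_timeout(MqttError("read timeout")))   # True: matched by message
print(is_timeout(MqttError("connection refused")))  # False
```

The trade-off is worth noting: this is duck-typing on exception text, so a non-timeout error whose message happens to contain "timeout" would be silently swallowed and retried rather than raised.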
@@ -19,7 +19,7 @@ class ConsumerSpec(Spec):
         consumer = Consumer(
             taskgroup = processor.taskgroup,
             flow = flow,
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
             topic = definition[self.name],
             subscriber = processor.id + "--" + flow.name + "--" + self.name,
             schema = self.schema,
159	trustgraph-base/trustgraph/base/logging.py	Normal file
@@ -0,0 +1,159 @@
+
+"""
+Centralized logging configuration for TrustGraph server-side components.
+
+This module provides standardized logging setup across all TrustGraph services,
+ensuring consistent log formats, levels, and command-line arguments.
+
+Supports dual output to console and Loki for centralized log aggregation.
+"""
+
+import logging
+import logging.handlers
+from queue import Queue
+import os
+
+
+def add_logging_args(parser):
+    """
+    Add standard logging arguments to an argument parser.
+
+    Args:
+        parser: argparse.ArgumentParser instance to add arguments to
+    """
+    parser.add_argument(
+        '-l', '--log-level',
+        default='INFO',
+        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+        help='Log level (default: INFO)'
+    )
+
+    parser.add_argument(
+        '--loki-enabled',
+        action='store_true',
+        default=True,
+        help='Enable Loki logging (default: True)'
+    )
+
+    parser.add_argument(
+        '--no-loki-enabled',
+        dest='loki_enabled',
+        action='store_false',
+        help='Disable Loki logging'
+    )
+
+    parser.add_argument(
+        '--loki-url',
+        default=os.getenv('LOKI_URL', 'http://loki:3100/loki/api/v1/push'),
+        help='Loki push URL (default: http://loki:3100/loki/api/v1/push)'
+    )
+
+    parser.add_argument(
+        '--loki-username',
+        default=os.getenv('LOKI_USERNAME', None),
+        help='Loki username for authentication (optional)'
+    )
+
+    parser.add_argument(
+        '--loki-password',
+        default=os.getenv('LOKI_PASSWORD', None),
+        help='Loki password for authentication (optional)'
+    )
+
+
+def setup_logging(args):
+    """
+    Configure logging from parsed command-line arguments.
+
+    Sets up logging with a standardized format and output to stdout.
+    Optionally enables Loki integration for centralized log aggregation.
+
+    This should be called early in application startup, before any
+    logging calls are made.
+
+    Args:
+        args: Dictionary of parsed arguments (typically from vars(args))
+              Must contain 'log_level' key, optional Loki configuration
+    """
+    log_level = args.get('log_level', 'INFO')
+    loki_enabled = args.get('loki_enabled', True)
+
+    # Build list of handlers starting with console
+    handlers = [logging.StreamHandler()]
+
+    # Add Loki handler if enabled
+    queue_listener = None
+    if loki_enabled:
+        loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
+        loki_username = args.get('loki_username')
+        loki_password = args.get('loki_password')
+        processor_id = args.get('id')  # Processor identity (e.g., "config-svc", "text-completion")
+
+        try:
+            from logging_loki import LokiHandler
+
+            # Create Loki handler with optional authentication and processor label
+            loki_handler_kwargs = {
+                'url': loki_url,
+                'version': "1",
+            }
+
+            if loki_username and loki_password:
+                loki_handler_kwargs['auth'] = (loki_username, loki_password)
+
+            # Add processor label if available (for consistency with Prometheus metrics)
+            if processor_id:
+                loki_handler_kwargs['tags'] = {'processor': processor_id}
+
+            loki_handler = LokiHandler(**loki_handler_kwargs)
+
+            # Wrap in QueueHandler for non-blocking operation
+            log_queue = Queue(maxsize=500)
+            queue_handler = logging.handlers.QueueHandler(log_queue)
+            handlers.append(queue_handler)
+
+            # Start QueueListener in background thread
+            queue_listener = logging.handlers.QueueListener(
+                log_queue,
+                loki_handler,
+                respect_handler_level=True
+            )
+            queue_listener.start()
+
+            # Store listener reference for potential cleanup
+            # (attached to root logger for access if needed)
+            logging.getLogger().loki_queue_listener = queue_listener
+
+        except ImportError:
+            # Graceful degradation if python-logging-loki not installed
+            print("WARNING: python-logging-loki not installed, Loki logging disabled")
+            print("Install with: pip install python-logging-loki")
+        except Exception as e:
+            # Graceful degradation if Loki connection fails
+            print(f"WARNING: Failed to setup Loki logging: {e}")
+            print("Continuing with console-only logging")
+
+    # Get processor ID for log formatting (use 'unknown' if not available)
+    processor_id = args.get('id', 'unknown')
+
+    # Configure logging with all handlers
+    # Use processor ID as the primary identifier in logs
+    logging.basicConfig(
+        level=getattr(logging, log_level.upper()),
+        format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s',
+        handlers=handlers,
+        force=True  # Force reconfiguration if already configured
+    )
+
+    # Prevent recursive logging from Loki's HTTP client
+    if loki_enabled and queue_listener:
+        # Disable urllib3 logging to prevent infinite loop
+        logging.getLogger('urllib3').setLevel(logging.WARNING)
+        logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)
+
+    logger = logging.getLogger(__name__)
+    logger.info(f"Logging configured with level: {log_level}")
+    if loki_enabled and queue_listener:
+        logger.info(f"Loki logging enabled: {loki_url}")
+    elif loki_enabled:
+        logger.warning("Loki logging requested but not available")
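The non-blocking pattern `setup_logging` uses for Loki is standard-library: records pass through a bounded `QueueHandler`, and a `QueueListener` thread forwards them to the slow, network-bound handler so application threads never block on I/O. A sketch with only stdlib pieces (`ListHandler` is a hypothetical stand-in for `LokiHandler`):

```python
import logging
import logging.handlers
from queue import Queue

class ListHandler(logging.Handler):
    """Stands in for the network-bound LokiHandler; records messages in memory."""
    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record.getMessage())

# Bounded queue, as in setup_logging, so a dead sink can't grow memory unboundedly
log_queue = Queue(maxsize=500)
sink = ListHandler()
listener = logging.handlers.QueueListener(
    log_queue, sink, respect_handler_level=True)
listener.start()  # background thread drains the queue into the sink

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(logging.handlers.QueueHandler(log_queue))
logger.info("hello from worker")  # enqueue only; no I/O on this thread

listener.stop()  # flushes remaining records before returning
print(sink.records)
```

`listener.stop()` is the piece worth remembering: it drains the queue before returning, which is why the diff stashes the listener on the root logger for potential cleanup at shutdown.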
@@ -1,5 +1,4 @@

-from pulsar.schema import JsonSchema
 import asyncio
 import logging


@@ -8,10 +7,10 @@ logger = logging.getLogger(__name__)

 class Producer:

-    def __init__(self, client, topic, schema, metrics=None,
+    def __init__(self, backend, topic, schema, metrics=None,
                  chunking_enabled=True):

-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
         self.topic = topic
         self.schema = schema


@@ -44,9 +43,9 @@ class Producer:

         try:
             logger.info(f"Connecting publisher to {self.topic}...")
-            self.producer = self.client.create_producer(
+            self.producer = self.backend.create_producer(
                 topic = self.topic,
-                schema = JsonSchema(self.schema),
+                schema = self.schema,
                 chunking_enabled = self.chunking_enabled,
             )
             logger.info(f"Connected publisher to {self.topic}")


@@ -15,7 +15,7 @@ class ProducerSpec(Spec):
         )

         producer = Producer(
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
             topic = definition[self.name],
             schema = self.schema,
             metrics = producer_metrics,
@@ -37,33 +37,34 @@ class PromptClient(RequestResponse):

         else:
             logger.info("DEBUG prompt_client: Streaming path")
-            # Streaming path - collect all chunks
-            full_text = ""
-            full_object = None
+            # Streaming path - just forward chunks, don't accumulate
+            last_text = ""
+            last_object = None

-            async def collect_chunks(resp):
-                nonlocal full_text, full_object
-                logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
+            async def forward_chunks(resp):
+                nonlocal last_text, last_object
+                logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")

                 if resp.error:
                     logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
                     raise RuntimeError(resp.error.message)

-                if resp.text:
-                    full_text += resp.text
-                    logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars")
-                    # Call chunk callback if provided
+                end_stream = getattr(resp, 'end_of_stream', False)
+
+                # Always call callback if there's text OR if it's the final message
+                if resp.text is not None:
+                    last_text = resp.text
+                    # Call chunk callback if provided with both chunk and end_of_stream flag
                     if chunk_callback:
-                        logger.info(f"DEBUG prompt_client: Calling chunk_callback")
+                        logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
                         if asyncio.iscoroutinefunction(chunk_callback):
-                            await chunk_callback(resp.text)
+                            await chunk_callback(resp.text, end_stream)
                         else:
-                            chunk_callback(resp.text)
+                            chunk_callback(resp.text, end_stream)
                 elif resp.object:
                     logger.info(f"DEBUG prompt_client: Got object response")
-                    full_object = resp.object
+                    last_object = resp.object

-                end_stream = getattr(resp, 'end_of_stream', False)
                 logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
                 return end_stream

@@ -79,17 +80,17 @@ class PromptClient(RequestResponse):
             logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
             await self.request(
                 req,
-                recipient=collect_chunks,
+                recipient=forward_chunks,
                 timeout=timeout
             )
-            logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars")
+            logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}")

-            if full_text:
-                logger.info("DEBUG prompt_client: Returning full_text")
-                return full_text
+            if last_text:
+                logger.info("DEBUG prompt_client: Returning last_text")
+                return last_text

-            logger.info("DEBUG prompt_client: Returning parsed full_object")
-            return json.loads(full_object)
+            logger.info("DEBUG prompt_client: Returning parsed last_object")
+            return json.loads(last_object) if last_object else None

     async def extract_definitions(self, text, timeout=600):
||||||
return await self.prompt(
|
return await self.prompt(
|
||||||
|
|
|
||||||
|
|
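The hunk above changes the chunk callback contract from `chunk_callback(text)` to `chunk_callback(text, end_of_stream)`. A minimal standalone sketch of a consumer of that contract (the `ChunkPrinter` and `simulate_stream` names are illustrative, not TrustGraph's API):

```python
import asyncio

# Hypothetical consumer of the (text, end_of_stream) callback contract
# shown in the diff: forward each chunk, act once on the final one.
class ChunkPrinter:
    def __init__(self):
        self.chunks = []
        self.finished = False

    async def __call__(self, text, end_of_stream):
        # Each call carries one chunk plus a flag marking the last message
        self.chunks.append(text)
        if end_of_stream:
            self.finished = True

async def simulate_stream(callback):
    # Stand-in for the request/response loop driving forward_chunks
    for i, chunk in enumerate(["Hello", ", ", "world"]):
        await callback(chunk, end_of_stream=(i == 2))

cb = ChunkPrinter()
asyncio.run(simulate_stream(cb))
```

Passing the flag alongside each chunk lets callers flush buffers or close UI spinners on the last message without waiting for the whole response to be re-accumulated client-side.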
@@ -1,9 +1,6 @@

-from pulsar.schema import JsonSchema
-
 import asyncio
 import time
-import pulsar
 import logging

 # Module logger
@@ -11,9 +8,9 @@ logger = logging.getLogger(__name__)

 class Publisher:

-    def __init__(self, client, topic, schema=None, max_size=10,
+    def __init__(self, backend, topic, schema=None, max_size=10,
                  chunking_enabled=True, drain_timeout=5.0):
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
        self.topic = topic
        self.schema = schema
        self.q = asyncio.Queue(maxsize=max_size)

@@ -47,9 +44,9 @@ class Publisher:

            try:

-                producer = self.client.create_producer(
+                producer = self.backend.create_producer(
                    topic=self.topic,
-                    schema=JsonSchema(self.schema),
+                    schema=self.schema,
                    chunking_enabled=self.chunking_enabled,
                )
@@ -4,8 +4,45 @@ import pulsar
 import _pulsar
 import uuid
 from pulsar.schema import JsonSchema
+import logging

 from .. log_level import LogLevel
+from .pulsar_backend import PulsarBackend
+
+logger = logging.getLogger(__name__)
+
+
+def get_pubsub(**config):
+    """
+    Factory function to create a pub/sub backend based on configuration.
+
+    Args:
+        config: Configuration dictionary from command-line args
+            Must include 'pubsub_backend' key
+
+    Returns:
+        Backend instance (PulsarBackend, MQTTBackend, etc.)
+
+    Example:
+        backend = get_pubsub(
+            pubsub_backend='pulsar',
+            pulsar_host='pulsar://localhost:6650'
+        )
+    """
+    backend_type = config.get('pubsub_backend', 'pulsar')
+
+    if backend_type == 'pulsar':
+        return PulsarBackend(
+            host=config.get('pulsar_host', PulsarClient.default_pulsar_host),
+            api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key),
+            listener=config.get('pulsar_listener'),
+        )
+    elif backend_type == 'mqtt':
+        # TODO: Implement MQTT backend
+        raise NotImplementedError("MQTT backend not yet implemented")
+    else:
+        raise ValueError(f"Unknown pub/sub backend: {backend_type}")


 class PulsarClient:

@@ -71,10 +108,3 @@ class PulsarClient:
            '--pulsar-listener',
            help=f'Pulsar listener (default: none)',
        )
-
-        parser.add_argument(
-            '-l', '--log-level',
-            default='INFO',
-            choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
-            help=f'Log level (default: INFO)'
-        )
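The `get_pubsub` factory above dispatches on a `pubsub_backend` key and forwards the remaining config to the chosen backend. The same dispatch pattern in isolation, with a stub standing in for `PulsarBackend` (which needs a running broker); the stub names are assumptions for illustration only:

```python
# Minimal sketch of the get_pubsub dispatch pattern; StubBackend is a
# placeholder for PulsarBackend so the logic can run without a broker.
class StubBackend:
    def __init__(self, host, api_key=None, listener=None):
        self.host = host
        self.api_key = api_key
        self.listener = listener

def get_pubsub_sketch(**config):
    # Default to 'pulsar' when no backend is named, as in the diff
    backend_type = config.get('pubsub_backend', 'pulsar')
    if backend_type == 'pulsar':
        return StubBackend(
            host=config.get('pulsar_host', 'pulsar://localhost:6650'),
            api_key=config.get('pulsar_api_key'),
            listener=config.get('pulsar_listener'),
        )
    # Unknown backends fail loudly rather than silently falling back
    raise ValueError(f"Unknown pub/sub backend: {backend_type}")

backend = get_pubsub_sketch(pubsub_backend='pulsar',
                            pulsar_host='pulsar://broker:6650')
```

Keeping all transport-specific arguments behind one factory call is what lets the rest of the diff swap `pulsar_client` for a generic `pubsub` attribute.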
350 trustgraph-base/trustgraph/base/pulsar_backend.py Normal file
@@ -0,0 +1,350 @@
+"""
+Pulsar backend implementation for pub/sub abstraction.
+
+This module provides a Pulsar-specific implementation of the backend interfaces,
+handling topic mapping, serialization, and Pulsar client management.
+"""
+
+import pulsar
+import _pulsar
+import json
+import logging
+import base64
+import types
+from dataclasses import asdict, is_dataclass
+from typing import Any
+
+from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
+
+logger = logging.getLogger(__name__)
+
+
+def dataclass_to_dict(obj: Any) -> dict:
+    """
+    Recursively convert a dataclass to a dictionary, handling None values and bytes.
+
+    None values are excluded from the dictionary (not serialized).
+    Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
+    Handles nested dataclasses, lists, and dictionaries recursively.
+    """
+    if obj is None:
+        return None
+
+    # Handle bytes - decode to UTF-8 for JSON serialization
+    if isinstance(obj, bytes):
+        return obj.decode('utf-8')
+
+    # Handle dataclass - convert to dict then recursively process all values
+    if is_dataclass(obj):
+        result = {}
+        for key, value in asdict(obj).items():
+            result[key] = dataclass_to_dict(value) if value is not None else None
+        return result
+
+    # Handle list - recursively process all items
+    if isinstance(obj, list):
+        return [dataclass_to_dict(item) for item in obj]
+
+    # Handle dict - recursively process all values
+    if isinstance(obj, dict):
+        return {k: dataclass_to_dict(v) for k, v in obj.items()}
+
+    # Return primitive types as-is
+    return obj
+
+
+def dict_to_dataclass(data: dict, cls: type) -> Any:
+    """
+    Convert a dictionary back to a dataclass instance.
+
+    Handles nested dataclasses and missing fields.
+    """
+    if data is None:
+        return None
+
+    if not is_dataclass(cls):
+        return data
+
+    # Get field types from the dataclass
+    field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
+    kwargs = {}
+
+    for key, value in data.items():
+        if key in field_types:
+            field_type = field_types[key]
+
+            # Handle modern union types (X | Y)
+            if isinstance(field_type, types.UnionType):
+                # Check if it's Optional (X | None)
+                if type(None) in field_type.__args__:
+                    # Get the non-None type
+                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
+                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
+                        kwargs[key] = dict_to_dataclass(value, actual_type)
+                    else:
+                        kwargs[key] = value
+                else:
+                    kwargs[key] = value
+            # Check if this is a generic type (list, dict, etc.)
+            elif hasattr(field_type, '__origin__'):
+                # Handle list[T]
+                if field_type.__origin__ == list:
+                    item_type = field_type.__args__[0] if field_type.__args__ else None
+                    if item_type and is_dataclass(item_type) and isinstance(value, list):
+                        kwargs[key] = [
+                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
+                            for item in value
+                        ]
+                    else:
+                        kwargs[key] = value
+                # Handle old-style Optional[T] (which is Union[T, None])
+                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
+                    # Get the non-None type from Union
+                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
+                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
+                        kwargs[key] = dict_to_dataclass(value, actual_type)
+                    else:
+                        kwargs[key] = value
+                else:
+                    kwargs[key] = value
+            # Handle direct dataclass fields
+            elif is_dataclass(field_type) and isinstance(value, dict):
+                kwargs[key] = dict_to_dataclass(value, field_type)
+            # Handle bytes fields (UTF-8 encoded strings from JSON)
+            elif field_type == bytes and isinstance(value, str):
+                kwargs[key] = value.encode('utf-8')
+            else:
+                kwargs[key] = value
+
+    return cls(**kwargs)
+
+
+class PulsarMessage:
+    """Wrapper for Pulsar messages to match Message protocol."""
+
+    def __init__(self, pulsar_msg, schema_cls):
+        self._msg = pulsar_msg
+        self._schema_cls = schema_cls
+        self._value = None
+
+    def value(self) -> Any:
+        """Deserialize and return the message value as a dataclass."""
+        if self._value is None:
+            # Get JSON string from Pulsar message
+            json_data = self._msg.data().decode('utf-8')
+            data_dict = json.loads(json_data)
+            # Convert to dataclass
+            self._value = dict_to_dataclass(data_dict, self._schema_cls)
+        return self._value
+
+    def properties(self) -> dict:
+        """Return message properties."""
+        return self._msg.properties()
+
+
+class PulsarBackendProducer:
+    """Pulsar-specific producer implementation."""
+
+    def __init__(self, pulsar_producer, schema_cls):
+        self._producer = pulsar_producer
+        self._schema_cls = schema_cls
+
+    def send(self, message: Any, properties: dict = {}) -> None:
+        """Send a dataclass message."""
+        # Convert dataclass to dict, excluding None values
+        data_dict = dataclass_to_dict(message)
+        # Serialize to JSON
+        json_data = json.dumps(data_dict)
+        # Send via Pulsar
+        self._producer.send(json_data.encode('utf-8'), properties=properties)
+
+    def flush(self) -> None:
+        """Flush buffered messages."""
+        self._producer.flush()
+
+    def close(self) -> None:
+        """Close the producer."""
+        self._producer.close()
+
+
+class PulsarBackendConsumer:
+    """Pulsar-specific consumer implementation."""
+
+    def __init__(self, pulsar_consumer, schema_cls):
+        self._consumer = pulsar_consumer
+        self._schema_cls = schema_cls
+
+    def receive(self, timeout_millis: int = 2000) -> Message:
+        """Receive a message."""
+        pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
+        return PulsarMessage(pulsar_msg, self._schema_cls)
+
+    def acknowledge(self, message: Message) -> None:
+        """Acknowledge a message."""
+        if isinstance(message, PulsarMessage):
+            self._consumer.acknowledge(message._msg)
+
+    def negative_acknowledge(self, message: Message) -> None:
+        """Negative acknowledge a message."""
+        if isinstance(message, PulsarMessage):
+            self._consumer.negative_acknowledge(message._msg)
+
+    def unsubscribe(self) -> None:
+        """Unsubscribe from the topic."""
+        self._consumer.unsubscribe()
+
+    def close(self) -> None:
+        """Close the consumer."""
+        self._consumer.close()
+
+
+class PulsarBackend:
+    """
+    Pulsar backend implementation.
+
+    Handles topic mapping, client management, and creation of Pulsar-specific
+    producers and consumers.
+    """
+
+    def __init__(self, host: str, api_key: str = None, listener: str = None):
+        """
+        Initialize Pulsar backend.
+
+        Args:
+            host: Pulsar broker URL (e.g., pulsar://localhost:6650)
+            api_key: Optional API key for authentication
+            listener: Optional listener name for multi-homed setups
+        """
+        self.host = host
+        self.api_key = api_key
+        self.listener = listener
+
+        # Create Pulsar client
+        client_args = {'service_url': host}
+
+        if listener:
+            client_args['listener_name'] = listener
+
+        if api_key:
+            client_args['authentication'] = pulsar.AuthenticationToken(api_key)
+
+        self.client = pulsar.Client(**client_args)
+        logger.info(f"Pulsar client connected to {host}")
+
+    def map_topic(self, generic_topic: str) -> str:
+        """
+        Map generic topic format to Pulsar URI.
+
+        Format: qos/tenant/namespace/queue
+        Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue
+
+        Args:
+            generic_topic: Generic topic string or already-formatted Pulsar URI
+
+        Returns:
+            Pulsar topic URI
+        """
+        # If already a Pulsar URI, return as-is
+        if '://' in generic_topic:
+            return generic_topic
+
+        parts = generic_topic.split('/', 3)
+        if len(parts) != 4:
+            raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue")
+
+        qos, tenant, namespace, queue = parts
+
+        # Map QoS to persistence
+        if qos == 'q0':
+            persistence = 'non-persistent'
+        elif qos in ['q1', 'q2']:
+            persistence = 'persistent'
+        else:
+            raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2")
+
+        return f"{persistence}://{tenant}/{namespace}/{queue}"
+
+    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
+        """
+        Create a Pulsar producer.
+
+        Args:
+            topic: Generic topic format (qos/tenant/namespace/queue)
+            schema: Dataclass type for messages
+            **options: Backend-specific options (e.g., chunking_enabled)
+
+        Returns:
+            PulsarBackendProducer instance
+        """
+        pulsar_topic = self.map_topic(topic)
+
+        producer_args = {
+            'topic': pulsar_topic,
+            'schema': pulsar.schema.BytesSchema(),  # We handle serialization ourselves
+        }
+
+        # Add optional parameters
+        if 'chunking_enabled' in options:
+            producer_args['chunking_enabled'] = options['chunking_enabled']
+
+        pulsar_producer = self.client.create_producer(**producer_args)
+        logger.debug(f"Created producer for topic: {pulsar_topic}")
+
+        return PulsarBackendProducer(pulsar_producer, schema)
+
+    def create_consumer(
+        self,
+        topic: str,
+        subscription: str,
+        schema: type,
+        initial_position: str = 'latest',
+        consumer_type: str = 'shared',
+        **options
+    ) -> BackendConsumer:
+        """
+        Create a Pulsar consumer.
+
+        Args:
+            topic: Generic topic format (qos/tenant/namespace/queue)
+            subscription: Subscription name
+            schema: Dataclass type for messages
+            initial_position: 'earliest' or 'latest'
+            consumer_type: 'shared', 'exclusive', or 'failover'
+            **options: Backend-specific options
+
+        Returns:
+            PulsarBackendConsumer instance
+        """
+        pulsar_topic = self.map_topic(topic)
+
+        # Map initial position
+        if initial_position == 'earliest':
+            pos = pulsar.InitialPosition.Earliest
+        else:
+            pos = pulsar.InitialPosition.Latest
+
+        # Map consumer type
+        if consumer_type == 'exclusive':
+            ctype = pulsar.ConsumerType.Exclusive
+        elif consumer_type == 'failover':
+            ctype = pulsar.ConsumerType.Failover
+        else:
+            ctype = pulsar.ConsumerType.Shared
+
+        consumer_args = {
+            'topic': pulsar_topic,
+            'subscription_name': subscription,
+            'schema': pulsar.schema.BytesSchema(),  # We handle deserialization ourselves
+            'initial_position': pos,
+            'consumer_type': ctype,
+        }
+
+        pulsar_consumer = self.client.subscribe(**consumer_args)
+        logger.debug(f"Created consumer for topic: {pulsar_topic}, subscription: {subscription}")
+
+        return PulsarBackendConsumer(pulsar_consumer, schema)
+
+    def close(self) -> None:
+        """Close the Pulsar client."""
+        self.client.close()
+        logger.info("Pulsar client closed")
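The `map_topic` rule in the new file above (`qos/tenant/namespace/queue` → Pulsar URI, with `q0` mapped to non-persistent delivery) can be exercised standalone; this sketch copies just that function's logic out of the diff so the mapping can be tested without a broker:

```python
def map_topic(generic_topic: str) -> str:
    # Pass through anything that is already a full Pulsar URI
    if '://' in generic_topic:
        return generic_topic
    parts = generic_topic.split('/', 3)
    if len(parts) != 4:
        raise ValueError(f"Invalid topic format: {generic_topic}")
    qos, tenant, namespace, queue = parts
    # q0 maps to non-persistent delivery; q1/q2 map to persistent
    if qos == 'q0':
        persistence = 'non-persistent'
    elif qos in ('q1', 'q2'):
        persistence = 'persistent'
    else:
        raise ValueError(f"Invalid QoS level: {qos}")
    return f"{persistence}://{tenant}/{namespace}/{queue}"

print(map_topic("q1/tg/flow/my-queue"))   # persistent://tg/flow/my-queue
print(map_topic("q0/tg/flow/metrics"))    # non-persistent://tg/flow/metrics
```

Encoding QoS in the topic name keeps the generic topic strings backend-neutral: another backend could map the same `q0`/`q1`/`q2` prefixes to its own delivery guarantees.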
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 class RequestResponse(Subscriber):

    def __init__(
-        self, client, subscription, consumer_name,
+        self, backend, subscription, consumer_name,
        request_topic, request_schema,
        request_metrics,
        response_topic, response_schema,
@@ -22,7 +22,7 @@ class RequestResponse(Subscriber):
    ):

        super(RequestResponse, self).__init__(
-            client = client,
+            backend = backend,
            subscription = subscription,
            consumer_name = consumer_name,
            topic = response_topic,
@@ -31,7 +31,7 @@ class RequestResponse(Subscriber):
        )

        self.producer = Producer(
-            client = client,
+            backend = backend,
            topic = request_topic,
            schema = request_schema,
            metrics = request_metrics,
@@ -126,7 +126,7 @@ class RequestResponseSpec(Spec):
        )

        rr = self.impl(
-            client = processor.pulsar_client,
+            backend = processor.pubsub,

            # Make subscription names unique, so that all subscribers get
            # to see all response messages
@@ -3,9 +3,7 @@
 # off of a queue and make it available using an internal broker system,
 # so suitable for when multiple recipients are reading from the same queue

-from pulsar.schema import JsonSchema
 import asyncio
-import _pulsar
 import time
 import logging
 import uuid
@@ -13,12 +11,16 @@ import uuid
 # Module logger
 logger = logging.getLogger(__name__)

+# Timeout exception - can come from different backends
+class TimeoutError(Exception):
+    pass
+
 class Subscriber:

-    def __init__(self, client, topic, subscription, consumer_name,
+    def __init__(self, backend, topic, subscription, consumer_name,
                  schema=None, max_size=100, metrics=None,
                  backpressure_strategy="block", drain_timeout=5.0):
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
@@ -43,18 +45,14 @@ class Subscriber:

    async def start(self):

-        # Build subscribe arguments
-        subscribe_args = {
-            'topic': self.topic,
-            'subscription_name': self.subscription,
-            'consumer_name': self.consumer_name,
-        }
-
-        # Only add schema if provided (omit if None)
-        if self.schema is not None:
-            subscribe_args['schema'] = JsonSchema(self.schema)
-
-        self.consumer = self.client.subscribe(**subscribe_args)
+        # Create consumer via backend
+        self.consumer = await asyncio.to_thread(
+            self.backend.create_consumer,
+            topic=self.topic,
+            subscription=self.subscription,
+            schema=self.schema,
+            consumer_type='shared',
+        )

        self.task = asyncio.create_task(self.run())

@@ -94,12 +92,13 @@ class Subscriber:
        drain_end_time = time.time() + self.drain_timeout
        logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")

-        # Stop accepting new messages from Pulsar during drain
-        if self.consumer:
+        # Stop accepting new messages during drain
+        # Note: Not all backends support pausing message listeners
+        if self.consumer and hasattr(self.consumer, 'pause_message_listener'):
            try:
                self.consumer.pause_message_listener()
-            except _pulsar.InvalidConfiguration:
-                # Not all consumers have message listeners (e.g., blocking receive mode)
+            except Exception:
+                # Not all consumers support message listeners
                pass

        # Check drain timeout

@@ -133,9 +132,10 @@ class Subscriber:
                    self.consumer.receive,
                    timeout_millis=250
                )
-            except _pulsar.Timeout:
-                continue
            except Exception as e:
+                # Handle timeout from any backend
+                if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
+                    continue
                logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
                raise e

@@ -157,19 +157,20 @@ class Subscriber:
        for msg in self.pending_acks.values():
            try:
                self.consumer.negative_acknowledge(msg)
-            except _pulsar.AlreadyClosed:
-                pass  # Consumer already closed
+            except Exception:
+                pass  # Consumer already closed or error
        self.pending_acks.clear()

        if self.consumer:
-            try:
-                self.consumer.unsubscribe()
-            except _pulsar.AlreadyClosed:
-                pass  # Already closed
+            if hasattr(self.consumer, 'unsubscribe'):
+                try:
+                    self.consumer.unsubscribe()
+                except Exception:
+                    pass  # Already closed or error
            try:
                self.consumer.close()
-            except _pulsar.AlreadyClosed:
-                pass  # Already closed
+            except Exception:
+                pass  # Already closed or error
            self.consumer = None
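The receive loop above stops catching `_pulsar.Timeout` directly and instead string-matches the exception's type name or message, so any backend's timeout class is treated as a benign poll timeout. A standalone check of that heuristic (the exception classes here are invented for illustration; `PulsarTimeout` stands in for `_pulsar.Timeout`):

```python
def is_timeout(e: Exception) -> bool:
    # Same heuristic as the diff: match 'timeout' in the type name or message
    return 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower()

class PulsarTimeout(Exception):
    # Stand-in for _pulsar.Timeout: matched via its type name
    pass

class BrokerError(Exception):
    pass

print(is_timeout(PulsarTimeout()))                    # True (type name matches)
print(is_timeout(BrokerError("read timeout")))        # True (message matches)
print(is_timeout(BrokerError("connection refused")))  # False
```

The trade-off is looseness: any exception whose class name or message happens to contain "timeout" is silently retried, which is why the loop still logs and re-raises everything else.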
@@ -16,7 +16,7 @@ class SubscriberSpec(Spec):
        )

        subscriber = Subscriber(
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
            topic = definition[self.name],
            subscription = flow.id,
            consumer_name = flow.id,
@@ -7,6 +7,7 @@ import time
 from pulsar.schema import JsonSchema

 from .. exceptions import *
+from ..base.pubsub import get_pubsub

 # Default timeout for a request/response. In seconds.
 DEFAULT_TIMEOUT=300
@@ -39,30 +40,25 @@ class BaseClient:
        if subscriber == None:
            subscriber = str(uuid.uuid4())

-        if pulsar_api_key:
-            auth = pulsar.AuthenticationToken(pulsar_api_key)
-            self.client = pulsar.Client(
-                pulsar_host,
-                logger=pulsar.ConsoleLogger(log_level),
-                authentication=auth,
-                listener=listener,
-            )
-        else:
-            self.client = pulsar.Client(
-                pulsar_host,
-                logger=pulsar.ConsoleLogger(log_level),
-                listener_name=listener,
-            )
+        # Create backend using factory
+        self.backend = get_pubsub(
+            pulsar_host=pulsar_host,
+            pulsar_api_key=pulsar_api_key,
+            pulsar_listener=listener,
+            pubsub_backend='pulsar'
+        )

-        self.producer = self.client.create_producer(
+        self.producer = self.backend.create_producer(
            topic=input_queue,
-            schema=JsonSchema(input_schema),
+            schema=input_schema,
            chunking_enabled=True,
        )

-        self.consumer = self.client.subscribe(
-            output_queue, subscriber,
-            schema=JsonSchema(output_schema),
-        )
+        self.consumer = self.backend.create_consumer(
+            topic=output_queue,
+            subscription=subscriber,
+            schema=output_schema,
+            consumer_type='shared',
+        )

        self.input_schema = input_schema
@@ -141,5 +137,6 @@ class BaseClient:
        self.producer.flush()
        self.producer.close()

-        self.client.close()
+        if hasattr(self, "backend"):
+            self.backend.close()
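The `close` change above guards `self.backend` with `hasattr`, so tearing down a client whose `__init__` failed before the backend was created no longer raises `AttributeError`. The pattern in isolation (class names invented for illustration):

```python
class FragileClient:
    def __init__(self, fail=False):
        if fail:
            # Simulates __init__ raising before self.backend is assigned
            raise RuntimeError("broker unreachable")
        self.backend = object()

    def close(self):
        # hasattr guard: safe even on a partially-constructed instance
        if hasattr(self, "backend"):
            self.closed = True  # stand-in for self.backend.close()

# A partially-constructed instance, as left behind by a failed __init__
c = FragileClient.__new__(FragileClient)
c.close()  # no AttributeError

ok = FragileClient()
ok.close()
```

This matters when `close` is called from cleanup paths (e.g. `finally` blocks or `__del__`) that run regardless of whether construction succeeded.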
@@ -64,7 +64,6 @@ class ConfigClient(BaseClient):
    def get(self, keys, timeout=300):

        resp = self.call(
-            id=id,
            operation="get",
            keys=[
                ConfigKey(
@@ -88,7 +87,6 @@ class ConfigClient(BaseClient):
    def list(self, type, timeout=300):

        resp = self.call(
-            id=id,
            operation="list",
            type=type,
            timeout=timeout
@@ -99,7 +97,6 @@ class ConfigClient(BaseClient):
    def getvalues(self, type, timeout=300):

        resp = self.call(
-            id=id,
            operation="getvalues",
            type=type,
            timeout=timeout
@@ -117,7 +114,6 @@ class ConfigClient(BaseClient):
    def delete(self, keys, timeout=300):

        resp = self.call(
-            id=id,
            operation="delete",
            keys=[
                ConfigKey(
@@ -134,7 +130,6 @@ class ConfigClient(BaseClient):
    def put(self, values, timeout=300):

        resp = self.call(
-            id=id,
            operation="put",
            values=[
                ConfigValue(
@@ -152,7 +147,6 @@ class ConfigClient(BaseClient):
    def config(self, timeout=300):

        resp = self.call(
-            id=id,
            operation="config",
            timeout=timeout
        )
@@ -15,8 +15,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
            name=data.get("name"),
            description=data.get("description"),
            tags=data.get("tags"),
-            created_at=data.get("created_at"),
-            updated_at=data.get("updated_at"),
            tag_filter=data.get("tag_filter"),
            limit=data.get("limit")
        )
@@ -38,10 +36,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
            result["description"] = obj.description
        if obj.tags is not None:
            result["tags"] = list(obj.tags)
-        if obj.created_at is not None:
-            result["created_at"] = obj.created_at
-        if obj.updated_at is not None:
-            result["updated_at"] = obj.updated_at
        if obj.tag_filter is not None:
            result["tag_filter"] = list(obj.tag_filter)
        if obj.limit is not None:
@@ -73,9 +67,7 @@ class CollectionManagementResponseTranslator(MessageTranslator):
                collection=coll_data.get("collection"),
                name=coll_data.get("name"),
                description=coll_data.get("description"),
-                tags=coll_data.get("tags"),
-                created_at=coll_data.get("created_at"),
-                updated_at=coll_data.get("updated_at")
+                tags=coll_data.get("tags", [])
            ))

        return CollectionManagementResponse(
@@ -104,9 +96,7 @@ class CollectionManagementResponseTranslator(MessageTranslator):
                "collection": coll.collection,
                "name": coll.name,
                "description": coll.description,
-                "tags": list(coll.tags) if coll.tags else [],
-                "created_at": coll.created_at,
-                "updated_at": coll.updated_at
+                "tags": list(coll.tags) if coll.tags else []
            })

        print("RESULT IS", result, flush=True)
@@ -57,7 +57,9 @@ class StructuredDataDiagnosisResponseTranslator(MessageTranslator):
             result["descriptor"] = obj.descriptor
         if obj.metadata:
             result["metadata"] = obj.metadata
-        if obj.schema_matches is not None:
+        # For schema-selection, always include schema_matches (even if empty)
+        # For other operations, only include if non-empty
+        if obj.operation == "schema-selection" or obj.schema_matches:
             result["schema-matches"] = obj.schema_matches

         return result
@@ -43,11 +43,16 @@ class PromptResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
         result = {}

-        if obj.text:
+        # Include text field if present (even if empty string)
+        if obj.text is not None:
             result["text"] = obj.text
-        if obj.object:
+        # Include object field if present
+        if obj.object is not None:
             result["object"] = obj.object

+        # Always include end_of_stream flag for streaming support
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
+
         return result

     def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
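The serialization rule in the hunk above can be seen in isolation in the sketch below. `PromptLike` is a hypothetical stand-in for the real PromptResponse record: optional fields are included whenever they are present (even as empty strings), and `end_of_stream` is always emitted so clients can detect stream completion.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class PromptLike:
    # Hypothetical stand-in for PromptResponse; field names follow the diff
    text: Optional[str] = None
    object: Optional[str] = None
    end_of_stream: bool = False

def from_pulsar(obj: PromptLike) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    # "is not None" rather than truthiness: an empty string is still a value
    if obj.text is not None:
        result["text"] = obj.text
    if obj.object is not None:
        result["object"] = obj.object
    # Always present, so consumers need no hasattr checks
    result["end_of_stream"] = getattr(obj, "end_of_stream", False)
    return result
```

The key behavioural change is that an empty-string `text` now round-trips instead of being silently dropped.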
@@ -34,15 +34,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
         result = {}

-        # Check if this is a streaming response (has chunk)
-        if hasattr(obj, 'chunk') and obj.chunk:
-            result["chunk"] = obj.chunk
-            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
-        else:
-            # Non-streaming response
-            if obj.response:
-                result["response"] = obj.response
+        # Include response content (even if empty string)
+        if obj.response is not None:
+            result["response"] = obj.response
+
+        # Include end_of_stream flag
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)

         # Always include error if present
         if hasattr(obj, 'error') and obj.error and obj.error.message:
             result["error"] = {"message": obj.error.message, "type": obj.error.type}
@@ -51,13 +49,7 @@ class DocumentRagResponseTranslator(MessageTranslator):

     def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        # For streaming responses, check end_of_stream
-        if hasattr(obj, 'chunk') and obj.chunk:
-            is_final = getattr(obj, 'end_of_stream', False)
-        else:
-            # For non-streaming responses, it's always final
-            is_final = True
+        is_final = getattr(obj, 'end_of_stream', False)

         return self.from_pulsar(obj), is_final
@@ -98,15 +90,13 @@ class GraphRagResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
         result = {}

-        # Check if this is a streaming response (has chunk)
-        if hasattr(obj, 'chunk') and obj.chunk:
-            result["chunk"] = obj.chunk
-            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
-        else:
-            # Non-streaming response
-            if obj.response:
-                result["response"] = obj.response
+        # Include response content (even if empty string)
+        if obj.response is not None:
+            result["response"] = obj.response
+
+        # Include end_of_stream flag
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)

         # Always include error if present
         if hasattr(obj, 'error') and obj.error and obj.error.message:
             result["error"] = {"message": obj.error.message, "type": obj.error.type}
@@ -115,11 +105,5 @@ class GraphRagResponseTranslator(MessageTranslator):

     def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        # For streaming responses, check end_of_stream
-        if hasattr(obj, 'chunk') and obj.chunk:
-            is_final = getattr(obj, 'end_of_stream', False)
-        else:
-            # For non-streaming responses, it's always final
-            is_final = True
+        is_final = getattr(obj, 'end_of_stream', False)

         return self.from_pulsar(obj), is_final
@@ -36,6 +36,9 @@ class TextCompletionResponseTranslator(MessageTranslator):
         if obj.model:
             result["model"] = obj.model

+        # Always include end_of_stream flag for streaming support
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
+
         return result

     def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
@@ -1,16 +1,14 @@

-from pulsar.schema import Record, String, Array
+from dataclasses import dataclass, field

 from .primitives import Triple

-class Metadata(Record):
+@dataclass
+class Metadata:
     # Source identifier
-    id = String()
+    id: str = ""

     # Subgraph
-    metadata = Array(Triple())
+    metadata: list[Triple] = field(default_factory=list)

     # Collection management
-    user = String()
-    collection = String()
+    user: str = ""
+    collection: str = ""
@@ -1,34 +1,39 @@

-from pulsar.schema import Record, String, Boolean, Array, Integer
+from dataclasses import dataclass, field

-class Error(Record):
-    type = String()
-    message = String()
+@dataclass
+class Error:
+    type: str = ""
+    message: str = ""

-class Value(Record):
-    value = String()
-    is_uri = Boolean()
-    type = String()
+@dataclass
+class Value:
+    value: str = ""
+    is_uri: bool = False
+    type: str = ""

-class Triple(Record):
-    s = Value()
-    p = Value()
-    o = Value()
+@dataclass
+class Triple:
+    s: Value | None = None
+    p: Value | None = None
+    o: Value | None = None

-class Field(Record):
-    name = String()
+@dataclass
+class Field:
+    name: str = ""
     # int, string, long, bool, float, double, timestamp
-    type = String()
-    size = Integer()
-    primary = Boolean()
-    description = String()
+    type: str = ""
+    size: int = 0
+    primary: bool = False
+    description: str = ""
     # NEW FIELDS for structured data:
-    required = Boolean()  # Whether field is required
-    enum_values = Array(String())  # For enum type fields
-    indexed = Boolean()  # Whether field should be indexed
+    required: bool = False  # Whether field is required
+    enum_values: list[str] = field(default_factory=list)  # For enum type fields
+    indexed: bool = False  # Whether field should be indexed

-class RowSchema(Record):
-    name = String()
-    description = String()
-    fields = Array(Field())
+@dataclass
+class RowSchema:
+    name: str = ""
+    description: str = ""
+    fields: list[Field] = field(default_factory=list)
@@ -1,4 +1,23 @@

-def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
-    return f"{kind}://{tenant}/{namespace}/{topic}"
+def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
+    """
+    Create a generic topic identifier that can be mapped by backends.
+
+    Args:
+        queue_name: The queue/topic name
+        qos: Quality of service
+            - 'q0' = best-effort (no ack)
+            - 'q1' = at-least-once (ack required)
+            - 'q2' = exactly-once (two-phase ack)
+        tenant: Tenant identifier for multi-tenancy
+        namespace: Namespace within tenant
+
+    Returns:
+        Generic topic string: qos/tenant/namespace/queue_name
+
+    Examples:
+        topic('my-queue')  # q1/tg/flow/my-queue
+        topic('config', qos='q2', namespace='config')  # q2/tg/config/config
+    """
+    return f"{qos}/{tenant}/{namespace}/{queue_name}"
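The rewritten helper above is self-contained, so its behaviour can be checked directly. This reproduces the function body verbatim from the hunk: generic `qos/tenant/namespace/queue_name` identifiers replace the old Pulsar-style `persistent://tenant/namespace/topic` URLs.

```python
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
    # Generic topic identifier, mapped to a concrete transport by backends
    return f"{qos}/{tenant}/{namespace}/{queue_name}"

print(topic('my-queue'))                              # q1/tg/flow/my-queue
print(topic('config', qos='q2', namespace='config'))  # q2/tg/config/config
```

Callers that previously passed `kind='non-persistent'` now express the same intent as `qos='q0'`, as the later hunks in this commit show.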
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Bytes
+from dataclasses import dataclass

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -6,24 +6,27 @@ from ..core.topic import topic
 ############################################################################

 # PDF docs etc.
-class Document(Record):
-    metadata = Metadata()
-    data = Bytes()
+@dataclass
+class Document:
+    metadata: Metadata | None = None
+    data: bytes = b""

 ############################################################################

 # Text documents / text from PDF

-class TextDocument(Record):
-    metadata = Metadata()
-    text = Bytes()
+@dataclass
+class TextDocument:
+    metadata: Metadata | None = None
+    text: bytes = b""

 ############################################################################

 # Chunks of text

-class Chunk(Record):
-    metadata = Metadata()
-    chunk = Bytes()
+@dataclass
+class Chunk:
+    metadata: Metadata | None = None
+    chunk: bytes = b""

 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double, Map
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.primitives import Value, RowSchema
@@ -8,49 +8,55 @@ from ..core.topic import topic

 # Graph embeddings are embeddings associated with a graph entity

-class EntityEmbeddings(Record):
-    entity = Value()
-    vectors = Array(Array(Double()))
+@dataclass
+class EntityEmbeddings:
+    entity: Value | None = None
+    vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data
-class GraphEmbeddings(Record):
-    metadata = Metadata()
-    entities = Array(EntityEmbeddings())
+@dataclass
+class GraphEmbeddings:
+    metadata: Metadata | None = None
+    entities: list[EntityEmbeddings] = field(default_factory=list)

 ############################################################################

 # Document embeddings are embeddings associated with a chunk

-class ChunkEmbeddings(Record):
-    chunk = Bytes()
-    vectors = Array(Array(Double()))
+@dataclass
+class ChunkEmbeddings:
+    chunk: bytes = b""
+    vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data
-class DocumentEmbeddings(Record):
-    metadata = Metadata()
-    chunks = Array(ChunkEmbeddings())
+@dataclass
+class DocumentEmbeddings:
+    metadata: Metadata | None = None
+    chunks: list[ChunkEmbeddings] = field(default_factory=list)

 ############################################################################

 # Object embeddings are embeddings associated with the primary key of an
 # object

-class ObjectEmbeddings(Record):
-    metadata = Metadata()
-    vectors = Array(Array(Double()))
-    name = String()
-    key_name = String()
-    id = String()
+@dataclass
+class ObjectEmbeddings:
+    metadata: Metadata | None = None
+    vectors: list[list[float]] = field(default_factory=list)
+    name: str = ""
+    key_name: str = ""
+    id: str = ""

 ############################################################################

 # Structured object embeddings with enhanced capabilities

-class StructuredObjectEmbedding(Record):
-    metadata = Metadata()
-    vectors = Array(Array(Double()))
-    schema_name = String()
-    object_id = String()  # Primary key value
-    field_embeddings = Map(Array(Double()))  # Per-field embeddings
+@dataclass
+class StructuredObjectEmbedding:
+    metadata: Metadata | None = None
+    vectors: list[list[float]] = field(default_factory=list)
+    schema_name: str = ""
+    object_id: str = ""  # Primary key value
+    field_embeddings: dict[str, list[float]] = field(default_factory=dict)  # Per-field embeddings

 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Array
+from dataclasses import dataclass, field

 from ..core.primitives import Value, Triple
 from ..core.metadata import Metadata
@@ -8,21 +8,24 @@ from ..core.topic import topic

 # Entity context are an entity associated with textual context

-class EntityContext(Record):
-    entity = Value()
-    context = String()
+@dataclass
+class EntityContext:
+    entity: Value | None = None
+    context: str = ""

 # This is a 'batching' mechanism for the above data
-class EntityContexts(Record):
-    metadata = Metadata()
-    entities = Array(EntityContext())
+@dataclass
+class EntityContexts:
+    metadata: Metadata | None = None
+    entities: list[EntityContext] = field(default_factory=list)

 ############################################################################

 # Graph triples

-class Triples(Record):
-    metadata = Metadata()
-    triples = Array(Triple())
+@dataclass
+class Triples:
+    metadata: Metadata | None = None
+    triples: list[Triple] = field(default_factory=list)

 ############################################################################
@@ -1,5 +1,4 @@
-from pulsar.schema import Record, Bytes, String, Array, Long, Boolean
+from dataclasses import dataclass, field

 from ..core.primitives import Triple, Error
 from ..core.topic import topic
 from ..core.metadata import Metadata
@@ -22,40 +21,40 @@ from .embeddings import GraphEmbeddings
 # <- ()
 # <- (error)

-class KnowledgeRequest(Record):
+@dataclass
+class KnowledgeRequest:
     # get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
     # load-kg-core, unload-kg-core
-    operation = String()
+    operation: str = ""

     # list-kg-cores, delete-kg-core, put-kg-core
-    user = String()
+    user: str = ""

     # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core,
     # load-kg-core, unload-kg-core
-    id = String()
+    id: str = ""

     # load-kg-core
-    flow = String()
+    flow: str = ""

     # load-kg-core
-    collection = String()
+    collection: str = ""

     # put-kg-core
-    triples = Triples()
-    graph_embeddings = GraphEmbeddings()
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None

-class KnowledgeResponse(Record):
-    error = Error()
-    ids = Array(String())
-    eos = Boolean()  # Indicates end of knowledge core stream
-    triples = Triples()
-    graph_embeddings = GraphEmbeddings()
+@dataclass
+class KnowledgeResponse:
+    error: Error | None = None
+    ids: list[str] = field(default_factory=list)
+    eos: bool = False  # Indicates end of knowledge core stream
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None

 knowledge_request_queue = topic(
-    'knowledge', kind='non-persistent', namespace='request'
+    'knowledge', qos='q0', namespace='request'
 )
 knowledge_response_queue = topic(
-    'knowledge', kind='non-persistent', namespace='response',
+    'knowledge', qos='q0', namespace='response',
 )
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Boolean
+from dataclasses import dataclass

 from ..core.topic import topic

@@ -6,21 +6,25 @@ from ..core.topic import topic

 # NLP extraction data types

-class Definition(Record):
-    name = String()
-    definition = String()
+@dataclass
+class Definition:
+    name: str = ""
+    definition: str = ""

-class Topic(Record):
-    name = String()
-    definition = String()
+@dataclass
+class Topic:
+    name: str = ""
+    definition: str = ""

-class Relationship(Record):
-    s = String()
-    p = String()
-    o = String()
-    o_entity = Boolean()
+@dataclass
+class Relationship:
+    s: str = ""
+    p: str = ""
+    o: str = ""
+    o_entity: bool = False

-class Fact(Record):
-    s = String()
-    p = String()
-    o = String()
+@dataclass
+class Fact:
+    s: str = ""
+    p: str = ""
+    o: str = ""
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Map, Double, Array
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -7,11 +7,13 @@ from ..core.topic import topic

 # Extracted object from text processing

-class ExtractedObject(Record):
-    metadata = Metadata()
-    schema_name = String()  # Which schema this object belongs to
-    values = Array(Map(String()))  # Array of objects, each object is field name -> value
-    confidence = Double()
-    source_span = String()  # Text span where object was found
+@dataclass
+class ExtractedObject:
+    metadata: Metadata | None = None
+    schema_name: str = ""  # Which schema this object belongs to
+    values: list[dict[str, str]] = field(default_factory=list)  # Array of objects, each object is field name -> value
+    confidence: float = 0.0
+    source_span: str = ""  # Text span where object was found

 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Array, Map, String
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.primitives import RowSchema
@@ -8,9 +8,10 @@ from ..core.topic import topic

 # Stores rows of information

-class Rows(Record):
-    metadata = Metadata()
-    row_schema = RowSchema()
-    rows = Array(Map(String()))
+@dataclass
+class Rows:
+    metadata: Metadata | None = None
+    row_schema: RowSchema | None = None
+    rows: list[dict[str, str]] = field(default_factory=list)

 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Bytes, Map
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -7,11 +7,13 @@ from ..core.topic import topic

 # Structured data submission for fire-and-forget processing

-class StructuredDataSubmission(Record):
-    metadata = Metadata()
-    format = String()  # "json", "csv", "xml"
-    schema_name = String()  # Reference to schema in config
-    data = Bytes()  # Raw data to ingest
-    options = Map(String())  # Format-specific options
+@dataclass
+class StructuredDataSubmission:
+    metadata: Metadata | None = None
+    format: str = ""  # "json", "csv", "xml"
+    schema_name: str = ""  # Reference to schema in config
+    data: bytes = b""  # Raw data to ingest
+    options: dict[str, str] = field(default_factory=dict)  # Format-specific options

 ############################################################################
@@ -1,5 +1,5 @@

-from pulsar.schema import Record, String, Array, Map, Boolean
+from dataclasses import dataclass, field

 from ..core.topic import topic
 from ..core.primitives import Error
@@ -8,33 +8,36 @@ from ..core.primitives import Error

 # Prompt services, abstract the prompt generation

-class AgentStep(Record):
-    thought = String()
-    action = String()
-    arguments = Map(String())
-    observation = String()
-    user = String()  # User context for the step
+@dataclass
+class AgentStep:
+    thought: str = ""
+    action: str = ""
+    arguments: dict[str, str] = field(default_factory=dict)
+    observation: str = ""
+    user: str = ""  # User context for the step

-class AgentRequest(Record):
-    question = String()
-    state = String()
-    group = Array(String())
-    history = Array(AgentStep())
-    user = String()  # User context for multi-tenancy
-    streaming = Boolean()  # NEW: Enable streaming response delivery (default false)
+@dataclass
+class AgentRequest:
+    question: str = ""
+    state: str = ""
+    group: list[str] | None = None
+    history: list[AgentStep] = field(default_factory=list)
+    user: str = ""  # User context for multi-tenancy
+    streaming: bool = False  # NEW: Enable streaming response delivery (default false)

-class AgentResponse(Record):
+@dataclass
+class AgentResponse:
     # Streaming-first design
-    chunk_type = String()  # "thought", "action", "observation", "answer", "error"
-    content = String()  # The actual content (interpretation depends on chunk_type)
-    end_of_message = Boolean()  # Current chunk type (thought/action/etc.) is complete
-    end_of_dialog = Boolean()  # Entire agent dialog is complete
+    chunk_type: str = ""  # "thought", "action", "observation", "answer", "error"
+    content: str = ""  # The actual content (interpretation depends on chunk_type)
+    end_of_message: bool = False  # Current chunk type (thought/action/etc.) is complete
+    end_of_dialog: bool = False  # Entire agent dialog is complete

     # Legacy fields (deprecated but kept for backward compatibility)
-    answer = String()
-    error = Error()
-    thought = String()
-    observation = String()
+    answer: str = ""
+    error: Error | None = None
+    thought: str = ""
+    observation: str = ""

 ############################################################################
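The streaming-first AgentResponse shape above can be sketched from the consumer's side. The `accumulate()` helper below is illustrative only, not part of the TrustGraph API; it shows how a client might fold a stream of chunks, keyed by `chunk_type`, into a final answer using the `end_of_dialog` flag.

```python
from dataclasses import dataclass

@dataclass
class AgentResponse:
    # Field names and defaults follow the diff above
    chunk_type: str = ""      # "thought", "action", "observation", "answer", "error"
    content: str = ""
    end_of_message: bool = False
    end_of_dialog: bool = False

def accumulate(chunks):
    # Hypothetical client-side helper: concatenate "answer" chunks until
    # the dialog-complete flag is seen
    answer = ""
    for c in chunks:
        if c.chunk_type == "answer":
            answer += c.content
        if c.end_of_dialog:
            break
    return answer

stream = [
    AgentResponse(chunk_type="thought", content="look up cats", end_of_message=True),
    AgentResponse(chunk_type="answer", content="Cats are "),
    AgentResponse(chunk_type="answer", content="mammals.", end_of_message=True, end_of_dialog=True),
]
print(accumulate(stream))  # Cats are mammals.
```

Intermediate "thought" and "observation" chunks can be surfaced progressively in a UI, while only "answer" chunks contribute to the final text.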
@ -1,4 +1,4 @@
|
||||||
from pulsar.schema import Record, String, Integer, Array
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from ..core.primitives import Error
|
from ..core.primitives import Error
|
||||||
|
|
@ -10,41 +10,40 @@ from ..core.topic import topic
|
||||||
|
|
||||||
# Collection metadata operations (for librarian service)
|
# Collection metadata operations (for librarian service)
|
||||||
|
|
-class CollectionMetadata(Record):
+@dataclass
+class CollectionMetadata:
     """Collection metadata record"""
-    user = String()
-    collection = String()
-    name = String()
-    description = String()
-    tags = Array(String())
-    created_at = String()  # ISO timestamp
-    updated_at = String()  # ISO timestamp
+    user: str = ""
+    collection: str = ""
+    name: str = ""
+    description: str = ""
+    tags: list[str] = field(default_factory=list)
 
 ############################################################################
 
-class CollectionManagementRequest(Record):
+@dataclass
+class CollectionManagementRequest:
     """Request for collection management operations"""
-    operation = String()  # e.g., "delete-collection"
+    operation: str = ""  # e.g., "delete-collection"
 
     # For 'list-collections'
-    user = String()
-    collection = String()
-    timestamp = String()  # ISO timestamp
-    name = String()
-    description = String()
-    tags = Array(String())
-    created_at = String()  # ISO timestamp
-    updated_at = String()  # ISO timestamp
+    user: str = ""
+    collection: str = ""
+    timestamp: str = ""  # ISO timestamp
+    name: str = ""
+    description: str = ""
+    tags: list[str] = field(default_factory=list)
 
     # For list
-    tag_filter = Array(String())  # Optional filter by tags
-    limit = Integer()
+    tag_filter: list[str] = field(default_factory=list)  # Optional filter by tags
+    limit: int = 0
 
-class CollectionManagementResponse(Record):
+@dataclass
+class CollectionManagementResponse:
     """Response for collection management operations"""
-    error = Error()  # Only populated if there's an error
-    timestamp = String()  # ISO timestamp
-    collections = Array(CollectionMetadata())
+    error: Error | None = None  # Only populated if there's an error
+    timestamp: str = ""  # ISO timestamp
+    collections: list[CollectionMetadata] = field(default_factory=list)
 
 ############################################################################
 
@@ -52,8 +51,9 @@ class CollectionManagementResponse(Record):
 # Topics
 
 collection_request_queue = topic(
-    'collection', kind='non-persistent', namespace='request'
+    'collection', qos='q0', namespace='request'
 )
 collection_response_queue = topic(
-    'collection', kind='non-persistent', namespace='response'
+    'collection', qos='q0', namespace='response'
 )
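The conversion pattern above is mechanical: each `pulsar.schema` field declaration becomes a typed dataclass field with a default. A minimal, self-contained sketch (the class body is copied from the new side of the diff; the example values `alice`, `docs`, and `research` are illustrative only) showing that the converted record can be built and serialized without any Pulsar dependency:

```python
from dataclasses import dataclass, field, asdict

@dataclass
class CollectionMetadata:
    """Collection metadata record (field list from the new side of the diff)"""
    user: str = ""
    collection: str = ""
    name: str = ""
    description: str = ""
    tags: list[str] = field(default_factory=list)

# Unset fields fall back to their declared defaults
meta = CollectionMetadata(user="alice", collection="docs", tags=["research"])

# asdict() gives a plain dict, convenient for JSON serialization
print(asdict(meta))
```

Since the class is a plain dataclass, `asdict(meta)` and `json.dumps` replace whatever serialization the Pulsar schema layer provided before.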
@@ -1,5 +1,5 @@
 
-from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
+from dataclasses import dataclass, field
 
 from ..core.topic import topic
 from ..core.primitives import Error
@@ -13,58 +13,61 @@ from ..core.primitives import Error
 # put(values) -> ()
 # delete(keys) -> ()
 # config() -> (version, config)
-class ConfigKey(Record):
-    type = String()
-    key = String()
+@dataclass
+class ConfigKey:
+    type: str = ""
+    key: str = ""
 
-class ConfigValue(Record):
-    type = String()
-    key = String()
-    value = String()
+@dataclass
+class ConfigValue:
+    type: str = ""
+    key: str = ""
+    value: str = ""
 
 # Prompt services, abstract the prompt generation
-class ConfigRequest(Record):
-    operation = String()  # get, list, getvalues, delete, put, config
+@dataclass
+class ConfigRequest:
+    operation: str = ""  # get, list, getvalues, delete, put, config
 
     # get, delete
-    keys = Array(ConfigKey())
+    keys: list[ConfigKey] = field(default_factory=list)
 
     # list, getvalues
-    type = String()
+    type: str = ""
 
     # put
-    values = Array(ConfigValue())
+    values: list[ConfigValue] = field(default_factory=list)
 
-class ConfigResponse(Record):
 
+@dataclass
+class ConfigResponse:
     # get, list, getvalues, config
-    version = Integer()
+    version: int = 0
 
     # get, getvalues
-    values = Array(ConfigValue())
+    values: list[ConfigValue] = field(default_factory=list)
 
     # list
-    directory = Array(String())
+    directory: list[str] = field(default_factory=list)
 
     # config
-    config = Map(Map(String()))
+    config: dict[str, dict[str, str]] = field(default_factory=dict)
 
     # Everything
-    error = Error()
+    error: Error | None = None
 
-class ConfigPush(Record):
-    version = Integer()
-    config = Map(Map(String()))
+@dataclass
+class ConfigPush:
+    version: int = 0
+    config: dict[str, dict[str, str]] = field(default_factory=dict)
 
 config_request_queue = topic(
-    'config', kind='non-persistent', namespace='request'
+    'config', qos='q0', namespace='request'
 )
 config_response_queue = topic(
-    'config', kind='non-persistent', namespace='response'
+    'config', qos='q0', namespace='response'
 )
 config_push_queue = topic(
-    'config', kind='persistent', namespace='config'
+    'config', qos='q2', namespace='config'
 )
 
 ############################################################################
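One detail worth noting in the converted classes: list and dict fields use `field(default_factory=...)` rather than a literal `[]` or `{}` default, because dataclasses reject mutable defaults (raising `ValueError` at class definition time), and a shared mutable default would leak state between instances. A small sketch, using the `ConfigKey`/`ConfigRequest` shapes from the diff above with illustrative key names, showing that each instance gets its own list:

```python
from dataclasses import dataclass, field

@dataclass
class ConfigKey:
    type: str = ""
    key: str = ""

@dataclass
class ConfigRequest:
    operation: str = ""
    keys: list[ConfigKey] = field(default_factory=list)

a = ConfigRequest(operation="get")
b = ConfigRequest(operation="get")
a.keys.append(ConfigKey(type="prompt", key="system"))

# default_factory ran once per instance, so b's list is untouched
print(len(a.keys), len(b.keys))  # 1 0
```

Writing `keys: list[ConfigKey] = []` instead would not even import: dataclasses refuse mutable defaults precisely to prevent the aliasing shown here.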
@@ -1,33 +1,36 @@
-from pulsar.schema import Record, String, Map, Double, Array
+from dataclasses import dataclass, field
 from ..core.primitives import Error
 
 ############################################################################
 
 # Structured data diagnosis services
 
-class StructuredDataDiagnosisRequest(Record):
-    operation = String()  # "detect-type", "generate-descriptor", "diagnose", or "schema-selection"
-    sample = String()  # Data sample to analyze (text content)
-    type = String()  # Data type (csv, json, xml) - optional, required for generate-descriptor
-    schema_name = String()  # Target schema name for descriptor generation - optional
+@dataclass
+class StructuredDataDiagnosisRequest:
+    operation: str = ""  # "detect-type", "generate-descriptor", "diagnose", or "schema-selection"
+    sample: str = ""  # Data sample to analyze (text content)
+    type: str = ""  # Data type (csv, json, xml) - optional, required for generate-descriptor
+    schema_name: str = ""  # Target schema name for descriptor generation - optional
 
     # JSON encoded options (e.g., delimiter for CSV)
-    options = Map(String())
+    options: dict[str, str] = field(default_factory=dict)
 
-class StructuredDataDiagnosisResponse(Record):
-    error = Error()
+@dataclass
+class StructuredDataDiagnosisResponse:
+    error: Error | None = None
 
-    operation = String()  # The operation that was performed
-    detected_type = String()  # Detected data type (for detect-type/diagnose) - optional
-    confidence = Double()  # Confidence score for type detection - optional
+    operation: str = ""  # The operation that was performed
+    detected_type: str = ""  # Detected data type (for detect-type/diagnose) - optional
+    confidence: float = 0.0  # Confidence score for type detection - optional
 
     # JSON encoded descriptor (for generate-descriptor/diagnose) - optional
-    descriptor = String()
+    descriptor: str = ""
 
     # JSON encoded additional metadata (e.g., field count, sample records)
-    metadata = Map(String())
+    metadata: dict[str, str] = field(default_factory=dict)
 
     # Array of matching schema IDs (for schema-selection operation) - optional
-    schema_matches = Array(String())
+    schema_matches: list[str] = field(default_factory=list)
 
 ############################################################################
@@ -1,5 +1,5 @@
 
-from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
+from dataclasses import dataclass, field
 
 from ..core.topic import topic
 from ..core.primitives import Error
@@ -18,54 +18,54 @@ from ..core.primitives import Error
 # stop_flow(flowid) -> ()
 
 # Prompt services, abstract the prompt generation
-class FlowRequest(Record):
-    operation = String()  # list-classes, get-class, put-class, delete-class
+@dataclass
+class FlowRequest:
+    operation: str = ""  # list-classes, get-class, put-class, delete-class
     # list-flows, get-flow, start-flow, stop-flow
 
     # get_class, put_class, delete_class, start_flow
-    class_name = String()
+    class_name: str = ""
 
     # put_class
-    class_definition = String()
+    class_definition: str = ""
 
     # start_flow
-    description = String()
+    description: str = ""
 
     # get_flow, start_flow, stop_flow
-    flow_id = String()
+    flow_id: str = ""
 
     # start_flow - optional parameters for flow customization
-    parameters = Map(String())
+    parameters: dict[str, str] = field(default_factory=dict)
 
-class FlowResponse(Record):
 
+@dataclass
+class FlowResponse:
     # list_classes
-    class_names = Array(String())
+    class_names: list[str] = field(default_factory=list)
 
     # list_flows
-    flow_ids = Array(String())
+    flow_ids: list[str] = field(default_factory=list)
 
     # get_class
-    class_definition = String()
+    class_definition: str = ""
 
     # get_flow
-    flow = String()
+    flow: str = ""
 
     # get_flow
-    description = String()
+    description: str = ""
 
     # get_flow - parameters used when flow was started
-    parameters = Map(String())
+    parameters: dict[str, str] = field(default_factory=dict)
 
     # Everything
-    error = Error()
+    error: Error | None = None
 
 flow_request_queue = topic(
-    'flow', kind='non-persistent', namespace='request'
+    'flow', qos='q0', namespace='request'
 )
 flow_response_queue = topic(
-    'flow', kind='non-persistent', namespace='response'
+    'flow', qos='q0', namespace='response'
 )
 
 ############################################################################
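Alongside the schema conversion, every `topic()` call swaps an explicit `kind=` for a `qos=` level: in this diff, each `kind='non-persistent'` topic becomes `qos='q0'` and each `kind='persistent'` one becomes `qos='q1'` or `qos='q2'`. A purely hypothetical sketch of such a mapping; the real helper lives in `..core.topic`, and its behaviour, the extra semantics of `q1` vs `q2`, and the `tenant` placeholder below are all assumptions, not taken from the source:

```python
# Hypothetical mapping inferred from the diff: q0 -> non-persistent,
# q1/q2 -> persistent. The real ..core.topic.topic() may differ.
QOS_TO_KIND = {"q0": "non-persistent", "q1": "persistent", "q2": "persistent"}

def topic(name: str, qos: str = "q0", namespace: str = "request") -> str:
    # Build a Pulsar-style topic URI; "tenant" is a placeholder value.
    kind = QOS_TO_KIND[qos]
    return f"{kind}://tenant/{namespace}/{name}"

print(topic("flow", qos="q0", namespace="request"))
# non-persistent://tenant/request/flow
```

Centralising persistence behind a QoS label like this means individual schema modules only state how much delivery guarantee they need, while the broker-level persistence policy can change in one place.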
@@ -1,9 +1,8 @@
-
-from pulsar.schema import Record, Bytes, String, Array, Long
+from dataclasses import dataclass, field
 from ..core.primitives import Triple, Error
 from ..core.topic import topic
 from ..core.metadata import Metadata
-from ..knowledge.document import Document, TextDocument
+# Note: Document imports will be updated after knowledge schemas are converted
 
 # add-document
 # -> (document_id, document_metadata, content)
@@ -50,76 +49,79 @@ from ..knowledge.document import Document, TextDocument
 # <- (processing_metadata[])
 # <- (error)
 
-class DocumentMetadata(Record):
-    id = String()
-    time = Long()
-    kind = String()
-    title = String()
-    comments = String()
-    metadata = Array(Triple())
-    user = String()
-    tags = Array(String())
+@dataclass
+class DocumentMetadata:
+    id: str = ""
+    time: int = 0
+    kind: str = ""
+    title: str = ""
+    comments: str = ""
+    metadata: list[Triple] = field(default_factory=list)
+    user: str = ""
+    tags: list[str] = field(default_factory=list)
 
-class ProcessingMetadata(Record):
-    id = String()
-    document_id = String()
-    time = Long()
-    flow = String()
-    user = String()
-    collection = String()
-    tags = Array(String())
+@dataclass
+class ProcessingMetadata:
+    id: str = ""
+    document_id: str = ""
+    time: int = 0
+    flow: str = ""
+    user: str = ""
+    collection: str = ""
+    tags: list[str] = field(default_factory=list)
 
-class Criteria(Record):
-    key = String()
-    value = String()
-    operator = String()
+@dataclass
+class Criteria:
+    key: str = ""
+    value: str = ""
+    operator: str = ""
 
-class LibrarianRequest(Record):
 
+@dataclass
+class LibrarianRequest:
     # add-document, remove-document, update-document, get-document-metadata,
     # get-document-content, add-processing, remove-processing, list-documents,
     # list-processing
-    operation = String()
+    operation: str = ""
 
     # add-document, remove-document, update-document, get-document-metadata,
     # get-document-content
-    document_id = String()
+    document_id: str = ""
 
     # add-processing, remove-processing
-    processing_id = String()
+    processing_id: str = ""
 
     # add-document, update-document
-    document_metadata = DocumentMetadata()
+    document_metadata: DocumentMetadata | None = None
 
     # add-processing
-    processing_metadata = ProcessingMetadata()
+    processing_metadata: ProcessingMetadata | None = None
 
     # add-document
-    content = Bytes()
+    content: bytes = b""
 
     # list-documents, list-processing
-    user = String()
+    user: str = ""
 
     # list-documents?, list-processing?
-    collection = String()
+    collection: str = ""
 
     #
-    criteria = Array(Criteria())
+    criteria: list[Criteria] = field(default_factory=list)
 
-class LibrarianResponse(Record):
-    error = Error()
-    document_metadata = DocumentMetadata()
-    content = Bytes()
-    document_metadatas = Array(DocumentMetadata())
-    processing_metadatas = Array(ProcessingMetadata())
+@dataclass
+class LibrarianResponse:
+    error: Error | None = None
+    document_metadata: DocumentMetadata | None = None
+    content: bytes = b""
+    document_metadatas: list[DocumentMetadata] = field(default_factory=list)
+    processing_metadatas: list[ProcessingMetadata] = field(default_factory=list)
 
 # FIXME: Is this right? Using persistence on librarian so that
 # message chunking works
 
 librarian_request_queue = topic(
-    'librarian', kind='persistent', namespace='request'
+    'librarian', qos='q1', namespace='request'
 )
 librarian_response_queue = topic(
-    'librarian', kind='persistent', namespace='response',
+    'librarian', qos='q1', namespace='response',
 )
Some files were not shown because too many files have changed in this diff.