Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-04-25 00:16:23 +02:00

Merge pull request #604 from trustgraph-ai/release/v1.8: Merge 1.8 into master

Commit 8dff90f36f: 233 changed files with 13294 additions and 4542 deletions
**.github/workflows/pull-request.yaml** (vendored, 2 changes)

````diff
@@ -22,7 +22,7 @@ jobs:
       uses: actions/checkout@v3

     - name: Setup packages
-      run: make update-package-versions VERSION=1.7.999
+      run: make update-package-versions VERSION=1.8.999

     - name: Setup environment
       run: python3 -m venv env
````
````diff
@@ -1,8 +1,8 @@
 # TrustGraph Librarian API

-This API provides document library management for TrustGraph. It handles document storage,
-metadata management, and processing orchestration using hybrid storage (MinIO for content,
-Cassandra for metadata) with multi-user support.
+This API provides document library management for TrustGraph. It handles document storage,
+metadata management, and processing orchestration using hybrid storage (S3-compatible object
+storage for content, Cassandra for metadata) with multi-user support.

 ## Request/response

@@ -374,13 +374,14 @@ await client.add_processing(

 ## Features

-- **Hybrid Storage**: MinIO for content, Cassandra for metadata
+- **Hybrid Storage**: S3-compatible object storage (MinIO, Ceph RGW, AWS S3, etc.) for content, Cassandra for metadata
 - **Multi-user Support**: User-based document ownership and access control
 - **Rich Metadata**: RDF-style metadata triples and tagging system
 - **Processing Integration**: Automatic triggering of document processing workflows
 - **Content Types**: Support for multiple document formats (PDF, text, etc.)
 - **Collection Management**: Optional document grouping by collection
 - **Metadata Search**: Query documents by metadata criteria
+- **Flexible Storage Backend**: Works with any S3-compatible storage (MinIO, Ceph RADOS Gateway, AWS S3, Cloudflare R2, etc.)

 ## Use Cases
````
````diff
@@ -233,9 +233,13 @@ When a user initiates collection deletion through the librarian service:
 #### Collection Management Interface

-All store writers implement a standardized collection management interface with a common schema:
+**⚠️ LEGACY APPROACH - REPLACED BY CONFIG-BASED PATTERN**

-**Message Schema (`StorageManagementRequest`):**
+The queue-based architecture described below has been replaced with a config-based approach using `CollectionConfigHandler`. All storage backends now receive collection updates via config push messages instead of dedicated management queues.
+
+~~All store writers implement a standardized collection management interface with a common schema:~~
+
+~~**Message Schema (`StorageManagementRequest`):**~~
 ```json
 {
   "operation": "create-collection" | "delete-collection",
@@ -244,24 +248,26 @@ All store writers implement a standardized collection management interface with
 }
 ```

-**Queue Architecture:**
-- **Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores
-- **Object Store Management Queue** (`object-storage-management`): Object/document stores
-- **Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores
-- **Storage Response Queue** (`storage-management-response`): All responses sent here
+~~**Queue Architecture:**~~
+- ~~**Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores~~
+- ~~**Object Store Management Queue** (`object-storage-management`): Object/document stores~~
+- ~~**Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores~~
+- ~~**Storage Response Queue** (`storage-management-response`): All responses sent here~~

-Each store writer implements:
-- **Collection Management Handler**: Processes `StorageManagementRequest` messages
-- **Create Collection Operation**: Establishes collection in storage backend
-- **Delete Collection Operation**: Removes all data associated with collection
-- **Collection State Tracking**: Maintains knowledge of which collections exist
-- **Message Processing**: Consumes from dedicated management queue
-- **Status Reporting**: Returns success/failure via `StorageManagementResponse`
-- **Idempotent Operations**: Safe to call create/delete multiple times
+**Current Implementation:**

-**Supported Operations:**
-- `create-collection`: Create collection in storage backend
-- `delete-collection`: Remove all collection data from storage backend
+All storage backends now use `CollectionConfigHandler`:
+- **Config Push Integration**: Storage services register for config push notifications
+- **Automatic Synchronization**: Collections created/deleted based on config changes
+- **Declarative Model**: Collections defined in config service, backends sync to match
+- **No Request/Response**: Eliminates coordination overhead and response tracking
+- **Collection State Tracking**: Maintained via `known_collections` cache
+- **Idempotent Operations**: Safe to process same config multiple times
+
+Each storage backend implements:
+- `create_collection(user: str, collection: str, metadata: dict)` - Create collection structures
+- `delete_collection(user: str, collection: str)` - Remove all collection data
+- `collection_exists(user: str, collection: str) -> bool` - Validate before writes
````
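The config-based contract above is small enough to sketch end to end. The class below is a hypothetical in-memory backend (the class name and the shape of the config snapshot, a set of `(user, collection)` pairs, are invented for illustration); only the three method signatures, the `known_collections` cache, and the idempotency requirement come from this spec:

```python
# Hypothetical sketch; not TrustGraph's actual classes.

class InMemoryCollectionBackend:
    """Minimal backend following the create/delete/exists contract."""

    def __init__(self):
        self.known_collections = set()   # collection state cache

    def create_collection(self, user, collection, metadata):
        # Idempotent: re-creating an existing collection is a no-op
        self.known_collections.add((user, collection))

    def delete_collection(self, user, collection):
        # Idempotent: deleting a missing collection is a no-op
        self.known_collections.discard((user, collection))

    def collection_exists(self, user, collection):
        return (user, collection) in self.known_collections

    def on_collection_config(self, desired):
        """Sync local state to a declarative config snapshot.

        `desired` stands in for the content of a config push message:
        a set of (user, collection) pairs that should exist.
        """
        for key in desired - self.known_collections:   # missing: create
            self.create_collection(*key, metadata={})
        for key in self.known_collections - desired:   # stale: delete
            self.delete_collection(*key)
```

Because synchronization is computed as a set difference against the declarative snapshot, replaying the same config push is a no-op, which is what makes the operations safe to process multiple times.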
````diff
 #### Cassandra Triple Store Refactor

@@ -365,62 +371,33 @@ Comprehensive testing will cover:
 - `triples_collection` table for SPO queries and deletion tracking
 - Collection deletion implemented with read-then-delete pattern

-### 🔄 In Progress Components
+### ✅ Migration to Config-Based Pattern - COMPLETED

-1. **Collection Creation Broadcast** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`)
-   - Update `update_collection()` to send "create-collection" to storage backends
-   - Wait for confirmations from all storage processors
-   - Handle creation failures appropriately
+**All storage backends have been migrated from the queue-based pattern to the config-based `CollectionConfigHandler` pattern.**

-2. **Document Submission Handler** (`trustgraph-flow/trustgraph/librarian/service.py` or similar)
-   - Check if collection exists when document submitted
-   - If not exists: Create collection with defaults before processing document
-   - Trigger same "create-collection" broadcast as `tg-set-collection`
-   - Ensure collection established before document flows to storage processors
+Completed migrations:
+- ✅ `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`

-### ❌ Pending Components
+All backends now:
+- Inherit from `CollectionConfigHandler`
+- Register for config push notifications via `self.register_config_handler(self.on_collection_config)`
+- Implement `create_collection(user, collection, metadata)` and `delete_collection(user, collection)`
+- Use `collection_exists(user, collection)` to validate before writes
+- Automatically sync with config service changes

-1. **Collection State Tracking** - Need to implement in each storage backend:
-   - **Cassandra Triples**: Use `triples_collection` table with marker triples
-   - **Neo4j/Memgraph/FalkorDB**: Create `:CollectionMetadata` nodes
-   - **Qdrant/Milvus/Pinecone**: Use native collection APIs
-   - **Cassandra Objects**: Add collection metadata tracking
-
-2. **Storage Management Handlers** - Need "create-collection" support in 12 files:
-   - `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
-   - `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
-   - `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
-   - `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
-   - `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
-   - `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
-   - `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
-   - `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
-   - `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
-   - `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
-   - `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
-   - Plus any other storage implementations
-
-3. **Write Operation Validation** - Add collection existence checks to all `store_*` methods
-
-4. **Query Operation Handling** - Update queries to return empty for non-existent collections
-
-### Next Implementation Steps
-
-**Phase 1: Core Infrastructure (2-3 days)**
-1. Add collection state tracking methods to all storage backends
-2. Implement `collection_exists()` and `create_collection()` methods
-
-**Phase 2: Storage Handlers (1 week)**
-3. Add "create-collection" handlers to all storage processors
-4. Add write validation to reject non-existent collections
-5. Update query handling for non-existent collections
-
-**Phase 3: Collection Manager (2-3 days)**
-6. Update collection_manager to broadcast creates
-7. Implement response tracking and error handling
-
-**Phase 4: Testing (3-5 days)**
-8. End-to-end testing of explicit creation workflow
-9. Test all storage backends
-10. Validate error handling and edge cases
+Legacy queue-based infrastructure removed:
+- ✅ Removed `StorageManagementRequest` and `StorageManagementResponse` schemas
+- ✅ Removed storage management queue topic definitions
+- ✅ Removed storage management consumer/producer from all backends
+- ✅ Removed `on_storage_management` handlers from all backends
````
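The read-then-delete pattern mentioned for the Cassandra triple store can be sketched in miniature. This is illustrative only: a Python list of `(user, collection, s, p, o)` tuples stands in for the `triples_collection` table, and the function name is invented:

```python
# Illustrative sketch of the read-then-delete pattern. Cassandra cannot
# efficiently delete by a non-key predicate, so the pattern is: query
# the matching rows first, then issue a delete per key.

def delete_collection_triples(rows, user, collection):
    """Read the rows belonging to the collection, then delete them."""
    doomed = [r for r in rows if r[0] == user and r[1] == collection]  # read
    for r in doomed:                                                   # delete
        rows.remove(r)
    return len(doomed)
```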
````diff
@@ -2,17 +2,29 @@
 ## Overview

-TrustGraph uses Python's built-in `logging` module for all logging operations. This provides a standardized, flexible approach to logging across all components of the system.
+TrustGraph uses Python's built-in `logging` module for all logging operations, with centralized configuration and optional Loki integration for log aggregation. This provides a standardized, flexible approach to logging across all components of the system.

 ## Default Configuration

 ### Logging Level
 - **Default Level**: `INFO`
-- **Debug Mode**: `DEBUG` (enabled via command-line argument)
-- **Production**: `WARNING` or `ERROR` as appropriate
+- **Configurable via**: `--log-level` command-line argument
+- **Choices**: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`

-### Output Destination
-All logs should be written to **standard output (stdout)** to ensure compatibility with containerized environments and log aggregation systems.
+### Output Destinations
+1. **Console (stdout)**: Always enabled - ensures compatibility with containerized environments
+2. **Loki**: Optional centralized log aggregation (enabled by default, can be disabled)
+
+## Centralized Logging Module
+
+All logging configuration is managed by the `trustgraph.base.logging` module, which provides:
+- `add_logging_args(parser)` - Adds standard logging CLI arguments
+- `setup_logging(args)` - Configures logging from parsed arguments
+
+This module is used by all server-side components:
+- AsyncProcessor-based services
+- API Gateway
+- MCP Server

 ## Implementation Guidelines
````
````diff
@@ -26,39 +38,80 @@ import logging
 logger = logging.getLogger(__name__)
 ```

-### 2. Centralized Configuration
+The logger name is automatically used as a label in Loki for filtering and searching.

-The logging configuration should be centralized in `async_processor.py` (or a dedicated logging configuration module) since it's inherited by much of the codebase:
+### 2. Service Initialization
+
+All server-side services automatically get logging configuration through the centralized module:

 ```python
 import logging
+from trustgraph.base import add_logging_args, setup_logging
 import argparse

-def setup_logging(log_level='INFO'):
-    """Configure logging for the entire application"""
-    logging.basicConfig(
-        level=getattr(logging, log_level.upper()),
-        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-        handlers=[logging.StreamHandler()]
-    )
-
-def parse_args():
+def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--log-level',
-        default='INFO',
-        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
-        help='Set the logging level (default: INFO)'
-    )
-    return parser.parse_args()
-
-# In main execution
-if __name__ == '__main__':
-    args = parse_args()
-    setup_logging(args.log_level)
+
+    # Add standard logging arguments (includes Loki configuration)
+    add_logging_args(parser)
+
+    # Add your service-specific arguments
+    parser.add_argument('--port', type=int, default=8080)
+
+    args = parser.parse_args()
+    args = vars(args)
+
+    # Setup logging early in startup
+    setup_logging(args)
+
+    # Rest of your service initialization
+    logger = logging.getLogger(__name__)
+    logger.info("Service starting...")
 ```

-### 3. Logging Best Practices
+### 3. Command-Line Arguments
+
+All services support these logging arguments:
+
+**Log Level:**
+```bash
+--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
+```
+
+**Loki Configuration:**
+```bash
+--loki-enabled             # Enable Loki (default)
+--no-loki-enabled          # Disable Loki
+--loki-url URL             # Loki push URL (default: http://loki:3100/loki/api/v1/push)
+--loki-username USERNAME   # Optional authentication
+--loki-password PASSWORD   # Optional authentication
+```
+
+**Examples:**
+```bash
+# Default - INFO level, Loki enabled
+./my-service
+
+# Debug mode, console only
+./my-service --log-level DEBUG --no-loki-enabled
+
+# Custom Loki server with auth
+./my-service --loki-url http://loki.prod:3100/loki/api/v1/push \
+             --loki-username admin --loki-password secret
+```
+
+### 4. Environment Variables
+
+Loki configuration supports environment variable fallbacks:
+
+```bash
+export LOKI_URL=http://loki.prod:3100/loki/api/v1/push
+export LOKI_USERNAME=admin
+export LOKI_PASSWORD=secret
+```
+
+Command-line arguments take precedence over environment variables.
````
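That precedence rule falls out naturally when the environment variable supplies the argparse default and the CLI flag overrides it. A minimal sketch, assuming only this arrangement (the helper name is invented; TrustGraph's actual `add_logging_args` may differ in detail):

```python
import argparse

# Illustrative helper: env var supplies the default, CLI flag wins.
# LOKI_URL and the default URL are the values named in this document.

def resolve_loki_url(argv, environ):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--loki-url",
        default=environ.get("LOKI_URL", "http://loki:3100/loki/api/v1/push"),
    )
    return parser.parse_args(argv).loki_url
```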
````diff

+### 5. Logging Best Practices

 #### Log Levels Usage
 - **DEBUG**: Detailed information for diagnosing problems (variable values, function entry/exit)
````
````diff
@@ -89,20 +142,25 @@ if logger.isEnabledFor(logging.DEBUG):
     logger.debug(f"Debug data: {debug_data}")
 ```

-### 4. Structured Logging
+### 6. Structured Logging with Loki

-For complex data, use structured logging:
+For complex data, use structured logging with extra tags for Loki:

 ```python
 logger.info("Request processed", extra={
-    'request_id': request_id,
-    'duration_ms': duration,
-    'status_code': status_code,
-    'user_id': user_id
+    'tags': {
+        'request_id': request_id,
+        'user_id': user_id,
+        'status': 'success'
+    }
 })
 ```

-### 5. Exception Logging
+These tags become searchable labels in Loki, in addition to automatic labels:
+- `severity` - Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+- `logger` - Module name (from `__name__`)
+
+### 7. Exception Logging

 Always include stack traces for exceptions:

@@ -114,9 +172,13 @@ except Exception as e:
     raise
 ```

-### 6. Async Logging Considerations
+### 8. Async Logging Considerations

-For async code, ensure thread-safe logging:
+The logging system uses non-blocking queued handlers for Loki:
+- Console output is synchronous (fast)
+- Loki output is queued with 500-message buffer
+- Background thread handles Loki transmission
+- No blocking of main application code

 ```python
 import asyncio
````
````diff
@@ -124,46 +186,165 @@ import logging

 async def async_operation():
     logger = logging.getLogger(__name__)
     # Logging is thread-safe and won't block async operations
     logger.info(f"Starting async operation in task: {asyncio.current_task().get_name()}")
 ```

-## Environment Variables
+## Loki Integration

-Support environment-based configuration as a fallback:
+### Architecture
+
+The logging system uses Python's built-in `QueueHandler` and `QueueListener` for non-blocking Loki integration:
+
+1. **QueueHandler**: Logs are placed in a 500-message queue (non-blocking)
+2. **Background Thread**: QueueListener sends logs to Loki asynchronously
+3. **Graceful Degradation**: If Loki is unavailable, console logging continues
````
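The arrangement above can be sketched with the standard library alone. This is not TrustGraph's actual wiring: a list-collecting handler stands in for the Loki handler, but the `QueueHandler`/`QueueListener` split and the 500-slot queue match the description:

```python
import logging
import logging.handlers
import queue

collected = []

class ListHandler(logging.Handler):
    """Stand-in for the Loki handler; runs on the listener's thread."""
    def emit(self, record):
        collected.append(record.getMessage())

# Bounded queue: the application only pays the cost of an enqueue
log_queue = queue.Queue(maxsize=500)

# Background thread drains the queue and feeds the slow handler
listener = logging.handlers.QueueListener(log_queue, ListHandler())
listener.start()

logger = logging.getLogger("queued-demo")
logger.addHandler(logging.handlers.QueueHandler(log_queue))
logger.setLevel(logging.INFO)

logger.info("shipped via background thread")

listener.stop()  # flushes queued records and joins the thread
```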
````diff
+
+### Automatic Labels
+
+Every log sent to Loki includes:
+- `processor`: Processor identity (e.g., `config-svc`, `text-completion`, `embeddings`)
+- `severity`: Log level (DEBUG, INFO, etc.)
+- `logger`: Module name (e.g., `trustgraph.gateway.service`, `trustgraph.agent.react.service`)
+
+### Custom Labels
+
+Add custom labels via the `extra` parameter:

 ```python
-import os
-
-log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', 'INFO')
+logger.info("User action", extra={
+    'tags': {
+        'user_id': user_id,
+        'action': 'document_upload',
+        'collection': collection_name
+    }
+})
 ```

+### Querying Logs in Loki
+
+```logql
+# All logs from a specific processor (recommended - matches Prometheus metrics)
+{processor="config-svc"}
+{processor="text-completion"}
+{processor="embeddings"}
+
+# Error logs from a specific processor
+{processor="config-svc", severity="ERROR"}
+
+# Error logs from all processors
+{severity="ERROR"}
+
+# Logs from a specific processor with text filter
+{processor="text-completion"} |= "Processing"
+
+# All logs from API gateway
+{processor="api-gateway"}
+
+# Logs from processors matching pattern
+{processor=~".*-completion"}
+
+# Logs with custom tags
+{processor="api-gateway"} | json | user_id="12345"
+```
+
+### Graceful Degradation
+
+If Loki is unavailable or `python-logging-loki` is not installed:
+- Warning message printed to console
+- Console logging continues normally
+- Application continues running
+- No retry logic for Loki connection (fail fast, degrade gracefully)
+
 ## Testing

 During tests, consider using a different logging configuration:

 ```python
-# In test setup
-logging.getLogger().setLevel(logging.WARNING)  # Reduce noise during tests
+import logging
+
+# Reduce noise during tests
+logging.getLogger().setLevel(logging.WARNING)
+
+# Or disable Loki for tests
+setup_logging({'log_level': 'WARNING', 'loki_enabled': False})
 ```

 ## Monitoring Integration

-Ensure log format is compatible with monitoring tools:
-- Include timestamps in ISO format
-- Use consistent field names
-- Include correlation IDs where applicable
-- Structure logs for easy parsing (JSON format for production)
+### Standard Format
+
+All logs use a consistent format:
+```
+2025-01-09 10:30:45,123 - trustgraph.gateway.service - INFO - Request processed
+```
+
+Format components:
+- Timestamp (ISO format with milliseconds)
+- Logger name (module path)
+- Log level
+- Message
````
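A line in this format can be parsed back into its components with a small regular expression. The pattern below mirrors `%(asctime)s - %(name)s - %(levelname)s - %(message)s` and is illustrative, not part of TrustGraph:

```python
import re

# Matches "2025-01-09 10:30:45,123 - some.module - INFO - message text"
LOG_RE = re.compile(
    r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}) - "
    r"(?P<logger>[\w.]+) - (?P<level>[A-Z]+) - (?P<message>.*)$"
)

def parse_log_line(line):
    """Return the format components as a dict, or None if no match."""
    m = LOG_RE.match(line)
    return m.groupdict() if m else None
```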
````diff

+### Loki Queries for Monitoring
+
+Common monitoring queries:
+
+```logql
+# Error rate by processor
+sum by (processor) (rate({severity="ERROR"}[5m]))
+
+# Top error-producing processors
+topk(5, sum by (processor) (count_over_time({severity="ERROR"}[1h])))
+
+# Recent errors with processor name
+{severity="ERROR"} | line_format "{{.processor}}: {{__line__}}"
+
+# All agent processors
+{processor=~".*agent.*"} |= "exception"
+
+# Specific processor error count
+count_over_time({processor="config-svc", severity="ERROR"}[1h])
+```
+
 ## Security Considerations

-- Never log sensitive information (passwords, API keys, personal data)
-- Sanitize user input before logging
-- Use placeholders for sensitive fields: `user_id=****1234`
+- **Never log sensitive information** (passwords, API keys, personal data, tokens)
+- **Sanitize user input** before logging
+- **Use placeholders** for sensitive fields: `user_id=****1234`
+- **Loki authentication**: Use `--loki-username` and `--loki-password` for secure deployments
+- **Secure transport**: Use HTTPS for Loki URL in production: `https://loki.prod:3100/loki/api/v1/push`
+
+## Dependencies
+
+The centralized logging module requires:
+- `python-logging-loki` - For Loki integration (optional, graceful degradation if missing)
+
+Already included in `trustgraph-base/pyproject.toml` and `requirements.txt`.

 ## Migration Path

-For existing code using print statements:
-1. Replace `print()` with appropriate logger calls
-2. Choose appropriate log levels based on message importance
-3. Add context to make logs more useful
-4. Test logging output at different levels
+For existing code:
+
+1. **Services already using AsyncProcessor**: No changes needed, Loki support is automatic
+2. **Services not using AsyncProcessor** (api-gateway, mcp-server): Already updated
+3. **CLI tools**: Out of scope - continue using print() or simple logging
+
+### From print() to logging:
+```python
+# Before
+print(f"Processing document {doc_id}")
+
+# After
+logger = logging.getLogger(__name__)
+logger.info(f"Processing document {doc_id}")
+```
+
+## Configuration Summary
+
+| Argument | Default | Environment Variable | Description |
+|----------|---------|---------------------|-------------|
+| `--log-level` | `INFO` | - | Console and Loki log level |
+| `--loki-enabled` | `True` | - | Enable Loki logging |
+| `--loki-url` | `http://loki:3100/loki/api/v1/push` | `LOKI_URL` | Loki push endpoint |
+| `--loki-username` | `None` | `LOKI_USERNAME` | Loki auth username |
+| `--loki-password` | `None` | `LOKI_PASSWORD` | Loki auth password |
````
**docs/tech-specs/minio-to-s3-migration.md** (new file, 258 lines)
# Tech Spec: S3-Compatible Storage Backend Support

## Overview

The Librarian service uses S3-compatible object storage for document blob storage. This spec documents the implementation that enables support for any S3-compatible backend, including MinIO, Ceph RADOS Gateway (RGW), AWS S3, Cloudflare R2, DigitalOcean Spaces, and others.

## Architecture

### Storage Components
- **Blob Storage**: S3-compatible object storage via `minio` Python client library
- **Metadata Storage**: Cassandra (stores object_id mapping and document metadata)
- **Affected Component**: Librarian service only
- **Storage Pattern**: Hybrid storage with metadata in Cassandra, content in S3-compatible storage

### Implementation
- **Library**: `minio` Python client (supports any S3-compatible API)
- **Location**: `trustgraph-flow/trustgraph/librarian/blob_store.py`
- **Operations**:
  - `add()` - Store blob with UUID object_id
  - `get()` - Retrieve blob by object_id
  - `remove()` - Delete blob by object_id
  - `ensure_bucket()` - Create bucket if not exists
- **Bucket**: `library`
- **Object Path**: `doc/{object_id}`
- **Supported MIME Types**: `text/plain`, `application/pdf`

### Key Files
1. `trustgraph-flow/trustgraph/librarian/blob_store.py` - BlobStore implementation
2. `trustgraph-flow/trustgraph/librarian/librarian.py` - BlobStore initialization
3. `trustgraph-flow/trustgraph/librarian/service.py` - Service configuration
4. `trustgraph-flow/pyproject.toml` - Dependencies (`minio` package)
5. `docs/apis/api-librarian.md` - API documentation

## Supported Storage Backends

The implementation works with any S3-compatible object storage system:

### Tested/Supported
- **Ceph RADOS Gateway (RGW)** - Distributed storage system with S3 API (default configuration)
- **MinIO** - Lightweight self-hosted object storage
- **Garage** - Lightweight geo-distributed S3-compatible storage

### Should Work (S3-Compatible)
- **AWS S3** - Amazon's cloud object storage
- **Cloudflare R2** - Cloudflare's S3-compatible storage
- **DigitalOcean Spaces** - DigitalOcean's object storage
- **Wasabi** - S3-compatible cloud storage
- **Backblaze B2** - S3-compatible backup storage
- Any other service implementing the S3 REST API

## Configuration

### CLI Arguments

```bash
librarian \
  --object-store-endpoint <hostname:port> \
  --object-store-access-key <access_key> \
  --object-store-secret-key <secret_key> \
  [--object-store-use-ssl] \
  [--object-store-region <region>]
```

**Note:** Do not include `http://` or `https://` in the endpoint. Use `--object-store-use-ssl` to enable HTTPS.

### Environment Variables (Alternative)

```bash
OBJECT_STORE_ENDPOINT=<hostname:port>
OBJECT_STORE_ACCESS_KEY=<access_key>
OBJECT_STORE_SECRET_KEY=<secret_key>
OBJECT_STORE_USE_SSL=true|false   # Optional, default: false
OBJECT_STORE_REGION=<region>      # Optional
```
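As a sketch of how these variables might be resolved with the documented defaults (the helper name is invented for illustration; the variable names and defaults are the ones listed above):

```python
import os

# Illustrative config resolution; OBJECT_STORE_USE_SSL defaults to false,
# OBJECT_STORE_REGION is optional, per the table above.

def object_store_config(environ=os.environ):
    return {
        "endpoint": environ.get("OBJECT_STORE_ENDPOINT"),
        "access_key": environ.get("OBJECT_STORE_ACCESS_KEY"),
        "secret_key": environ.get("OBJECT_STORE_SECRET_KEY"),
        "use_ssl": environ.get("OBJECT_STORE_USE_SSL", "false").lower() == "true",
        "region": environ.get("OBJECT_STORE_REGION"),  # None if unset
    }
```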
### Examples

**Ceph RADOS Gateway (default):**
```bash
--object-store-endpoint ceph-rgw:7480 \
--object-store-access-key object-user \
--object-store-secret-key object-password
```

**MinIO:**
```bash
--object-store-endpoint minio:9000 \
--object-store-access-key minioadmin \
--object-store-secret-key minioadmin
```

**Garage (S3-compatible):**
```bash
--object-store-endpoint garage:3900 \
--object-store-access-key GK000000000000000000000001 \
--object-store-secret-key b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
```

**AWS S3 with SSL:**
```bash
--object-store-endpoint s3.amazonaws.com \
--object-store-access-key AKIAIOSFODNN7EXAMPLE \
--object-store-secret-key wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY \
--object-store-use-ssl \
--object-store-region us-east-1
```

## Authentication

All S3-compatible backends require AWS Signature Version 4 (or v2) authentication:

- **Access Key** - Public identifier (like a username)
- **Secret Key** - Private signing key (like a password)

The MinIO Python client handles all signature calculation automatically.
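For reference, the signing-key derivation at the heart of Signature Version 4 is a short HMAC-SHA256 chain. This is a sketch of what the client library does internally, not code TrustGraph contains:

```python
import hashlib
import hmac

def sigv4_signing_key(secret_key, date, region, service):
    """Derive the SigV4 signing key (the chained-HMAC step of the
    scheme; the full protocol also canonicalizes and signs each
    request, which the client handles)."""
    def h(key, msg):
        return hmac.new(key, msg.encode(), hashlib.sha256).digest()
    k_date = h(("AWS4" + secret_key).encode(), date)   # e.g. "20250109"
    k_region = h(k_date, region)                       # e.g. "us-east-1"
    k_service = h(k_region, service)                   # "s3"
    return h(k_service, "aws4_request")
```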
|
||||
### Creating Credentials
|
||||
|
||||
**For MinIO:**
|
||||
```bash
|
||||
# Use default credentials or create user via MinIO Console
|
||||
minioadmin / minioadmin
|
||||
```
|
||||
|
||||
**For Ceph RGW:**
|
||||
```bash
|
||||
radosgw-admin user create --uid="trustgraph" --display-name="TrustGraph Service"
|
||||
# Returns access_key and secret_key
|
||||
```
|
||||
|
||||
**For AWS S3:**
|
||||
- Create IAM user with S3 permissions
|
||||
- Generate access key in AWS Console
|
||||
|
||||
## Library Selection: MinIO Python Client
|
||||
|
||||
**Rationale:**
|
||||
- Lightweight (~500KB vs boto3's ~50MB)
|
||||
- S3-compatible - works with any S3 API endpoint
|
||||
- Simpler API than boto3 for basic operations
|
||||
- Already in use, no migration needed
|
||||
- Battle-tested with MinIO and other S3 systems
|
||||
|
||||
## BlobStore Implementation

**Location:** `trustgraph-flow/trustgraph/librarian/blob_store.py`

```python
from minio import Minio
import io
import logging

logger = logging.getLogger(__name__)

class BlobStore:
    """
    S3-compatible blob storage for document content.
    Supports MinIO, Ceph RGW, AWS S3, and other S3-compatible backends.
    """

    def __init__(self, endpoint, access_key, secret_key, bucket_name,
                 use_ssl=False, region=None):
        """
        Initialize S3-compatible blob storage.

        Args:
            endpoint: S3 endpoint (e.g., "minio:9000", "ceph-rgw:7480")
            access_key: S3 access key
            secret_key: S3 secret key
            bucket_name: Bucket name for storage
            use_ssl: Use HTTPS instead of HTTP (default: False)
            region: S3 region (optional, e.g., "us-east-1")
        """
        self.client = Minio(
            endpoint=endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=use_ssl,
            region=region,
        )

        self.bucket_name = bucket_name

        protocol = "https" if use_ssl else "http"
        logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}")

        self.ensure_bucket()

    def ensure_bucket(self):
        """Create bucket if it doesn't exist"""
        found = self.client.bucket_exists(bucket_name=self.bucket_name)
        if not found:
            self.client.make_bucket(bucket_name=self.bucket_name)
            logger.info(f"Created bucket {self.bucket_name}")
        else:
            logger.debug(f"Bucket {self.bucket_name} already exists")

    async def add(self, object_id, blob, kind):
        """Store blob in S3-compatible storage"""
        self.client.put_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
            length=len(blob),
            data=io.BytesIO(blob),
            content_type=kind,
        )
        logger.debug("Add blob complete")

    async def remove(self, object_id):
        """Delete blob from S3-compatible storage"""
        self.client.remove_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
        )
        logger.debug("Remove blob complete")

    async def get(self, object_id):
        """Retrieve blob from S3-compatible storage"""
        resp = self.client.get_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
        )
        return resp.read()
```

## Key Benefits

1. **No Vendor Lock-in** - Works with any S3-compatible storage
2. **Lightweight** - MinIO client is only ~500KB
3. **Simple Configuration** - Just endpoint + credentials
4. **No Data Migration** - Drop-in replacement between backends
5. **Battle-Tested** - MinIO client works with all major S3 implementations

## Implementation Status

All code has been updated to use generic S3 parameter names:

- ✅ `blob_store.py` - Updated to accept `endpoint`, `access_key`, `secret_key`
- ✅ `librarian.py` - Updated parameter names
- ✅ `service.py` - Updated CLI arguments and configuration
- ✅ Documentation updated

## Future Enhancements

1. **SSL/TLS Support** - Add `--s3-use-ssl` flag for HTTPS
2. **Retry Logic** - Implement exponential backoff for transient failures
3. **Presigned URLs** - Generate temporary upload/download URLs
4. **Multi-region Support** - Replicate blobs across regions
5. **CDN Integration** - Serve blobs via CDN
6. **Storage Classes** - Use S3 storage classes for cost optimization
7. **Lifecycle Policies** - Automatic archival/deletion
8. **Versioning** - Store multiple versions of blobs

## References

- MinIO Python Client: https://min.io/docs/minio/linux/developers/python/API.html
- Ceph RGW S3 API: https://docs.ceph.com/en/latest/radosgw/s3/
- S3 API Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/Welcome.html

772
docs/tech-specs/multi-tenant-support.md
Normal file

@@ -0,0 +1,772 @@

# Technical Specification: Multi-Tenant Support

## Overview

Enable multi-tenant deployments by fixing parameter name mismatches that prevent queue customization and adding Cassandra keyspace parameterization.

## Architecture Context

### Flow-Based Queue Resolution

The TrustGraph system uses a **flow-based architecture** for dynamic queue resolution, which inherently supports multi-tenancy:

- **Flow Definitions** are stored in Cassandra and specify queue names via interface definitions
- **Queue names use templates** with `{id}` variables that are replaced with flow instance IDs
- **Services dynamically resolve queues** by looking up flow configurations at request time
- **Each tenant can have unique flows** with different queue names, providing isolation

Example flow interface definition:
```json
{
  "interfaces": {
    "triples-store": "persistent://tg/flow/triples-store:{id}",
    "graph-embeddings-store": "persistent://tg/flow/graph-embeddings-store:{id}"
  }
}
```

When tenant A starts flow `tenant-a-prod` and tenant B starts flow `tenant-b-prod`, they automatically get isolated queues:
- `persistent://tg/flow/triples-store:tenant-a-prod`
- `persistent://tg/flow/triples-store:tenant-b-prod`

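The `{id}` substitution in the templates above is plain string templating. A minimal standalone sketch; the `resolve_queues` helper and the data shapes are illustrative assumptions, not TrustGraph's actual implementation:

```python
# Illustrative sketch: resolve flow interface queue templates for a flow
# instance id. The helper name is an assumption, not TrustGraph code.

def resolve_queues(interfaces: dict, flow_id: str) -> dict:
    """Replace the {id} placeholder in each queue template."""
    return {name: tmpl.format(id=flow_id) for name, tmpl in interfaces.items()}

interfaces = {
    "triples-store": "persistent://tg/flow/triples-store:{id}",
    "graph-embeddings-store": "persistent://tg/flow/graph-embeddings-store:{id}",
}

# Two tenants starting separate flows get isolated queues automatically.
tenant_a = resolve_queues(interfaces, "tenant-a-prod")
tenant_b = resolve_queues(interfaces, "tenant-b-prod")
assert tenant_a["triples-store"] == "persistent://tg/flow/triples-store:tenant-a-prod"
assert tenant_b["triples-store"] == "persistent://tg/flow/triples-store:tenant-b-prod"
```

Because isolation comes from the flow instance id, no service code needs tenant awareness; the template resolution alone keeps the queues apart.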
**Services correctly designed for multi-tenancy:**
- ✅ **Knowledge Management (cores)** - Dynamically resolves queues from flow configuration passed in requests

**Services needing fixes:**
- 🔴 **Config Service** - Parameter name mismatch prevents queue customization
- 🔴 **Librarian Service** - Hardcoded storage management topics (discussed below)
- 🔴 **All Services** - Cannot customize Cassandra keyspace

## Problem Statement

### Issue #1: Parameter Name Mismatch in AsyncProcessor
- **CLI defines:** `--config-queue` (unclear naming)
- **Argparse converts to:** `config_queue` (in params dict)
- **Code looks for:** `config_push_queue`
- **Result:** Parameter is ignored, defaults to `persistent://tg/config/config`
- **Impact:** Affects all 32+ services inheriting from AsyncProcessor
- **Blocks:** Multi-tenant deployments cannot use tenant-specific config queues
- **Solution:** Rename CLI parameter to `--config-push-queue` for clarity (breaking change acceptable since feature is currently broken)

### Issue #2: Parameter Name Mismatch in Config Service
- **CLI defines:** `--push-queue` (ambiguous naming)
- **Argparse converts to:** `push_queue` (in params dict)
- **Code looks for:** `config_push_queue`
- **Result:** Parameter is ignored
- **Impact:** Config service cannot use custom push queue
- **Solution:** Rename CLI parameter to `--config-push-queue` for consistency and clarity (breaking change acceptable)

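The mismatch in Issues #1 and #2 is easy to reproduce with argparse alone; a minimal sketch, not TrustGraph code:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--config-queue', default='persistent://tg/config/config')

# argparse converts dashes to underscores: the value lands under 'config_queue'.
params = vars(parser.parse_args(
    ['--config-queue', 'persistent://tg-dev/config/config']))

# Code that looks up 'config_push_queue' silently falls back to the default.
queue = params.get('config_push_queue', 'persistent://tg/config/config')

assert params['config_queue'] == 'persistent://tg-dev/config/config'
assert queue == 'persistent://tg/config/config'  # the CLI override is ignored
```

This is why the fix is a rename rather than a lookup change: once the flag is `--config-push-queue`, argparse produces exactly the `config_push_queue` key the code already expects.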
### Issue #3: Hardcoded Cassandra Keyspace
- **Current:** Keyspace hardcoded as `"config"`, `"knowledge"`, `"librarian"` in various services
- **Result:** Cannot customize keyspace for multi-tenant deployments
- **Impact:** Config, cores, and librarian services
- **Blocks:** Multiple tenants cannot use separate Cassandra keyspaces

### Issue #4: Collection Management Architecture ✅ COMPLETED
- **Previous:** Collections stored in Cassandra librarian keyspace via separate collections table
- **Previous:** Librarian used 4 hardcoded storage management topics to coordinate collection create/delete:
  - `vector_storage_management_topic`
  - `object_storage_management_topic`
  - `triples_storage_management_topic`
  - `storage_management_response_topic`
- **Problems (Resolved):**
  - Hardcoded topics could not be customized for multi-tenant deployments
  - Complex async coordination between librarian and 4+ storage services
  - Separate Cassandra table and management infrastructure
  - Non-persistent request/response queues for critical operations
- **Solution Implemented:** Migrated collections to config service storage, use config push for distribution
- **Status:** All storage backends migrated to `CollectionConfigHandler` pattern

## Solution

This spec addresses Issues #1, #2, #3, and #4.

### Part 1: Fix Parameter Name Mismatches

#### Change 1: AsyncProcessor Base Class - Rename CLI Parameter
**File:** `trustgraph-base/trustgraph/base/async_processor.py`
**Line:** 260-264

**Current:**
```python
parser.add_argument(
    '--config-queue',
    default=default_config_queue,
    help=f'Config push queue {default_config_queue}',
)
```

**Fixed:**
```python
parser.add_argument(
    '--config-push-queue',
    default=default_config_queue,
    help=f'Config push queue (default: {default_config_queue})',
)
```

**Rationale:**
- Clearer, more explicit naming
- Matches the internal variable name `config_push_queue`
- Breaking change acceptable since feature is currently non-functional
- No code change needed in params.get() - it already looks for the correct name

#### Change 2: Config Service - Rename CLI Parameter
**File:** `trustgraph-flow/trustgraph/config/service/service.py`
**Line:** 276-279

**Current:**
```python
parser.add_argument(
    '--push-queue',
    default=default_config_push_queue,
    help=f'Config push queue (default: {default_config_push_queue})'
)
```

**Fixed:**
```python
parser.add_argument(
    '--config-push-queue',
    default=default_config_push_queue,
    help=f'Config push queue (default: {default_config_push_queue})'
)
```

**Rationale:**
- Clearer naming - "config-push-queue" is more explicit than just "push-queue"
- Matches the internal variable name `config_push_queue`
- Consistent with AsyncProcessor's `--config-push-queue` parameter
- Breaking change acceptable since feature is currently non-functional
- No code change needed in params.get() - it already looks for the correct name

### Part 2: Add Cassandra Keyspace Parameterization

#### Change 3: Add Keyspace Parameter to cassandra_config Module
**File:** `trustgraph-base/trustgraph/base/cassandra_config.py`

**Add CLI argument** (in `add_cassandra_args()` function):
```python
parser.add_argument(
    '--cassandra-keyspace',
    default=None,
    help='Cassandra keyspace (default: service-specific)'
)
```

**Add environment variable support** (in `resolve_cassandra_config()`, which gains a `default_keyspace` parameter). Note that `params.get(key, fallback)` is not sufficient here: when the flag is omitted, argparse still places the key in the params dict with value `None`, so the fallback must treat `None` as absent:
```python
keyspace = (
    params.get("cassandra_keyspace")
    or os.environ.get("CASSANDRA_KEYSPACE")
    or default_keyspace
)
```

**Update return value** of `resolve_cassandra_config()`:
- Currently returns: `(hosts, username, password)`
- Change to return: `(hosts, username, password, keyspace)`

**Rationale:**
- Consistent with existing Cassandra configuration pattern
- Available to all services via `add_cassandra_args()`
- Supports both CLI and environment variable configuration

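Putting the pieces of Change 3 together, the extended resolver could look like the following sketch. The hosts/username/password handling shown here is a placeholder assumption; only the keyspace precedence (CLI value, then `CASSANDRA_KEYSPACE`, then the service default) is the behaviour this spec requires:

```python
import os

def resolve_cassandra_config(params, default_keyspace=None):
    """Sketch of the extended resolver. CLI value wins, then the
    CASSANDRA_KEYSPACE environment variable, then the service default.
    Host/credential resolution is simplified for illustration."""
    hosts = params.get("cassandra_host") or "cassandra"
    username = params.get("cassandra_username")
    password = params.get("cassandra_password")
    keyspace = (
        params.get("cassandra_keyspace")
        or os.environ.get("CASSANDRA_KEYSPACE")
        or default_keyspace
    )
    return hosts, username, password, keyspace

# CLI override wins:
assert resolve_cassandra_config(
    {"cassandra_keyspace": "tg_dev_config"},
    default_keyspace="config")[3] == "tg_dev_config"

# Absent any override, the service-specific default applies:
os.environ.pop("CASSANDRA_KEYSPACE", None)
assert resolve_cassandra_config({}, default_keyspace="config")[3] == "config"
```

Each service then supplies its own `default_keyspace`, preserving today's behaviour while allowing per-tenant overrides.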
#### Change 4: Config Service - Use Parameterized Keyspace
**File:** `trustgraph-flow/trustgraph/config/service/service.py`

**Line 30** - Remove hardcoded keyspace:
```python
# DELETE THIS LINE:
keyspace = "config"
```

**Lines 69-73** - Update cassandra config resolution:

**Current:**
```python
cassandra_host, cassandra_username, cassandra_password = \
    resolve_cassandra_config(params)
```

**Fixed:**
```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="config")
```

**Rationale:**
- Maintains backward compatibility with "config" as default
- Allows override via `--cassandra-keyspace` or `CASSANDRA_KEYSPACE`

#### Change 5: Cores/Knowledge Service - Use Parameterized Keyspace
**File:** `trustgraph-flow/trustgraph/cores/service.py`

**Line 37** - Remove hardcoded keyspace:
```python
# DELETE THIS LINE:
keyspace = "knowledge"
```

**Update cassandra config resolution** (similar location as config service):
```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="knowledge")
```

#### Change 6: Librarian Service - Use Parameterized Keyspace
**File:** `trustgraph-flow/trustgraph/librarian/service.py`

**Line 51** - Remove hardcoded keyspace:
```python
# DELETE THIS LINE:
keyspace = "librarian"
```

**Update cassandra config resolution** (similar location as config service):
```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="librarian")
```

### Part 3: Migrate Collection Management to Config Service

#### Overview
Migrate collections from the Cassandra librarian keyspace to config service storage. This eliminates hardcoded storage management topics and simplifies the architecture by using the existing config push mechanism for distribution.

#### Current Architecture
```
API Request → Gateway → Librarian Service
        ↓
CollectionManager
        ↓
Cassandra Collections Table (librarian keyspace)
        ↓
Broadcast to 4 Storage Management Topics (hardcoded)
        ↓
Wait for 4+ Storage Service Responses
        ↓
Response to Gateway
```

#### New Architecture
```
API Request → Gateway → Librarian Service
        ↓
CollectionManager
        ↓
Config Service API (put/delete/getvalues)
        ↓
Cassandra Config Table (class='collections', key='user:collection')
        ↓
Config Push (to all subscribers on config-push-queue)
        ↓
All Storage Services receive config update independently
```

#### Change 7: Collection Manager - Use Config Service API
**File:** `trustgraph-flow/trustgraph/librarian/collection_manager.py`

**Remove:**
- `LibraryTableStore` usage (Lines 33, 40-41)
- Storage management producers initialization (Lines 86-140)
- `on_storage_response` method (Lines 400-430)
- `pending_deletions` tracking (Lines 57, 90-96, and usage throughout)

**Add:**
- Config service client for API calls (request/response pattern)

**Config Client Setup:**
```python
# In __init__, add config request/response producers/consumers
from trustgraph.schema.services.config import ConfigRequest, ConfigResponse

# Producer for config requests
self.config_request_producer = Producer(
    client=pulsar_client,
    topic=config_request_queue,
    schema=ConfigRequest,
)

# Consumer for config responses (with correlation ID)
self.config_response_consumer = Consumer(
    taskgroup=taskgroup,
    client=pulsar_client,
    flow=None,
    topic=config_response_queue,
    subscriber=f"{id}-config",
    schema=ConfigResponse,
    handler=self.on_config_response,
)

# Tracking for pending config requests
self.pending_config_requests = {}  # request_id -> asyncio.Event
```

**Modify `list_collections` (Lines 145-180):**
```python
async def list_collections(self, user, tag_filter=None, limit=None):
    """List collections from config service"""
    # Send getvalues request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='getvalues',
        type='collections',
    )

    # Send request and wait for response
    response = await self.send_config_request(request)

    # Parse collections from response
    collections = []
    for key, value_json in response.values.items():
        if ":" in key:
            coll_user, collection = key.split(":", 1)
            if coll_user == user:
                metadata = json.loads(value_json)
                collections.append(CollectionMetadata(**metadata))

    # Apply tag filtering in-memory (as before)
    if tag_filter:
        collections = [
            c for c in collections
            if any(tag in c.tags for tag in tag_filter)
        ]

    # Apply limit
    if limit:
        collections = collections[:limit]

    return collections

async def send_config_request(self, request):
    """Send config request and wait for response"""
    event = asyncio.Event()
    self.pending_config_requests[request.id] = event

    await self.config_request_producer.send(request)
    await event.wait()

    # Clean up the event entry as well as the stored response,
    # so completed requests don't accumulate in the dict
    del self.pending_config_requests[request.id]
    return self.pending_config_requests.pop(request.id + "_response")

async def on_config_response(self, message, consumer, flow):
    """Handle config response"""
    response = message.value()
    if response.id in self.pending_config_requests:
        self.pending_config_requests[response.id + "_response"] = response
        self.pending_config_requests[response.id].set()
```

**Modify `update_collection` (Lines 182-312):**
```python
async def update_collection(self, user, collection, name, description, tags):
    """Update collection via config service"""
    # Create metadata
    metadata = CollectionMetadata(
        user=user,
        collection=collection,
        name=name,
        description=description,
        tags=tags,
    )

    # Send put request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='put',
        type='collections',
        key=f'{user}:{collection}',
        value=json.dumps(metadata.to_dict()),
    )

    response = await self.send_config_request(request)

    if response.error:
        raise RuntimeError(f"Config update failed: {response.error.message}")

    # Config service will trigger config push automatically
    # Storage services will receive update and create collections
```

**Modify `delete_collection` (Lines 314-398):**
```python
async def delete_collection(self, user, collection):
    """Delete collection via config service"""
    # Send delete request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='delete',
        type='collections',
        key=f'{user}:{collection}',
    )

    response = await self.send_config_request(request)

    if response.error:
        raise RuntimeError(f"Config delete failed: {response.error.message}")

    # Config service will trigger config push automatically
    # Storage services will receive update and delete collections
```

**Collection Metadata Format:**
- Stored in config table as: `class='collections', key='user:collection'`
- Value is JSON-serialized CollectionMetadata (without timestamp fields)
- Fields: `user`, `collection`, `name`, `description`, `tags`
- Example: `class='collections', key='alice:my-docs', value='{"user":"alice","collection":"my-docs","name":"My Documents","description":"...","tags":["work"]}'`

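The `user:collection` key scheme above can be exercised standalone; a small sketch (the helper names are illustrative, not TrustGraph code):

```python
import json

def collection_key(user: str, collection: str) -> str:
    """Build the config-table key for a collection."""
    return f"{user}:{collection}"

def parse_key(key: str):
    """Split on the first ':' only, so collection names may contain ':'."""
    user, collection = key.split(":", 1)
    return user, collection

metadata = {
    "user": "alice", "collection": "my-docs",
    "name": "My Documents", "description": "...", "tags": ["work"],
}
key = collection_key(metadata["user"], metadata["collection"])
value = json.dumps(metadata)

assert key == "alice:my-docs"
assert parse_key(key) == ("alice", "my-docs")
assert json.loads(value)["tags"] == ["work"]
```

Splitting on the first colon only is what allows the same parse to be reused by every storage service consuming the config push.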
#### Change 8: Librarian Service - Remove Storage Management Infrastructure
**File:** `trustgraph-flow/trustgraph/librarian/service.py`

**Remove:**
- Storage management producers (Lines 173-190):
  - `vector_storage_management_producer`
  - `object_storage_management_producer`
  - `triples_storage_management_producer`
- Storage response consumer (Lines 192-201)
- `on_storage_response` handler (Lines 467-473)

**Modify:**
- CollectionManager initialization (Lines 215-224) - remove storage producer parameters

**Note:** External collection API remains unchanged:
- `list-collections`
- `update-collection`
- `delete-collection`

#### Change 9: Remove Collections Table from LibraryTableStore
**File:** `trustgraph-flow/trustgraph/tables/library.py`

**Delete:**
- Collections table CREATE statement (Lines 114-127)
- Collections prepared statements (Lines 205-240)
- All collection methods (Lines 578-717):
  - `ensure_collection_exists`
  - `list_collections`
  - `update_collection`
  - `delete_collection`
  - `get_collection`
  - `create_collection`

**Rationale:**
- Collections now stored in config table
- Breaking change acceptable - no data migration needed
- Simplifies librarian service significantly

#### Change 10: Storage Services - Config-Based Collection Management ✅ COMPLETED

**Status:** All 11 storage backends have been migrated to use `CollectionConfigHandler`.

**Affected Services (11 total):**
- Document embeddings: milvus, pinecone, qdrant
- Graph embeddings: milvus, pinecone, qdrant
- Object storage: cassandra
- Triples storage: cassandra, falkordb, memgraph, neo4j

**Files:**
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`

**Implementation Pattern (all services):**

1. **Register config handler in `__init__`:**

   ```python
   # Add after AsyncProcessor initialization
   self.register_config_handler(self.on_collection_config)
   self.known_collections = set()  # Track (user, collection) tuples
   ```

2. **Implement config handler:**

   ```python
   async def on_collection_config(self, config, version):
       """Handle collection configuration updates"""
       logger.info(f"Collection config version: {version}")

       if "collections" not in config:
           return

       # Parse collections from config
       # Key format: "user:collection" in config["collections"]
       config_collections = set()
       for key in config["collections"].keys():
           if ":" in key:
               user, collection = key.split(":", 1)
               config_collections.add((user, collection))

       # Determine changes
       to_create = config_collections - self.known_collections
       to_delete = self.known_collections - config_collections

       # Create new collections (idempotent)
       for user, collection in to_create:
           try:
               await self.create_collection_internal(user, collection)
               self.known_collections.add((user, collection))
               logger.info(f"Created collection: {user}/{collection}")
           except Exception as e:
               logger.error(f"Failed to create {user}/{collection}: {e}")

       # Delete removed collections (idempotent)
       for user, collection in to_delete:
           try:
               await self.delete_collection_internal(user, collection)
               self.known_collections.discard((user, collection))
               logger.info(f"Deleted collection: {user}/{collection}")
           except Exception as e:
               logger.error(f"Failed to delete {user}/{collection}: {e}")
   ```

3. **Initialize known collections on startup:**

   ```python
   async def start(self):
       """Start the processor"""
       await super().start()
       await self.sync_known_collections()

   async def sync_known_collections(self):
       """Query backend to populate known_collections set"""
       # Backend-specific implementation:
       # - Milvus/Pinecone/Qdrant: List collections/indexes matching naming pattern
       # - Cassandra: Query keyspaces or collection metadata
       # - Neo4j/Memgraph/FalkorDB: Query CollectionMetadata nodes
       pass
   ```

4. **Refactor existing handler methods:**

   ```python
   # Rename and remove response sending:
   # handle_create_collection → create_collection_internal
   # handle_delete_collection → delete_collection_internal

   async def create_collection_internal(self, user, collection):
       """Create collection (idempotent)"""
       # Same logic as current handle_create_collection
       # But remove response producer calls
       # Handle "already exists" gracefully
       pass

   async def delete_collection_internal(self, user, collection):
       """Delete collection (idempotent)"""
       # Same logic as current handle_delete_collection
       # But remove response producer calls
       # Handle "not found" gracefully
       pass
   ```

5. **Remove storage management infrastructure:**
   - Remove `self.storage_request_consumer` setup and start
   - Remove `self.storage_response_producer` setup
   - Remove `on_storage_management` dispatcher method
   - Remove metrics for storage management
   - Remove imports: `StorageManagementRequest`, `StorageManagementResponse`

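The reconciliation at the heart of the handler in step 2 is just two set differences between the pushed config and the backend's known state. A standalone sketch with sample data:

```python
# Standalone sketch of the reconciliation logic used by on_collection_config.
# The sample data is illustrative.
known_collections = {("alice", "old-docs"), ("alice", "my-docs")}

config = {"collections": {"alice:my-docs": "{}", "bob:reports": "{}"}}

config_collections = set()
for key in config["collections"]:
    if ":" in key:
        user, collection = key.split(":", 1)
        config_collections.add((user, collection))

to_create = config_collections - known_collections  # in config, not yet in backend
to_delete = known_collections - config_collections  # in backend, removed from config

assert to_create == {("bob", "reports")}
assert to_delete == {("alice", "old-docs")}
```

Because the diff is recomputed from scratch on every push, a failed create or delete is simply retried on the next config update, which is what gives the pattern its self-healing property.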
**Backend-Specific Considerations:**

- **Vector stores (Milvus, Pinecone, Qdrant):** Track logical `(user, collection)` in `known_collections`, but may create multiple backend collections per dimension. Continue lazy creation pattern. Delete operations must remove all dimension variants.

- **Cassandra Objects:** Collections are row properties, not structures. Track keyspace-level information.

- **Graph stores (Neo4j, Memgraph, FalkorDB):** Query `CollectionMetadata` nodes on startup. Create/delete metadata nodes on sync.

- **Cassandra Triples:** Use `KnowledgeGraph` API for collection operations.

**Key Design Points:**

- **Eventual consistency:** No request/response mechanism, config push is broadcast
- **Idempotency:** All create/delete operations must be safe to retry
- **Error handling:** Log errors but don't block config updates
- **Self-healing:** Failed operations will retry on next config push
- **Collection key format:** `"user:collection"` in `config["collections"]`

#### Change 11: Update Collection Schema - Remove Timestamps
**File:** `trustgraph-base/trustgraph/schema/services/collection.py`

**Modify CollectionMetadata (Lines 13-21):**
Remove `created_at` and `updated_at` fields:
```python
class CollectionMetadata(Record):
    user = String()
    collection = String()
    name = String()
    description = String()
    tags = Array(String())
    # Remove: created_at = String()
    # Remove: updated_at = String()
```

**Modify CollectionManagementRequest (Lines 25-47):**
Remove timestamp fields:
```python
class CollectionManagementRequest(Record):
    operation = String()
    user = String()
    collection = String()
    timestamp = String()
    name = String()
    description = String()
    tags = Array(String())
    # Remove: created_at = String()
    # Remove: updated_at = String()
    tag_filter = Array(String())
    limit = Integer()
```

**Rationale:**
- Timestamps don't add value for collections
- Config service maintains its own version tracking
- Simplifies schema and reduces storage

#### Benefits of Config Service Migration

1. ✅ **Eliminates hardcoded storage management topics** - Solves multi-tenant blocker
2. ✅ **Simpler coordination** - No complex async waiting for 4+ storage responses
3. ✅ **Eventual consistency** - Storage services update independently via config push
4. ✅ **Better reliability** - Persistent config push vs non-persistent request/response
5. ✅ **Unified configuration model** - Collections treated as configuration
6. ✅ **Reduces complexity** - Removes ~300 lines of coordination code
7. ✅ **Multi-tenant ready** - Config already supports tenant isolation via keyspace
8. ✅ **Version tracking** - Config service version mechanism provides audit trail

## Implementation Notes

### Backward Compatibility

**Parameter Changes:**
- CLI parameter renames are breaking changes but acceptable (feature currently non-functional)
- Services work without parameters (use defaults)
- Default keyspaces preserved: "config", "knowledge", "librarian"
- Default queue: `persistent://tg/config/config`

**Collection Management:**
- **Breaking change:** Collections table removed from librarian keyspace
- **No data migration provided** - acceptable for this phase
- External collection API unchanged (list/update/delete operations)
- Collection metadata format simplified (timestamps removed)

### Testing Requirements

**Parameter Testing:**
1. Verify `--config-push-queue` parameter works on graph-embeddings service
2. Verify `--config-push-queue` parameter works on text-completion service
3. Verify `--config-push-queue` parameter works on config service
4. Verify `--cassandra-keyspace` parameter works for config service
5. Verify `--cassandra-keyspace` parameter works for cores service
6. Verify `--cassandra-keyspace` parameter works for librarian service
7. Verify services work without parameters (uses defaults)
8. Verify multi-tenant deployment with custom queue names and keyspace

**Collection Management Testing:**
9. Verify `list-collections` operation via config service
10. Verify `update-collection` creates/updates in config table
11. Verify `delete-collection` removes from config table
12. Verify config push is triggered on collection updates
13. Verify tag filtering works with config-based storage
14. Verify collection operations work without timestamp fields

### Multi-Tenant Deployment Example
```bash
# Tenant: tg-dev
graph-embeddings \
  -p pulsar+ssl://broker:6651 \
  --pulsar-api-key <KEY> \
  --config-push-queue persistent://tg-dev/config/config

config-service \
  -p pulsar+ssl://broker:6651 \
  --pulsar-api-key <KEY> \
  --config-push-queue persistent://tg-dev/config/config \
  --cassandra-keyspace tg_dev_config
```

## Impact Analysis

### Services Affected by Change 1-2 (CLI Parameter Rename)
All services inheriting from AsyncProcessor or FlowProcessor:
- config-service
- cores-service
- librarian-service
- graph-embeddings
- document-embeddings
- text-completion-* (all providers)
- extract-* (all extractors)
- query-* (all query services)
- retrieval-* (all RAG services)
- storage-* (all storage services)
- And 20+ more services

### Services Affected by Changes 3-6 (Cassandra Keyspace)
- config-service
- cores-service
- librarian-service

### Services Affected by Changes 7-11 (Collection Management)

**Immediate Changes:**
- librarian-service (collection_manager.py, service.py)
- tables/library.py (collections table removal)
- schema/services/collection.py (timestamp removal)

**Completed Changes (Change 10):** ✅
- All storage services (11 total) - migrated to config push for collection updates via `CollectionConfigHandler`
- Storage management schema removed from `storage.py`

## Future Considerations
|
||||
|
||||
### Per-User Keyspace Model
|
||||
|
||||
Some services use **per-user keyspaces** dynamically, where each user gets their own Cassandra keyspace:
|
||||
|
||||
**Services with per-user keyspaces:**
|
||||
1. **Triples Query Service** (`trustgraph-flow/trustgraph/query/triples/cassandra/service.py:65`)
|
||||
- Uses `keyspace=query.user`
|
||||
2. **Objects Query Service** (`trustgraph-flow/trustgraph/query/objects/cassandra/service.py:479`)
|
||||
- Uses `keyspace=self.sanitize_name(user)`
|
||||
3. **KnowledgeGraph Direct Access** (`trustgraph-flow/trustgraph/direct/cassandra_kg.py:18`)
|
||||
- Default parameter `keyspace="trustgraph"`
|
||||
|
||||
**Status:** These are **not modified** in this specification.
|
||||
|
||||
**Future Review Required:**
|
||||
- Evaluate whether per-user keyspace model creates tenant isolation issues
|
||||
- Consider if multi-tenant deployments need keyspace prefix patterns (e.g., `tenant_a_user1`)
|
||||
- Review for potential user ID collision across tenants
|
||||
- Assess if single shared keyspace per tenant with user-based row isolation is preferable
|
||||
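As a sketch of the keyspace prefix pattern mentioned above: the helper below is purely hypothetical (not current TrustGraph code), assuming Cassandra's identifier rules of `[a-zA-Z0-9_]` and a 48-character limit:

```python
import re

def tenant_keyspace(tenant: str, user: str) -> str:
    """Illustrative only: derive a tenant-scoped Cassandra keyspace name
    (e.g. 'tenant_a_user1') from a tenant ID and a user ID."""
    def sanitize(s: str) -> str:
        # Replace anything outside [a-z0-9_] with '_'
        return re.sub(r"[^a-z0-9_]", "_", s.lower())
    # Cassandra keyspace names are limited to 48 characters
    return f"{sanitize(tenant)}_{sanitize(user)}"[:48]

# tenant_keyspace("tenant-a", "user1") → "tenant_a_user1"
```

This keeps user IDs from colliding across tenants at the cost of one keyspace per (tenant, user) pair.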
**Note:** This does not block the current multi-tenant implementation but should be reviewed before production multi-tenant deployments.

## Implementation Phases

### Phase 1: Parameter Fixes (Changes 1-6)
- Fix `--config-push-queue` parameter naming
- Add `--cassandra-keyspace` parameter support
- **Outcome:** Multi-tenant queue and keyspace configuration enabled

### Phase 2: Collection Management Migration (Changes 7-9, 11)
- Migrate collection storage to config service
- Remove collections table from librarian
- Update collection schema (remove timestamps)
- **Outcome:** Eliminates hardcoded storage management topics, simplifies librarian

### Phase 3: Storage Service Updates (Change 10) ✅ COMPLETED
- Updated all storage services to use config push for collections via `CollectionConfigHandler`
- Removed storage management request/response infrastructure
- Removed legacy schema definitions
- **Outcome:** Complete config-based collection management achieved

## References
- GitHub Issue: https://github.com/trustgraph-ai/trustgraph/issues/582
- Related Files:
  - `trustgraph-base/trustgraph/base/async_processor.py`
  - `trustgraph-base/trustgraph/base/cassandra_config.py`
  - `trustgraph-base/trustgraph/schema/core/topic.py`
  - `trustgraph-base/trustgraph/schema/services/collection.py`
  - `trustgraph-flow/trustgraph/config/service/service.py`
  - `trustgraph-flow/trustgraph/cores/service.py`
  - `trustgraph-flow/trustgraph/librarian/service.py`
  - `trustgraph-flow/trustgraph/librarian/collection_manager.py`
  - `trustgraph-flow/trustgraph/tables/library.py`
761
docs/tech-specs/ontology-extract-phase-2.md
Normal file

@ -0,0 +1,761 @@
# Ontology Knowledge Extraction - Phase 2 Refactor

**Status**: Draft
**Author**: Analysis Session 2025-12-03
**Related**: `ontology.md`, `ontorag.md`

## Overview

This document identifies inconsistencies in the current ontology-based knowledge extraction system and proposes a refactor to improve LLM performance and reduce information loss.

## Current Implementation

### How It Works Now

1. **Ontology Loading** (`ontology_loader.py`)
   - Loads ontology JSON with keys like `"fo/Recipe"`, `"fo/Food"`, `"fo/produces"`
   - Class IDs include the namespace prefix in the key itself
   - Example from `food.ontology`:
   ```json
   "classes": {
     "fo/Recipe": {
       "uri": "http://purl.org/ontology/fo/Recipe",
       "rdfs:comment": "A Recipe is a combination..."
     }
   }
   ```

2. **Prompt Construction** (`extract.py:299-307`, `ontology-prompt.md`)
   - Template receives `classes`, `object_properties`, `datatype_properties` dicts
   - Template iterates: `{% for class_id, class_def in classes.items() %}`
   - LLM sees: `**fo/Recipe**: A Recipe is a combination...`
   - Example output format shows:
   ```json
   {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}
   {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"}
   ```

3. **Response Parsing** (`extract.py:382-428`)
   - Expects JSON array: `[{"subject": "...", "predicate": "...", "object": "..."}]`
   - Validates against the ontology subset
   - Expands URIs via `expand_uri()` (extract.py:473-521)

4. **URI Expansion** (`extract.py:473-521`)
   - Checks if the value is in the `ontology_subset.classes` dict
   - If found, extracts the URI from the class definition
   - If not found, constructs a URI: `f"https://trustgraph.ai/ontology/{ontology_id}#{value}"`

### Data Flow Example

**Ontology JSON → Loader → Prompt:**
```
"fo/Recipe" → classes["fo/Recipe"] → LLM sees "**fo/Recipe**"
```

**LLM → Parser → Output:**
```
"Recipe" → not in classes["fo/Recipe"] → constructs URI → LOSES original URI
"fo/Recipe" → found in classes → uses original URI → PRESERVES URI
```
## Problems Identified

### 1. **Inconsistent Examples in Prompt**

**Issue**: The prompt template shows class IDs with prefixes (`fo/Recipe`), but the example output uses unprefixed class names (`Recipe`).

**Location**: `ontology-prompt.md:5-52`

```markdown
## Ontology Classes:
- **fo/Recipe**: A Recipe is...

## Example Output:
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}
```

**Impact**: The LLM receives conflicting signals about which format to use.

### 2. **Information Loss in URI Expansion**

**Issue**: When the LLM returns unprefixed class names following the example, `expand_uri()` can't find them in the ontology dict and constructs fallback URIs, losing the original proper URIs.

**Location**: `extract.py:494-500`

```python
if value in ontology_subset.classes:  # Looks for "Recipe"
    class_def = ontology_subset.classes[value]  # But key is "fo/Recipe"
    if isinstance(class_def, dict) and 'uri' in class_def:
        return class_def['uri']  # Never reached!
return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"  # Fallback
```

**Impact**:
- Original URI: `http://purl.org/ontology/fo/Recipe`
- Constructed URI: `https://trustgraph.ai/ontology/food#Recipe`
- Semantic meaning is lost, breaking interoperability

### 3. **Ambiguous Entity Instance Format**

**Issue**: No clear guidance on the entity instance URI format.

**Examples in prompt**:
- `"recipe:cornish-pasty"` (namespace-like prefix)
- `"ingredient:flour"` (different prefix)

**Actual behavior** (extract.py:517-520):
```python
# Treat as entity instance - construct unique URI
normalized = value.replace(" ", "-").lower()
return f"https://trustgraph.ai/{ontology_id}/{normalized}"
```

**Impact**: The LLM must guess the prefixing convention with no ontology context.

### 4. **No Namespace Prefix Guidance**

**Issue**: The ontology JSON contains namespace definitions (lines 10-25 in food.ontology):
```json
"namespaces": {
  "fo": "http://purl.org/ontology/fo/",
  "rdf": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
  ...
}
```

But these are never surfaced to the LLM. The LLM doesn't know:
- What "fo" means
- What prefix to use for entities
- Which namespace applies to which elements

### 5. **Labels Not Used in Prompt**

**Issue**: Every class has `rdfs:label` fields (e.g., `{"value": "Recipe", "lang": "en-gb"}`), but the prompt template doesn't use them.

**Current**: Shows only `class_id` and `comment`
```jinja
- **{{class_id}}**{% if class_def.comment %}: {{class_def.comment}}{% endif %}
```

**Available but unused**:
```python
"rdfs:label": [{"value": "Recipe", "lang": "en-gb"}]
```

**Impact**: Could provide human-readable names alongside technical IDs.
## Proposed Solutions

### Option A: Normalize to Unprefixed IDs

**Approach**: Strip prefixes from class IDs before showing them to the LLM.

**Changes**:
1. Modify `build_extraction_variables()` to transform keys:
```python
classes_for_prompt = {
    k.split('/')[-1]: v  # "fo/Recipe" → "Recipe"
    for k, v in ontology_subset.classes.items()
}
```

2. Update the prompt example to match (it already uses unprefixed names)

3. Modify `expand_uri()` to handle both formats:
```python
# Try exact match first
if value in ontology_subset.classes:
    return ontology_subset.classes[value]['uri']

# Try with prefix
for prefix in ['fo/', 'rdf:', 'rdfs:']:
    prefixed = f"{prefix}{value}"
    if prefixed in ontology_subset.classes:
        return ontology_subset.classes[prefixed]['uri']
```

**Pros**:
- Cleaner, more human-readable
- Matches existing prompt examples
- LLMs work better with simpler tokens

**Cons**:
- Class name collisions if multiple ontologies have the same class name
- Loses namespace information
- Requires fallback logic for lookups

### Option B: Use Full Prefixed IDs Consistently

**Approach**: Update examples to use prefixed IDs matching what's shown in the class list.

**Changes**:
1. Update the prompt example (ontology-prompt.md:46-52):
```json
[
  {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "fo/Recipe"},
  {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
  {"subject": "recipe:cornish-pasty", "predicate": "fo/produces", "object": "food:cornish-pasty"},
  {"subject": "food:cornish-pasty", "predicate": "rdf:type", "object": "fo/Food"}
]
```

2. Add a namespace explanation to the prompt:
```markdown
## Namespace Prefixes:
- **fo/**: Food Ontology (http://purl.org/ontology/fo/)
- **rdf:**: RDF core namespace
- **rdfs:**: RDF Schema

Use these prefixes exactly as shown when referencing classes and properties.
```

3. Keep `expand_uri()` as-is (it works correctly when matches are found)

**Pros**:
- Input = output consistency
- No information loss
- Preserves namespace semantics
- Works with multiple ontologies

**Cons**:
- More verbose tokens for the LLM
- Requires the LLM to track prefixes
### Option C: Hybrid - Show Both Label and ID

**Approach**: Enhance the prompt to show both human-readable labels and technical IDs.

**Changes**:
1. Update the prompt template:
```jinja
{% for class_id, class_def in classes.items() %}
- **{{class_id}}** (label: "{{class_def.labels[0].value if class_def.labels else class_id}}"){% if class_def.comment %}: {{class_def.comment}}{% endif %}
{% endfor %}
```

Example output:
```markdown
- **fo/Recipe** (label: "Recipe"): A Recipe is a combination...
```

2. Update the instructions:
```markdown
When referencing classes:
- Use the full prefixed ID (e.g., "fo/Recipe") in JSON output
- The label (e.g., "Recipe") is for human understanding only
```

**Pros**:
- Clearest for the LLM
- Preserves all information
- Explicit about what to use

**Cons**:
- Longer prompt
- More complex template

## Implemented Approach

**Simplified Entity-Relationship-Attribute Format** - completely replaces the old triple-based format.

The new approach was chosen because:

1. **No Information Loss**: Original URIs are preserved correctly
2. **Simpler Logic**: No transformation needed; direct dict lookups work
3. **Namespace Safety**: Handles multiple ontologies without collisions
4. **Semantic Correctness**: Maintains RDF/OWL semantics

## Implementation Complete

### What Was Built:

1. **New Prompt Template** (`prompts/ontology-extract-v2.txt`)
   - ✅ Clear sections: Entity Types, Relationships, Attributes
   - ✅ Example using full type identifiers (`fo/Recipe`, `fo/has_ingredient`)
   - ✅ Instructions to use exact identifiers from the schema
   - ✅ New JSON format with entities/relationships/attributes arrays

2. **Entity Normalization** (`entity_normalizer.py`)
   - ✅ `normalize_entity_name()` - Converts names to URI-safe format
   - ✅ `normalize_type_identifier()` - Handles slashes in types (`fo/Recipe` → `fo-recipe`)
   - ✅ `build_entity_uri()` - Creates unique URIs using the (name, type) tuple
   - ✅ `EntityRegistry` - Tracks entities for deduplication

3. **JSON Parser** (`simplified_parser.py`)
   - ✅ Parses the new format: `{entities: [...], relationships: [...], attributes: [...]}`
   - ✅ Supports kebab-case and snake_case field names
   - ✅ Returns structured dataclasses
   - ✅ Graceful error handling with logging

4. **Triple Converter** (`triple_converter.py`)
   - ✅ `convert_entity()` - Generates type + label triples automatically
   - ✅ `convert_relationship()` - Connects entity URIs via properties
   - ✅ `convert_attribute()` - Adds literal values
   - ✅ Looks up full URIs from ontology definitions

5. **Updated Main Processor** (`extract.py`)
   - ✅ Removed the old triple-based extraction code
   - ✅ Added `extract_with_simplified_format()` method
   - ✅ Now exclusively uses the new simplified format
   - ✅ Calls the prompt with the `extract-with-ontologies-v2` ID
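The normalization scheme described above can be sketched as follows. This is illustrative only, assuming simple lowercase/hyphen rules; the function bodies are not the actual `entity_normalizer.py` code:

```python
def normalize_entity_name(name: str) -> str:
    # "Cornish pasty" → "cornish-pasty"
    return name.strip().replace(" ", "-").lower()

def normalize_type_identifier(type_id: str) -> str:
    # "fo/Recipe" → "fo-recipe"; "Recipe" → "recipe"
    return type_id.strip().replace("/", "-").lower()

def build_entity_uri(name: str, type_id: str, ontology_id: str) -> str:
    # The (name, type) tuple keeps same-named entities of different types distinct
    return (f"https://trustgraph.ai/{ontology_id}/"
            f"{normalize_type_identifier(type_id)}-{normalize_entity_name(name)}")

# build_entity_uri("Cornish pasty", "Recipe", "food")
# → "https://trustgraph.ai/food/recipe-cornish-pasty"
```

Because the type participates in the derived ID, the same entity name classified under two types yields two URIs.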
## Test Cases

### Test 1: URI Preservation
```python
# Given ontology class
classes = {"fo/Recipe": {"uri": "http://purl.org/ontology/fo/Recipe", ...}}

# When LLM returns
llm_output = {"subject": "x", "predicate": "rdf:type", "object": "fo/Recipe"}

# Then expanded URI should be
assert expanded == "http://purl.org/ontology/fo/Recipe"
# Not: "https://trustgraph.ai/ontology/food#Recipe"
```

### Test 2: Multi-Ontology Collision
```python
# Given two ontologies
ont1 = {"fo/Recipe": {...}}
ont2 = {"cooking/Recipe": {...}}

# LLM should use the full prefix to disambiguate
llm_output = {"object": "fo/Recipe"}  # Not just "Recipe"
```

### Test 3: Entity Instance Format
```python
# Given a prompt with the food ontology
# LLM should create instances like
{"subject": "recipe:cornish-pasty"}  # Namespace-style
{"subject": "food:beef"}             # Consistent prefix
```
## Open Questions

1. **Should entity instances use namespace prefixes?**
   - Current: `"recipe:cornish-pasty"` (arbitrary)
   - Alternative: use the ontology prefix, `"fo:cornish-pasty"`?
   - Alternative: no prefix; expand `"cornish-pasty"` to a full URI?

2. **How to handle domain/range in the prompt?**
   - Currently shows: `(Recipe → Food)`
   - Should it be: `(fo/Recipe → fo/Food)`?

3. **Should we validate domain/range constraints?**
   - TODO comment at extract.py:470
   - Would catch more errors but is more complex

4. **What about inverse properties and equivalences?**
   - The ontology has `owl:inverseOf`, `owl:equivalentClass`
   - Not currently used in extraction
   - Should they be?

## Success Metrics

- ✅ Zero URI information loss (100% preservation of original URIs)
- ✅ LLM output format matches input format
- ✅ No ambiguous examples in the prompt
- ✅ Tests pass with multiple ontologies
- ✅ Improved extraction quality (measured by valid triple %)
## Alternative Approach: Simplified Extraction Format

### Philosophy

Instead of asking the LLM to understand RDF/OWL semantics, ask it to do what it's good at: **find entities and relationships in text**.

Let the code handle URI construction, RDF conversion, and semantic web formalities.

### Example: Entity Classification

**Input Text:**
```
Cornish pasty is a traditional British pastry filled with meat and vegetables.
```

**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
- Food: A food is something that can be eaten
- Ingredient: An ingredient combines a quantity and a food
```

**What LLM Returns (Simple JSON):**
```json
{
  "entities": [
    {"entity": "Cornish pasty", "type": "Recipe"}
  ]
}
```

**What Code Produces (RDF Triples):**
```python
# 1. Normalize entity name + type to ID (the type prevents collisions)
entity_id = "recipe-cornish-pasty"  # normalize("Cornish pasty", "Recipe")
entity_uri = "https://trustgraph.ai/food/recipe-cornish-pasty"

# Note: Same name, different type = different URI
# "Cornish pasty" (Recipe) → recipe-cornish-pasty
# "Cornish pasty" (Food) → food-cornish-pasty

# 2. Generate triples
triples = [
    # Type triple
    Triple(
        s=Value(value=entity_uri, is_uri=True),
        p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
        o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True)
    ),
    # Label triple (automatic)
    Triple(
        s=Value(value=entity_uri, is_uri=True),
        p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
        o=Value(value="Cornish pasty", is_uri=False)
    )
]
```

### Benefits

1. **The LLM doesn't need to:**
   - Understand URI syntax
   - Invent identifier prefixes (`recipe:`, `ingredient:`)
   - Know about `rdf:type` or `rdfs:label`
   - Construct semantic web identifiers

2. **The LLM just needs to:**
   - Find entities in text
   - Map them to ontology classes
   - Extract relationships and attributes

3. **The code handles:**
   - URI normalization and construction
   - RDF triple generation
   - Automatic label assignment
   - Namespace management

### Why This Works Better

- **Simpler prompt** = less confusion = fewer errors
- **Consistent IDs** = the code controls the normalization rules
- **Auto-generated labels** = no missing rdfs:label triples
- **LLM focuses on extraction** = what it's actually good at
### Example: Entity Relationships

**Input Text:**
```
Cornish pasty is a traditional British pastry filled with beef and potatoes.
```

**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
- Food: A food is something that can be eaten
- Ingredient: An ingredient combines a quantity and a food

## Relationships:
- has_ingredient: Relates a recipe to an ingredient it uses (Recipe → Ingredient)
- food: Relates an ingredient to the food that is required (Ingredient → Food)
```

**What LLM Returns (Simple JSON):**
```json
{
  "entities": [
    {"entity": "Cornish pasty", "type": "Recipe"},
    {"entity": "beef", "type": "Food"},
    {"entity": "potatoes", "type": "Food"}
  ],
  "relationships": [
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "has_ingredient",
      "object": "beef",
      "object-type": "Food"
    },
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "has_ingredient",
      "object": "potatoes",
      "object-type": "Food"
    }
  ]
}
```

**What Code Produces (RDF Triples):**
```python
# Normalize (name, type) pairs to URIs
cornish_pasty_uri = "https://trustgraph.ai/food/recipe-cornish-pasty"
beef_uri = "https://trustgraph.ai/food/food-beef"
potatoes_uri = "https://trustgraph.ai/food/food-potatoes"

# Look up the relation URI from the ontology
has_ingredient_uri = "http://purl.org/ontology/fo/ingredients"  # from fo/has_ingredient

triples = [
    # Entity type triples (as before)
    Triple(s=cornish_pasty_uri, p=rdf_type, o="http://purl.org/ontology/fo/Recipe"),
    Triple(s=cornish_pasty_uri, p=rdfs_label, o="Cornish pasty"),

    Triple(s=beef_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"),
    Triple(s=beef_uri, p=rdfs_label, o="beef"),

    Triple(s=potatoes_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"),
    Triple(s=potatoes_uri, p=rdfs_label, o="potatoes"),

    # Relationship triples
    Triple(
        s=Value(value=cornish_pasty_uri, is_uri=True),
        p=Value(value=has_ingredient_uri, is_uri=True),
        o=Value(value=beef_uri, is_uri=True)
    ),
    Triple(
        s=Value(value=cornish_pasty_uri, is_uri=True),
        p=Value(value=has_ingredient_uri, is_uri=True),
        o=Value(value=potatoes_uri, is_uri=True)
    )
]
```

**Key Points:**
- The LLM returns natural language entity names: `"Cornish pasty"`, `"beef"`, `"potatoes"`
- The LLM includes types to disambiguate: `subject-type`, `object-type`
- The LLM uses the relation name from the schema: `"has_ingredient"`
- The code derives consistent IDs using (name, type): `("Cornish pasty", "Recipe")` → `recipe-cornish-pasty`
- The code looks up the relation URI from the ontology: `fo/has_ingredient` → full URI
- The same (name, type) tuple always gets the same URI (deduplication)
### Example: Entity Name Disambiguation

**Problem:** The same name can refer to different entity types.

**Real-world case:**
```
"Cornish pasty" can be:
- A Recipe (instructions for making it)
- A Food (the dish itself)
```

**How It's Handled:**

The LLM returns both as separate entities:
```json
{
  "entities": [
    {"entity": "Cornish pasty", "type": "Recipe"},
    {"entity": "Cornish pasty", "type": "Food"}
  ],
  "relationships": [
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "produces",
      "object": "Cornish pasty",
      "object-type": "Food"
    }
  ]
}
```

**Code Resolution:**
```python
# Different types → different URIs
recipe_uri = normalize("Cornish pasty", "Recipe")
# → "https://trustgraph.ai/food/recipe-cornish-pasty"

food_uri = normalize("Cornish pasty", "Food")
# → "https://trustgraph.ai/food/food-cornish-pasty"

# The relationship connects them correctly
triple = Triple(
    s=recipe_uri,  # The Recipe
    p="http://purl.org/ontology/fo/produces",
    o=food_uri     # The Food
)
```

**Why This Works:**
- The type is included in ALL references (entities, relationships, attributes)
- The code uses the `(name, type)` tuple as the lookup key
- No ambiguity, no collisions
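The `(name, type)` keying can be pictured as a tiny registry. This is an illustrative stand-in for `EntityRegistry` with assumed normalization rules, not the actual implementation:

```python
class EntityRegistry:
    """Illustrative registry: the (name, type) tuple is the dedup key."""

    def __init__(self, ontology_id: str):
        self.ontology_id = ontology_id
        self.uris: dict[tuple[str, str], str] = {}

    def uri_for(self, name: str, type_id: str) -> str:
        key = (name, type_id)
        if key not in self.uris:
            # Assumed slug rule: lowercase type, hyphenated lowercase name
            slug = f"{type_id.lower()}-{name.replace(' ', '-').lower()}"
            self.uris[key] = f"https://trustgraph.ai/{self.ontology_id}/{slug}"
        return self.uris[key]

reg = EntityRegistry("food")
# Same name, different types → distinct URIs
assert reg.uri_for("Cornish pasty", "Recipe") != reg.uri_for("Cornish pasty", "Food")
# Same (name, type) → same URI every time
assert reg.uri_for("Cornish pasty", "Recipe") == reg.uri_for("Cornish pasty", "Recipe")
```

Keying on the tuple rather than the bare name is exactly what prevents the Recipe/Food collision above.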
### Example: Entity Attributes

**Input Text:**
```
This Cornish pasty recipe serves 4-6 people and takes 45 minutes to prepare.
```

**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method

## Attributes:
- serves: Indicates what the recipe is intended to serve (Recipe → text)
- preparation_time: Time needed to prepare the recipe (Recipe → text)
```

**What LLM Returns (Simple JSON):**
```json
{
  "entities": [
    {"entity": "Cornish pasty recipe", "type": "Recipe"}
  ],
  "attributes": [
    {
      "entity": "Cornish pasty recipe",
      "entity-type": "Recipe",
      "attribute": "serves",
      "value": "4-6 people"
    },
    {
      "entity": "Cornish pasty recipe",
      "entity-type": "Recipe",
      "attribute": "preparation_time",
      "value": "45 minutes"
    }
  ]
}
```

**What Code Produces (RDF Triples):**
```python
# Normalize the entity name to a URI
recipe_uri = "https://trustgraph.ai/food/cornish-pasty-recipe"

# Look up attribute URIs from the ontology
serves_uri = "http://purl.org/ontology/fo/serves"  # from fo/serves
prep_time_uri = "http://purl.org/ontology/fo/preparation_time"  # from fo/preparation_time

triples = [
    # Entity type triple
    Triple(
        s=Value(value=recipe_uri, is_uri=True),
        p=Value(value=rdf_type, is_uri=True),
        o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True)
    ),

    # Label triple (automatic)
    Triple(
        s=Value(value=recipe_uri, is_uri=True),
        p=Value(value=rdfs_label, is_uri=True),
        o=Value(value="Cornish pasty recipe", is_uri=False)
    ),

    # Attribute triples (objects are literals, not URIs)
    Triple(
        s=Value(value=recipe_uri, is_uri=True),
        p=Value(value=serves_uri, is_uri=True),
        o=Value(value="4-6 people", is_uri=False)  # Literal value!
    ),
    Triple(
        s=Value(value=recipe_uri, is_uri=True),
        p=Value(value=prep_time_uri, is_uri=True),
        o=Value(value="45 minutes", is_uri=False)  # Literal value!
    )
]
```

**Key Points:**
- The LLM extracts literal values: `"4-6 people"`, `"45 minutes"`
- The LLM includes the entity type for disambiguation: `entity-type`
- The LLM uses the attribute name from the schema: `"serves"`, `"preparation_time"`
- The code looks up the attribute URI from the ontology datatype properties
- **The object is a literal** (`is_uri=False`), not a URI reference
- Values stay as natural text; no normalization needed

**Difference from Relationships:**
- Relationships: both subject and object are entities (URIs)
- Attributes: subject is an entity (URI); object is a literal value (string/number)
### Complete Example: Entities + Relationships + Attributes

**Input Text:**
```
Cornish pasty is a savory pastry filled with beef and potatoes.
This recipe serves 4 people.
```

**What LLM Returns:**
```json
{
  "entities": [
    {"entity": "Cornish pasty", "type": "Recipe"},
    {"entity": "beef", "type": "Food"},
    {"entity": "potatoes", "type": "Food"}
  ],
  "relationships": [
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "has_ingredient",
      "object": "beef",
      "object-type": "Food"
    },
    {
      "subject": "Cornish pasty",
      "subject-type": "Recipe",
      "relation": "has_ingredient",
      "object": "potatoes",
      "object-type": "Food"
    }
  ],
  "attributes": [
    {
      "entity": "Cornish pasty",
      "entity-type": "Recipe",
      "attribute": "serves",
      "value": "4 people"
    }
  ]
}
```

**Result:** 9 RDF triples generated:
- 3 entity type triples (rdf:type)
- 3 entity label triples (rdfs:label) - automatic
- 2 relationship triples (has_ingredient)
- 1 attribute triple (serves)

All from simple, natural language extractions by the LLM!
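The triple tally follows mechanically from the format: every entity yields a type and a label triple, and every relationship or attribute yields one. A sketch of the counting rule, using the field names shown above (this is not the actual `simplified_parser.py`):

```python
import json

def count_triples(doc: dict) -> int:
    """Each entity → rdf:type + rdfs:label; each relationship
    and each attribute → one triple."""
    return (2 * len(doc.get("entities", []))
            + len(doc.get("relationships", []))
            + len(doc.get("attributes", [])))

llm_output = json.loads("""{
  "entities": [{"entity": "Cornish pasty", "type": "Recipe"},
               {"entity": "beef", "type": "Food"},
               {"entity": "potatoes", "type": "Food"}],
  "relationships": [{"subject": "Cornish pasty", "subject-type": "Recipe",
                     "relation": "has_ingredient", "object": "beef", "object-type": "Food"},
                    {"subject": "Cornish pasty", "subject-type": "Recipe",
                     "relation": "has_ingredient", "object": "potatoes", "object-type": "Food"}],
  "attributes": [{"entity": "Cornish pasty", "entity-type": "Recipe",
                  "attribute": "serves", "value": "4 people"}]
}""")

# count_triples(llm_output) → 9
```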
## References

- Current implementation: `trustgraph-flow/trustgraph/extract/kg/ontology/extract.py`
- Prompt template: `ontology-prompt.md`
- Test cases: `tests/unit/test_extract/test_ontology/`
- Example ontology: `e2e/test-data/food.ontology`
958
docs/tech-specs/pubsub.md
Normal file

@ -0,0 +1,958 @@
# Pub/Sub Infrastructure

## Overview

This document catalogs all connections between the TrustGraph codebase and the pub/sub infrastructure. Currently, the system is hardcoded to use Apache Pulsar. This analysis identifies all integration points to inform future refactoring toward a configurable pub/sub abstraction.
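As a thought experiment for that refactor, a broker-neutral interface might look like the sketch below. The class and method names are hypothetical, not existing TrustGraph APIs; a Pulsar-backed implementation would wrap `pulsar.Client` behind these two methods:

```python
from abc import ABC, abstractmethod
from typing import Callable

class PubSubClient(ABC):
    """Hypothetical broker-neutral pub/sub interface."""

    @abstractmethod
    def publish(self, topic: str, message: bytes) -> None: ...

    @abstractmethod
    def subscribe(self, topic: str, handler: Callable[[bytes], None]) -> None: ...

class InMemoryPubSub(PubSubClient):
    """Toy in-memory implementation, e.g. for unit tests."""

    def __init__(self):
        self.handlers: dict[str, list[Callable[[bytes], None]]] = {}

    def publish(self, topic, message):
        # Deliver synchronously to every handler registered on the topic
        for handler in self.handlers.get(topic, []):
            handler(message)

    def subscribe(self, topic, handler):
        self.handlers.setdefault(topic, []).append(handler)

bus = InMemoryPubSub()
received = []
bus.subscribe("config", received.append)
bus.publish("config", b"v1")
# received == [b"v1"]
```

The integration points cataloged below are the surface such an abstraction would have to cover (schemas, subscription modes, acknowledgement, flushing).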
## Current State: Pulsar Integration Points
|
||||
|
||||
### 1. Direct Pulsar Client Usage
|
||||
|
||||
**Location:** `trustgraph-flow/trustgraph/gateway/service.py`
|
||||
|
||||
The API gateway directly imports and instantiates the Pulsar client:
|
||||
|
||||
- **Line 20:** `import pulsar`
|
||||
- **Lines 54-61:** Direct instantiation of `pulsar.Client()` with optional `pulsar.AuthenticationToken()`
|
||||
- **Lines 33-35:** Default Pulsar host configuration from environment variables
|
||||
- **Lines 178-192:** CLI arguments for `--pulsar-host`, `--pulsar-api-key`, and `--pulsar-listener`
|
||||
- **Lines 78, 124:** Passes `pulsar_client` to `ConfigReceiver` and `DispatcherManager`
|
||||
|
||||
This is the only location that directly instantiates a Pulsar client outside of the abstraction layer.
|
||||
|
||||
### 2. Base Processor Framework
|
||||
|
||||
**Location:** `trustgraph-base/trustgraph/base/async_processor.py`
|
||||
|
||||
The base class for all processors provides Pulsar connectivity:
|
||||
|
||||
- **Line 9:** `import _pulsar` (for exception handling)
|
||||
- **Line 18:** `from . pubsub import PulsarClient`
|
||||
- **Line 38:** Creates `pulsar_client_object = PulsarClient(**params)`
|
||||
- **Lines 104-108:** Properties exposing `pulsar_host` and `pulsar_client`
|
||||
- **Line 250:** Static method `add_args()` calls `PulsarClient.add_args(parser)` for CLI arguments
|
||||
- **Lines 223-225:** Exception handling for `_pulsar.Interrupted`
|
||||
|
||||
All processors inherit from `AsyncProcessor`, making this the central integration point.

### 3. Consumer Abstraction

**Location:** `trustgraph-base/trustgraph/base/consumer.py`

Consumes messages from queues and invokes handler functions:

**Pulsar imports:**
- **Line 12:** `from pulsar.schema import JsonSchema`
- **Line 13:** `import pulsar`
- **Line 14:** `import _pulsar`

**Pulsar-specific usage:**
- **Lines 100, 102:** `pulsar.InitialPosition.Earliest` / `pulsar.InitialPosition.Latest`
- **Line 108:** `JsonSchema(self.schema)` wrapper
- **Line 110:** `pulsar.ConsumerType.Shared`
- **Lines 104-111:** `self.client.subscribe()` with Pulsar-specific parameters
- **Lines 143, 150, 65:** `consumer.unsubscribe()` and `consumer.close()` methods
- **Line 162:** `_pulsar.Timeout` exception
- **Lines 182, 205, 232:** `consumer.acknowledge()` / `consumer.negative_acknowledge()`

**Spec file:** `trustgraph-base/trustgraph/base/consumer_spec.py`
- **Line 22:** References `processor.pulsar_client`

### 4. Producer Abstraction

**Location:** `trustgraph-base/trustgraph/base/producer.py`

Sends messages to queues:

**Pulsar imports:**
- **Line 2:** `from pulsar.schema import JsonSchema`

**Pulsar-specific usage:**
- **Line 49:** `JsonSchema(self.schema)` wrapper
- **Lines 47-51:** `self.client.create_producer()` with Pulsar-specific parameters (topic, schema, chunking_enabled)
- **Lines 31, 76:** `producer.close()` method
- **Lines 64-65:** `producer.send()` with message and properties

**Spec file:** `trustgraph-base/trustgraph/base/producer_spec.py`
- **Line 18:** References `processor.pulsar_client`

### 5. Publisher Abstraction

**Location:** `trustgraph-base/trustgraph/base/publisher.py`

Asynchronous message publishing with queue buffering:

**Pulsar imports:**
- **Line 2:** `from pulsar.schema import JsonSchema`
- **Line 6:** `import pulsar`

**Pulsar-specific usage:**
- **Line 52:** `JsonSchema(self.schema)` wrapper
- **Lines 50-54:** `self.client.create_producer()` with Pulsar-specific parameters
- **Lines 101, 103:** `producer.send()` with message and optional properties
- **Lines 106-107:** `producer.flush()` and `producer.close()` methods

### 6. Subscriber Abstraction

**Location:** `trustgraph-base/trustgraph/base/subscriber.py`

Provides multi-recipient message distribution from queues:

**Pulsar imports:**
- **Line 6:** `from pulsar.schema import JsonSchema`
- **Line 8:** `import _pulsar`

**Pulsar-specific usage:**
- **Line 55:** `JsonSchema(self.schema)` wrapper
- **Line 57:** `self.client.subscribe(**subscribe_args)`
- **Lines 101, 136, 160, 167-172:** Pulsar exceptions: `_pulsar.Timeout`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed`
- **Lines 159, 166, 170:** Consumer methods: `negative_acknowledge()`, `unsubscribe()`, `close()`
- **Lines 247, 251:** Message acknowledgment: `acknowledge()`, `negative_acknowledge()`

**Spec file:** `trustgraph-base/trustgraph/base/subscriber_spec.py`
- **Line 19:** References `processor.pulsar_client`

### 7. Schema System (Heart of Darkness)

**Location:** `trustgraph-base/trustgraph/schema/`

Every message schema in the system is defined using Pulsar's schema framework.

**Core primitives:** `schema/core/primitives.py`
- **Line 2:** `from pulsar.schema import Record, String, Boolean, Array, Integer`
- All schemas inherit from Pulsar's `Record` base class
- All field types are Pulsar types: `String()`, `Integer()`, `Boolean()`, `Array()`, `Map()`, `Double()`

**Example schemas:**
- `schema/services/llm.py` (Line 2): `from pulsar.schema import Record, String, Array, Double, Integer, Boolean`
- `schema/services/config.py` (Line 2): `from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer`

**Topic naming:** `schema/core/topic.py`
- **Lines 2-3:** Topic format: `{kind}://{tenant}/{namespace}/{topic}`
- This URI structure is Pulsar-specific (e.g., `persistent://tg/flow/config`)
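
For illustration, the URI structure decomposes as follows (a minimal sketch; `parse_pulsar_topic` is a hypothetical helper, not a function in the codebase):

```python
# Hypothetical helper, not part of the TrustGraph codebase: splits a
# Pulsar topic URI into the components named above.
def parse_pulsar_topic(uri: str) -> dict:
    kind, rest = uri.split("://", 1)               # persistent / non-persistent
    tenant, namespace, topic = rest.split("/", 2)  # tg / flow / config
    return {"kind": kind, "tenant": tenant,
            "namespace": namespace, "topic": topic}
```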

**Impact:**
- All request/response message definitions throughout the codebase use Pulsar schemas
- This includes services for: config, flow, llm, prompt, query, storage, agent, collection, diagnosis, library, lookup, nlp_query, objects_query, retrieval, structured_query
- Schema definitions are imported and used extensively across all processors and services

## Summary

### Pulsar Dependencies by Category

1. **Client instantiation:**
   - Direct: `gateway/service.py`
   - Abstracted: `async_processor.py` → `pubsub.py` (PulsarClient)

2. **Message transport:**
   - Consumer: `consumer.py`, `consumer_spec.py`
   - Producer: `producer.py`, `producer_spec.py`
   - Publisher: `publisher.py`
   - Subscriber: `subscriber.py`, `subscriber_spec.py`

3. **Schema system:**
   - Base types: `schema/core/primitives.py`
   - All service schemas: `schema/services/*.py`
   - Topic naming: `schema/core/topic.py`

4. **Pulsar-specific concepts required:**
   - Topic-based messaging
   - Schema system (Record, field types)
   - Shared subscriptions
   - Message acknowledgment (positive/negative)
   - Consumer positioning (earliest/latest)
   - Message properties
   - Initial positions and consumer types
   - Chunking support
   - Persistent vs non-persistent topics

### Refactoring Challenges

The good news: the abstraction layer (Consumer, Producer, Publisher, Subscriber) provides a clean encapsulation of most Pulsar interactions.

The challenges:
1. **Schema system pervasiveness:** Every message definition uses `pulsar.schema.Record` and Pulsar field types
2. **Pulsar-specific enums:** `InitialPosition`, `ConsumerType`
3. **Pulsar exceptions:** `_pulsar.Timeout`, `_pulsar.Interrupted`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed`
4. **Method signatures:** `acknowledge()`, `negative_acknowledge()`, `subscribe()`, `create_producer()`, etc.
5. **Topic URI format:** Pulsar's `kind://tenant/namespace/topic` structure

### Next Steps

To make the pub/sub infrastructure configurable, we need to:

1. Create an abstraction interface for the client/schema system
2. Abstract Pulsar-specific enums and exceptions
3. Create schema wrappers or alternative schema definitions
4. Implement the interface for both Pulsar and alternative systems (Kafka, RabbitMQ, Redis Streams, etc.)
5. Update `pubsub.py` to be configurable and support multiple backends
6. Provide a migration path for existing deployments

## Approach Draft 1: Adapter Pattern with Schema Translation Layer

### Key Insight

The **schema system** is the deepest integration point - everything else flows from it. We need to solve this first, or we'll be rewriting the entire codebase.

### Strategy: Minimal Disruption with Adapters

**1. Keep Pulsar schemas as the internal representation**
- Don't rewrite all the schema definitions
- Schemas remain `pulsar.schema.Record` internally
- Use adapters to translate at the boundary between our code and the pub/sub backend

**2. Create a pub/sub abstraction layer:**

```
┌─────────────────────────────────────┐
│ Existing Code (unchanged)           │
│ - Uses Pulsar schemas internally    │
│ - Consumer/Producer/Publisher       │
└──────────────┬──────────────────────┘
               │
┌──────────────┴──────────────────────┐
│ PubSubFactory (configurable)        │
│ - Creates backend-specific client   │
└──────────────┬──────────────────────┘
               │
        ┌──────┴──────┐
        │             │
┌───────▼──────┐ ┌────▼─────────┐
│ PulsarAdapter│ │ KafkaAdapter │ etc...
│ (passthrough)│ │ (translates) │
└──────────────┘ └──────────────┘
```

**3. Define abstract interfaces:**
- `PubSubClient` - client connection
- `PubSubProducer` - sending messages
- `PubSubConsumer` - receiving messages
- `SchemaAdapter` - translating Pulsar schemas to/from JSON or backend-specific formats

**4. Implementation details:**

For the **Pulsar adapter**: nearly a passthrough, with minimal translation.

For **other backends** (Kafka, RabbitMQ, etc.):
- Serialize Pulsar Record objects to JSON/bytes
- Map concepts like:
  - `InitialPosition.Earliest/Latest` → Kafka's `auto.offset.reset`
  - `acknowledge()` → Kafka's commit
  - `negative_acknowledge()` → re-queue or DLQ pattern
  - Topic URIs → backend-specific topic names
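
A rough sketch of what that mapping might look like on the Kafka side (the `kafka_consumer_config` helper is hypothetical; the dictionary keys are standard Kafka consumer configuration names):

```python
# Illustrative sketch of the concept translation a Kafka adapter might do.
KAFKA_OFFSET_RESET = {
    "earliest": "earliest",  # InitialPosition.Earliest
    "latest": "latest",      # InitialPosition.Latest
}

def kafka_consumer_config(subscription: str, initial_position: str) -> dict:
    """Map abstract consumer settings onto Kafka client configuration."""
    return {
        "group.id": subscription,       # shared subscription -> consumer group
        "auto.offset.reset": KAFKA_OFFSET_RESET[initial_position],
        "enable.auto.commit": False,    # acknowledge() becomes an explicit commit
    }
```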

### Analysis

**Pros:**
- ✅ Minimal code changes to existing services
- ✅ Schemas stay as-is (no massive rewrite)
- ✅ Gradual migration path
- ✅ Pulsar users see no difference
- ✅ New backends added via adapters

**Cons:**
- ⚠️ Still carries a Pulsar dependency (for schema definitions)
- ⚠️ Some impedance mismatch when translating concepts

### Alternative Consideration

Create a **TrustGraph schema system** that's pub/sub agnostic (using dataclasses or Pydantic), then generate Pulsar/Kafka/etc. schemas from it. This requires rewriting every schema file and potentially introduces breaking changes.

### Recommendation for Draft 1

Start with the **adapter approach** because:
1. It's pragmatic - it works with existing code
2. It proves the concept with minimal risk
3. It can evolve to a native schema system later if needed
4. It's configuration-driven: one env var switches backends

## Approach Draft 2: Backend-Agnostic Schema System with Dataclasses

### Core Concept

Use Python **dataclasses** as the neutral schema definition format. Each pub/sub backend provides its own serialization/deserialization for dataclasses, eliminating the need for Pulsar schemas to remain in the codebase.

### Schema Polymorphism at the Factory Level

Instead of translating Pulsar schemas, **each backend provides its own schema handling** that works with standard Python dataclasses.

### Publisher Flow

```python
# 1. Get the configured backend from the factory
pubsub = get_pubsub()  # Returns PulsarBackend, MQTTBackend, etc.

# 2. Get the schema class
# (Can be imported directly - backend-agnostic)
from trustgraph.schema.services.llm import TextCompletionRequest

# 3. Create a producer/publisher for a specific topic
producer = pubsub.create_producer(
    topic="text-completion-requests",
    schema=TextCompletionRequest  # Tells the backend what schema to use
)

# 4. Create message instances (same API regardless of backend)
request = TextCompletionRequest(
    system="You are helpful",
    prompt="Hello world",
    streaming=False
)

# 5. Send the message
producer.send(request)  # Backend serializes appropriately
```

### Consumer Flow

```python
# 1. Get the configured backend
pubsub = get_pubsub()

# 2. Create a consumer
consumer = pubsub.subscribe(
    topic="text-completion-requests",
    schema=TextCompletionRequest  # Tells the backend how to deserialize
)

# 3. Receive and deserialize
msg = consumer.receive()
request = msg.value()  # Returns a TextCompletionRequest dataclass instance

# 4. Use the data (type-safe access)
print(request.system)     # "You are helpful"
print(request.prompt)     # "Hello world"
print(request.streaming)  # False
```

### What Happens Behind the Scenes

**For the Pulsar backend:**
- `create_producer()` → creates a Pulsar producer with a JSON schema or dynamically generated Record
- `send(request)` → serializes the dataclass to JSON/Pulsar format, sends to Pulsar
- `receive()` → gets a Pulsar message, deserializes back to a dataclass

**For the MQTT backend:**
- `create_producer()` → connects to the MQTT broker, no schema registration needed
- `send(request)` → converts the dataclass to JSON, publishes to the MQTT topic
- `receive()` → subscribes to the MQTT topic, deserializes JSON to a dataclass

**For the Kafka backend:**
- `create_producer()` → creates a Kafka producer, registers an Avro schema if needed
- `send(request)` → serializes the dataclass to Avro format, sends to Kafka
- `receive()` → gets a Kafka message, deserializes Avro back to a dataclass

### Key Design Points

1. **Schema object creation**: The dataclass instance (`TextCompletionRequest(...)`) is identical regardless of backend
2. **Backend handles encoding**: Each backend knows how to serialize the dataclass to its wire format
3. **Schema definition at creation**: When creating a producer/consumer, you specify the schema type
4. **Type safety preserved**: You get back a proper `TextCompletionRequest` object, not a dict
5. **No backend leakage**: Application code never imports backend-specific libraries
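
These points can be demonstrated with a toy in-memory backend (everything here is an illustrative sketch; `InMemoryBackend` and its methods are hypothetical, not part of the proposed API):

```python
from dataclasses import dataclass
from collections import defaultdict, deque

@dataclass
class TextCompletionRequest:
    system: str
    prompt: str
    streaming: bool = False

class InMemoryBackend:
    """Toy backend: queues are plain deques; 'serialization' is a no-op."""

    def __init__(self):
        self.queues = defaultdict(deque)

    def create_producer(self, topic, schema):
        queues = self.queues

        class Producer:
            def send(self, message):
                # The application always hands over a schema instance
                assert isinstance(message, schema)
                queues[topic].append(message)

        return Producer()

    def subscribe(self, topic, schema):
        queues = self.queues

        class Consumer:
            def receive(self):
                return queues[topic].popleft()

        return Consumer()

# Application code: identical regardless of which backend is configured
pubsub = InMemoryBackend()
producer = pubsub.create_producer("text-completion-requests", TextCompletionRequest)
producer.send(TextCompletionRequest(system="You are helpful", prompt="Hello world"))

consumer = pubsub.subscribe("text-completion-requests", TextCompletionRequest)
request = consumer.receive()
```

Swapping `InMemoryBackend` for a real backend would leave the producer/consumer lines untouched, which is the point of design items 1 and 5.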

### Example Transformation

**Current (Pulsar-specific):**
```python
# schema/services/llm.py
from pulsar.schema import Record, String, Boolean, Integer

class TextCompletionRequest(Record):
    system = String()
    prompt = String()
    streaming = Boolean()
```

**New (Backend-agnostic):**
```python
# schema/services/llm.py
from dataclasses import dataclass

@dataclass
class TextCompletionRequest:
    system: str
    prompt: str
    streaming: bool = False
```

### Backend Integration

Each backend handles serialization/deserialization of dataclasses:

**Pulsar backend:**
- Dynamically generate `pulsar.schema.Record` classes from dataclasses
- Or serialize dataclasses to JSON and use Pulsar's JSON schema
- Maintains compatibility with existing Pulsar deployments

**MQTT/Redis backend:**
- Direct JSON serialization of dataclass instances
- Use `dataclasses.asdict()` / `from_dict()`
- Lightweight, no schema registry needed

**Kafka backend:**
- Generate Avro schemas from dataclass definitions
- Use Confluent's schema registry
- Type-safe serialization with schema evolution support

### Architecture

```
┌─────────────────────────────────────┐
│ Application Code                    │
│ - Uses dataclass schemas            │
│ - Backend-agnostic                  │
└──────────────┬──────────────────────┘
               │
┌──────────────┴──────────────────────┐
│ PubSubFactory (configurable)        │
│ - get_pubsub() returns backend      │
└──────────────┬──────────────────────┘
               │
        ┌──────┴──────┐
        │             │
┌───────▼─────────┐ ┌─▼─────────────────┐
│ PulsarBackend   │ │ MQTTBackend       │
│ - JSON schema   │ │ - JSON serialize  │
│ - or dynamic    │ │ - Simple queues   │
│   Record gen    │ │                   │
└─────────────────┘ └───────────────────┘
```

### Implementation Details

**1. Schema definitions:** plain dataclasses with type hints
- `str`, `int`, `bool`, `float` for primitives
- `list[T]` for arrays
- `dict[str, T]` for maps
- Nested dataclasses for complex types

**2. Each backend provides:**
- Serializer: `dataclass → bytes/wire format`
- Deserializer: `bytes/wire format → dataclass`
- Schema registration (if needed, as for Pulsar/Kafka)

**3. Consumer/Producer abstraction:**
- Already exists (`consumer.py`, `producer.py`)
- Update to use the backend's serialization
- Remove direct Pulsar imports

**4. Type mappings:**
- Pulsar `String()` → Python `str`
- Pulsar `Integer()` → Python `int`
- Pulsar `Boolean()` → Python `bool`
- Pulsar `Array(T)` → Python `list[T]`
- Pulsar `Map(K, V)` → Python `dict[K, V]`
- Pulsar `Double()` → Python `float`
- Pulsar `Bytes()` → Python `bytes`
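
A sketch of how a backend could derive these mappings mechanically from type hints (`describe_schema` is a hypothetical helper; the wire-type names are illustrative):

```python
import typing
from dataclasses import dataclass, fields

# Mirrors the mapping table above, in reverse (Python type -> wire type name)
WIRE_TYPES = {str: "string", int: "integer", bool: "boolean",
              float: "double", bytes: "bytes"}

def describe_schema(cls) -> dict:
    """Map each dataclass field's type hint to a wire-format type name."""
    hints = typing.get_type_hints(cls)
    out = {}
    for f in fields(cls):
        t = hints[f.name]
        origin = typing.get_origin(t)
        if origin is list:
            out[f.name] = f"array<{WIRE_TYPES[typing.get_args(t)[0]]}>"
        elif origin is dict:
            k, v = typing.get_args(t)
            out[f.name] = f"map<{WIRE_TYPES[k]},{WIRE_TYPES[v]}>"
        else:
            out[f.name] = WIRE_TYPES[t]
    return out

@dataclass
class Example:
    name: str
    count: int
    tags: list[str]
    attrs: dict[str, str]
```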

### Migration Path

1. **Create dataclass versions** of all schemas in `trustgraph/schema/`
2. **Update backend classes** (Consumer, Producer, Publisher, Subscriber) to use backend-provided serialization
3. **Implement PulsarBackend** with JSON schema or dynamic Record generation
4. **Test with Pulsar** to ensure backward compatibility with existing deployments
5. **Add new backends** (MQTT, Kafka, Redis, etc.) as needed
6. **Remove Pulsar imports** from schema files

### Benefits

- ✅ **No pub/sub dependency** in schema definitions
- ✅ **Standard Python** - easy to understand, type-check, document
- ✅ **Modern tooling** - works with mypy, IDE autocomplete, linters
- ✅ **Backend-optimized** - each backend uses native serialization
- ✅ **No translation overhead** - direct serialization, no adapters
- ✅ **Type safety** - real objects with proper types
- ✅ **Easy validation** - can use Pydantic if needed

### Challenges & Solutions

**Challenge:** Pulsar's `Record` has runtime field validation
**Solution:** Use Pydantic dataclasses for validation if needed, or Python 3.10+ dataclass features with `__post_init__`

**Challenge:** Some Pulsar-specific features (like the `Bytes` type)
**Solution:** Map to the `bytes` type in the dataclass; the backend handles encoding appropriately

**Challenge:** Topic naming (`persistent://tenant/namespace/topic`)
**Solution:** Abstract topic names in schema definitions; the backend converts them to its proper format

**Challenge:** Schema evolution and versioning
**Solution:** Each backend handles this according to its capabilities (Pulsar schema versions, Kafka schema registry, etc.)

**Challenge:** Nested complex types
**Solution:** Use nested dataclasses; backends recursively serialize/deserialize

### Design Decisions

1. **Plain dataclasses or Pydantic?**
   - ✅ **Decision: Use plain Python dataclasses**
   - Simpler, no additional dependencies
   - Validation not required in practice
   - Easier to understand and maintain

2. **Schema evolution:**
   - ✅ **Decision: No versioning mechanism needed**
   - Schemas are stable and long-lasting
   - Updates typically add new fields (backward compatible)
   - Backends handle schema evolution according to their capabilities

3. **Backward compatibility:**
   - ✅ **Decision: Major version change, no backward compatibility required**
   - Will be a breaking change with migration instructions
   - A clean break allows for a better design
   - A migration guide will be provided for existing deployments

4. **Nested types and complex structures:**
   - ✅ **Decision: Use nested dataclasses naturally**
   - Python dataclasses handle nesting perfectly
   - `list[T]` for arrays, `dict[K, V]` for maps
   - Backends recursively serialize/deserialize
   - Example:

   ```python
   @dataclass
   class Value:
       value: str
       is_uri: bool

   @dataclass
   class Triple:
       s: Value  # Nested dataclass
       p: Value
       o: Value

   @dataclass
   class GraphQuery:
       triples: list[Triple]  # Array of nested dataclasses
       metadata: dict[str, str]
   ```

5. **Default values and optional fields:**
   - ✅ **Decision: Mix of required, defaulted, and optional fields**
   - Required fields: no default value
   - Fields with defaults: always present, with a sensible default
   - Truly optional fields: `T | None = None`, omitted from serialization when `None`
   - Example:

   ```python
   @dataclass
   class TextCompletionRequest:
       system: str                   # Required, no default
       prompt: str                   # Required, no default
       streaming: bool = False      # Optional with default value
       metadata: dict | None = None  # Truly optional, can be absent
   ```

**Important serialization semantics:**

When `metadata = None`, the field is not present in the output at all:
```json
{
  "system": "...",
  "prompt": "...",
  "streaming": false
}
```

When `metadata = {}` (explicitly empty), the field is present but empty:
```json
{
  "system": "...",
  "prompt": "...",
  "streaming": false,
  "metadata": {}
}
```

**Key distinction:**
- `None` → field absent from JSON (not serialized)
- Empty value (`{}`, `[]`, `""`) → field present with an empty value
- This matters semantically: "not provided" vs "explicitly empty"
- Serialization backends must skip `None` fields, not encode them as `null`

## Approach Draft 3: Implementation Details

### Generic Queue Naming Format

Replace backend-specific queue names with a generic format that backends can map appropriately.

**Format:** `{qos}/{tenant}/{namespace}/{queue-name}`

Where:
- `qos`: Quality of Service level
  - `q0` = best-effort (fire and forget, no acknowledgment)
  - `q1` = at-least-once (requires acknowledgment)
  - `q2` = exactly-once (two-phase acknowledgment)
- `tenant`: Logical grouping for multi-tenancy
- `namespace`: Sub-grouping within tenant
- `queue-name`: Actual queue/topic name

**Examples:**
```
q1/tg/flow/text-completion-requests
q2/tg/config/config-push
q0/tg/metrics/stats
```

### Backend Topic Mapping

Each backend maps the generic format to its native format:

**Pulsar Backend:**
```python
def map_topic(self, generic_topic: str) -> str:
    # Parse: q1/tg/flow/text-completion-requests
    qos, tenant, namespace, queue = generic_topic.split('/', 3)

    # Map QoS to persistence
    persistence = 'persistent' if qos in ['q1', 'q2'] else 'non-persistent'

    # Return Pulsar URI: persistent://tg/flow/text-completion-requests
    return f"{persistence}://{tenant}/{namespace}/{queue}"
```

**MQTT Backend:**
```python
def map_topic(self, generic_topic: str) -> tuple[str, int]:
    # Parse: q1/tg/flow/text-completion-requests
    qos, tenant, namespace, queue = generic_topic.split('/', 3)

    # Map QoS level
    qos_level = {'q0': 0, 'q1': 1, 'q2': 2}[qos]

    # Build MQTT topic including tenant/namespace for proper namespacing
    mqtt_topic = f"{tenant}/{namespace}/{queue}"

    return mqtt_topic, qos_level
```

### Updated Topic Helper Function

```python
# schema/core/topic.py
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
    """
    Create a generic topic identifier that can be mapped by backends.

    Args:
        queue_name: The queue/topic name
        qos: Quality of service
            - 'q0' = best-effort (no ack)
            - 'q1' = at-least-once (ack required)
            - 'q2' = exactly-once (two-phase ack)
        tenant: Tenant identifier for multi-tenancy
        namespace: Namespace within tenant

    Returns:
        Generic topic string: qos/tenant/namespace/queue_name

    Examples:
        topic('my-queue')                              # q1/tg/flow/my-queue
        topic('config', qos='q2', namespace='config')  # q2/tg/config/config
    """
    return f"{qos}/{tenant}/{namespace}/{queue_name}"
```

### Configuration and Initialization

**Command-Line Arguments + Environment Variables:**

```python
# In base/async_processor.py - add_args() method
@staticmethod
def add_args(parser):

    # Pub/sub backend selection
    parser.add_argument(
        '--pubsub-backend',
        default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
        choices=['pulsar', 'mqtt'],
        help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)'
    )

    # Pulsar-specific configuration
    parser.add_argument(
        '--pulsar-host',
        default=os.getenv('PULSAR_HOST', 'pulsar://localhost:6650'),
        help='Pulsar host (default: pulsar://localhost:6650, env: PULSAR_HOST)'
    )

    parser.add_argument(
        '--pulsar-api-key',
        default=os.getenv('PULSAR_API_KEY', None),
        help='Pulsar API key (env: PULSAR_API_KEY)'
    )

    parser.add_argument(
        '--pulsar-listener',
        default=os.getenv('PULSAR_LISTENER', None),
        help='Pulsar listener name (env: PULSAR_LISTENER)'
    )

    # MQTT-specific configuration
    parser.add_argument(
        '--mqtt-host',
        default=os.getenv('MQTT_HOST', 'localhost'),
        help='MQTT broker host (default: localhost, env: MQTT_HOST)'
    )

    parser.add_argument(
        '--mqtt-port',
        type=int,
        default=int(os.getenv('MQTT_PORT', '1883')),
        help='MQTT broker port (default: 1883, env: MQTT_PORT)'
    )

    parser.add_argument(
        '--mqtt-username',
        default=os.getenv('MQTT_USERNAME', None),
        help='MQTT username (env: MQTT_USERNAME)'
    )

    parser.add_argument(
        '--mqtt-password',
        default=os.getenv('MQTT_PASSWORD', None),
        help='MQTT password (env: MQTT_PASSWORD)'
    )
```

**Factory Function:**

```python
# In base/pubsub.py or base/pubsub_factory.py
def get_pubsub(**config) -> PubSubBackend:
    """
    Create and return a pub/sub backend based on configuration.

    Args:
        config: Configuration dict from command-line args.
            Must include the 'pubsub_backend' key.

    Returns:
        Backend instance (PulsarBackend, MQTTBackend, etc.)
    """
    backend_type = config.get('pubsub_backend', 'pulsar')

    if backend_type == 'pulsar':
        return PulsarBackend(
            host=config.get('pulsar_host'),
            api_key=config.get('pulsar_api_key'),
            listener=config.get('pulsar_listener'),
        )
    elif backend_type == 'mqtt':
        return MQTTBackend(
            host=config.get('mqtt_host'),
            port=config.get('mqtt_port'),
            username=config.get('mqtt_username'),
            password=config.get('mqtt_password'),
        )
    else:
        raise ValueError(f"Unknown pub/sub backend: {backend_type}")
```

**Usage in AsyncProcessor:**

```python
# In async_processor.py
class AsyncProcessor:
    def __init__(self, **params):
        self.id = params.get("id")

        # Create backend from config (replaces PulsarClient)
        self.pubsub = get_pubsub(**params)

        # Rest of initialization...
```

### Backend Interface

```python
from typing import Any, Protocol

class PubSubBackend(Protocol):
    """Protocol defining the interface all pub/sub backends must implement."""

    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
        """
        Create a producer for a topic.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            schema: Dataclass type for messages
            options: Backend-specific options (e.g., chunking_enabled)

        Returns:
            Backend-specific producer instance
        """
        ...

    def create_consumer(
        self,
        topic: str,
        subscription: str,
        schema: type,
        initial_position: str = 'latest',
        consumer_type: str = 'shared',
        **options
    ) -> BackendConsumer:
        """
        Create a consumer for a topic.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            subscription: Subscription/consumer group name
            schema: Dataclass type for messages
            initial_position: 'earliest' or 'latest' (MQTT may ignore)
            consumer_type: 'shared', 'exclusive', 'failover' (MQTT may ignore)
            options: Backend-specific options

        Returns:
            Backend-specific consumer instance
        """
        ...

    def close(self) -> None:
        """Close the backend connection."""
        ...
```

```python
class BackendProducer(Protocol):
    """Protocol for a backend-specific producer."""

    def send(self, message: Any, properties: dict = {}) -> None:
        """Send a message (dataclass instance) with optional properties."""
        ...

    def flush(self) -> None:
        """Flush any buffered messages."""
        ...

    def close(self) -> None:
        """Close the producer."""
        ...
```

```python
class BackendConsumer(Protocol):
    """Protocol for a backend-specific consumer."""

    def receive(self, timeout_millis: int = 2000) -> Message:
        """
        Receive a message from the topic.

        Raises:
            TimeoutError: If no message is received within the timeout
        """
        ...

    def acknowledge(self, message: Message) -> None:
        """Acknowledge successful processing of a message."""
        ...

    def negative_acknowledge(self, message: Message) -> None:
        """Negative acknowledge - triggers redelivery."""
        ...

    def unsubscribe(self) -> None:
        """Unsubscribe from the topic."""
        ...

    def close(self) -> None:
        """Close the consumer."""
        ...
```

```python
class Message(Protocol):
    """Protocol for a received message."""

    def value(self) -> Any:
        """Get the deserialized message (dataclass instance)."""
        ...

    def properties(self) -> dict:
        """Get message properties/metadata."""
        ...
```

### Existing Classes Refactoring

The existing `Consumer`, `Producer`, `Publisher`, `Subscriber` classes remain largely intact:

**Current responsibilities (keep):**
- Async threading model and task groups
- Reconnection logic and retry handling
- Metrics collection
- Rate limiting
- Concurrency management

**Changes needed:**
- Remove direct Pulsar imports (`pulsar.schema`, `pulsar.InitialPosition`, etc.)
- Accept `BackendProducer`/`BackendConsumer` instead of a Pulsar client
- Delegate actual pub/sub operations to backend instances
- Map generic concepts to backend calls

**Example refactoring:**

```python
# OLD - consumer.py
class Consumer:
    def __init__(self, client, topic, subscriber, schema, ...):
        self.client = client  # Direct Pulsar client
        # ...

    async def consumer_run(self):
        # Uses pulsar.InitialPosition, pulsar.ConsumerType
        self.consumer = self.client.subscribe(
            topic=self.topic,
            schema=JsonSchema(self.schema),
            initial_position=pulsar.InitialPosition.Earliest,
            consumer_type=pulsar.ConsumerType.Shared,
        )

# NEW - consumer.py
class Consumer:
    def __init__(self, backend_consumer, schema, ...):
        self.backend_consumer = backend_consumer  # Backend-specific consumer
        self.schema = schema
        # ...

    async def consumer_run(self):
        # Backend consumer already created with the right settings;
        # just use it directly
        while self.running:
            msg = await asyncio.to_thread(
                self.backend_consumer.receive,
                timeout_millis=2000
            )
            await self.handle_message(msg)
```

### Backend-Specific Behaviors

**Pulsar Backend:**
- Maps `q0` → `non-persistent://`, `q1`/`q2` → `persistent://`
- Supports all consumer types (shared, exclusive, failover)
- Supports initial position (earliest/latest)
- Native message acknowledgment
- Schema registry support

**MQTT Backend:**
- Maps `q0`/`q1`/`q2` → MQTT QoS levels 0/1/2
- Includes tenant/namespace in topic path for namespacing
- Auto-generates client IDs from subscription names
- Ignores initial position (no message history in basic MQTT)
- Ignores consumer type (MQTT uses client IDs, not consumer groups)
- Simple publish/subscribe model

### Design Decisions Summary

1. ✅ **Generic queue naming**: `qos/tenant/namespace/queue-name` format
2. ✅ **QoS in queue ID**: Determined by queue definition, not configuration
3. ✅ **Reconnection**: Handled by Consumer/Producer classes, not backends
4. ✅ **MQTT topics**: Include tenant/namespace for proper namespacing
5. ✅ **Message history**: MQTT ignores `initial_position` parameter (future enhancement)
6. ✅ **Client IDs**: MQTT backend auto-generates from subscription name

### Future Enhancements

**MQTT message history:**
- Could add optional persistence layer (e.g., retained messages, external store)
- Would allow supporting `initial_position='earliest'`
- Not required for initial implementation
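The naming rules above can be sketched as two pure functions, one per backend. This is illustrative only; the helper names are hypothetical, and the actual mapping lives inside the backend implementations:

```python
def pulsar_topic(queue_id: str) -> str:
    """Map 'qos/tenant/namespace/name' to a Pulsar topic URI:
    q0 -> non-persistent://, q1/q2 -> persistent://"""
    qos, tenant, namespace, name = queue_id.split("/", 3)
    persistence = "non-persistent" if qos == "q0" else "persistent"
    return f"{persistence}://{tenant}/{namespace}/{name}"

def mqtt_topic_and_qos(queue_id: str) -> tuple[str, int]:
    """Map 'qos/tenant/namespace/name' to an MQTT topic path plus QoS level:
    q0/q1/q2 -> QoS 0/1/2; tenant/namespace kept in the topic for namespacing."""
    qos, tenant, namespace, name = queue_id.split("/", 3)
    return f"{tenant}/{namespace}/{name}", int(qos[1])

print(pulsar_topic("q0/tg/flow/chunk"))        # non-persistent://tg/flow/chunk
print(mqtt_topic_and_qos("q2/tg/flow/chunk"))  # ('tg/flow/chunk', 2)
```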

1508 docs/tech-specs/python-api-refactor.md Normal file
File diff suppressed because it is too large

54 ontology-prompt.md Normal file

@ -0,0 +1,54 @@
You are a knowledge extraction expert. Extract structured triples from text using ONLY the provided ontology elements.

## Ontology Classes:

{% for class_id, class_def in classes.items() %}
- **{{class_id}}**{% if class_def.subclass_of %} (subclass of {{class_def.subclass_of}}){% endif %}{% if class_def.comment %}: {{class_def.comment}}{% endif %}
{% endfor %}

## Object Properties (connect entities):

{% for prop_id, prop_def in object_properties.items() %}
- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %}
{% endfor %}

## Datatype Properties (entity attributes):

{% for prop_id, prop_def in datatype_properties.items() %}
- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %}
{% endfor %}

## Text to Analyze:

{{text}}

## Extraction Rules:

1. Only use classes defined above for entity types
2. Only use properties defined above for relationships and attributes
3. Respect domain and range constraints where specified
4. For class instances, use `rdf:type` as the predicate
5. Include `rdfs:label` for new entities to provide human-readable names
6. Extract all relevant triples that can be inferred from the text
7. Use entity URIs or meaningful identifiers as subjects/objects

## Output Format:

Return ONLY a valid JSON array (no markdown, no code blocks) containing objects with these fields:
- "subject": the subject entity (URI or identifier)
- "predicate": the property (from ontology or rdf:type/rdfs:label)
- "object": the object entity or literal value

Important: Return raw JSON only, with no markdown formatting, no code blocks, and no backticks.

## Example Output:

[
  {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},
  {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
  {"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"},
  {"subject": "ingredient:flour", "predicate": "rdf:type", "object": "Ingredient"},
  {"subject": "ingredient:flour", "predicate": "rdfs:label", "object": "plain flour"}
]

Now extract triples from the text above.
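Because the prompt demands a bare JSON array with fixed keys, downstream consumers need nothing more than `json.loads` plus a key check. A minimal sketch (the variable names here are illustrative, not part of the prompt contract):

```python
import json

# Raw model output is expected to be a bare JSON array, no code fences.
raw = '[{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}]'

triples = json.loads(raw)
for t in triples:
    # Every element carries exactly the three fields named in the prompt.
    assert set(t) == {"subject", "predicate", "object"}

print(triples[0]["predicate"])  # rdf:type
```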

@ -21,3 +21,4 @@ prometheus-client
pyarrow
boto3
ollama
python-logging-loki

60 tests/conftest.py Normal file

@ -0,0 +1,60 @@
"""
|
||||
Global pytest configuration for all tests.
|
||||
|
||||
This conftest.py applies to all test directories.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
# import asyncio
|
||||
# import tracemalloc
|
||||
# import warnings
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Uncomment the lines below to enable asyncio debug mode and tracemalloc
|
||||
# for tracing unawaited coroutines and their creation points
|
||||
# tracemalloc.start()
|
||||
# asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
||||
# warnings.simplefilter("always", ResourceWarning)
|
||||
# warnings.simplefilter("always", RuntimeWarning)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_loki_handler(session_mocker=None):
|
||||
"""
|
||||
Mock LokiHandler to prevent connection attempts during tests.
|
||||
|
||||
This fixture runs once per test session and prevents the logging
|
||||
module from trying to connect to a Loki server that doesn't exist
|
||||
in the test environment.
|
||||
"""
|
||||
# Try to import logging_loki and mock it if available
|
||||
try:
|
||||
import logging_loki
|
||||
# Create a mock LokiHandler that does nothing
|
||||
original_loki_handler = logging_loki.LokiHandler
|
||||
|
||||
class MockLokiHandler:
|
||||
"""Mock LokiHandler that doesn't make network calls."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def emit(self, record):
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
# Replace the real LokiHandler with our mock
|
||||
logging_loki.LokiHandler = MockLokiHandler
|
||||
|
||||
yield
|
||||
|
||||
# Restore original after tests
|
||||
logging_loki.LokiHandler = original_loki_handler
|
||||
|
||||
except ImportError:
|
||||
# If logging_loki isn't installed, no need to mock
|
||||
yield
|
||||
|
|
@ -257,7 +257,6 @@ class TestAgentMessageContracts:
        # Act
        request = AgentRequest(
            question="What comes next?",
            plan="Multi-step plan",
            state="processing",
            history=history_steps
        )

@ -588,7 +587,6 @@ class TestSerializationContracts:

        request = AgentRequest(
            question="Test with array",
            plan="Test plan",
            state="Test state",
            history=steps
        )
@ -189,6 +189,7 @@ class TestObjectsCassandraContracts:
        assert result == expected_val
        assert isinstance(result, expected_type) or result is None

    @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
    def test_extracted_object_serialization_contract(self):
        """Test that ExtractedObject can be serialized/deserialized correctly"""
        # Create test object

@ -408,6 +409,7 @@ class TestObjectsCassandraContractsBatch:
        assert isinstance(single_batch_object.values[0], dict)
        assert single_batch_object.values[0]["customer_id"] == "CUST999"

    @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
    def test_extracted_object_batch_serialization_contract(self):
        """Test that batched ExtractedObject can be serialized/deserialized correctly"""
        # Create batch object
@ -480,11 +480,15 @@ def streaming_chunk_collector():
    class ChunkCollector:
        def __init__(self):
            self.chunks = []
            self.end_of_stream_flags = []
            self.complete = False

        async def collect(self, chunk):
            """Async callback to collect chunks"""
        async def collect(self, chunk, end_of_stream=False):
            """Async callback to collect chunks with end_of_stream flag"""
            self.chunks.append(chunk)
            self.end_of_stream_flags.append(end_of_stream)
            if end_of_stream:
                self.complete = True

        def get_full_text(self):
            """Concatenate all chunk content"""

@ -496,6 +500,14 @@ def streaming_chunk_collector():
                return [c.get("chunk_type") for c in self.chunks]
            return []

        def verify_streaming_protocol(self):
            """Verify that streaming protocol is correct"""
            assert len(self.chunks) > 0, "Should have received at least one chunk"
            assert len(self.chunks) == len(self.end_of_stream_flags), "Each chunk should have an end_of_stream flag"
            assert self.end_of_stream_flags.count(True) == 1, "Exactly one chunk should have end_of_stream=True"
            assert self.end_of_stream_flags[-1] is True, "Last chunk should have end_of_stream=True"
            assert self.complete is True, "Should be marked complete after final chunk"

    return ChunkCollector
@ -47,8 +47,9 @@ Args: {
                "}"
            ]

            for chunk in chunks:
                await chunk_callback(chunk)
            for i, chunk in enumerate(chunks):
                is_final = (i == len(chunks) - 1)
                await chunk_callback(chunk, is_final)

            return full_text
        else:
@ -312,8 +313,10 @@ Final Answer: AI is the simulation of human intelligence in machines."""
        call_count += 1

        if streaming and chunk_callback:
            for chunk in response.split():
                await chunk_callback(chunk + " ")
            chunks = response.split()
            for i, chunk in enumerate(chunks):
                is_final = (i == len(chunks) - 1)
                await chunk_callback(chunk + " ", is_final)
            return response
        return response
@ -373,13 +373,13 @@ class TestMultipleHostsHandling:
        from trustgraph.base.cassandra_config import resolve_cassandra_config

        # Test various whitespace scenarios
        hosts1, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
        hosts1, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
        assert hosts1 == ['host1', 'host2', 'host3']

        hosts2, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
        hosts2, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
        assert hosts2 == ['host1', 'host2', 'host3']

        hosts3, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
        hosts3, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
        assert hosts3 == ['host1', 'host2']
@ -46,9 +46,16 @@ class TestDocumentRagStreaming:
            full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

            if streaming and chunk_callback:
                # Simulate streaming chunks
                # Simulate streaming chunks with end_of_stream flags
                chunks = []
                async for chunk in mock_streaming_llm_response():
                    await chunk_callback(chunk)
                    chunks.append(chunk)

                # Send all chunks with end_of_stream=False except the last
                for i, chunk in enumerate(chunks):
                    is_final = (i == len(chunks) - 1)
                    await chunk_callback(chunk, is_final)

                return full_text
            else:
                # Non-streaming response - same text

@ -89,6 +96,9 @@ class TestDocumentRagStreaming:
        assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
        assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

        # Verify streaming protocol compliance
        collector.verify_streaming_protocol()

        # Verify full response matches concatenated chunks
        full_from_chunks = collector.get_full_text()
        assert result == full_from_chunks

@ -117,7 +127,7 @@ class TestDocumentRagStreaming:
        # Act - Streaming
        streaming_chunks = []

        async def collect(chunk):
        async def collect(chunk, end_of_stream):
            streaming_chunks.append(chunk)

        streaming_result = await document_rag_streaming.query(
@ -59,9 +59,16 @@ class TestGraphRagStreaming:
            full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

            if streaming and chunk_callback:
                # Simulate streaming chunks
                # Simulate streaming chunks with end_of_stream flags
                chunks = []
                async for chunk in mock_streaming_llm_response():
                    await chunk_callback(chunk)
                    chunks.append(chunk)

                # Send all chunks with end_of_stream=False except the last
                for i, chunk in enumerate(chunks):
                    is_final = (i == len(chunks) - 1)
                    await chunk_callback(chunk, is_final)

                return full_text
            else:
                # Non-streaming response - same text

@ -102,6 +109,9 @@ class TestGraphRagStreaming:
        assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
        assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

        # Verify streaming protocol compliance
        collector.verify_streaming_protocol()

        # Verify full response matches concatenated chunks
        full_from_chunks = collector.get_full_text()
        assert result == full_from_chunks

@ -128,7 +138,7 @@ class TestGraphRagStreaming:
        # Act - Streaming
        streaming_chunks = []

        async def collect(chunk):
        async def collect(chunk, end_of_stream):
            streaming_chunks.append(chunk)

        streaming_result = await graph_rag_streaming.query(
@ -59,17 +59,17 @@ class MockWebSocket:


@pytest.fixture
def mock_pulsar_client():
    """Mock Pulsar client for integration testing."""
    client = MagicMock()

def mock_backend():
    """Mock backend for integration testing."""
    backend = MagicMock()

    # Mock producer
    producer = MagicMock()
    producer.send = MagicMock()
    producer.flush = MagicMock()
    producer.close = MagicMock()
    client.create_producer.return_value = producer

    backend.create_producer.return_value = producer

    # Mock consumer
    consumer = MagicMock()
    consumer.receive = AsyncMock()

@ -78,33 +78,31 @@ def mock_pulsar_client():
    consumer.pause_message_listener = MagicMock()
    consumer.unsubscribe = MagicMock()
    consumer.close = MagicMock()
    client.subscribe.return_value = consumer

    return client
    backend.create_consumer.return_value = consumer

    return backend


@pytest.mark.asyncio
async def test_import_graceful_shutdown_integration():
async def test_import_graceful_shutdown_integration(mock_backend):
    """Test import path handles shutdown gracefully with real message flow."""
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    mock_producer = mock_backend.create_producer.return_value

    # Track sent messages
    sent_messages = []
    def track_send(message, properties=None):
        sent_messages.append((message, properties))

    mock_producer.send.side_effect = track_send

    ws = MockWebSocket()
    running = Running()

    # Create import handler
    import_handler = TriplesImport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        backend=mock_backend,
        queue="test-triples-import"
    )

@ -151,11 +149,9 @@ async def test_import_graceful_shutdown_integration():


@pytest.mark.asyncio
async def test_export_no_message_loss_integration():
async def test_export_no_message_loss_integration(mock_backend):
    """Test export path doesn't lose acknowledged messages."""
    mock_client = MagicMock()
    mock_consumer = MagicMock()
    mock_client.subscribe.return_value = mock_consumer
    mock_consumer = mock_backend.create_consumer.return_value

    # Create test messages
    test_messages = []

@ -202,7 +198,7 @@ async def test_export_no_message_loss_integration():
    export_handler = TriplesExport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        backend=mock_backend,
        queue="test-triples-export",
        consumer="test-consumer",
        subscriber="test-subscriber"

@ -245,14 +241,14 @@ async def test_export_no_message_loss_integration():
async def test_concurrent_import_export_shutdown():
    """Test concurrent import and export shutdown scenarios."""
    # Setup mock clients
    import_client = MagicMock()
    export_client = MagicMock()
    import_backend = MagicMock()
    export_backend = MagicMock()

    import_producer = MagicMock()
    export_consumer = MagicMock()

    import_client.create_producer.return_value = import_producer
    export_client.subscribe.return_value = export_consumer
    import_backend.create_producer.return_value = import_producer
    export_backend.subscribe.return_value = export_consumer

    # Track operations
    import_operations = []

@ -280,14 +276,14 @@ async def test_concurrent_import_export_shutdown():
    import_handler = TriplesImport(
        ws=import_ws,
        running=import_running,
        pulsar_client=import_client,
        backend=import_backend,
        queue="concurrent-import"
    )

    export_handler = TriplesExport(
        ws=export_ws,
        running=export_running,
        pulsar_client=export_client,
        backend=export_backend,
        queue="concurrent-export",
        consumer="concurrent-consumer",
        subscriber="concurrent-subscriber"

@ -328,9 +324,9 @@ async def test_concurrent_import_export_shutdown():
@pytest.mark.asyncio
async def test_websocket_close_during_message_processing():
    """Test graceful handling when websocket closes during active message processing."""
    mock_client = MagicMock()
    mock_backend_local = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer
    mock_backend_local.create_producer.return_value = mock_producer

    # Simulate slow message processing
    processed_messages = []

@ -346,7 +342,7 @@ async def test_websocket_close_during_message_processing():
    import_handler = TriplesImport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        backend=mock_backend_local,
        queue="slow-processing-import"
    )

@ -395,9 +391,9 @@ async def test_websocket_close_during_message_processing():
@pytest.mark.asyncio
async def test_backpressure_during_shutdown():
    """Test graceful shutdown under backpressure conditions."""
    mock_client = MagicMock()
    mock_backend_local = MagicMock()
    mock_consumer = MagicMock()
    mock_client.subscribe.return_value = mock_consumer
    mock_backend_local.subscribe.return_value = mock_consumer

    # Mock slow websocket
    class SlowWebSocket(MockWebSocket):

@ -410,8 +406,8 @@ async def test_backpressure_during_shutdown():

    export_handler = TriplesExport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        running=running,
        backend=mock_backend_local,
        queue="backpressure-export",
        consumer="backpressure-consumer",
        subscriber="backpressure-subscriber"
@ -117,7 +117,7 @@ class TestObjectsCassandraIntegration:
        assert "customer_records" in processor.schemas

        # Step 1.5: Create the collection first (simulate tg-set-collection)
        await processor.create_collection("test_user", "import_2024")
        await processor.create_collection("test_user", "import_2024", {})

        # Step 2: Process an ExtractedObject
        test_obj = ExtractedObject(

@ -213,8 +213,8 @@ class TestObjectsCassandraIntegration:
        assert len(processor.schemas) == 2

        # Create collections first
        await processor.create_collection("shop", "catalog")
        await processor.create_collection("shop", "sales")
        await processor.create_collection("shop", "catalog", {})
        await processor.create_collection("shop", "sales", {})

        # Process objects for different schemas
        product_obj = ExtractedObject(

@ -263,7 +263,7 @@ class TestObjectsCassandraIntegration:
        )

        # Create collection first
        await processor.create_collection("test", "test")
        await processor.create_collection("test", "test", {})

        # Create object missing required field
        test_obj = ExtractedObject(

@ -302,7 +302,7 @@ class TestObjectsCassandraIntegration:
        )

        # Create collection first
        await processor.create_collection("logger", "app_events")
        await processor.create_collection("logger", "app_events", {})

        # Process object
        test_obj = ExtractedObject(

@ -407,7 +407,7 @@ class TestObjectsCassandraIntegration:

        # Create all collections first
        for coll in collections:
            await processor.create_collection("analytics", coll)
            await processor.create_collection("analytics", coll, {})

        for coll in collections:
            obj = ExtractedObject(

@ -486,7 +486,7 @@ class TestObjectsCassandraIntegration:
        )

        # Create collection first
        await processor.create_collection("test_user", "batch_import")
        await processor.create_collection("test_user", "batch_import", {})

        msg = MagicMock()
        msg.value.return_value = batch_obj

@ -532,7 +532,7 @@ class TestObjectsCassandraIntegration:
        )

        # Create collection first
        await processor.create_collection("test", "empty")
        await processor.create_collection("test", "empty", {})

        # Process empty batch object
        empty_obj = ExtractedObject(

@ -573,7 +573,7 @@ class TestObjectsCassandraIntegration:
        )

        # Create collection first
        await processor.create_collection("test", "mixed")
        await processor.create_collection("test", "mixed", {})

        # Single object (backward compatibility)
        single_obj = ExtractedObject(
351 tests/integration/test_rag_streaming_protocol.py Normal file

@ -0,0 +1,351 @@
"""
|
||||
Integration tests for RAG service streaming protocol compliance.
|
||||
|
||||
These tests verify that RAG services correctly forward end_of_stream flags
|
||||
and don't duplicate final chunks, ensuring proper streaming semantics.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
"""Integration tests for GraphRAG streaming protocol"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_client(self):
|
||||
"""Mock graph embeddings client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = ["entity1", "entity2"]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_triples_client(self):
|
||||
"""Mock triples client"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_streaming_prompt_client(self):
|
||||
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
|
||||
client = AsyncMock()
|
||||
|
||||
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
await chunk_callback("The", False)
|
||||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
else:
|
||||
return "The answer is here."
|
||||
|
||||
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_receives_end_of_stream_parameter(self, graph_rag):
|
||||
"""Test that callback receives end_of_stream parameter"""
|
||||
# Arrange
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert - callback should receive (chunk, end_of_stream) signature
|
||||
assert callback.call_count == 4
|
||||
# All calls should have 2 arguments
|
||||
for call_args in callback.call_args_list:
|
||||
assert len(call_args.args) == 2, "Callback should receive (chunk, end_of_stream)"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_of_stream_flag_forwarded_correctly(self, graph_rag):
|
||||
"""Test that end_of_stream flags are forwarded correctly"""
|
||||
# Arrange
|
||||
chunks_with_flags = []
|
||||
|
||||
async def collect(chunk, end_of_stream):
|
||||
chunks_with_flags.append((chunk, end_of_stream))
|
||||
|
||||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(chunks_with_flags) == 4
|
||||
|
||||
# First three chunks should have end_of_stream=False
|
||||
assert chunks_with_flags[0] == ("The", False)
|
||||
assert chunks_with_flags[1] == (" answer", False)
|
||||
assert chunks_with_flags[2] == (" is here.", False)
|
||||
|
||||
# Final chunk should have end_of_stream=True
|
||||
assert chunks_with_flags[3] == ("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_duplicate_final_chunk(self, graph_rag):
|
||||
"""Test that final chunk is not duplicated"""
|
||||
# Arrange
|
||||
chunks = []
|
||||
|
||||
async def collect(chunk, end_of_stream):
|
||||
chunks.append(chunk)
|
||||
|
||||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
)
|
||||
|
||||
# Assert - should have exactly 4 chunks, no duplicates
|
||||
assert len(chunks) == 4
|
||||
assert chunks == ["The", " answer", " is here.", ""]
|
||||
|
||||
# The last chunk appears exactly once
|
||||
assert chunks.count("") == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exactly_one_end_of_stream_true(self, graph_rag):
|
||||
"""Test that exactly one message has end_of_stream=True"""
|
||||
# Arrange
|
||||
end_of_stream_flags = []
|
||||
|
||||
async def collect(chunk, end_of_stream):
|
||||
end_of_stream_flags.append(end_of_stream)
|
||||
|
||||
# Act
|
||||
await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
streaming=True,
|
||||
chunk_callback=collect
|
||||
)
|
||||
|
||||
# Assert - exactly one True
|
||||
assert end_of_stream_flags.count(True) == 1
|
||||
assert end_of_stream_flags.count(False) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_final_chunk_preserved(self, graph_rag):
|
||||
"""Test that empty final chunks are preserved and forwarded"""
|
||||
# Arrange
|
||||
final_chunk = None
|
||||
final_flag = None
|
||||
|
||||
async def collect(chunk, end_of_stream):
|
||||
nonlocal final_chunk, final_flag
|
||||
if end_of_stream:
|
||||
final_chunk = chunk
|
||||
final_flag = end_of_stream
|
||||
|
||||
# Act
|
        await graph_rag.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=collect
        )

        # Assert
        assert final_flag is True
        assert final_chunk == "", "Empty final chunk should be preserved"


class TestDocumentRagStreamingProtocol:
    """Integration tests for DocumentRAG streaming protocol"""

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client"""
        client = AsyncMock()
        client.embed.return_value = [[0.1, 0.2, 0.3]]
        return client

    @pytest.fixture
    def mock_doc_embeddings_client(self):
        """Mock document embeddings client"""
        client = AsyncMock()
        client.query.return_value = ["doc1", "doc2"]
        return client

    @pytest.fixture
    def mock_streaming_prompt_client(self):
        """Mock prompt client with streaming support"""
        client = AsyncMock()

        async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None):
            if streaming and chunk_callback:
                # Simulate streaming with non-empty final chunk (some LLMs do this)
                await chunk_callback("Document", False)
                await chunk_callback(" summary", False)
                await chunk_callback(".", True)  # Non-empty final chunk
                return ""
            else:
                return "Document summary."

        client.document_prompt.side_effect = document_prompt_side_effect
        return client

    @pytest.fixture
    def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
                     mock_streaming_prompt_client):
        """Create DocumentRag instance with mocked dependencies"""
        return DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_streaming_prompt_client,
            verbose=False
        )

    @pytest.mark.asyncio
    async def test_callback_receives_end_of_stream_parameter(self, document_rag):
        """Test that callback receives end_of_stream parameter"""
        # Arrange
        callback = AsyncMock()

        # Act
        await document_rag.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=callback
        )

        # Assert
        assert callback.call_count == 3
        for call_args in callback.call_args_list:
            assert len(call_args.args) == 2

    @pytest.mark.asyncio
    async def test_non_empty_final_chunk_preserved(self, document_rag):
        """Test that non-empty final chunks are preserved with correct flag"""
        # Arrange
        chunks_with_flags = []

        async def collect(chunk, end_of_stream):
            chunks_with_flags.append((chunk, end_of_stream))

        # Act
        await document_rag.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=collect
        )

        # Assert
        assert len(chunks_with_flags) == 3
        assert chunks_with_flags[0] == ("Document", False)
        assert chunks_with_flags[1] == (" summary", False)
        assert chunks_with_flags[2] == (".", True)  # Non-empty final chunk

    @pytest.mark.asyncio
    async def test_no_duplicate_final_chunk(self, document_rag):
        """Test that final chunk is not duplicated"""
        # Arrange
        chunks = []

        async def collect(chunk, end_of_stream):
            chunks.append(chunk)

        # Act
        await document_rag.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=collect
        )

        # Assert - final "." appears exactly once
        assert chunks.count(".") == 1
        assert chunks == ["Document", " summary", "."]


class TestStreamingProtocolEdgeCases:
    """Test edge cases in streaming protocol"""

    @pytest.mark.asyncio
    async def test_multiple_empty_chunks_before_final(self):
        """Test handling of multiple empty chunks (edge case)"""
        # Arrange
        client = AsyncMock()

        async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
            if streaming and chunk_callback:
                await chunk_callback("text", False)
                await chunk_callback("", False)  # Empty but not final
                await chunk_callback("more", False)
                await chunk_callback("", True)  # Empty and final
                return ""
            else:
                return "textmore"

        client.kg_prompt.side_effect = kg_prompt_with_empties

        rag = GraphRag(
            embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
            graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
            triples_client=AsyncMock(query=AsyncMock(return_value=[])),
            prompt_client=client,
            verbose=False
        )

        chunks_with_flags = []

        async def collect(chunk, end_of_stream):
            chunks_with_flags.append((chunk, end_of_stream))

        # Act
        await rag.query(
            query="test",
            streaming=True,
            chunk_callback=collect
        )

        # Assert
        assert len(chunks_with_flags) == 4
        assert chunks_with_flags[-1] == ("", True)  # Final empty chunk
        end_of_stream_flags = [f for c, f in chunks_with_flags]
        assert end_of_stream_flags.count(True) == 1
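The invariant these tests exercise is that every callback invocation receives a `(chunk, end_of_stream)` pair and that exactly one invocation carries `end_of_stream=True`, whether or not the final chunk is empty. A standalone sketch of that contract, independent of TrustGraph (the `stream_chunks` helper is hypothetical, not a library function):

```python
import asyncio


async def stream_chunks(chunks, chunk_callback):
    """Drive a (chunk, end_of_stream) callback over a list of chunks.

    Hypothetical helper: the last element is flagged end_of_stream=True,
    empty or not, so consumers see exactly one terminating invocation.
    """
    for i, chunk in enumerate(chunks):
        await chunk_callback(chunk, i == len(chunks) - 1)


async def main():
    received = []

    async def collect(chunk, end_of_stream):
        received.append((chunk, end_of_stream))

    await stream_chunks(["Document", " summary", "."], collect)
    return received


received = asyncio.run(main())
```

A consumer can then rebuild the full text by concatenating chunks and stop as soon as the flag is set, without special-casing empty final chunks.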
@@ -14,19 +14,20 @@ from trustgraph.base.async_processor import AsyncProcessor
 class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
     """Test AsyncProcessor base class functionality"""

-    @patch('trustgraph.base.async_processor.PulsarClient')
+    @patch('trustgraph.base.async_processor.get_pubsub')
     @patch('trustgraph.base.async_processor.Consumer')
     @patch('trustgraph.base.async_processor.ProcessorMetrics')
     @patch('trustgraph.base.async_processor.ConsumerMetrics')
-    async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics,
-                                                        mock_consumer, mock_pulsar_client):
+    async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics,
+                                                        mock_consumer, mock_get_pubsub):
         """Test basic AsyncProcessor initialization"""
         # Arrange
-        mock_pulsar_client.return_value = MagicMock()
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
         mock_consumer.return_value = MagicMock()
         mock_processor_metrics.return_value = MagicMock()
         mock_consumer_metrics.return_value = MagicMock()

         config = {
             'id': 'test-async-processor',
             'taskgroup': AsyncMock()
@@ -42,14 +43,14 @@ class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
         assert processor.running == True
         assert hasattr(processor, 'config_handlers')
         assert processor.config_handlers == []

-        # Verify PulsarClient was created
-        mock_pulsar_client.assert_called_once_with(**config)
+        # Verify get_pubsub was called to create backend
+        mock_get_pubsub.assert_called_once_with(**config)

         # Verify metrics were initialized
         mock_processor_metrics.assert_called_once()
         mock_consumer_metrics.assert_called_once()

         # Verify Consumer was created for config subscription
         mock_consumer.assert_called_once()
@@ -145,7 +145,7 @@ class TestResolveCassandraConfig:
     def test_default_configuration(self):
         """Test resolution with no parameters or environment variables."""
         with patch.dict(os.environ, {}, clear=True):
-            hosts, username, password = resolve_cassandra_config()
+            hosts, username, password, keyspace = resolve_cassandra_config()

             assert hosts == ['cassandra']
             assert username is None
@@ -160,7 +160,7 @@ class TestResolveCassandraConfig:
         }

         with patch.dict(os.environ, env_vars, clear=True):
-            hosts, username, password = resolve_cassandra_config()
+            hosts, username, password, keyspace = resolve_cassandra_config()

             assert hosts == ['env1', 'env2', 'env3']
             assert username == 'env-user'
@@ -175,7 +175,7 @@ class TestResolveCassandraConfig:
         }

         with patch.dict(os.environ, env_vars, clear=True):
-            hosts, username, password = resolve_cassandra_config(
+            hosts, username, password, keyspace = resolve_cassandra_config(
                 host='explicit-host',
                 username='explicit-user',
                 password='explicit-pass'
@@ -188,19 +188,19 @@ class TestResolveCassandraConfig:
     def test_host_list_parsing(self):
         """Test different host list formats."""
         # Single host
-        hosts, _, _ = resolve_cassandra_config(host='single-host')
+        hosts, _, _, _ = resolve_cassandra_config(host='single-host')
         assert hosts == ['single-host']

         # Multiple hosts with spaces
-        hosts, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
+        hosts, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
         assert hosts == ['host1', 'host2', 'host3']

         # Empty elements filtered out
-        hosts, _, _ = resolve_cassandra_config(host='host1,,host2,')
+        hosts, _, _, _ = resolve_cassandra_config(host='host1,,host2,')
         assert hosts == ['host1', 'host2']

         # Already a list
-        hosts, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
+        hosts, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
         assert hosts == ['list-host1', 'list-host2']

     def test_args_object_resolution(self):
@@ -212,7 +212,7 @@ class TestResolveCassandraConfig:
         cassandra_password = 'args-pass'

         args = MockArgs()
-        hosts, username, password = resolve_cassandra_config(args)
+        hosts, username, password, keyspace = resolve_cassandra_config(args)

         assert hosts == ['args-host1', 'args-host2']
         assert username == 'args-user'
@@ -233,7 +233,7 @@ class TestResolveCassandraConfig:

         with patch.dict(os.environ, env_vars, clear=True):
             args = PartialArgs()
-            hosts, username, password = resolve_cassandra_config(args)
+            hosts, username, password, keyspace = resolve_cassandra_config(args)

             assert hosts == ['args-host']  # From args
             assert username == 'env-user'  # From env
@@ -251,7 +251,7 @@ class TestGetCassandraConfigFromParams:
             'cassandra_password': 'new-pass'
         }

-        hosts, username, password = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace = get_cassandra_config_from_params(params)

         assert hosts == ['new-host1', 'new-host2']
         assert username == 'new-user'
@@ -265,7 +265,7 @@ class TestGetCassandraConfigFromParams:
             'graph_password': 'old-pass'
         }

-        hosts, username, password = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace = get_cassandra_config_from_params(params)

         # Should use defaults since graph_* params are not recognized
         assert hosts == ['cassandra']  # Default
@@ -280,7 +280,7 @@ class TestGetCassandraConfigFromParams:
             'cassandra_password': 'compat-pass'
         }

-        hosts, username, password = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace = get_cassandra_config_from_params(params)

         assert hosts == ['compat-host']
         assert username is None  # cassandra_user is not recognized
@@ -298,7 +298,7 @@ class TestGetCassandraConfigFromParams:
             'graph_password': 'old-pass'
         }

-        hosts, username, password = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace = get_cassandra_config_from_params(params)

         assert hosts == ['new-host']  # Only cassandra_* params work
         assert username == 'new-user'  # Only cassandra_* params work
@@ -314,7 +314,7 @@ class TestGetCassandraConfigFromParams:

         with patch.dict(os.environ, env_vars, clear=True):
             params = {}
-            hosts, username, password = get_cassandra_config_from_params(params)
+            hosts, username, password, keyspace = get_cassandra_config_from_params(params)

             assert hosts == ['fallback-host1', 'fallback-host2']
             assert username == 'fallback-user'
@@ -334,7 +334,7 @@ class TestConfigurationPriority:

         with patch.dict(os.environ, env_vars, clear=True):
             # CLI args should override everything
-            hosts, username, password = resolve_cassandra_config(
+            hosts, username, password, keyspace = resolve_cassandra_config(
                 host='cli-host',
                 username='cli-user',
                 password='cli-pass'
@@ -354,7 +354,7 @@ class TestConfigurationPriority:

         with patch.dict(os.environ, env_vars, clear=True):
             # Only provide host via CLI
-            hosts, username, password = resolve_cassandra_config(
+            hosts, username, password, keyspace = resolve_cassandra_config(
                 host='cli-host'
                 # username and password not provided
             )
@@ -366,7 +366,7 @@ class TestConfigurationPriority:
     def test_no_config_defaults(self):
         """Test that defaults are used when no configuration is provided."""
         with patch.dict(os.environ, {}, clear=True):
-            hosts, username, password = resolve_cassandra_config()
+            hosts, username, password, keyspace = resolve_cassandra_config()

             assert hosts == ['cassandra']  # Default
             assert username is None  # Default
@@ -378,17 +378,17 @@ class TestEdgeCases:

     def test_empty_host_string(self):
         """Test handling of empty host string falls back to default."""
-        hosts, _, _ = resolve_cassandra_config(host='')
+        hosts, _, _, _ = resolve_cassandra_config(host='')
         assert hosts == ['cassandra']  # Falls back to default

     def test_whitespace_only_host(self):
         """Test handling of whitespace-only host string."""
-        hosts, _, _ = resolve_cassandra_config(host='   ')
+        hosts, _, _, _ = resolve_cassandra_config(host='   ')
         assert hosts == []  # Empty after stripping whitespace

     def test_none_values_preserved(self):
         """Test that None values are preserved correctly."""
-        hosts, username, password = resolve_cassandra_config(
+        hosts, username, password, keyspace = resolve_cassandra_config(
             host=None,
             username=None,
             password=None
@@ -401,7 +401,7 @@ class TestEdgeCases:

     def test_mixed_none_and_values(self):
         """Test mixing None and actual values."""
-        hosts, username, password = resolve_cassandra_config(
+        hosts, username, password, keyspace = resolve_cassandra_config(
             host='mixed-host',
             username=None,
             password='mixed-pass'
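The tests above pin down the new contract: `resolve_cassandra_config` now returns a `(hosts, username, password, keyspace)` 4-tuple, with explicit arguments taking priority over environment variables over defaults, and with comma-separated host strings split and cleaned. A minimal reimplementation sketch consistent with those assertions (not the actual TrustGraph code; the `CASSANDRA_*` environment variable names are assumptions):

```python
import os


def resolve_cassandra_config_sketch(host=None, username=None, password=None, keyspace=None):
    """Sketch of the 4-tuple resolution asserted by the tests:
    explicit argument > CASSANDRA_* environment variable > default.

    Hypothetical: env var names and the 'cassandra' default host are
    assumptions made for illustration.
    """
    host = host or os.environ.get("CASSANDRA_HOST") or "cassandra"
    username = username or os.environ.get("CASSANDRA_USERNAME")
    password = password or os.environ.get("CASSANDRA_PASSWORD")
    keyspace = keyspace or os.environ.get("CASSANDRA_KEYSPACE")

    # Accept either a list or a comma-separated string; strip whitespace
    # and drop empty elements.
    if isinstance(host, str):
        hosts = [h.strip() for h in host.split(",") if h.strip()]
    else:
        hosts = list(host)

    return hosts, username, password, keyspace
```

Note that an empty string falls through to the default host (it is falsy), while a whitespace-only string is truthy and therefore parses to an empty host list, matching `test_empty_host_string` and `test_whitespace_only_host` above.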
 260  tests/unit/test_base/test_prompt_client_streaming.py  (new file)

@@ -0,0 +1,260 @@
"""
|
||||
Unit tests for PromptClient streaming callback behavior.
|
||||
|
||||
These tests verify that the prompt client correctly passes the end_of_stream
|
||||
flag to chunk callbacks, ensuring proper streaming protocol compliance.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
from trustgraph.base.prompt_client import PromptClient
|
||||
from trustgraph.schema import PromptResponse
|
||||
|
||||
|
||||
class TestPromptClientStreamingCallback:
|
||||
"""Test PromptClient streaming callback behavior"""
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_client(self):
|
||||
"""Create a PromptClient with mocked dependencies"""
|
||||
# Mock all the required initialization parameters
|
||||
with patch.object(PromptClient, '__init__', lambda self: None):
|
||||
client = PromptClient()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_response(self):
|
||||
"""Create a mock request/response handler"""
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Simulate streaming responses
|
||||
responses = [
|
||||
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text=" world", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="!", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
else:
|
||||
# Non-streaming response
|
||||
return PromptResponse(text="Hello world!", object=None, error=None)
|
||||
|
||||
return mock_request
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response):
|
||||
"""Test that callback receives both chunk text and end_of_stream flag"""
|
||||
# Arrange
|
||||
prompt_client.request = mock_request_response
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.prompt(
|
||||
id="test-prompt",
|
||||
variables={"query": "test"},
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert - callback should be called with (chunk, end_of_stream) signature
|
||||
assert callback.call_count == 4
|
||||
|
||||
# Verify first chunk: text + end_of_stream=False
|
||||
assert callback.call_args_list[0] == call("Hello", False)
|
||||
|
||||
# Verify second chunk
|
||||
assert callback.call_args_list[1] == call(" world", False)
|
||||
|
||||
# Verify third chunk
|
||||
assert callback.call_args_list[2] == call("!", False)
|
||||
|
||||
# Verify final chunk: empty text + end_of_stream=True
|
||||
assert callback.call_args_list[3] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response):
|
||||
"""Test that empty final chunks are passed to callback"""
|
||||
# Arrange
|
||||
prompt_client.request = mock_request_response
|
||||
|
||||
chunks_received = []
|
||||
|
||||
async def collect_chunks(chunk, end_of_stream):
|
||||
chunks_received.append((chunk, end_of_stream))
|
||||
|
||||
# Act
|
||||
await prompt_client.prompt(
|
||||
id="test-prompt",
|
||||
variables={"query": "test"},
|
||||
streaming=True,
|
||||
chunk_callback=collect_chunks
|
||||
)
|
||||
|
||||
# Assert - should receive the empty final chunk
|
||||
final_chunk = chunks_received[-1]
|
||||
assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client):
|
||||
"""Test callback signature when LLM sends non-empty final chunk"""
|
||||
# Arrange
|
||||
async def mock_request_non_empty_final(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Some LLMs send content in the final chunk
|
||||
responses = [
|
||||
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text=" world!", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request_non_empty_final
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.prompt(
|
||||
id="test-prompt",
|
||||
variables={"query": "test"},
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("Hello", False)
|
||||
assert callback.call_args_list[1] == call(" world!", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_not_called_without_text(self, prompt_client):
|
||||
"""Test that callback is not called for responses without text"""
|
||||
# Arrange
|
||||
async def mock_request_no_text(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Response with only end_of_stream, no text
|
||||
responses = [
|
||||
PromptResponse(text="Content", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text=None, object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request_no_text
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.prompt(
|
||||
id="test-prompt",
|
||||
variables={"query": "test"},
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert - callback should only be called once (for "Content")
|
||||
assert callback.call_count == 1
|
||||
assert callback.call_args_list[0] == call("Content", False)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client):
|
||||
"""Test that synchronous callbacks also receive end_of_stream parameter"""
|
||||
# Arrange
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
responses = [
|
||||
PromptResponse(text="test", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request
|
||||
|
||||
callback = MagicMock() # Synchronous mock
|
||||
|
||||
# Act
|
||||
await prompt_client.prompt(
|
||||
id="test-prompt",
|
||||
variables={"query": "test"},
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert - synchronous callback should also get both parameters
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("test", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that kg_prompt correctly passes streaming parameters"""
|
||||
# Arrange
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
responses = [
|
||||
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.kg_prompt(
|
||||
query="What is machine learning?",
|
||||
kg=[("subject", "predicate", "object")],
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("Answer", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that document_prompt correctly passes streaming parameters"""
|
||||
# Arrange
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
responses = [
|
||||
PromptResponse(text="Summary", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.document_prompt(
|
||||
query="Summarize this",
|
||||
documents=["doc1", "doc2"],
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("Summary", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
|
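The `mock_request` coroutines above all share one shape: a pump that feeds responses to a `recipient` until the recipient returns a truthy stop signal, with the recipient forwarding only responses that carry text. That pattern can be sketched standalone (the `Response` dataclass and `pump` function are hypothetical stand-ins, not TrustGraph types):

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Response:
    # Stand-in for the PromptResponse fields the tests rely on.
    text: Optional[str]
    end_of_stream: bool = False


async def pump(responses, recipient):
    """Feed responses to recipient until it signals completion (returns True)."""
    for resp in responses:
        if await recipient(resp):
            break


async def main():
    chunks = []

    async def recipient(resp):
        # Forward only responses that actually carry text, as the
        # test_callback_not_called_without_text case expects.
        if resp.text is not None:
            chunks.append((resp.text, resp.end_of_stream))
        return resp.end_of_stream  # stop pumping after the final response

    await pump(
        [Response("Hello"), Response(" world"), Response("", end_of_stream=True)],
        recipient,
    )
    return chunks


chunks = asyncio.run(main())
```

Returning `end_of_stream` from the recipient is what lets the pump stop without a sentinel exception, which is why the mocks check `should_stop` after each delivery.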
@@ -8,22 +8,22 @@ from trustgraph.base.publisher import Publisher


 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for testing."""
-    client = MagicMock()
+def mock_pulsar_backend():
+    """Mock Pulsar backend for testing."""
+    backend = MagicMock()
     producer = AsyncMock()
     producer.send = MagicMock()
     producer.flush = MagicMock()
     producer.close = MagicMock()
-    client.create_producer.return_value = producer
-    return client
+    backend.create_producer.return_value = producer
+    return backend


 @pytest.fixture
-def publisher(mock_pulsar_client):
+def publisher(mock_pulsar_backend):
     """Create Publisher instance for testing."""
     return Publisher(
-        client=mock_pulsar_client,
+        backend=mock_pulsar_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -34,12 +34,12 @@ def publisher(mock_pulsar_client):
 @pytest.mark.asyncio
 async def test_publisher_queue_drain():
     """Verify Publisher drains queue on shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -85,12 +85,12 @@ async def test_publisher_queue_drain():
 @pytest.mark.asyncio
 async def test_publisher_rejects_messages_during_drain():
     """Verify Publisher rejects new messages during shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -113,12 +113,12 @@ async def test_publisher_rejects_messages_during_drain():
 @pytest.mark.asyncio
 async def test_publisher_drain_timeout():
     """Verify Publisher respects drain timeout."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -169,12 +169,12 @@ async def test_publisher_drain_timeout():
 @pytest.mark.asyncio
 async def test_publisher_successful_drain():
     """Verify Publisher drains successfully under normal conditions."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -224,12 +224,12 @@ async def test_publisher_successful_drain():
 @pytest.mark.asyncio
 async def test_publisher_state_transitions():
     """Test Publisher state transitions during graceful shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -276,9 +276,9 @@ async def test_publisher_state_transitions():
 @pytest.mark.asyncio
 async def test_publisher_exception_handling():
     """Test Publisher handles exceptions during drain gracefully."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     # Mock producer.send to raise exception on second call
     call_count = 0
@@ -291,7 +291,7 @@ async def test_publisher_exception_handling():
     mock_producer.send.side_effect = failing_send

     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
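The Publisher tests above cover draining the queue on shutdown, rejecting new messages during the drain, respecting a drain timeout, and surviving send failures mid-drain. A minimal sketch of such a drain loop (hypothetical, assuming a simple queue plus a send callable, not the Publisher implementation):

```python
import asyncio
import time


async def drain(queue, producer_send, timeout=5.0):
    """Send everything left in the queue, giving up after `timeout` seconds.

    Returns True if the queue emptied, False if the deadline hit first.
    Hypothetical sketch of the behavior the drain tests assert.
    """
    deadline = time.monotonic() + timeout
    while not queue.empty():
        if time.monotonic() >= deadline:
            return False  # drain timeout: stop, even with messages left
        msg = queue.get_nowait()
        try:
            producer_send(msg)
        except Exception:
            # A failed send must not abort draining the remaining messages.
            continue
    return True


async def main():
    sent = []
    q = asyncio.Queue()
    for i in range(3):
        q.put_nowait(i)
    ok = await drain(q, sent.append)
    return ok, sent


ok, sent = asyncio.run(main())
```

Rejection of new messages during the drain would then be a flag checked by the publish path before enqueueing, which is the state transition the `test_publisher_state_transitions` case inspects.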
@@ -6,23 +6,11 @@ import uuid
 from unittest.mock import AsyncMock, MagicMock, patch
 from trustgraph.base.subscriber import Subscriber

-# Mock JsonSchema globally to avoid schema issues in tests
-# Patch at the module level where it's imported in subscriber
-@patch('trustgraph.base.subscriber.JsonSchema')
-def mock_json_schema_global(mock_schema):
-    mock_schema.return_value = MagicMock()
-    return mock_schema
-
-# Apply the global patch
-_json_schema_patch = patch('trustgraph.base.subscriber.JsonSchema')
-_mock_json_schema = _json_schema_patch.start()
-_mock_json_schema.return_value = MagicMock()


 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for testing."""
-    client = MagicMock()
+def mock_pulsar_backend():
+    """Mock Pulsar backend for testing."""
+    backend = MagicMock()
     consumer = MagicMock()
     consumer.receive = MagicMock()
     consumer.acknowledge = MagicMock()
@@ -30,15 +18,15 @@ def mock_pulsar_backend():
     consumer.pause_message_listener = MagicMock()
     consumer.unsubscribe = MagicMock()
     consumer.close = MagicMock()
-    client.subscribe.return_value = consumer
-    return client
+    backend.create_consumer.return_value = consumer
+    return backend


 @pytest.fixture
-def subscriber(mock_pulsar_client):
+def subscriber(mock_pulsar_backend):
     """Create Subscriber instance for testing."""
     return Subscriber(
-        client=mock_pulsar_client,
+        backend=mock_pulsar_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -60,14 +48,14 @@ def create_mock_message(message_id="test-id", data=None):
 @pytest.mark.asyncio
 async def test_subscriber_deferred_acknowledgment_success():
     """Verify Subscriber only acks on successful delivery."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
         schema=dict,
         max_size=10,
@@ -102,15 +90,15 @@ async def test_subscriber_deferred_acknowledgment_success():
 @pytest.mark.asyncio
 async def test_subscriber_deferred_acknowledgment_failure():
     """Verify Subscriber negative acks on delivery failure."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
         schema=dict,
         max_size=1,  # Very small queue
         backpressure_strategy="drop_new"
@@ -140,14 +128,14 @@ async def test_subscriber_deferred_acknowledgment_failure():
 @pytest.mark.asyncio
 async def test_subscriber_backpressure_strategies():
     """Test different backpressure strategies."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     # Test drop_oldest strategy
     subscriber = Subscriber(
-        client=mock_client,
-        topic="test-topic",
+        backend=mock_backend,
+        topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
         schema=dict,
@@ -187,12 +175,12 @@ async def test_subscriber_backpressure_strategies():
 @pytest.mark.asyncio
 async def test_subscriber_graceful_shutdown():
     """Test Subscriber graceful shutdown with queue draining."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -253,14 +241,14 @@ async def test_subscriber_graceful_shutdown():
 @pytest.mark.asyncio
 async def test_subscriber_drain_timeout():
     """Test Subscriber respects drain timeout."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
         schema=dict,
         max_size=10,
@@ -288,12 +276,12 @@ async def test_subscriber_drain_timeout():
 @pytest.mark.asyncio
 async def test_subscriber_pending_acks_cleanup():
     """Test Subscriber cleans up pending acknowledgments on shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer

     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -342,12 +330,12 @@ async def test_subscriber_pending_acks_cleanup():
 @pytest.mark.asyncio
 async def test_subscriber_multiple_subscribers():
     """Test Subscriber with multiple concurrent subscribers."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_client.subscribe.return_value = mock_consumer
|
||||
|
||||
mock_backend.create_consumer.return_value = mock_consumer
|
||||
|
||||
subscriber = Subscriber(
|
||||
client=mock_client,
|
||||
backend=mock_backend,
|
||||
topic="test-topic",
|
||||
subscription="test-subscription",
|
||||
consumer_name="test-consumer",
|
||||
|
|
|
|||
|
|
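The subscriber hunks above all repeat the same setup: a mock backend whose `create_consumer` returns the mock consumer, injected into `Subscriber` alongside the raw client. A minimal sketch of that injection pattern, outside the diff — note that `MiniSubscriber` is a hypothetical stand-in for illustration; TrustGraph's real `Subscriber` signature lives in the source tree, not here:

```python
from unittest.mock import MagicMock

class MiniSubscriber:
    """Hypothetical reduction of the pattern under test: the injected
    backend, not the raw client, owns consumer creation."""
    def __init__(self, client, backend, topic, subscription,
                 consumer_name, schema=dict, max_size=10,
                 backpressure_strategy="drop_new"):
        self.backend = backend
        self.max_size = max_size
        self.backpressure_strategy = backpressure_strategy
        # Ask the backend for a consumer; tests can swap this freely.
        self.consumer = backend.create_consumer(
            topic=topic, subscription=subscription, name=consumer_name)

mock_backend = MagicMock()
mock_consumer = MagicMock()
mock_backend.create_consumer.return_value = mock_consumer

sub = MiniSubscriber(
    client=MagicMock(), backend=mock_backend,
    topic="test-topic", subscription="test-subscription",
    consumer_name="test-consumer")

assert sub.consumer is mock_consumer
mock_backend.create_consumer.assert_called_once()
```

The design point the tests rely on is that no Pulsar connection is needed to construct the object, so queue sizing and backpressure options can be exercised with plain mocks.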
@ -108,7 +108,8 @@ class TestListConfigItems:
mock_list.assert_called_once_with(
url='http://custom.com',
config_type='prompt',
format_type='json'
format_type='json',
token=None
)

def test_list_main_uses_defaults(self):

@ -126,7 +127,8 @@ class TestListConfigItems:
mock_list.assert_called_once_with(
url='http://localhost:8088/',
config_type='prompt',
format_type='text'
format_type='text',
token=None
)


@ -193,7 +195,8 @@ class TestGetConfigItem:
url='http://custom.com',
config_type='prompt',
key='template-1',
format_type='json'
format_type='json',
token=None
)


@ -249,7 +252,8 @@ class TestPutConfigItem:
url='http://custom.com',
config_type='prompt',
key='new-template',
value='Custom prompt: {input}'
value='Custom prompt: {input}',
token=None
)

def test_put_main_with_stdin_arg(self):

@ -273,7 +277,8 @@ class TestPutConfigItem:
url='http://localhost:8088/',
config_type='prompt',
key='stdin-template',
value=stdin_content
value=stdin_content,
token=None
)

def test_put_main_mutually_exclusive_args(self):

@ -328,7 +333,8 @@ class TestDeleteConfigItem:
mock_delete.assert_called_once_with(
url='http://custom.com',
config_type='prompt',
key='old-template'
key='old-template',
token=None
)


@ -2,17 +2,16 @@
Unit tests for the load_knowledge CLI module.

Tests the business logic of loading triples and entity contexts from Turtle files
while mocking WebSocket connections and external dependencies.
using the BulkClient API.
"""

import pytest
import json
import tempfile
import asyncio
from unittest.mock import AsyncMock, Mock, patch, mock_open, MagicMock
from unittest.mock import Mock, patch, MagicMock, call
from pathlib import Path

from trustgraph.cli.load_knowledge import KnowledgeLoader, main
from trustgraph.api import Triple


@pytest.fixture

@ -38,159 +37,80 @@ def temp_turtle_file(sample_turtle_content):
f.write(sample_turtle_content)
f.flush()
yield f.name


# Cleanup
Path(f.name).unlink(missing_ok=True)


@pytest.fixture
def mock_websocket():
"""Mock WebSocket connection."""
mock_ws = MagicMock()

async def async_send(data):
return None

async def async_recv():
return ""

async def async_close():
return None

mock_ws.send = Mock(side_effect=async_send)
mock_ws.recv = Mock(side_effect=async_recv)
mock_ws.close = Mock(side_effect=async_close)
return mock_ws


@pytest.fixture
def knowledge_loader():
"""Create a KnowledgeLoader instance with test parameters."""
return KnowledgeLoader(
files=["test.ttl"],
flow="test-flow",
user="test-user",
user="test-user",
collection="test-collection",
document_id="test-doc-123",
url="ws://test.example.com/"
url="http://test.example.com/",
token=None
)


class TestKnowledgeLoader:
"""Test the KnowledgeLoader class business logic."""

def test_init_constructs_urls_correctly(self):
"""Test that URLs are constructed properly."""
def test_init_stores_parameters_correctly(self):
"""Test that initialization stores parameters correctly."""
loader = KnowledgeLoader(
files=["test.ttl"],
files=["file1.ttl", "file2.ttl"],
flow="my-flow",
user="user1",
collection="col1",
document_id="doc1",
url="ws://example.com/"
)

assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples"
assert loader.entity_contexts_url == "ws://example.com/api/v1/flow/my-flow/import/entity-contexts"
assert loader.user == "user1"
assert loader.collection == "col1"
assert loader.document_id == "doc1"

def test_init_adds_trailing_slash(self):
"""Test that trailing slash is added to URL if missing."""
loader = KnowledgeLoader(
files=["test.ttl"],
flow="my-flow",
user="user1",
collection="col1",
document_id="doc1",
url="ws://example.com"  # No trailing slash
url="http://example.com/",
token="test-token"
)

assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples"

@pytest.mark.asyncio
async def test_load_triples_sends_correct_messages(self, temp_turtle_file, mock_websocket):
"""Test that triple loading sends correctly formatted messages."""
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)

await loader.load_triples(temp_turtle_file, mock_websocket)

# Verify WebSocket send was called
assert mock_websocket.send.call_count > 0

# Check message format for one of the calls
sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]

# Verify message structure
sample_message = sent_messages[0]
assert "metadata" in sample_message
assert "triples" in sample_message

metadata = sample_message["metadata"]
assert metadata["id"] == "test-doc"
assert metadata["user"] == "test-user"
assert metadata["collection"] == "test-collection"
assert isinstance(metadata["metadata"], list)

triple = sample_message["triples"][0]
assert "s" in triple
assert "p" in triple
assert "o" in triple

# Check Value structure
assert "v" in triple["s"]
assert "e" in triple["s"]
assert triple["s"]["e"] is True  # Subject should be URI
assert loader.files == ["file1.ttl", "file2.ttl"]
assert loader.flow == "my-flow"
assert loader.user == "user1"
assert loader.collection == "col1"
assert loader.document_id == "doc1"
assert loader.url == "http://example.com/"
assert loader.token == "test-token"

@pytest.mark.asyncio
async def test_load_entity_contexts_processes_literals_only(self, temp_turtle_file, mock_websocket):
def test_load_triples_from_file_yields_triples(self, temp_turtle_file, knowledge_loader):
"""Test that load_triples_from_file yields Triple objects."""
triples = list(knowledge_loader.load_triples_from_file(temp_turtle_file))

# Should have triples for all statements in the file
assert len(triples) > 0

# Verify they are Triple objects
for triple in triples:
assert isinstance(triple, Triple)
assert hasattr(triple, 's')
assert hasattr(triple, 'p')
assert hasattr(triple, 'o')
assert isinstance(triple.s, str)
assert isinstance(triple.p, str)
assert isinstance(triple.o, str)

def test_load_entity_contexts_from_file_yields_literals_only(self, temp_turtle_file, knowledge_loader):
"""Test that entity contexts are created only for literals."""
loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)

await loader.load_entity_contexts(temp_turtle_file, mock_websocket)

# Get all sent messages
sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]

# Verify we got entity context messages
assert len(sent_messages) > 0

for message in sent_messages:
assert "metadata" in message
assert "entities" in message

metadata = message["metadata"]
assert metadata["id"] == "test-doc"
assert metadata["user"] == "test-user"
assert metadata["collection"] == "test-collection"

entity_context = message["entities"][0]
assert "entity" in entity_context
assert "context" in entity_context

entity = entity_context["entity"]
assert "v" in entity
assert "e" in entity
assert entity["e"] is True  # Entity should be URI (subject)

# Context should be a string (the literal value)
assert isinstance(entity_context["context"], str)
contexts = list(knowledge_loader.load_entity_contexts_from_file(temp_turtle_file))

@pytest.mark.asyncio
async def test_load_entity_contexts_skips_uri_objects(self, mock_websocket):
# Should have contexts for literal objects (foaf:name, foaf:age, foaf:email)
assert len(contexts) > 0

# Verify format: (entity, context) tuples
for entity, context in contexts:
assert isinstance(entity, str)
assert isinstance(context, str)
# Entity should be a URI (subject)
assert entity.startswith("http://")

def test_load_entity_contexts_skips_uri_objects(self):
"""Test that URI objects don't generate entity contexts."""
# Create turtle with only URI objects (no literals)
turtle_content = """

@ -198,242 +118,229 @@ class TestKnowledgeLoader:
ex:john ex:knows ex:mary .
ex:mary ex:knows ex:bob .
"""


with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
f.write(turtle_content)
f.flush()


loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
user="test-user",
collection="test-collection",
document_id="test-doc"
document_id="test-doc",
url="http://test.example.com/"
)

await loader.load_entity_contexts(f.name, mock_websocket)

Path(f.name).unlink(missing_ok=True)

# Should not send any messages since there are no literals
mock_websocket.send.assert_not_called()

@pytest.mark.asyncio
@patch('trustgraph.cli.load_knowledge.connect')
async def test_run_calls_both_loaders(self, mock_connect, knowledge_loader, temp_turtle_file):
"""Test that run() calls both triple and entity context loaders."""
knowledge_loader.files = [temp_turtle_file]

# Create a simple mock websocket
mock_ws = MagicMock()
async def mock_send(data):
pass
mock_ws.send = mock_send

# Create async context manager mock
async def mock_aenter(self):
return mock_ws

async def mock_aexit(self, exc_type, exc_val, exc_tb):
return None

mock_connection = MagicMock()
mock_connection.__aenter__ = mock_aenter
mock_connection.__aexit__ = mock_aexit
mock_connect.return_value = mock_connection

# Create AsyncMock objects that can track calls properly
mock_load_triples = AsyncMock(return_value=None)
mock_load_contexts = AsyncMock(return_value=None)

with patch.object(knowledge_loader, 'load_triples', mock_load_triples), \
patch.object(knowledge_loader, 'load_entity_contexts', mock_load_contexts):

await knowledge_loader.run()

# Verify both methods were called
mock_load_triples.assert_called_once_with(temp_turtle_file, mock_ws)
mock_load_contexts.assert_called_once_with(temp_turtle_file, mock_ws)

# Verify WebSocket connections were made to both URLs
assert mock_connect.call_count == 2
contexts = list(loader.load_entity_contexts_from_file(f.name))

Path(f.name).unlink(missing_ok=True)

# Should have no contexts since there are no literals
assert len(contexts) == 0

@patch('trustgraph.cli.load_knowledge.Api')
def test_run_calls_bulk_api(self, mock_api_class, temp_turtle_file):
"""Test that run() uses BulkClient API."""
# Setup mocks
mock_api = MagicMock()
mock_bulk = MagicMock()
mock_api_class.return_value = mock_api
mock_api.bulk.return_value = mock_bulk

loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/",
token="test-token"
)

loader.run()

# Verify Api was created with correct parameters
mock_api_class.assert_called_once_with(
url="http://test.example.com/",
token="test-token"
)

# Verify bulk client was obtained
mock_api.bulk.assert_called_once()

# Verify import_triples was called
assert mock_bulk.import_triples.call_count == 1
call_args = mock_bulk.import_triples.call_args
assert call_args[1]['flow'] == "test-flow"
assert call_args[1]['metadata']['id'] == "test-doc"
assert call_args[1]['metadata']['user'] == "test-user"
assert call_args[1]['metadata']['collection'] == "test-collection"

# Verify import_entity_contexts was called
assert mock_bulk.import_entity_contexts.call_count == 1
call_args = mock_bulk.import_entity_contexts.call_args
assert call_args[1]['flow'] == "test-flow"
assert call_args[1]['metadata']['id'] == "test-doc"


class TestCLIArgumentParsing:
"""Test CLI argument parsing and main function."""

@patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
@patch('trustgraph.cli.load_knowledge.asyncio.run')
def test_main_parses_args_correctly(self, mock_asyncio_run, mock_loader_class):
@patch('trustgraph.cli.load_knowledge.time.sleep')
def test_main_parses_args_correctly(self, mock_sleep, mock_loader_class):
"""Test that main() parses arguments correctly."""
mock_loader_instance = MagicMock()
mock_loader_class.return_value = mock_loader_instance


test_args = [
'tg-load-knowledge',
'-i', 'doc-123',
'-f', 'my-flow',
'-f', 'my-flow',
'-U', 'my-user',
'-C', 'my-collection',
'-u', 'ws://custom.example.com/',
'-u', 'http://custom.example.com/',
'-t', 'my-token',
'file1.ttl',
'file2.ttl'
]


with patch('sys.argv', test_args):
main()


# Verify KnowledgeLoader was instantiated with correct args
mock_loader_class.assert_called_once_with(
document_id='doc-123',
url='ws://custom.example.com/',
url='http://custom.example.com/',
token='my-token',
flow='my-flow',
files=['file1.ttl', 'file2.ttl'],
user='my-user',
collection='my-collection'
)

# Verify asyncio.run was called once
mock_asyncio_run.assert_called_once()

# Verify run was called
mock_loader_instance.run.assert_called_once()

@patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
@patch('trustgraph.cli.load_knowledge.asyncio.run')
def test_main_uses_defaults(self, mock_asyncio_run, mock_loader_class):
@patch('trustgraph.cli.load_knowledge.time.sleep')
def test_main_uses_defaults(self, mock_sleep, mock_loader_class):
"""Test that main() uses default values when not specified."""
mock_loader_instance = MagicMock()
mock_loader_class.return_value = mock_loader_instance


test_args = [
'tg-load-knowledge',
'-i', 'doc-123',
'file1.ttl'
]


with patch('sys.argv', test_args):
main()


# Verify defaults were used
call_args = mock_loader_class.call_args[1]
assert call_args['flow'] == 'default'
assert call_args['user'] == 'trustgraph'
assert call_args['collection'] == 'default'
assert call_args['url'] == 'ws://localhost:8088/'
assert call_args['url'] == 'http://localhost:8088/'
assert call_args['token'] is None


class TestErrorHandling:
"""Test error handling scenarios."""

@pytest.mark.asyncio
async def test_load_triples_handles_invalid_turtle(self, mock_websocket):
def test_load_triples_handles_invalid_turtle(self, knowledge_loader):
"""Test handling of invalid Turtle content."""
# Create file with invalid Turtle content
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
f.write("Invalid Turtle Content {{{")
f.flush()

loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)


# Should raise an exception for invalid Turtle
with pytest.raises(Exception):
await loader.load_triples(f.name, mock_websocket)

list(knowledge_loader.load_triples_from_file(f.name))

Path(f.name).unlink(missing_ok=True)

@pytest.mark.asyncio
async def test_load_entity_contexts_handles_invalid_turtle(self, mock_websocket):
def test_load_entity_contexts_handles_invalid_turtle(self, knowledge_loader):
"""Test handling of invalid Turtle content in entity contexts."""
# Create file with invalid Turtle content
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
f.write("Invalid Turtle Content {{{")
f.flush()

loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)


# Should raise an exception for invalid Turtle
with pytest.raises(Exception):
await loader.load_entity_contexts(f.name, mock_websocket)

list(knowledge_loader.load_entity_contexts_from_file(f.name))

Path(f.name).unlink(missing_ok=True)

@pytest.mark.asyncio
@patch('trustgraph.cli.load_knowledge.connect')
@patch('trustgraph.cli.load_knowledge.Api')
@patch('builtins.print')  # Mock print to avoid output during tests
async def test_run_handles_connection_errors(self, mock_print, mock_connect, knowledge_loader, temp_turtle_file):
"""Test handling of WebSocket connection errors."""
knowledge_loader.files = [temp_turtle_file]

# Mock connection failure
mock_connect.side_effect = ConnectionError("Failed to connect")

# Should not raise exception, just print error
await knowledge_loader.run()
def test_run_handles_api_errors(self, mock_print, mock_api_class, temp_turtle_file):
"""Test handling of API errors."""
# Mock API to raise an error
mock_api_class.side_effect = Exception("API connection failed")

loader = KnowledgeLoader(
files=[temp_turtle_file],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc",
url="http://test.example.com/"
)

# Should raise the exception
with pytest.raises(Exception, match="API connection failed"):
loader.run()

@patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
@patch('trustgraph.cli.load_knowledge.asyncio.run')
@patch('trustgraph.cli.load_knowledge.time.sleep')
@patch('builtins.print')  # Mock print to avoid output during tests
def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_asyncio_run, mock_loader_class):
def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_loader_class):
"""Test that main() retries on exceptions."""
mock_loader_instance = MagicMock()
mock_loader_class.return_value = mock_loader_instance


# First call raises exception, second succeeds
mock_asyncio_run.side_effect = [Exception("Test error"), None]

mock_loader_instance.run.side_effect = [Exception("Test error"), None]

test_args = [
'tg-load-knowledge',
'-i', 'doc-123',
'-i', 'doc-123',
'file1.ttl'
]


with patch('sys.argv', test_args):
main()


# Should have been called twice (first failed, second succeeded)
assert mock_asyncio_run.call_count == 2
assert mock_loader_instance.run.call_count == 2
mock_sleep.assert_called_once_with(10)


class TestDataValidation:
"""Test data validation and edge cases."""

@pytest.mark.asyncio
async def test_empty_turtle_file(self, mock_websocket):
def test_empty_turtle_file(self, knowledge_loader):
"""Test handling of empty Turtle files."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
f.write("")  # Empty file
f.flush()

loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)

await loader.load_triples(f.name, mock_websocket)
await loader.load_entity_contexts(f.name, mock_websocket)

# Should not send any messages for empty file
mock_websocket.send.assert_not_called()


triples = list(knowledge_loader.load_triples_from_file(f.name))
contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name))

# Should return empty lists for empty file
assert len(triples) == 0
assert len(contexts) == 0

Path(f.name).unlink(missing_ok=True)

@pytest.mark.asyncio
async def test_turtle_with_mixed_literals_and_uris(self, mock_websocket):
def test_turtle_with_mixed_literals_and_uris(self, knowledge_loader):
"""Test handling of Turtle with mixed literal and URI objects."""
turtle_content = """
@prefix ex: <http://example.org/> .

@ -443,37 +350,23 @@ ex:john ex:name "John Smith" ;
ex:city "New York" .
ex:mary ex:name "Mary Johnson" .
"""


with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
f.write(turtle_content)
f.flush()

loader = KnowledgeLoader(
files=[f.name],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)

await loader.load_entity_contexts(f.name, mock_websocket)

sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]


contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name))

# Should have 4 entity contexts (for the 4 literals: "John Smith", "25", "New York", "Mary Johnson")
# URI ex:mary should be skipped
assert len(sent_messages) == 4

assert len(contexts) == 4

# Verify all contexts are for literals (subjects should be URIs)
contexts = []
for message in sent_messages:
entity_context = message["entities"][0]
assert entity_context["entity"]["e"] is True  # Subject is URI
contexts.append(entity_context["context"])

assert "John Smith" in contexts
assert "25" in contexts
assert "New York" in contexts
assert "Mary Johnson" in contexts

Path(f.name).unlink(missing_ok=True)
context_values = [context for entity, context in contexts]

assert "John Smith" in context_values
assert "25" in context_values
assert "New York" in context_values
assert "Mary Johnson" in context_values

Path(f.name).unlink(missing_ok=True)

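The rewritten `test_main_retries_on_exception` above expects `run()` to be attempted twice with exactly one `sleep(10)` between attempts. A minimal sketch of a retry loop that would satisfy those assertions — this is an assumed shape, not the actual `tg-load-knowledge` source:

```python
import time
from unittest.mock import MagicMock, patch

def run_with_retry(loader, delay=10):
    """Hypothetical retry shape matching the test: on failure, wait a
    fixed delay and try again until run() succeeds."""
    while True:
        try:
            loader.run()
            return
        except Exception as e:
            print(f"Exception: {e}, retrying...")
            time.sleep(delay)

# Exercise it the way the test does: first run() fails, second succeeds.
loader = MagicMock()
loader.run.side_effect = [Exception("Test error"), None]

with patch("time.sleep") as mock_sleep:
    run_with_retry(loader)

assert loader.run.call_count == 2
mock_sleep.assert_called_once_with(10)
```

Patching `time.sleep` (as the test patches `trustgraph.cli.load_knowledge.time.sleep`) keeps the retry path fast and lets the delay value itself be asserted.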
@ -135,7 +135,8 @@ class TestSetToolStructuredQuery:
arguments=[],
group=None,
state=None,
applicable_states=None
applicable_states=None,
token=None
)

def test_set_main_structured_query_no_arguments_needed(self):

@ -313,7 +314,7 @@ class TestShowToolsStructuredQuery:

show_main()

mock_show.assert_called_once_with(url='http://custom.com')
mock_show.assert_called_once_with(url='http://custom.com', token=None)


class TestStructuredQueryToolValidation:

@ -22,18 +22,18 @@ class TestConfigReceiver:

def test_config_receiver_initialization(self):
"""Test ConfigReceiver initialization"""
mock_pulsar_client = Mock()
mock_backend = Mock()

config_receiver = ConfigReceiver(mock_pulsar_client)
config_receiver = ConfigReceiver(mock_backend)

assert config_receiver.pulsar_client == mock_pulsar_client
assert config_receiver.backend == mock_backend
assert config_receiver.flow_handlers == []
assert config_receiver.flows == {}

def test_add_handler(self):
"""Test adding flow handlers"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

handler1 = Mock()
handler2 = Mock()

@ -48,8 +48,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_on_config_with_new_flows(self):
"""Test on_config method with new flows"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Track calls manually instead of using AsyncMock
start_flow_calls = []

@ -87,8 +87,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_on_config_with_removed_flows(self):
"""Test on_config method with removed flows"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Pre-populate with existing flows
config_receiver.flows = {

@ -128,8 +128,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_on_config_with_no_flows(self):
"""Test on_config method with no flows in config"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Mock the start_flow and stop_flow methods with async functions
async def mock_start_flow(*args):

@ -158,8 +158,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_on_config_exception_handling(self):
"""Test on_config method handles exceptions gracefully"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Create mock message that will cause an exception
mock_msg = Mock()

@ -174,8 +174,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_start_flow_with_handlers(self):
"""Test start_flow method with multiple handlers"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Add mock handlers
handler1 = Mock()

@ -197,8 +197,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_start_flow_with_handler_exception(self):
"""Test start_flow method handles handler exceptions"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Add mock handler that raises exception
handler = Mock()

@ -217,8 +217,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_stop_flow_with_handlers(self):
"""Test stop_flow method with multiple handlers"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Add mock handlers
handler1 = Mock()

@ -240,8 +240,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_stop_flow_with_handler_exception(self):
"""Test stop_flow method handles handler exceptions"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Add mock handler that raises exception
handler = Mock()

@ -260,9 +260,9 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_config_loader_creates_consumer(self):
"""Test config_loader method creates Pulsar consumer"""
mock_pulsar_client = Mock()
mock_backend = Mock()

config_receiver = ConfigReceiver(mock_pulsar_client)
config_receiver = ConfigReceiver(mock_backend)
# Temporarily restore the real config_loader for this test
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)


@ -291,8 +291,8 @@ class TestConfigReceiver:
# Verify Consumer was created with correct parameters
mock_consumer_class.assert_called_once()
call_args = mock_consumer_class.call_args

assert call_args[1]['client'] == mock_pulsar_client

assert call_args[1]['backend'] == mock_backend
assert call_args[1]['subscriber'] == "gateway-test-uuid"
assert call_args[1]['handler'] == config_receiver.on_config
assert call_args[1]['start_of_messages'] is True

@ -301,8 +301,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_start_creates_config_loader_task(self, mock_create_task):
"""Test start method creates config loader task"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)

# Mock create_task to avoid actually creating tasks with real coroutines
mock_task = Mock()

@ -320,8 +320,8 @@ class TestConfigReceiver:
@pytest.mark.asyncio
async def test_on_config_mixed_flow_operations(self):
"""Test on_config with mixed add/remove operations"""
mock_pulsar_client = Mock()
config_receiver = ConfigReceiver(mock_pulsar_client)
mock_backend = Mock()
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
|
|
@ -380,8 +380,8 @@ class TestConfigReceiver:
|
|||
@pytest.mark.asyncio
|
||||
async def test_on_config_invalid_json_flow_data(self):
|
||||
"""Test on_config handles invalid JSON in flow data"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
mock_backend = Mock()
|
||||
config_receiver = ConfigReceiver(mock_backend)
|
||||
|
||||
# Mock the start_flow method with an async function
|
||||
async def mock_start_flow(*args):
|
||||
|
|
|
|||
|
|
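The hunks above repeat one mechanical change: `ConfigReceiver` is now constructed from a backend object instead of a raw Pulsar client. A minimal runnable sketch of the test pattern, using a hypothetical stand-in class (the real `ConfigReceiver` lives in `trustgraph.gateway` and does much more):

```python
from unittest.mock import Mock

class ConfigReceiver:
    # Hypothetical stand-in mirroring only the signature change under test:
    # the receiver now stores a generic backend rather than a Pulsar client.
    def __init__(self, backend):
        self.backend = backend
        self.flows = {}
        self.handlers = []

# The tests construct the receiver from a mock backend and assert on it.
backend = Mock()
receiver = ConfigReceiver(backend)
assert receiver.backend is backend
assert receiver.flows == {}
```

Because the collaborator is a `Mock`, the same test body works unchanged against any backend implementation, which is the point of the refactor.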
@@ -24,10 +24,10 @@ class TestConfigRequestor:
         mock_translator_registry.get_response_translator.return_value = mock_response_translator

         # Mock dependencies
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()

         requestor = ConfigRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             consumer="test-consumer",
             subscriber="test-subscriber",
             timeout=60

@@ -55,7 +55,7 @@ class TestConfigRequestor:
         with patch.object(ServiceRequestor, 'start', return_value=None), \
              patch.object(ServiceRequestor, 'process', return_value=None):
             requestor = ConfigRequestor(
-                pulsar_client=Mock(),
+                backend=Mock(),
                 consumer="test-consumer",
                 subscriber="test-subscriber"
             )

@@ -79,7 +79,7 @@ class TestConfigRequestor:
         mock_response_translator.from_response_with_completion.return_value = "translated_response"

         requestor = ConfigRequestor(
-            pulsar_client=Mock(),
+            backend=Mock(),
             consumer="test-consumer",
             subscriber="test-subscriber"
         )
@@ -39,12 +39,12 @@ class TestDispatcherManager:

     def test_dispatcher_manager_initialization(self):
         """Test DispatcherManager initialization"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

-        assert manager.pulsar_client == mock_pulsar_client
+        assert manager.backend == mock_backend
         assert manager.config_receiver == mock_config_receiver
         assert manager.prefix == "api-gateway" # default prefix
         assert manager.flows == {}

@@ -55,19 +55,19 @@ class TestDispatcherManager:

     def test_dispatcher_manager_initialization_with_custom_prefix(self):
         """Test DispatcherManager initialization with custom prefix"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver, prefix="custom-prefix")
+        manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")

         assert manager.prefix == "custom-prefix"

     @pytest.mark.asyncio
     async def test_start_flow(self):
         """Test start_flow method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         flow_data = {"name": "test_flow", "steps": []}

@@ -79,9 +79,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_stop_flow(self):
         """Test stop_flow method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Pre-populate with a flow
         flow_data = {"name": "test_flow", "steps": []}

@@ -93,9 +93,9 @@ class TestDispatcherManager:

     def test_dispatch_global_service_returns_wrapper(self):
         """Test dispatch_global_service returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_global_service()

@@ -104,9 +104,9 @@ class TestDispatcherManager:

     def test_dispatch_core_export_returns_wrapper(self):
         """Test dispatch_core_export returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_core_export()

@@ -115,9 +115,9 @@ class TestDispatcherManager:

     def test_dispatch_core_import_returns_wrapper(self):
         """Test dispatch_core_import returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_core_import()

@@ -127,9 +127,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_core_import(self):
         """Test process_core_import method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
             mock_importer = Mock()

@@ -138,16 +138,16 @@ class TestDispatcherManager:

             result = await manager.process_core_import("data", "error", "ok", "request")

-            mock_core_import.assert_called_once_with(mock_pulsar_client)
+            mock_core_import.assert_called_once_with(mock_backend)
             mock_importer.process.assert_called_once_with("data", "error", "ok", "request")
             assert result == "import_result"

     @pytest.mark.asyncio
     async def test_process_core_export(self):
         """Test process_core_export method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
             mock_exporter = Mock()

@@ -156,16 +156,16 @@ class TestDispatcherManager:

             result = await manager.process_core_export("data", "error", "ok", "request")

-            mock_core_export.assert_called_once_with(mock_pulsar_client)
+            mock_core_export.assert_called_once_with(mock_backend)
             mock_exporter.process.assert_called_once_with("data", "error", "ok", "request")
             assert result == "export_result"

     @pytest.mark.asyncio
     async def test_process_global_service(self):
         """Test process_global_service method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         manager.invoke_global_service = AsyncMock(return_value="global_result")

@@ -178,9 +178,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_global_service_with_existing_dispatcher(self):
         """Test invoke_global_service with existing dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Pre-populate with existing dispatcher
         mock_dispatcher = Mock()

@@ -195,9 +195,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_global_service_creates_new_dispatcher(self):
         """Test invoke_global_service creates new dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
             mock_dispatcher_class = Mock()

@@ -211,10 +211,12 @@ class TestDispatcherManager:

             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 timeout=120,
                 consumer="api-gateway-config-request",
-                subscriber="api-gateway-config-request"
+                subscriber="api-gateway-config-request",
+                request_queue=None,
+                response_queue=None
             )
             mock_dispatcher.start.assert_called_once()
             mock_dispatcher.process.assert_called_once_with("data", "responder")

@@ -225,9 +227,9 @@ class TestDispatcherManager:

     def test_dispatch_flow_import_returns_method(self):
         """Test dispatch_flow_import returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_flow_import()

@@ -235,9 +237,9 @@ class TestDispatcherManager:

     def test_dispatch_flow_export_returns_method(self):
         """Test dispatch_flow_export returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_flow_export()

@@ -245,9 +247,9 @@ class TestDispatcherManager:

     def test_dispatch_socket_returns_method(self):
         """Test dispatch_socket returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_socket()

@@ -255,9 +257,9 @@ class TestDispatcherManager:

     def test_dispatch_flow_service_returns_wrapper(self):
         """Test dispatch_flow_service returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_flow_service()

@@ -267,9 +269,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_import_with_valid_flow_and_kind(self):
         """Test process_flow_import with valid flow and kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {

@@ -292,7 +294,7 @@ class TestDispatcherManager:
             result = await manager.process_flow_import("ws", "running", params)

             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 ws="ws",
                 running="running",
                 queue={"queue": "test_queue"}

@@ -303,9 +305,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_import_with_invalid_flow(self):
         """Test process_flow_import with invalid flow"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         params = {"flow": "invalid_flow", "kind": "triples"}

@@ -318,9 +320,9 @@ class TestDispatcherManager:
         import warnings
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", RuntimeWarning)
-            mock_pulsar_client = Mock()
+            mock_backend = Mock()
             mock_config_receiver = Mock()
-            manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+            manager = DispatcherManager(mock_backend, mock_config_receiver)

             # Setup test flow
             manager.flows["test_flow"] = {

@@ -340,9 +342,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_export_with_valid_flow_and_kind(self):
         """Test process_flow_export with valid flow and kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {

@@ -364,7 +366,7 @@ class TestDispatcherManager:
             result = await manager.process_flow_export("ws", "running", params)

             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 ws="ws",
                 running="running",
                 queue={"queue": "test_queue"},

@@ -376,9 +378,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_socket(self):
         """Test process_socket method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
             mock_mux_instance = Mock()

@@ -392,9 +394,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_service(self):
         """Test process_flow_service method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         manager.invoke_flow_service = AsyncMock(return_value="flow_result")

@@ -407,9 +409,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_with_existing_dispatcher(self):
         """Test invoke_flow_service with existing dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Add flow to the flows dictionary
         manager.flows["test_flow"] = {"services": {"agent": {}}}

@@ -427,9 +429,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_creates_request_response_dispatcher(self):
         """Test invoke_flow_service creates request-response dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {

@@ -454,7 +456,7 @@ class TestDispatcherManager:

             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 request_queue="agent_request_queue",
                 response_queue="agent_response_queue",
                 timeout=120,

@@ -471,9 +473,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_creates_sender_dispatcher(self):
         """Test invoke_flow_service creates sender dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {

@@ -498,7 +500,7 @@ class TestDispatcherManager:

             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 queue={"queue": "text_load_queue"}
             )
             mock_dispatcher.start.assert_called_once()

@@ -511,9 +513,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_invalid_flow(self):
         """Test invoke_flow_service with invalid flow"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with pytest.raises(RuntimeError, match="Invalid flow"):
             await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")

@@ -521,9 +523,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_unsupported_kind_by_flow(self):
         """Test invoke_flow_service with kind not supported by flow"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow without agent interface
         manager.flows["test_flow"] = {

@@ -538,9 +540,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_invalid_kind(self):
         """Test invoke_flow_service with invalid kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow with interface but unsupported kind
         manager.flows["test_flow"] = {
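Across the `DispatcherManager` hunks the assertion style is identical: patch the collaborating class, drive the manager, then inspect `call_args` for the new `backend` keyword. A self-contained sketch of that mock-and-inspect pattern, with a hypothetical `make_dispatcher` factory standing in for the patched dispatcher classes:

```python
from unittest.mock import Mock, patch

# Hypothetical factory standing in for the dispatcher classes the tests patch.
def make_dispatcher(backend, timeout, consumer, subscriber):
    raise RuntimeError("never reached; patched out in the test")

def create(backend):
    # Code under test calls the (patched) factory with keyword arguments.
    return make_dispatcher(
        backend=backend,
        timeout=120,
        consumer="api-gateway-config-request",
        subscriber="api-gateway-config-request",
    )

backend = Mock()
with patch(f"{__name__}.make_dispatcher") as mock_factory:
    create(backend)

# call_args[1] is the kwargs dict captured by the mock.
mock_factory.assert_called_once()
assert mock_factory.call_args[1]["backend"] is backend
assert mock_factory.call_args[1]["timeout"] == 120
```

Inspecting `call_args[1]` by keyword, rather than asserting the full call, keeps such tests resilient when unrelated parameters are added, as happened with `request_queue`/`response_queue` in this diff.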
@@ -15,12 +15,12 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_initialization(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor initialization"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_request_schema = MagicMock()
         mock_response_schema = MagicMock()

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-request-queue",
             request_schema=mock_request_schema,
             response_queue="test-response-queue",

@@ -32,12 +32,12 @@ class TestServiceRequestor:

         # Verify Publisher was created correctly
         mock_publisher.assert_called_once_with(
-            mock_pulsar_client, "test-request-queue", schema=mock_request_schema
+            mock_backend, "test-request-queue", schema=mock_request_schema
         )

         # Verify Subscriber was created correctly
         mock_subscriber.assert_called_once_with(
-            mock_pulsar_client, "test-response-queue",
+            mock_backend, "test-response-queue",
             "test-subscription", "test-consumer", mock_response_schema
         )

@@ -48,12 +48,12 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_with_defaults(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor initialization with default parameters"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_request_schema = MagicMock()
         mock_response_schema = MagicMock()

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=mock_request_schema,
             response_queue="response-queue",

@@ -62,7 +62,7 @@ class TestServiceRequestor:

         # Verify default values
         mock_subscriber.assert_called_once_with(
-            mock_pulsar_client, "response-queue",
+            mock_backend, "response-queue",
             "api-gateway", "api-gateway", mock_response_schema
         )
         assert requestor.timeout == 600 # Default timeout

@@ -72,14 +72,14 @@ class TestServiceRequestor:
     @pytest.mark.asyncio
     async def test_service_requestor_start(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor start method"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_sub_instance = AsyncMock()
         mock_pub_instance = AsyncMock()
         mock_subscriber.return_value = mock_sub_instance
         mock_publisher.return_value = mock_pub_instance

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=MagicMock(),
             response_queue="response-queue",

@@ -98,14 +98,14 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_attributes(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor has correct attributes"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_pub_instance = AsyncMock()
         mock_sub_instance = AsyncMock()
         mock_publisher.return_value = mock_pub_instance
         mock_subscriber.return_value = mock_sub_instance

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=MagicMock(),
             response_queue="response-queue",
@@ -14,18 +14,18 @@ class TestServiceSender:
     @patch('trustgraph.gateway.dispatch.sender.Publisher')
     def test_service_sender_initialization(self, mock_publisher):
         """Test ServiceSender initialization"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_schema = MagicMock()

         sender = ServiceSender(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue",
             schema=mock_schema
         )

         # Verify Publisher was created correctly
         mock_publisher.assert_called_once_with(
-            mock_pulsar_client, "test-queue", schema=mock_schema
+            mock_backend, "test-queue", schema=mock_schema
         )

     @patch('trustgraph.gateway.dispatch.sender.Publisher')

@@ -36,7 +36,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -55,7 +55,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -70,7 +70,7 @@ class TestServiceSender:
     def test_service_sender_to_request_not_implemented(self, mock_publisher):
         """Test ServiceSender to_request method raises RuntimeError"""
         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -91,7 +91,7 @@ class TestServiceSender:
             return {"processed": request}

         sender = ConcreteSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

@@ -111,7 +111,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )
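The `ServiceSender` hunks move the backend into the first positional argument of `Publisher`. A runnable sketch under assumed signatures (both classes here are simplified stand-ins, not the TrustGraph implementations):

```python
from unittest.mock import MagicMock

class Publisher:
    # Hypothetical stand-in: first positional argument is now the backend.
    def __init__(self, backend, queue, schema=None):
        self.backend = backend
        self.queue = queue
        self.schema = schema

class ServiceSender:
    # Minimal sketch of the constructor shape the tests exercise.
    def __init__(self, backend, queue, schema):
        self.publisher = Publisher(backend, queue, schema=schema)

backend, schema = MagicMock(), MagicMock()
sender = ServiceSender(backend=backend, queue="test-queue", schema=schema)
assert sender.publisher.backend is backend
assert sender.publisher.queue == "test-queue"
```

In the real tests `Publisher` itself is patched, so the equivalent check is `mock_publisher.assert_called_once_with(mock_backend, "test-queue", schema=mock_schema)`, as the diff shows.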
@@ -16,7 +16,7 @@ from trustgraph.schema import Metadata, ExtractedObject


 @pytest.fixture
-def mock_pulsar_client():
+def mock_backend():
     """Mock Pulsar client."""
     client = Mock()
     return client

@@ -96,7 +96,7 @@ class TestObjectsImportInitialization:
     """Test ObjectsImport initialization."""

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
-    def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that ObjectsImport creates Publisher with correct parameters."""
         mock_publisher_instance = Mock()
         mock_publisher_class.return_value = mock_publisher_instance

@@ -104,13 +104,13 @@ class TestObjectsImportInitialization:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-objects-queue"
         )

         # Verify Publisher was created with correct parameters
         mock_publisher_class.assert_called_once_with(
-            mock_pulsar_client,
+            mock_backend,
             topic="test-objects-queue",
             schema=ExtractedObject
         )

@@ -121,12 +121,12 @@ class TestObjectsImportInitialization:
         assert objects_import.publisher == mock_publisher_instance

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
-    def test_init_stores_references_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that ObjectsImport stores all required references."""
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="objects-queue"
         )

@@ -139,7 +139,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_start_calls_publisher_start(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that start() calls publisher.start()."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.start = AsyncMock()

@@ -148,7 +148,7 @@ class TestObjectsImportLifecycle:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -158,7 +158,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that destroy() properly stops publisher and closes websocket."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.stop = AsyncMock()

@@ -167,7 +167,7 @@ class TestObjectsImportLifecycle:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -180,7 +180,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_pulsar_client, mock_running):
+    async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
         """Test that destroy() handles None websocket gracefully."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.stop = AsyncMock()

@@ -189,7 +189,7 @@ class TestObjectsImportLifecycle:
         objects_import = ObjectsImport(
             ws=None,  # None websocket
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -205,7 +205,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
+    async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
         """Test that receive() processes complete message correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -214,7 +214,7 @@ class TestObjectsImportMessageProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

@@ -248,7 +248,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, minimal_objects_message):
+    async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
         """Test that receive() handles message with minimal required fields."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()

@@ -257,7 +257,7 @@ class TestObjectsImportMessageProcessing:
         objects_import = ObjectsImport(
             ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -281,7 +281,7 @@ class TestObjectsImportMessageProcessing:
|
|||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() uses appropriate default values for optional fields."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
|
|
@ -290,7 +290,7 @@ class TestObjectsImportMessageProcessing:
|
|||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -323,7 +323,7 @@ class TestObjectsImportRunMethod:
|
|||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that run() loops while running.get() returns True."""
|
||||
mock_sleep.return_value = None
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
|
@ -334,7 +334,7 @@ class TestObjectsImportRunMethod:
|
|||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -353,7 +353,7 @@ class TestObjectsImportRunMethod:
|
|||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_running):
|
||||
async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
|
||||
"""Test that run() handles None websocket gracefully."""
|
||||
mock_sleep.return_value = None
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
|
@ -363,7 +363,7 @@ class TestObjectsImportRunMethod:
|
|||
objects_import = ObjectsImport(
|
||||
ws=None, # None websocket
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -417,7 +417,7 @@ class TestObjectsImportBatchProcessing:
|
|||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
|
||||
async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
|
||||
"""Test that receive() processes batch message correctly."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
|
|
@ -426,7 +426,7 @@ class TestObjectsImportBatchProcessing:
|
|||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -467,7 +467,7 @@ class TestObjectsImportBatchProcessing:
|
|||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles empty batch correctly."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock()
|
||||
|
|
@ -476,7 +476,7 @@ class TestObjectsImportBatchProcessing:
|
|||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -507,7 +507,7 @@ class TestObjectsImportErrorHandling:
|
|||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
|
||||
async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
|
||||
"""Test that receive() propagates publisher send errors."""
|
||||
mock_publisher_instance = Mock()
|
||||
mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
|
||||
|
|
@ -516,7 +516,7 @@ class TestObjectsImportErrorHandling:
|
|||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
@ -528,14 +528,14 @@ class TestObjectsImportErrorHandling:
|
|||
|
||||
@patch('trustgraph.gateway.dispatch.objects_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
|
||||
async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
|
||||
"""Test that receive() handles malformed JSON appropriately."""
|
||||
mock_publisher_class.return_value = Mock()
|
||||
|
||||
objects_import = ObjectsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
pulsar_client=mock_pulsar_client,
|
||||
backend=mock_backend,
|
||||
queue="test-queue"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@@ -19,23 +19,21 @@ class TestApi:
    def test_api_initialization_with_defaults(self):
        """Test Api initialization with default values"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()

            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_backend = Mock()
                mock_get_pubsub.return_value = mock_backend

                api = Api()

        assert api.port == default_port
        assert api.timeout == default_timeout
        assert api.pulsar_host == default_pulsar_host
        assert api.pulsar_api_key is None
        assert api.prometheus_url == default_prometheus_url + "/"
        assert api.auth.allow_all is True

        # Verify Pulsar client was created without API key
        mock_client.assert_called_once_with(
            default_pulsar_host,
            listener_name=None
        )

        # Verify get_pubsub was called
        mock_get_pubsub.assert_called_once()

    def test_api_initialization_with_custom_config(self):
        """Test Api initialization with custom configuration"""

@@ -48,14 +46,13 @@ class TestApi:
            "prometheus_url": "http://custom-prometheus:9090",
            "api_token": "secret-token"
        }

        with patch('pulsar.Client') as mock_client, \
             patch('pulsar.AuthenticationToken') as mock_auth:
            mock_client.return_value = Mock()
            mock_auth.return_value = Mock()

            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_backend = Mock()
                mock_get_pubsub.return_value = mock_backend

                api = Api(**config)

        assert api.port == 9000
        assert api.timeout == 300
        assert api.pulsar_host == "pulsar://custom-host:6650"

@@ -63,35 +60,25 @@ class TestApi:
        assert api.prometheus_url == "http://custom-prometheus:9090/"
        assert api.auth.token == "secret-token"
        assert api.auth.allow_all is False

        # Verify Pulsar client was created with API key
        mock_auth.assert_called_once_with("test-api-key")
        mock_client.assert_called_once_with(
            "pulsar://custom-host:6650",
            listener_name="custom-listener",
            authentication=mock_auth.return_value
        )

        # Verify get_pubsub was called with config
        mock_get_pubsub.assert_called_once_with(**config)

    def test_api_initialization_with_pulsar_api_key(self):
        """Test Api initialization with Pulsar API key authentication"""
        with patch('pulsar.Client') as mock_client, \
             patch('pulsar.AuthenticationToken') as mock_auth:
            mock_client.return_value = Mock()
            mock_auth.return_value = Mock()

            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                api = Api(pulsar_api_key="test-key")

        mock_auth.assert_called_once_with("test-key")
        mock_client.assert_called_once_with(
            default_pulsar_host,
            listener_name=None,
            authentication=mock_auth.return_value
        )

        # Verify api key was stored
        assert api.pulsar_api_key == "test-key"
        mock_get_pubsub.assert_called_once()

    def test_api_initialization_prometheus_url_normalization(self):
        """Test that prometheus_url gets normalized with trailing slash"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()
            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                # Test URL without trailing slash
                api = Api(prometheus_url="http://prometheus:9090")

@@ -103,16 +90,16 @@ class TestApi:
    def test_api_initialization_empty_api_token_means_no_auth(self):
        """Test that empty API token results in allow_all authentication"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()
            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                api = Api(api_token="")
                assert api.auth.allow_all is True

    def test_api_initialization_none_api_token_means_no_auth(self):
        """Test that None API token results in allow_all authentication"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()
            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                api = Api(api_token=None)
                assert api.auth.allow_all is True

@@ -120,8 +107,8 @@ class TestApi:
    @pytest.mark.asyncio
    async def test_app_factory_creates_application(self):
        """Test that app_factory creates aiohttp application"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()
            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                api = Api()

@@ -147,8 +134,8 @@ class TestApi:
    @pytest.mark.asyncio
    async def test_app_factory_with_custom_endpoints(self):
        """Test app_factory with custom endpoints"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()
            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                api = Api()

@@ -180,13 +167,13 @@ class TestApi:
    def test_run_method_calls_web_run_app(self):
        """Test that run method calls web.run_app"""
        with patch('pulsar.Client') as mock_client, \
        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
             patch('aiohttp.web.run_app') as mock_run_app:
            mock_client.return_value = Mock()
            mock_get_pubsub.return_value = Mock()

            api = Api(port=8080)
            api.run()

            # Verify run_app was called once with the correct port
            mock_run_app.assert_called_once()
            args, kwargs = mock_run_app.call_args

@@ -195,19 +182,19 @@ class TestApi:
    def test_api_components_initialization(self):
        """Test that all API components are properly initialized"""
        with patch('pulsar.Client') as mock_client:
            mock_client.return_value = Mock()

            with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
                mock_get_pubsub.return_value = Mock()

                api = Api()

                # Verify all components are initialized
                assert api.config_receiver is not None
                assert api.dispatcher_manager is not None
                assert api.endpoint_manager is not None
                assert api.endpoints == []

                # Verify component relationships
                assert api.dispatcher_manager.pulsar_client == api.pulsar_client
                assert api.dispatcher_manager.backend == api.pubsub_backend
                assert api.dispatcher_manager.config_receiver == api.config_receiver
                assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
                # EndpointManager doesn't store auth directly, it passes it to individual endpoints
@@ -102,7 +102,7 @@ async def test_handle_normal_flow():
    """Test normal websocket handling flow."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = True

    dispatcher_created = False
    async def mock_dispatcher_factory(ws, running, match_info):
        nonlocal dispatcher_created

@@ -110,33 +110,43 @@ async def test_handle_normal_flow():
        dispatcher = AsyncMock()
        dispatcher.destroy = AsyncMock()
        return dispatcher

    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = False
        mock_ws_class.return_value = mock_ws

        with patch('asyncio.TaskGroup') as mock_task_group:
            # Mock task group context manager
            mock_tg = AsyncMock()
            mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
            mock_tg.__aexit__ = AsyncMock(return_value=None)
            mock_tg.create_task = MagicMock(return_value=AsyncMock())

            # Create proper mock tasks that look like asyncio.Task objects
            def create_task_mock(coro):
                # Consume the coroutine to avoid "was never awaited" warning
                coro.close()
                task = AsyncMock()
                task.done = MagicMock(return_value=True)
                task.cancelled = MagicMock(return_value=False)
                return task

            mock_tg.create_task = MagicMock(side_effect=create_task_mock)
            mock_task_group.return_value = mock_tg

            result = await socket_endpoint.handle(request)

            # Should have created dispatcher
            assert dispatcher_created is True

            # Should return websocket
            assert result == mock_ws

@@ -146,50 +156,64 @@ async def test_handle_exception_group_cleanup():
    """Test exception group triggers dispatcher cleanup."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = True

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Mock TaskGroup to raise ExceptionGroup
    class TestException(Exception):
        pass

    exception_group = ExceptionGroup("Test exceptions", [TestException("test")])

    with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = False
        mock_ws_class.return_value = mock_ws

        with patch('asyncio.TaskGroup') as mock_task_group:
            mock_tg = AsyncMock()
            mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
            mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
            mock_tg.create_task = MagicMock(side_effect=TestException("test"))

            # Create proper mock tasks that look like asyncio.Task objects
            def create_task_mock(coro):
                # Consume the coroutine to avoid "was never awaited" warning
                coro.close()
                task = AsyncMock()
                task.done = MagicMock(return_value=True)
                task.cancelled = MagicMock(return_value=False)
                return task

            mock_tg.create_task = MagicMock(side_effect=create_task_mock)
            mock_task_group.return_value = mock_tg

            with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
                mock_wait_for.return_value = None

            with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for:
                # Make wait_for consume the coroutine passed to it
                async def wait_for_side_effect(coro, timeout=None):
                    coro.close()  # Consume the coroutine
                    return None
                mock_wait_for.side_effect = wait_for_side_effect

                result = await socket_endpoint.handle(request)

                # Should have attempted graceful cleanup
                mock_wait_for.assert_called_once()

                # Should have called destroy in finally block
                assert mock_dispatcher.destroy.call_count >= 1

                # Should have closed websocket
                mock_ws.close.assert_called()

@@ -199,48 +223,62 @@ async def test_handle_dispatcher_cleanup_timeout():
    """Test dispatcher cleanup with timeout."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = True

    # Mock dispatcher that takes long to destroy
    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    # Mock TaskGroup to raise exception
    exception_group = ExceptionGroup("Test", [Exception("test")])

    with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = False
        mock_ws_class.return_value = mock_ws

        with patch('asyncio.TaskGroup') as mock_task_group:
            mock_tg = AsyncMock()
            mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
            mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
            mock_tg.create_task = MagicMock(side_effect=Exception("test"))

            # Create proper mock tasks that look like asyncio.Task objects
            def create_task_mock(coro):
                # Consume the coroutine to avoid "was never awaited" warning
                coro.close()
                task = AsyncMock()
                task.done = MagicMock(return_value=True)
                task.cancelled = MagicMock(return_value=False)
                return task

            mock_tg.create_task = MagicMock(side_effect=create_task_mock)
            mock_task_group.return_value = mock_tg

            # Mock asyncio.wait_for to raise TimeoutError
            with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
                mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")

            with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for:
                # Make wait_for consume the coroutine before raising
                async def wait_for_timeout(coro, timeout=None):
                    coro.close()  # Consume the coroutine
                    raise asyncio.TimeoutError("Cleanup timeout")
                mock_wait_for.side_effect = wait_for_timeout

                result = await socket_endpoint.handle(request)

                # Should have attempted cleanup with timeout
                mock_wait_for.assert_called_once()
                # Check that timeout was passed correctly
                assert mock_wait_for.call_args[1]['timeout'] == 5.0

                # Should still call destroy in finally block
                assert mock_dispatcher.destroy.call_count >= 1

@@ -290,37 +328,47 @@ async def test_handle_websocket_already_closed():
    """Test handling when websocket is already closed."""
    mock_auth = MagicMock()
    mock_auth.permitted.return_value = True

    mock_dispatcher = AsyncMock()
    mock_dispatcher.destroy = AsyncMock()

    async def mock_dispatcher_factory(ws, running, match_info):
        return mock_dispatcher

    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)

    request = MagicMock()
    request.query = {"token": "valid-token"}
    request.match_info = {}

    with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
        mock_ws = AsyncMock()
        mock_ws.prepare = AsyncMock()
        mock_ws.close = AsyncMock()
        mock_ws.closed = True  # Already closed
        mock_ws_class.return_value = mock_ws

        with patch('asyncio.TaskGroup') as mock_task_group:
            mock_tg = AsyncMock()
            mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
            mock_tg.__aexit__ = AsyncMock(return_value=None)
            mock_tg.create_task = MagicMock(return_value=AsyncMock())

            # Create proper mock tasks that look like asyncio.Task objects
            def create_task_mock(coro):
                # Consume the coroutine to avoid "was never awaited" warning
                coro.close()
                task = AsyncMock()
                task.done = MagicMock(return_value=True)
                task.cancelled = MagicMock(return_value=False)
                return task

            mock_tg.create_task = MagicMock(side_effect=create_task_mock)
            mock_task_group.return_value = mock_tg

            result = await socket_endpoint.handle(request)

            # Should still have called destroy
            mock_dispatcher.destroy.assert_called()

            # Should not attempt to close already closed websocket
            mock_ws.close.assert_not_called()  # Not called in finally since ws.closed = True
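The `create_task_mock` helper repeated in the tests above closes each coroutine before handing back a mock task. A self-contained sketch of the pattern (the `work` coroutine here is illustrative, not part of the codebase):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

async def work():
    """Illustrative coroutine that the mocked TaskGroup would receive."""
    await asyncio.sleep(0)

def create_task_mock(coro):
    # Close the never-started coroutine so its destructor does not emit
    # a "coroutine ... was never awaited" RuntimeWarning during the test.
    coro.close()
    task = AsyncMock()
    task.done = MagicMock(return_value=True)
    task.cancelled = MagicMock(return_value=False)
    return task

task = create_task_mock(work())
assert task.done() is True
assert task.cancelled() is False
```

Because the mocked `TaskGroup.create_task` never schedules anything, the coroutine would otherwise be garbage-collected unstarted; `close()` quietly disposes of it.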
326  tests/unit/test_gateway/test_streaming_translators.py  (new file)

@@ -0,0 +1,326 @@
"""
Unit tests for streaming behavior in message translators.

These tests verify that translators correctly handle empty strings and
end_of_stream flags in streaming responses, preventing bugs where empty
final chunks could be dropped due to falsy value checks.
"""

import pytest
from unittest.mock import MagicMock
from trustgraph.messaging.translators.retrieval import (
    GraphRagResponseTranslator,
    DocumentRagResponseTranslator,
)
from trustgraph.messaging.translators.prompt import PromptResponseTranslator
from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator
from trustgraph.schema import (
    GraphRagResponse,
    DocumentRagResponse,
    PromptResponse,
    TextCompletionResponse,
)


class TestGraphRagResponseTranslator:
    """Test GraphRagResponseTranslator streaming behavior"""

    def test_from_pulsar_with_empty_response(self):
        """Test that empty response strings are preserved"""
        # Arrange
        translator = GraphRagResponseTranslator()
        response = GraphRagResponse(
            response="",
            end_of_stream=True,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert - Empty string should be included in result
        assert "response" in result
        assert result["response"] == ""
        assert result["end_of_stream"] is True

    def test_from_pulsar_with_non_empty_response(self):
        """Test that non-empty responses work correctly"""
        # Arrange
        translator = GraphRagResponseTranslator()
        response = GraphRagResponse(
            response="Some text",
            end_of_stream=False,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert result["response"] == "Some text"
        assert result["end_of_stream"] is False

    def test_from_pulsar_with_none_response(self):
        """Test that None response is handled correctly"""
        # Arrange
        translator = GraphRagResponseTranslator()
        response = GraphRagResponse(
            response=None,
            end_of_stream=True,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert - None should not be included
        assert "response" not in result
        assert result["end_of_stream"] is True

    def test_from_response_with_completion_returns_correct_flag(self):
        """Test that from_response_with_completion returns correct is_final flag"""
        # Arrange
        translator = GraphRagResponseTranslator()

        # Test non-final chunk
        response_chunk = GraphRagResponse(
            response="chunk",
            end_of_stream=False,
            error=None
        )

        # Act
        result, is_final = translator.from_response_with_completion(response_chunk)

        # Assert
        assert is_final is False
        assert result["end_of_stream"] is False

        # Test final chunk with empty content
        final_response = GraphRagResponse(
            response="",
            end_of_stream=True,
            error=None
        )

        # Act
        result, is_final = translator.from_response_with_completion(final_response)

        # Assert
        assert is_final is True
        assert result["response"] == ""
        assert result["end_of_stream"] is True


class TestDocumentRagResponseTranslator:
    """Test DocumentRagResponseTranslator streaming behavior"""

    def test_from_pulsar_with_empty_response(self):
        """Test that empty response strings are preserved"""
        # Arrange
        translator = DocumentRagResponseTranslator()
        response = DocumentRagResponse(
            response="",
            end_of_stream=True,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert "response" in result
        assert result["response"] == ""
        assert result["end_of_stream"] is True

    def test_from_pulsar_with_non_empty_response(self):
        """Test that non-empty responses work correctly"""
        # Arrange
        translator = DocumentRagResponseTranslator()
        response = DocumentRagResponse(
            response="Document content",
            end_of_stream=False,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert result["response"] == "Document content"
        assert result["end_of_stream"] is False


class TestPromptResponseTranslator:
    """Test PromptResponseTranslator streaming behavior"""

    def test_from_pulsar_with_empty_text(self):
        """Test that empty text strings are preserved"""
        # Arrange
        translator = PromptResponseTranslator()
        response = PromptResponse(
            text="",
            object=None,
            end_of_stream=True,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert "text" in result
        assert result["text"] == ""
        assert result["end_of_stream"] is True

    def test_from_pulsar_with_non_empty_text(self):
        """Test that non-empty text works correctly"""
        # Arrange
        translator = PromptResponseTranslator()
        response = PromptResponse(
            text="Some prompt response",
            object=None,
            end_of_stream=False,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert result["text"] == "Some prompt response"
        assert result["end_of_stream"] is False

    def test_from_pulsar_with_none_text(self):
        """Test that None text is handled correctly"""
        # Arrange
        translator = PromptResponseTranslator()
        response = PromptResponse(
            text=None,
            object='{"result": "data"}',
            end_of_stream=True,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert "text" not in result
        assert "object" in result
        assert result["end_of_stream"] is True

    def test_from_pulsar_includes_end_of_stream(self):
        """Test that end_of_stream flag is always included"""
        # Arrange
        translator = PromptResponseTranslator()

        # Test with end_of_stream=False
        response = PromptResponse(
            text="chunk",
            object=None,
            end_of_stream=False,
            error=None
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert
        assert "end_of_stream" in result
        assert result["end_of_stream"] is False


class TestTextCompletionResponseTranslator:
    """Test TextCompletionResponseTranslator streaming behavior"""

    def test_from_pulsar_always_includes_response(self):
        """Test that response field is always included, even if empty"""
        # Arrange
        translator = TextCompletionResponseTranslator()
        response = TextCompletionResponse(
            response="",
            end_of_stream=True,
            error=None,
            in_token=100,
            out_token=5,
            model="test-model"
        )

        # Act
        result = translator.from_pulsar(response)

        # Assert - Response should always be present
        assert "response" in result
        assert result["response"] == ""

    def test_from_response_with_completion_with_empty_final(self):
        """Test that empty final response is handled correctly"""
        # Arrange
        translator = TextCompletionResponseTranslator()
        response = TextCompletionResponse(
            response="",
            end_of_stream=True,
            error=None,
            in_token=100,
            out_token=5,
            model="test-model"
        )

        # Act
        result, is_final = translator.from_response_with_completion(response)

        # Assert
        assert is_final is True
        assert result["response"] == ""


class TestStreamingProtocolCompliance:
    """Test that all translators follow streaming protocol conventions"""
|
||||
|
||||
@pytest.mark.parametrize("translator_class,response_class,field_name", [
|
||||
(GraphRagResponseTranslator, GraphRagResponse, "response"),
|
||||
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
|
||||
(PromptResponseTranslator, PromptResponse, "text"),
|
||||
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
|
||||
])
|
||||
def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name):
|
||||
"""Test that all translators preserve empty final chunks"""
|
||||
# Arrange
|
||||
translator = translator_class()
|
||||
kwargs = {
|
||||
field_name: "",
|
||||
"end_of_stream": True,
|
||||
"error": None,
|
||||
}
|
||||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
|
||||
assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string"
|
||||
|
||||
@pytest.mark.parametrize("translator_class,response_class,field_name", [
|
||||
(GraphRagResponseTranslator, GraphRagResponse, "response"),
|
||||
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
|
||||
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
|
||||
])
|
||||
def test_end_of_stream_flag_included(self, translator_class, response_class, field_name):
|
||||
"""Test that end_of_stream flag is included in all response translators"""
|
||||
# Arrange
|
||||
translator = translator_class()
|
||||
kwargs = {
|
||||
field_name: "test content",
|
||||
"end_of_stream": True,
|
||||
"error": None,
|
||||
}
|
||||
response = response_class(**kwargs)
|
||||
|
||||
# Act
|
||||
result = translator.from_pulsar(response)
|
||||
|
||||
# Assert
|
||||
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
|
||||
assert result["end_of_stream"] is True
|
||||
446
tests/unit/test_python_api_client.py
Normal file

@ -0,0 +1,446 @@
"""
Unit tests for TrustGraph Python API client library

These tests use mocks and do not require a running server.
"""

import pytest
from unittest.mock import Mock, patch, MagicMock, call
import json

from trustgraph.api import (
    Api,
    Triple,
    AgentThought,
    AgentObservation,
    AgentAnswer,
    RAGChunk,
)


class TestApiInstantiation:
    """Test Api class instantiation and configuration"""

    def test_api_instantiation_defaults(self):
        """Test Api with default parameters"""
        api = Api()
        assert api.url == "http://localhost:8088/api/v1/"
        assert api.timeout == 60
        assert api.token is None

    def test_api_instantiation_with_url(self):
        """Test Api with custom URL"""
        api = Api(url="http://test-server:9000/")
        assert api.url == "http://test-server:9000/api/v1/"

    def test_api_instantiation_with_url_trailing_slash(self):
        """Test Api adds trailing slash if missing"""
        api = Api(url="http://test-server:9000")
        assert api.url == "http://test-server:9000/api/v1/"

    def test_api_instantiation_with_token(self):
        """Test Api with authentication token"""
        api = Api(token="test-token-123")
        assert api.token == "test-token-123"

    def test_api_instantiation_with_timeout(self):
        """Test Api with custom timeout"""
        api = Api(timeout=120)
        assert api.timeout == 120


class TestApiLazyInitialization:
    """Test lazy initialization of client components"""

    def test_socket_client_lazy_init(self):
        """Test socket client is created on first access"""
        api = Api(url="http://test/", token="token")

        assert api._socket_client is None
        socket = api.socket()
        assert api._socket_client is not None
        assert socket is api._socket_client

        # Second access returns same instance
        socket2 = api.socket()
        assert socket2 is socket

    def test_bulk_client_lazy_init(self):
        """Test bulk client is created on first access"""
        api = Api(url="http://test/")

        assert api._bulk_client is None
        bulk = api.bulk()
        assert api._bulk_client is not None

    def test_async_flow_lazy_init(self):
        """Test async flow is created on first access"""
        api = Api(url="http://test/")

        assert api._async_flow is None
        async_flow = api.async_flow()
        assert api._async_flow is not None

    def test_metrics_lazy_init(self):
        """Test metrics client is created on first access"""
        api = Api(url="http://test/")

        assert api._metrics is None
        metrics = api.metrics()
        assert api._metrics is not None


class TestApiContextManager:
    """Test context manager functionality"""

    def test_sync_context_manager(self):
        """Test synchronous context manager"""
        with Api(url="http://test/") as api:
            assert api is not None
            assert isinstance(api, Api)
        # Should exit cleanly

    @pytest.mark.asyncio
    async def test_async_context_manager(self):
        """Test asynchronous context manager"""
        async with Api(url="http://test/") as api:
            assert api is not None
            assert isinstance(api, Api)
        # Should exit cleanly


class TestFlowClient:
    """Test Flow client functionality"""

    @patch('requests.post')
    def test_flow_list(self, mock_post):
        """Test listing flows"""
        mock_post.return_value.status_code = 200
        mock_post.return_value.json.return_value = {"flow-ids": ["flow1", "flow2"]}

        api = Api(url="http://test/")
        flows = api.flow().list()

        assert flows == ["flow1", "flow2"]
        assert mock_post.called

    @patch('requests.post')
    def test_flow_list_with_token(self, mock_post):
        """Test flow listing includes auth token"""
        mock_post.return_value.status_code = 200
        mock_post.return_value.json.return_value = {"flow-ids": []}

        api = Api(url="http://test/", token="my-token")
        api.flow().list()

        # Verify Authorization header was set
        call_args = mock_post.call_args
        headers = call_args[1]['headers'] if 'headers' in call_args[1] else {}
        assert 'Authorization' in headers
        assert headers['Authorization'] == 'Bearer my-token'

    @patch('requests.post')
    def test_flow_get(self, mock_post):
        """Test getting flow definition"""
        flow_def = {"name": "test-flow", "description": "Test"}
        mock_post.return_value.status_code = 200
        mock_post.return_value.json.return_value = {"flow": json.dumps(flow_def)}

        api = Api(url="http://test/")
        result = api.flow().get("test-flow")

        assert result == flow_def

    def test_flow_instance_creation(self):
        """Test creating flow instance"""
        api = Api(url="http://test/")
        flow_instance = api.flow().id("my-flow")

        assert flow_instance is not None
        assert flow_instance.id == "my-flow"

    def test_flow_instance_has_methods(self):
        """Test flow instance has expected methods"""
        api = Api(url="http://test/")
        flow_instance = api.flow().id("my-flow")

        expected_methods = [
            'text_completion', 'agent', 'graph_rag', 'document_rag',
            'graph_embeddings_query', 'embeddings', 'prompt',
            'triples_query', 'objects_query'
        ]

        for method in expected_methods:
            assert hasattr(flow_instance, method), f"Missing method: {method}"


class TestSocketClient:
    """Test WebSocket client functionality"""

    def test_socket_client_url_conversion_http(self):
        """Test HTTP URL converted to WebSocket"""
        api = Api(url="http://test-server:8088/")
        socket = api.socket()

        assert socket.url.startswith("ws://")
        assert "test-server" in socket.url

    def test_socket_client_url_conversion_https(self):
        """Test HTTPS URL converted to secure WebSocket"""
        api = Api(url="https://test-server:8088/")
        socket = api.socket()

        assert socket.url.startswith("wss://")

    def test_socket_client_token_passed(self):
        """Test token is passed to socket client"""
        api = Api(url="http://test/", token="socket-token")
        socket = api.socket()

        assert socket.token == "socket-token"

    def test_socket_flow_instance(self):
        """Test creating socket flow instance"""
        api = Api(url="http://test/")
        socket = api.socket()
        flow_instance = socket.flow("test-flow")

        assert flow_instance is not None
        assert flow_instance.flow_id == "test-flow"

    def test_socket_flow_has_methods(self):
        """Test socket flow instance has expected methods"""
        api = Api(url="http://test/")
        flow_instance = api.socket().flow("test-flow")

        expected_methods = [
            'agent', 'text_completion', 'graph_rag', 'document_rag',
            'prompt', 'graph_embeddings_query', 'embeddings',
            'triples_query', 'objects_query', 'mcp_tool'
        ]

        for method in expected_methods:
            assert hasattr(flow_instance, method), f"Missing method: {method}"


class TestBulkClient:
    """Test bulk operations client"""

    def test_bulk_client_url_conversion(self):
        """Test bulk client uses WebSocket URL"""
        api = Api(url="http://test/")
        bulk = api.bulk()

        assert bulk.url.startswith("ws://")

    def test_bulk_client_has_import_methods(self):
        """Test bulk client has import methods"""
        api = Api(url="http://test/")
        bulk = api.bulk()

        import_methods = [
            'import_triples',
            'import_graph_embeddings',
            'import_document_embeddings',
            'import_entity_contexts',
            'import_objects'
        ]

        for method in import_methods:
            assert hasattr(bulk, method), f"Missing method: {method}"

    def test_bulk_client_has_export_methods(self):
        """Test bulk client has export methods"""
        api = Api(url="http://test/")
        bulk = api.bulk()

        export_methods = [
            'export_triples',
            'export_graph_embeddings',
            'export_document_embeddings',
            'export_entity_contexts'
        ]

        for method in export_methods:
            assert hasattr(bulk, method), f"Missing method: {method}"


class TestMetricsClient:
    """Test metrics client"""

    @patch('requests.get')
    def test_metrics_get(self, mock_get):
        """Test getting metrics"""
        mock_get.return_value.status_code = 200
        mock_get.return_value.text = "# HELP metric_name\nmetric_name 42"

        api = Api(url="http://test/")
        metrics_text = api.metrics().get()

        assert "metric_name" in metrics_text
        assert mock_get.called

    @patch('requests.get')
    def test_metrics_with_token(self, mock_get):
        """Test metrics request includes token"""
        mock_get.return_value.status_code = 200
        mock_get.return_value.text = "metrics"

        api = Api(url="http://test/", token="metrics-token")
        api.metrics().get()

        # Verify token in headers
        call_args = mock_get.call_args
        headers = call_args[1].get('headers', {})
        assert 'Authorization' in headers


class TestStreamingTypes:
    """Test streaming chunk types"""

    def test_agent_thought_creation(self):
        """Test creating AgentThought chunk"""
        chunk = AgentThought(content="thinking...", end_of_message=False)

        assert chunk.content == "thinking..."
        assert chunk.end_of_message is False
        assert chunk.chunk_type == "thought"

    def test_agent_observation_creation(self):
        """Test creating AgentObservation chunk"""
        chunk = AgentObservation(content="observing...", end_of_message=False)

        assert chunk.content == "observing..."
        assert chunk.chunk_type == "observation"

    def test_agent_answer_creation(self):
        """Test creating AgentAnswer chunk"""
        chunk = AgentAnswer(
            content="answer",
            end_of_message=True,
            end_of_dialog=True
        )

        assert chunk.content == "answer"
        assert chunk.end_of_message is True
        assert chunk.end_of_dialog is True
        assert chunk.chunk_type == "final-answer"

    def test_rag_chunk_creation(self):
        """Test creating RAGChunk"""
        chunk = RAGChunk(
            content="response chunk",
            end_of_stream=False,
            error=None
        )

        assert chunk.content == "response chunk"
        assert chunk.end_of_stream is False
        assert chunk.error is None

    def test_rag_chunk_with_error(self):
        """Test RAGChunk with error"""
        error_dict = {"type": "error", "message": "failed"}
        chunk = RAGChunk(
            content="",
            end_of_stream=True,
            error=error_dict
        )

        assert chunk.error == error_dict


class TestTripleType:
    """Test Triple data structure"""

    def test_triple_creation(self):
        """Test creating Triple"""
        triple = Triple(s="subject", p="predicate", o="object")

        assert triple.s == "subject"
        assert triple.p == "predicate"
        assert triple.o == "object"

    def test_triple_with_uris(self):
        """Test Triple with URI values"""
        triple = Triple(
            s="http://example.org/entity1",
            p="http://example.org/relation",
            o="http://example.org/entity2"
        )

        assert triple.s.startswith("http://")
        assert triple.p.startswith("http://")
        assert triple.o.startswith("http://")


class TestAsyncClients:
    """Test async client availability"""

    def test_async_flow_creation(self):
        """Test creating async flow client"""
        api = Api(url="http://test/")
        async_flow = api.async_flow()

        assert async_flow is not None

    def test_async_socket_creation(self):
        """Test creating async socket client"""
        api = Api(url="http://test/")
        async_socket = api.async_socket()

        assert async_socket is not None
        assert async_socket.url.startswith("ws://")

    def test_async_bulk_creation(self):
        """Test creating async bulk client"""
        api = Api(url="http://test/")
        async_bulk = api.async_bulk()

        assert async_bulk is not None

    def test_async_metrics_creation(self):
        """Test creating async metrics client"""
        api = Api(url="http://test/")
        async_metrics = api.async_metrics()

        assert async_metrics is not None


class TestErrorHandling:
    """Test error handling"""

    @patch('requests.post')
    def test_protocol_exception_on_non_200(self, mock_post):
        """Test ProtocolException raised on non-200 status"""
        from trustgraph.api.exceptions import ProtocolException

        mock_post.return_value.status_code = 500

        api = Api(url="http://test/")

        with pytest.raises(ProtocolException):
            api.flow().list()

    @patch('requests.post')
    def test_application_exception_on_error_response(self, mock_post):
        """Test ApplicationException on error in response"""
        from trustgraph.api.exceptions import ApplicationException

        mock_post.return_value.status_code = 200
        mock_post.return_value.json.return_value = {
            "error": {
                "type": "ValidationError",
                "message": "Invalid input"
            }
        }

        api = Api(url="http://test/")

        with pytest.raises(ApplicationException):
            api.flow().list()


# Run tests with: pytest tests/unit/test_python_api_client.py -v
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@ -23,9 +23,9 @@ class TestStructuredDiagnosisSchemaContract:

        assert request.operation == "detect-type"
        assert request.sample == "test data"
        assert request.type is None  # Optional, defaults to None
        assert request.schema_name is None  # Optional, defaults to None
        assert request.options is None  # Optional, defaults to None
        assert request.type == ""  # Optional, defaults to empty string
        assert request.schema_name == ""  # Optional, defaults to empty string
        assert request.options == {}  # Optional, defaults to empty dict

    def test_request_schema_all_operations(self):
        """Test request schema supports all operations"""

@ -66,9 +66,9 @@ class TestStructuredDiagnosisSchemaContract:
        assert response.detected_type == "xml"
        assert response.confidence == 0.9
        assert response.error is None
        assert response.descriptor is None
        assert response.metadata is None
        assert response.schema_matches is None  # New field, defaults to None
        assert response.descriptor == ""  # Defaults to empty string
        assert response.metadata == {}  # Defaults to empty dict
        assert response.schema_matches == []  # Defaults to empty list

    def test_response_schema_with_error(self):
        """Test response schema with error"""

@ -140,6 +140,7 @@ class TestStructuredDiagnosisSchemaContract:
        assert response.metadata == metadata
        assert response.metadata["field_count"] == "5"

    @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses")
    def test_schema_serialization(self):
        """Test that schemas can be serialized and deserialized correctly"""
        # Test request serialization

@ -158,6 +159,7 @@ class TestStructuredDiagnosisSchemaContract:
        assert deserialized.sample == request.sample
        assert deserialized.options == request.options

    @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses")
    def test_response_serialization_with_schema_matches(self):
        """Test response serialization with schema_matches array"""
        response = StructuredDataDiagnosisResponse(

@ -185,7 +187,7 @@ class TestStructuredDiagnosisSchemaContract:
        )

        # Verify default value for new field
        assert response.schema_matches is None  # Defaults to None when not set
        assert response.schema_matches == []  # Defaults to empty list when not set

        # Verify old fields still work
        assert response.detected_type == "json"

@ -221,7 +223,7 @@ class TestStructuredDiagnosisSchemaContract:
        )

        assert error_response.error is not None
        assert error_response.schema_matches is None  # Default None when not set
        assert error_response.schema_matches == []  # Default empty list when not set

    def test_all_operations_supported(self):
        """Verify all operations are properly supported in the contract"""
@ -72,7 +72,7 @@ class TestMessageDispatcher:
        assert dispatcher.max_workers == 10
        assert dispatcher.semaphore._value == 10
        assert dispatcher.active_tasks == set()
        assert dispatcher.pulsar_client is None
        assert dispatcher.backend is None
        assert dispatcher.dispatcher_manager is None
        assert len(dispatcher.service_mapping) > 0

@ -86,7 +86,7 @@ class TestMessageDispatcher:
    @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
    def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
        """Test MessageDispatcher initialization with pulsar_client and config_receiver"""
        mock_pulsar_client = MagicMock()
        mock_backend = MagicMock()
        mock_config_receiver = MagicMock()
        mock_dispatcher_instance = MagicMock()
        mock_dispatcher_manager.return_value = mock_dispatcher_instance

@ -94,14 +94,14 @@ class TestMessageDispatcher:
        dispatcher = MessageDispatcher(
            max_workers=8,
            config_receiver=mock_config_receiver,
            pulsar_client=mock_pulsar_client
            backend=mock_backend
        )

        assert dispatcher.max_workers == 8
        assert dispatcher.pulsar_client == mock_pulsar_client
        assert dispatcher.backend == mock_backend
        assert dispatcher.dispatcher_manager == mock_dispatcher_instance
        mock_dispatcher_manager.assert_called_once_with(
            mock_pulsar_client, mock_config_receiver, prefix="rev-gateway"
            mock_backend, mock_config_receiver, prefix="rev-gateway"
        )

    def test_message_dispatcher_service_mapping(self):
@ -16,11 +16,11 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    def test_reverse_gateway_initialization_defaults(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway initialization with default parameters"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway()

@ -38,11 +38,11 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    def test_reverse_gateway_initialization_custom_params(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway initialization with custom parameters"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway(
            websocket_uri="wss://example.com:8080/websocket",

@ -65,11 +65,11 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    def test_reverse_gateway_initialization_with_missing_path(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway initialization with WebSocket URI missing path"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway(websocket_uri="ws://example.com")

@ -78,53 +78,49 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    def test_reverse_gateway_initialization_invalid_scheme(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway initialization with invalid WebSocket scheme"""
        with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
            ReverseGateway(websocket_uri="http://example.com")

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    def test_reverse_gateway_initialization_missing_hostname(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway initialization with missing hostname"""
        with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
            ReverseGateway(websocket_uri="ws://")

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    def test_reverse_gateway_pulsar_client_with_auth(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway creates Pulsar client with authentication"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance

        with patch('pulsar.AuthenticationToken') as mock_auth:
            mock_auth_instance = MagicMock()
            mock_auth.return_value = mock_auth_instance

            gateway = ReverseGateway(
                pulsar_api_key="test-key",
                pulsar_listener="test-listener"
            )

            mock_auth.assert_called_once_with("test-key")
            mock_pulsar_client.assert_called_once_with(
                "pulsar://pulsar:6650",
                listener_name="test-listener",
                authentication=mock_auth_instance
            )
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway creates backend with authentication"""
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway(
            pulsar_api_key="test-key",
            pulsar_listener="test-listener"
        )

        # Verify get_pubsub was called with the correct parameters
        mock_get_pubsub.assert_called_once_with(
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key="test-key",
            pulsar_listener="test-listener"
        )

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    @patch('trustgraph.rev_gateway.service.ClientSession')
    @pytest.mark.asyncio
    async def test_reverse_gateway_connect_success(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway successful connection"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        mock_session = AsyncMock()
        mock_ws = AsyncMock()

@ -142,13 +138,13 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    @patch('trustgraph.rev_gateway.service.ClientSession')
    @pytest.mark.asyncio
    async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway connection failure"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        mock_session = AsyncMock()
        mock_session.ws_connect.side_effect = Exception("Connection failed")

@ -162,12 +158,12 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    @pytest.mark.asyncio
    async def test_reverse_gateway_disconnect(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway disconnect"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway()

@ -189,12 +185,12 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    @pytest.mark.asyncio
    async def test_reverse_gateway_send_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway send message"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway()

@ -211,12 +207,12 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    @pytest.mark.asyncio
    async def test_reverse_gateway_send_message_closed_connection(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
        """Test ReverseGateway send message with closed connection"""
        mock_client_instance = MagicMock()
        mock_pulsar_client.return_value = mock_client_instance
        mock_backend = MagicMock()
        mock_get_pubsub.return_value = mock_backend

        gateway = ReverseGateway()

@ -234,12 +230,12 @@ class TestReverseGateway:

    @patch('trustgraph.rev_gateway.service.ConfigReceiver')
    @patch('trustgraph.rev_gateway.service.MessageDispatcher')
    @patch('pulsar.Client')
    @patch('trustgraph.rev_gateway.service.get_pubsub')
    @pytest.mark.asyncio
    async def test_reverse_gateway_handle_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
    async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
|
||||
|
|
@ -263,12 +259,12 @@ class TestReverseGateway:
|
|||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message with invalid JSON"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
|
|
@ -285,12 +281,12 @@ class TestReverseGateway:
|
|||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_text_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with text message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
|
@ -318,12 +314,12 @@ class TestReverseGateway:
|
|||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_binary_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with binary message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
|
@ -351,12 +347,12 @@ class TestReverseGateway:
|
|||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_close_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with close message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
|
@ -383,36 +379,36 @@ class TestReverseGateway:
|
|||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_shutdown(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway shutdown"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
|
||||
# Mock disconnect
|
||||
gateway.disconnect = AsyncMock()
|
||||
|
||||
|
||||
await gateway.shutdown()
|
||||
|
||||
|
||||
assert gateway.running is False
|
||||
mock_dispatcher_instance.shutdown.assert_called_once()
|
||||
gateway.disconnect.assert_called_once()
|
||||
mock_client_instance.close.assert_called_once()
|
||||
mock_backend.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_stop(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway stop"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
|
@ -427,12 +423,12 @@ class TestReverseGatewayRun:
|
|||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.get_pubsub')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_run_successful_cycle(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
mock_backend = MagicMock()
|
||||
mock_get_pubsub.return_value = mock_backend
|
||||
|
||||
mock_config_receiver_instance = AsyncMock()
|
||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||
|
|
|
|||
|
|
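The hunks above all make the same change: the gateway now obtains its messaging backend from a `get_pubsub()` factory instead of constructing `pulsar.Client` directly, so the tests patch the factory and assert against the backend it returns. A minimal stand-alone sketch of that pattern (illustrative names only, not the real trustgraph code):

```python
from unittest.mock import MagicMock

class ReverseGatewaySketch:
    """Toy stand-in for ReverseGateway: takes a backend factory."""
    def __init__(self, get_pubsub):
        self.backend = get_pubsub()   # factory call replaces pulsar.Client(...)
        self.running = True

    def shutdown(self):
        self.running = False
        self.backend.close()          # shutdown closes the backend, as asserted above

# Tests mock the factory, not the concrete client class
mock_get_pubsub = MagicMock()
mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend

gateway = ReverseGatewaySketch(mock_get_pubsub)
gateway.shutdown()

assert gateway.running is False
mock_backend.close.assert_called_once()
```

Patching the factory keeps the tests independent of any concrete pub/sub implementation, which is presumably the point of introducing `get_pubsub`.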
@@ -15,11 +15,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     """Test Qdrant document embeddings storage functionality"""

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_basic(self, mock_qdrant_client):
         """Test basic Qdrant processor initialization"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -34,9 +32,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Assert
-        # Verify base class initialization was called
-        mock_base_init.assert_called_once()
-
         # Verify QdrantClient was created with correct parameters
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')

@@ -45,11 +40,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert processor.qdrant == mock_qdrant_instance

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
         """Test processor initialization with default values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -68,11 +61,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_basic(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with basic message"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -87,7 +78,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('test_user', 'test_collection')] = {}
+
         # Create mock message with chunks and vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'test_user'

@@ -121,11 +115,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with multiple chunks"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -140,7 +132,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('multi_user', 'multi_collection')] = {}
+
         # Create mock message with multiple chunks
         mock_message = MagicMock()
         mock_message.metadata.user = 'multi_user'

@@ -180,11 +175,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with multiple vectors per chunk"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -199,7 +192,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('vector_user', 'vector_collection')] = {}
+
         # Create mock message with chunk having multiple vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'vector_user'

@@ -237,11 +233,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['doc'] == 'multi-vector document chunk'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
+    async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client):
         """Test storing document embeddings skips empty chunks"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True # Collection exists
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -277,11 +271,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_collection_creation_when_not_exists(self, mock_uuid, mock_qdrant_client):
         """Test that writing to non-existent collection creates it lazily"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -297,6 +289,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('new_user', 'new_collection')] = {}
+
         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.user = 'new_user'

@@ -326,11 +321,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_collection_creation_exception(self, mock_uuid, mock_qdrant_client):
         """Test that collection creation errors are propagated"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
         # Simulate creation failure

@@ -348,6 +341,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('error_user', 'error_collection')] = {}
+
         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.user = 'error_user'

@@ -364,12 +360,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         await processor.store_document_embeddings(mock_message)

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
+    async def test_collection_validation_on_write(self, mock_uuid, mock_qdrant_client):
         """Test collection validation checks collection exists before writing"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -385,6 +379,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('cache_user', 'cache_collection')] = {}
+
         # Create first mock message
         mock_message1 = MagicMock()
         mock_message1.metadata.user = 'cache_user'

@@ -428,11 +425,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_different_dimensions_different_collections(self, mock_uuid, mock_qdrant_client):
         """Test that different vector dimensions create different collections"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -448,6 +443,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('dim_user', 'dim_collection')] = {}
+
         # Create mock message with different dimension vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'dim_user'

@@ -482,11 +480,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
+    async def test_add_args_calls_parent(self, mock_qdrant_client):
         """Test that add_args() calls parent add_args method"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_client.return_value = MagicMock()
         mock_parser = MagicMock()

@@ -502,11 +498,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client):
         """Test proper UTF-8 decoding of chunk text"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -521,7 +515,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('utf8_user', 'utf8_collection')] = {}
+
         # Create mock message with UTF-8 encoded text
         mock_message = MagicMock()
         mock_message.metadata.user = 'utf8_user'

@@ -546,11 +543,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
+    async def test_chunk_decode_exception_handling(self, mock_qdrant_client):
         """Test handling of chunk decode exceptions"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -562,7 +557,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('decode_user', 'decode_collection')] = {}
+
         # Create mock message with decode error
         mock_message = MagicMock()
         mock_message.metadata.user = 'decode_user'

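The document-embeddings hunks above repeatedly assert collection names like `d_dim_user_dim_collection_3` and test lazy creation when `collection_exists` returns False. A small self-contained sketch of that behaviour (a fake client standing in for `qdrant_client.QdrantClient`; the helper name is illustrative, not the actual trustgraph implementation):

```python
class FakeQdrant:
    """Stand-in for qdrant_client.QdrantClient in this sketch."""
    def __init__(self):
        self.created = []

    def collection_exists(self, name):
        return name in self.created

    def create_collection(self, name):
        self.created.append(name)

def ensure_collection(client, user, collection, dim):
    # Collection name encodes user, logical collection, and vector
    # dimension, matching the d_<user>_<collection>_<dim> names the
    # tests assert on; it is created only on first write.
    name = f"d_{user}_{collection}_{dim}"
    if not client.collection_exists(name):
        client.create_collection(name)
    return name

client = FakeQdrant()
ensure_collection(client, "dim_user", "dim_collection", 3)
ensure_collection(client, "dim_user", "dim_collection", 3)  # second write: no new collection
assert client.created == ["d_dim_user_dim_collection_3"]
```

Encoding the dimension in the name is why, in the tests above, vectors of different lengths end up in different collections.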
@@ -15,11 +15,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
     """Test Qdrant graph embeddings storage functionality"""

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_basic(self, mock_qdrant_client):
         """Test basic Qdrant processor initialization"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -34,9 +32,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Assert
-        # Verify base class initialization was called
-        mock_base_init.assert_called_once()
-
         # Verify QdrantClient was created with correct parameters
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')

@@ -46,11 +41,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_basic(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with basic message"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -64,7 +57,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('test_user', 'test_collection')] = {}
+
         # Create mock message with entities and vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'test_user'

@@ -98,11 +94,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_multiple_entities(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with multiple entities"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -116,7 +110,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('multi_user', 'multi_collection')] = {}
+
         # Create mock message with multiple entities
         mock_message = MagicMock()
         mock_message.metadata.user = 'multi_user'

@@ -156,11 +153,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with multiple vectors per entity"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -174,7 +169,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('vector_user', 'vector_collection')] = {}
+
         # Create mock message with entity having multiple vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'vector_user'

@@ -212,11 +210,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['entity'] == 'multi_vector_entity'

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
+    async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
         """Test storing graph embeddings skips empty entity values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -253,11 +249,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         mock_qdrant_instance.collection_exists.assert_not_called()

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
         """Test processor initialization with default values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance

@@ -275,11 +269,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
+    async def test_add_args_calls_parent(self, mock_qdrant_client):
         """Test that add_args() calls parent add_args method"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_client.return_value = MagicMock()
         mock_parser = MagicMock()

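The graph-embeddings tests above assert on points whose payload carries an `entity` key (e.g. `point.payload['entity'] == 'multi_vector_entity'`). A minimal sketch of that point shape, with a toy class standing in for the Qdrant client library's point type (names here are illustrative assumptions, not the real trustgraph code):

```python
import uuid

class FakePoint:
    """Toy stand-in for a Qdrant point: id, vector, payload."""
    def __init__(self, id, vector, payload):
        self.id = id
        self.vector = vector
        self.payload = payload

def make_entity_point(entity, vector):
    # One point per (entity, vector) pair; the entity value rides in
    # the payload so similarity search results can be mapped back to
    # graph entities.
    return FakePoint(id=str(uuid.uuid4()), vector=vector, payload={"entity": entity})

p = make_entity_point("multi_vector_entity", [0.1, 0.2, 0.3])
assert p.payload["entity"] == "multi_vector_entity"
```

An entity with several vectors, as in `test_store_graph_embeddings_multiple_vectors_per_entity`, simply yields several such points sharing the same payload.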
@@ -13,6 +13,7 @@ dependencies = [
     "pulsar-client",
     "prometheus-client",
     "requests",
+    "python-logging-loki",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",

@@ -1,3 +1,114 @@
-from . api import *
+# Core API
+from .api import Api
+
+# Flow clients
+from .flow import Flow, FlowInstance
+from .async_flow import AsyncFlow, AsyncFlowInstance
+
+# WebSocket clients
+from .socket_client import SocketClient, SocketFlowInstance
+from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance
+
+# Bulk operation clients
+from .bulk_client import BulkClient
+from .async_bulk_client import AsyncBulkClient
+
+# Metrics clients
+from .metrics import Metrics
+from .async_metrics import AsyncMetrics
+
+# Types
+from .types import (
+    Triple,
+    ConfigKey,
+    ConfigValue,
+    DocumentMetadata,
+    ProcessingMetadata,
+    CollectionMetadata,
+    StreamingChunk,
+    AgentThought,
+    AgentObservation,
+    AgentAnswer,
+    RAGChunk,
+)
+
+# Exceptions
+from .exceptions import (
+    ProtocolException,
+    TrustGraphException,
+    AgentError,
+    ConfigError,
+    DocumentRagError,
+    FlowError,
+    GatewayError,
+    GraphRagError,
+    LLMError,
+    LoadError,
+    LookupError,
+    NLPQueryError,
+    ObjectsQueryError,
+    RequestError,
+    StructuredQueryError,
+    UnexpectedError,
+    # Legacy alias
+    ApplicationException,
+)
+
+__all__ = [
+    # Core API
+    "Api",
+
+    # Flow clients
+    "Flow",
+    "FlowInstance",
+    "AsyncFlow",
+    "AsyncFlowInstance",
+
+    # WebSocket clients
+    "SocketClient",
+    "SocketFlowInstance",
+    "AsyncSocketClient",
+    "AsyncSocketFlowInstance",
+
+    # Bulk operation clients
+    "BulkClient",
+    "AsyncBulkClient",
+
+    # Metrics clients
+    "Metrics",
+    "AsyncMetrics",
+
+    # Types
+    "Triple",
+    "ConfigKey",
+    "ConfigValue",
+    "DocumentMetadata",
+    "ProcessingMetadata",
+    "CollectionMetadata",
+    "StreamingChunk",
+    "AgentThought",
+    "AgentObservation",
+    "AgentAnswer",
+    "RAGChunk",
+
+    # Exceptions
+    "ProtocolException",
+    "TrustGraphException",
+    "AgentError",
+    "ConfigError",
+    "DocumentRagError",
+    "FlowError",
+    "GatewayError",
+    "GraphRagError",
+    "LLMError",
+    "LoadError",
+    "LookupError",
+    "NLPQueryError",
+    "ObjectsQueryError",
+    "RequestError",
+    "StructuredQueryError",
+    "UnexpectedError",
+    "ApplicationException",  # Legacy alias
+]
@@ -3,6 +3,7 @@ import requests
 import json
 import base64
 import time
+from typing import Optional

 from . library import Library
 from . flow import Flow

@@ -26,7 +27,7 @@ def check_error(response):

 class Api:

-    def __init__(self, url="http://localhost:8088/", timeout=60):
+    def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None):

         self.url = url

@@ -36,6 +37,16 @@ class Api:
             self.url += "api/v1/"

         self.timeout = timeout
+        self.token = token
+
+        # Lazy initialization for new clients
+        self._socket_client = None
+        self._bulk_client = None
+        self._async_flow = None
+        self._async_socket_client = None
+        self._async_bulk_client = None
+        self._metrics = None
+        self._async_metrics = None

     def flow(self):
         return Flow(api=self)

@@ -50,8 +61,12 @@ class Api:

         url = f"{self.url}{path}"

+        headers = {}
+        if self.token:
+            headers["Authorization"] = f"Bearer {self.token}"
+
         # Invoke the API, input is passed as JSON
-        resp = requests.post(url, json=request, timeout=self.timeout)
+        resp = requests.post(url, json=request, timeout=self.timeout, headers=headers)

         # Should be a 200 status code
         if resp.status_code != 200:

@@ -72,3 +87,96 @@ class Api:

     def collection(self):
         return Collection(self)
+
+    # New synchronous methods
+    def socket(self):
+        """Synchronous WebSocket-based interface for streaming operations"""
+        if self._socket_client is None:
+            from . socket_client import SocketClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._socket_client = SocketClient(base_url, self.timeout, self.token)
+        return self._socket_client
+
+    def bulk(self):
+        """Synchronous bulk operations interface for import/export"""
+        if self._bulk_client is None:
+            from . bulk_client import BulkClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._bulk_client = BulkClient(base_url, self.timeout, self.token)
+        return self._bulk_client
+
+    def metrics(self):
+        """Synchronous metrics interface"""
+        if self._metrics is None:
+            from . metrics import Metrics
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._metrics = Metrics(base_url, self.timeout, self.token)
+        return self._metrics
+
+    # New asynchronous methods
+    def async_flow(self):
+        """Asynchronous REST-based flow interface"""
+        if self._async_flow is None:
+            from . async_flow import AsyncFlow
+            self._async_flow = AsyncFlow(self.url, self.timeout, self.token)
+        return self._async_flow
+
+    def async_socket(self):
+        """Asynchronous WebSocket-based interface for streaming operations"""
+        if self._async_socket_client is None:
+            from . async_socket_client import AsyncSocketClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token)
+        return self._async_socket_client
+
+    def async_bulk(self):
+        """Asynchronous bulk operations interface for import/export"""
+        if self._async_bulk_client is None:
+            from . async_bulk_client import AsyncBulkClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
+        return self._async_bulk_client
+
+    def async_metrics(self):
+        """Asynchronous metrics interface"""
+        if self._async_metrics is None:
+            from . async_metrics import AsyncMetrics
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_metrics = AsyncMetrics(base_url, self.timeout, self.token)
+        return self._async_metrics
+
+    # Resource management
+    def close(self):
+        """Close all synchronous connections"""
+        if self._socket_client:
+            self._socket_client.close()
+        if self._bulk_client:
+            self._bulk_client.close()
+
+    async def aclose(self):
+        """Close all asynchronous connections"""
+        if self._async_socket_client:
+            await self._async_socket_client.aclose()
+        if self._async_bulk_client:
+            await self._async_bulk_client.aclose()
+        if self._async_flow:
+            await self._async_flow.aclose()
+
+    # Context manager support
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *args):
+        await self.aclose()
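As a reviewer's note: the bearer-token behaviour added to `Api.request` above can be exercised in isolation. This is a minimal sketch; `build_headers` is a hypothetical stand-in for the inline header logic, not a function in the library:

```python
from typing import Dict, Optional

def build_headers(token: Optional[str]) -> Dict[str, str]:
    # Mirror of the new logic in Api.request: attach an Authorization
    # header only when a token was passed to the Api constructor.
    headers: Dict[str, str] = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers

print(build_headers("abc123"))  # {'Authorization': 'Bearer abc123'}
print(build_headers(None))      # {}
```

Note the truthiness check means an empty-string token also sends no header, matching the diff's `if self.token:` guard.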
trustgraph-base/trustgraph/api/async_bulk_client.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import json
import websockets
from typing import Optional, AsyncIterator, Dict, Any, Iterator

from . types import Triple


class AsyncBulkClient:
    """Asynchronous bulk operations client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = self._convert_to_ws_url(url)
        self.timeout: int = timeout
        self.token: Optional[str] = token

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            return f"ws://{url}"

    async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
        """Bulk import triples via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for triple in triples:
                message = {
                    "s": triple.s,
                    "p": triple.p,
                    "o": triple.o
                }
                await websocket.send(json.dumps(message))

    async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
        """Bulk export triples via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                data = json.loads(raw_message)
                yield Triple(
                    s=data.get("s", ""),
                    p=data.get("p", ""),
                    o=data.get("o", "")
                )

    async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import graph embeddings via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for embedding in embeddings:
                await websocket.send(json.dumps(embedding))

    async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
        """Bulk export graph embeddings via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import document embeddings via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for embedding in embeddings:
                await websocket.send(json.dumps(embedding))

    async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
        """Bulk export document embeddings via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import entity contexts via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for context in contexts:
                await websocket.send(json.dumps(context))

    async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
        """Bulk export entity contexts via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import objects via WebSocket"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for obj in objects:
                await websocket.send(json.dumps(obj))

    async def aclose(self) -> None:
        """Close connections"""
        # Cleanup handled by context managers
        pass
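The HTTP-to-WebSocket scheme mapping used by `_convert_to_ws_url` in the new clients is worth checking on its own; a self-contained sketch of the same logic:

```python
def convert_to_ws_url(url: str) -> str:
    # http -> ws, https -> wss; pass ws/wss URLs through unchanged;
    # bare host:port strings default to the insecure ws scheme.
    if url.startswith("http://"):
        return url.replace("http://", "ws://", 1)
    if url.startswith("https://"):
        return url.replace("https://", "wss://", 1)
    if url.startswith(("ws://", "wss://")):
        return url
    return f"ws://{url}"

print(convert_to_ws_url("https://gateway:8088"))  # wss://gateway:8088
print(convert_to_ws_url("localhost:8088"))        # ws://localhost:8088
```

The `count=1` argument to `replace` matters: it keeps a later literal `http://` inside the path (however unlikely) from being rewritten too.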
trustgraph-base/trustgraph/api/async_flow.py (new file, 245 lines)
@@ -0,0 +1,245 @@
import aiohttp
import json
from typing import Optional, Dict, Any, List

from . exceptions import ProtocolException, ApplicationException


def check_error(response):
    if "error" in response:
        try:
            msg = response["error"]["message"]
            tp = response["error"]["type"]
        except:
            raise ApplicationException(response["error"])

        raise ApplicationException(f"{tp}: {msg}")


class AsyncFlow:
    """Asynchronous REST-based flow interface"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = url
        self.timeout: int = timeout
        self.token: Optional[str] = token

    async def request(self, path: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Make async HTTP request to Gateway API"""
        url = f"{self.url}{path}"

        headers = {"Content-Type": "application/json"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        timeout = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=request_data, headers=headers) as resp:
                if resp.status != 200:
                    raise ProtocolException(f"Status code {resp.status}")

                try:
                    obj = await resp.json()
                except:
                    raise ProtocolException(f"Expected JSON response")

                check_error(obj)
                return obj

    async def list(self) -> List[str]:
        """List all flows"""
        result = await self.request("flow", {"operation": "list-flows"})
        return result.get("flow-ids", [])

    async def get(self, id: str) -> Dict[str, Any]:
        """Get flow definition"""
        result = await self.request("flow", {
            "operation": "get-flow",
            "flow-id": id
        })
        return json.loads(result.get("flow", "{}"))

    async def start(self, class_name: str, id: str, description: str, parameters: Optional[Dict] = None):
        """Start a flow"""
        request_data = {
            "operation": "start-flow",
            "flow-id": id,
            "class-name": class_name,
            "description": description
        }
        if parameters:
            request_data["parameters"] = json.dumps(parameters)

        await self.request("flow", request_data)

    async def stop(self, id: str):
        """Stop a flow"""
        await self.request("flow", {
            "operation": "stop-flow",
            "flow-id": id
        })

    async def list_classes(self) -> List[str]:
        """List flow classes"""
        result = await self.request("flow", {"operation": "list-classes"})
        return result.get("class-names", [])

    async def get_class(self, class_name: str) -> Dict[str, Any]:
        """Get flow class definition"""
        result = await self.request("flow", {
            "operation": "get-class",
            "class-name": class_name
        })
        return json.loads(result.get("class-definition", "{}"))

    async def put_class(self, class_name: str, definition: Dict[str, Any]):
        """Create/update flow class"""
        await self.request("flow", {
            "operation": "put-class",
            "class-name": class_name,
            "class-definition": json.dumps(definition)
        })

    async def delete_class(self, class_name: str):
        """Delete flow class"""
        await self.request("flow", {
            "operation": "delete-class",
            "class-name": class_name
        })

    def id(self, flow_id: str):
        """Get async flow instance"""
        return AsyncFlowInstance(self, flow_id)

    async def aclose(self) -> None:
        """Close connection (cleanup handled by aiohttp session)"""
        pass


class AsyncFlowInstance:
    """Asynchronous REST flow instance"""

    def __init__(self, flow: AsyncFlow, flow_id: str):
        self.flow = flow
        self.flow_id = flow_id

    async def request(self, service: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Make request to flow-scoped service"""
        return await self.flow.request(f"flow/{self.flow_id}/service/{service}", request_data)

    async def agent(self, question: str, user: str, state: Optional[Dict] = None,
                    group: Optional[str] = None, history: Optional[List] = None, **kwargs: Any) -> Dict[str, Any]:
        """Execute agent (non-streaming, use async_socket for streaming)"""
        request_data = {
            "question": question,
            "user": user,
            "streaming": False  # REST doesn't support streaming
        }
        if state is not None:
            request_data["state"] = state
        if group is not None:
            request_data["group"] = group
        if history is not None:
            request_data["history"] = history
        request_data.update(kwargs)

        return await self.request("agent", request_data)

    async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
        """Text completion (non-streaming, use async_socket for streaming)"""
        request_data = {
            "system": system,
            "prompt": prompt,
            "streaming": False
        }
        request_data.update(kwargs)

        result = await self.request("text-completion", request_data)
        return result.get("response", "")

    async def graph_rag(self, query: str, user: str, collection: str,
                        max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
                        max_entity_distance: int = 3, **kwargs: Any) -> str:
        """Graph RAG (non-streaming, use async_socket for streaming)"""
        request_data = {
            "query": query,
            "user": user,
            "collection": collection,
            "max-subgraph-size": max_subgraph_size,
            "max-subgraph-count": max_subgraph_count,
            "max-entity-distance": max_entity_distance,
            "streaming": False
        }
        request_data.update(kwargs)

        result = await self.request("graph-rag", request_data)
        return result.get("response", "")

    async def document_rag(self, query: str, user: str, collection: str,
                           doc_limit: int = 10, **kwargs: Any) -> str:
        """Document RAG (non-streaming, use async_socket for streaming)"""
        request_data = {
            "query": query,
            "user": user,
            "collection": collection,
            "doc-limit": doc_limit,
            "streaming": False
        }
        request_data.update(kwargs)

        result = await self.request("document-rag", request_data)
        return result.get("response", "")

    async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any):
        """Query graph embeddings for semantic search"""
        request_data = {
            "text": text,
            "user": user,
            "collection": collection,
            "limit": limit
        }
        request_data.update(kwargs)

        return await self.request("graph-embeddings", request_data)

    async def embeddings(self, text: str, **kwargs: Any):
        """Generate text embeddings"""
        request_data = {"text": text}
        request_data.update(kwargs)

        return await self.request("embeddings", request_data)

    async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs: Any):
        """Triple pattern query"""
        request_data = {"limit": limit}
        if s is not None:
            request_data["s"] = str(s)
        if p is not None:
            request_data["p"] = str(p)
        if o is not None:
            request_data["o"] = str(o)
        if user is not None:
            request_data["user"] = user
        if collection is not None:
            request_data["collection"] = collection
        request_data.update(kwargs)

        return await self.request("triples", request_data)

    async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
                            operation_name: Optional[str] = None, **kwargs: Any):
        """GraphQL query"""
        request_data = {
            "query": query,
            "user": user,
            "collection": collection
        }
        if variables:
            request_data["variables"] = variables
        if operation_name:
            request_data["operationName"] = operation_name
        request_data.update(kwargs)

        return await self.request("objects", request_data)
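The error-envelope convention the Gateway clients rely on (a `{"error": {"type": ..., "message": ...}}` payload) can be demonstrated standalone. A minimal sketch of the `check_error` pattern, with a local `ApplicationException` standing in for the one in `trustgraph.api.exceptions`:

```python
class ApplicationException(Exception):
    """Stand-in for trustgraph's ApplicationException."""

class ProtocolException(Exception):
    """Stand-in for trustgraph's ProtocolException."""

def check_error(response: dict) -> None:
    # Raise when the Gateway returns an error envelope; prefer the
    # structured type/message pair, fall back to the raw error value.
    if "error" not in response:
        return
    try:
        msg = response["error"]["message"]
        tp = response["error"]["type"]
    except (TypeError, KeyError):
        raise ApplicationException(response["error"])
    raise ApplicationException(f"{tp}: {msg}")

check_error({"response": "ok"})  # no error key, returns silently
```

Note this sketch narrows the bare `except:` in the diff to `(TypeError, KeyError)`, which is the failure mode actually expected when the `error` value is not a well-formed dict.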
trustgraph-base/trustgraph/api/async_metrics.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import aiohttp
from typing import Optional, Dict


class AsyncMetrics:
    """Asynchronous metrics client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = url
        self.timeout: int = timeout
        self.token: Optional[str] = token

    async def get(self) -> str:
        """Get Prometheus metrics as text"""
        url: str = f"{self.url}/api/metrics"

        headers: Dict[str, str] = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        timeout = aiohttp.ClientTimeout(total=self.timeout)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as resp:
                if resp.status != 200:
                    raise Exception(f"Status code {resp.status}")

                return await resp.text()

    async def aclose(self) -> None:
        """Close connections"""
        pass
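`AsyncMetrics.get()` returns the raw Prometheus text exposition rather than parsed values. A rough sketch of turning that text into a dict, assuming simple unlabelled samples (labelled series are kept keyed by their full `name{labels}` string):

```python
def parse_metrics(text: str) -> dict:
    # Minimal parse of the Prometheus text format: skip comments and
    # blank lines, split the remaining "name value" sample lines.
    metrics = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, value = line.rpartition(" ")
        if name:
            metrics[name] = float(value)
    return metrics

sample = """# HELP requests_total Total requests
# TYPE requests_total counter
requests_total 42
process_cpu_seconds_total 1.5
"""
print(parse_metrics(sample))  # {'requests_total': 42.0, 'process_cpu_seconds_total': 1.5}
```

For anything beyond a quick check, the `prometheus-client` package already listed in the dependencies ships a proper exposition-format parser.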
trustgraph-base/trustgraph/api/async_socket_client.py (new file, 343 lines)
@@ -0,0 +1,343 @@
import json
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union

from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
from . exceptions import ProtocolException, ApplicationException


class AsyncSocketClient:
    """Asynchronous WebSocket client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]):
        self.url = self._convert_to_ws_url(url)
        self.timeout = timeout
        self.token = token
        self._request_counter = 0

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            # Assume ws://
            return f"ws://{url}"

    def flow(self, flow_id: str):
        """Get async flow instance for WebSocket operations"""
        return AsyncSocketFlowInstance(self, flow_id)

    async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
        """Async WebSocket request implementation (non-streaming)"""
        # Generate unique request ID
        self._request_counter += 1
        request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Wait for single response
            raw_message = await websocket.recv()
            response = json.loads(raw_message)

            if response.get("id") != request_id:
                raise ProtocolException("Response ID mismatch")

            if "error" in response:
                raise ApplicationException(response["error"])

            if "response" not in response:
                raise ProtocolException("Missing response in message")

            return response["response"]

    async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
        """Async WebSocket request implementation (streaming)"""
        # Generate unique request ID
        self._request_counter += 1
        request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Yield chunks as they arrive
            async for raw_message in websocket:
                response = json.loads(raw_message)

                if response.get("id") != request_id:
                    continue  # Ignore messages for other requests

                if "error" in response:
                    raise ApplicationException(response["error"])

                if "response" in response:
                    resp = response["response"]

                    # Parse different chunk types
                    chunk = self._parse_chunk(resp)
                    yield chunk

                    # Check if this is the final chunk
                    if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
                        break

    def _parse_chunk(self, resp: Dict[str, Any]):
        """Parse response chunk into appropriate type"""
        chunk_type = resp.get("chunk_type")

        if chunk_type == "thought":
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "observation":
            return AgentObservation(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "answer" or chunk_type == "final-answer":
            return AgentAnswer(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False),
                end_of_dialog=resp.get("end_of_dialog", False)
            )
        elif chunk_type == "action":
            # Agent action chunks - treat as thoughts for display purposes
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        else:
            # RAG-style chunk (or generic chunk)
            # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
            content = resp.get("response", resp.get("chunk", resp.get("text", "")))
            return RAGChunk(
                content=content,
                end_of_stream=resp.get("end_of_stream", False),
                error=None  # Errors are always thrown, never stored
            )

    async def aclose(self):
        """Close WebSocket connection"""
        # Cleanup handled by context manager
        pass


class AsyncSocketFlowInstance:
    """Asynchronous WebSocket flow instance"""

    def __init__(self, client: AsyncSocketClient, flow_id: str):
        self.client = client
        self.flow_id = flow_id

    async def agent(self, question: str, user: str, state: Optional[Dict[str, Any]] = None,
                    group: Optional[str] = None, history: Optional[list] = None,
                    streaming: bool = False, **kwargs) -> Union[Dict[str, Any], AsyncIterator]:
        """Agent with optional streaming"""
        request = {
            "question": question,
            "user": user,
            "streaming": streaming
        }
        if state is not None:
            request["state"] = state
        if group is not None:
            request["group"] = group
        if history is not None:
            request["history"] = history
        request.update(kwargs)

        if streaming:
            return self.client._send_request_streaming("agent", self.flow_id, request)
        else:
            return await self.client._send_request("agent", self.flow_id, request)

    async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
        """Text completion with optional streaming"""
        request = {
            "system": system,
            "prompt": prompt,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._text_completion_streaming(request)
        else:
            result = await self.client._send_request("text-completion", self.flow_id, request)
            return result.get("response", "")

    async def _text_completion_streaming(self, request):
        """Helper for streaming text completion"""
        async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def graph_rag(self, query: str, user: str, collection: str,
                        max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
                        max_entity_distance: int = 3, streaming: bool = False, **kwargs):
        """Graph RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "max-subgraph-size": max_subgraph_size,
            "max-subgraph-count": max_subgraph_count,
            "max-entity-distance": max_entity_distance,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._graph_rag_streaming(request)
        else:
            result = await self.client._send_request("graph-rag", self.flow_id, request)
            return result.get("response", "")

    async def _graph_rag_streaming(self, request):
        """Helper for streaming graph RAG"""
        async for chunk in self.client._send_request_streaming("graph-rag", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def document_rag(self, query: str, user: str, collection: str,
                           doc_limit: int = 10, streaming: bool = False, **kwargs):
        """Document RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "doc-limit": doc_limit,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._document_rag_streaming(request)
        else:
            result = await self.client._send_request("document-rag", self.flow_id, request)
            return result.get("response", "")

    async def _document_rag_streaming(self, request):
        """Helper for streaming document RAG"""
        async for chunk in self.client._send_request_streaming("document-rag", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def prompt(self, id: str, variables: Dict[str, str], streaming: bool = False, **kwargs):
        """Execute prompt with optional streaming"""
        request = {
            "id": id,
            "variables": variables,
            "streaming": streaming
        }
        request.update(kwargs)

        if streaming:
            return self._prompt_streaming(request)
        else:
            result = await self.client._send_request("prompt", self.flow_id, request)
            return result.get("response", "")

    async def _prompt_streaming(self, request):
        """Helper for streaming prompt"""
        async for chunk in self.client._send_request_streaming("prompt", self.flow_id, request):
            if hasattr(chunk, 'content'):
                yield chunk.content

    async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
        """Query graph embeddings for semantic search"""
        request = {
            "text": text,
            "user": user,
            "collection": collection,
            "limit": limit
        }
        request.update(kwargs)

        return await self.client._send_request("graph-embeddings", self.flow_id, request)

    async def embeddings(self, text: str, **kwargs):
        """Generate text embeddings"""
        request = {"text": text}
        request.update(kwargs)

        return await self.client._send_request("embeddings", self.flow_id, request)

    async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs):
|
||||
"""Triple pattern query"""
|
||||
request = {"limit": limit}
|
||||
if s is not None:
|
||||
request["s"] = str(s)
|
||||
if p is not None:
|
||||
request["p"] = str(p)
|
||||
if o is not None:
|
||||
request["o"] = str(o)
|
||||
if user is not None:
|
||||
request["user"] = user
|
||||
if collection is not None:
|
||||
request["collection"] = collection
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("triples", self.flow_id, request)
|
||||
|
||||
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
|
||||
operation_name: Optional[str] = None, **kwargs):
|
||||
"""GraphQL query"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection
|
||||
}
|
||||
if variables:
|
||||
request["variables"] = variables
|
||||
if operation_name:
|
||||
request["operationName"] = operation_name
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("objects", self.flow_id, request)
|
||||
|
||||
async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
|
||||
"""Execute MCP tool"""
|
||||
request = {
|
||||
"name": name,
|
||||
"parameters": parameters
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("mcp-tool", self.flow_id, request)
|
||||
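The dual-mode pattern used by the methods above — return an async generator when `streaming=True`, otherwise await one response and return its `response` field — can be exercised standalone. The `FakeClient` below is a hypothetical stand-in for the real WebSocket client, not part of the TrustGraph API:

```python
import asyncio

class FakeClient:
    # Hypothetical stand-in for the real client's request plumbing
    async def _send_request(self, service, flow, request):
        return {"response": "full answer"}

    async def _send_request_streaming(self, service, flow, request):
        for piece in ["full ", "answer"]:
            # Minimal chunk object exposing a .content attribute
            yield type("Chunk", (), {"content": piece})()

class Flow:
    def __init__(self):
        self.client = FakeClient()
        self.flow_id = "default"

    async def text_completion(self, system, prompt, streaming=False):
        request = {"system": system, "prompt": prompt, "streaming": streaming}
        if streaming:
            # Returning (not awaiting) the helper hands back an async generator
            return self._stream(request)
        result = await self.client._send_request("text-completion", self.flow_id, request)
        return result.get("response", "")

    async def _stream(self, request):
        async for chunk in self.client._send_request_streaming(
                "text-completion", self.flow_id, request):
            if hasattr(chunk, "content"):
                yield chunk.content

async def main():
    flow = Flow()
    whole = await flow.text_completion("sys", "hi")
    parts = [c async for c in await flow.text_completion("sys", "hi", streaming=True)]
    return whole, parts

whole, parts = asyncio.run(main())
print(whole)           # full answer
print("".join(parts))  # full answer
```

Splitting the streaming path into a separate `async def` helper is what keeps the public coroutine from itself becoming a generator, so the non-streaming branch can still `return` a plain value.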
trustgraph-base/trustgraph/api/bulk_client.py (new file, 270 lines)

@@ -0,0 +1,270 @@
import json
import asyncio
import websockets
from typing import Optional, Iterator, Dict, Any, Coroutine

from . types import Triple
from . exceptions import ProtocolException


class BulkClient:
    """Synchronous bulk operations client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = self._convert_to_ws_url(url)
        self.timeout: int = timeout
        self.token: Optional[str] = token

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            return f"ws://{url}"

    def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
        """Run async coroutine synchronously"""
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(coro)

    def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None:
        """Bulk import triples via WebSocket"""
        self._run_async(self._import_triples_async(flow, triples))

    async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None:
        """Async implementation of triple import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for triple in triples:
                message = {
                    "s": triple.s,
                    "p": triple.p,
                    "o": triple.o
                }
                await websocket.send(json.dumps(message))

    def export_triples(self, flow: str, **kwargs: Any) -> Iterator[Triple]:
        """Bulk export triples via WebSocket"""
        async_gen = self._export_triples_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    triple = loop.run_until_complete(async_gen.__anext__())
                    yield triple
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
        """Async implementation of triple export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                data = json.loads(raw_message)
                yield Triple(
                    s=data.get("s", ""),
                    p=data.get("p", ""),
                    o=data.get("o", "")
                )

    def import_graph_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import graph embeddings via WebSocket"""
        self._run_async(self._import_graph_embeddings_async(flow, embeddings))

    async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of graph embeddings import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for embedding in embeddings:
                await websocket.send(json.dumps(embedding))

    def export_graph_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
        """Bulk export graph embeddings via WebSocket"""
        async_gen = self._export_graph_embeddings_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    embedding = loop.run_until_complete(async_gen.__anext__())
                    yield embedding
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
        """Async implementation of graph embeddings export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    def import_document_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import document embeddings via WebSocket"""
        self._run_async(self._import_document_embeddings_async(flow, embeddings))

    async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of document embeddings import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for embedding in embeddings:
                await websocket.send(json.dumps(embedding))

    def export_document_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
        """Bulk export document embeddings via WebSocket"""
        async_gen = self._export_document_embeddings_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    embedding = loop.run_until_complete(async_gen.__anext__())
                    yield embedding
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
        """Async implementation of document embeddings export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import entity contexts via WebSocket"""
        self._run_async(self._import_entity_contexts_async(flow, contexts))

    async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of entity contexts import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for context in contexts:
                await websocket.send(json.dumps(context))

    def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
        """Bulk export entity contexts via WebSocket"""
        async_gen = self._export_entity_contexts_async(flow)

        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        try:
            while True:
                try:
                    context = loop.run_until_complete(async_gen.__anext__())
                    yield context
                except StopAsyncIteration:
                    break
        finally:
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
        """Async implementation of entity contexts export"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            async for raw_message in websocket:
                yield json.loads(raw_message)

    def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
        """Bulk import objects via WebSocket"""
        self._run_async(self._import_objects_async(flow, objects))

    async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
        """Async implementation of objects import"""
        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            for obj in objects:
                await websocket.send(json.dumps(obj))

    def close(self) -> None:
        """Close connections"""
        # Cleanup handled by context managers
        pass
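The `export_*` methods in `BulkClient` all use the same bridge: drive an async generator from synchronous code by calling `run_until_complete(gen.__anext__())` in a loop until `StopAsyncIteration`. A minimal self-contained version of that bridge, with a toy producer in place of the WebSocket stream:

```python
import asyncio

async def produce():
    # Stand-in for an async export stream (e.g. triples from a WebSocket)
    for n in range(3):
        yield n

def sync_iter(async_gen):
    # Pull items from an async generator one at a time, synchronously
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            try:
                yield loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        # Always close the generator so its cleanup (context managers) runs
        try:
            loop.run_until_complete(async_gen.aclose())
        except Exception:
            pass
        loop.close()

items = list(sync_iter(produce()))
print(items)  # [0, 1, 2]
```

The `finally`/`aclose()` step mirrors the class above: it guarantees the `async with websockets.connect(...)` block inside the exporter is exited even if the consumer abandons the iteration early.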
@@ -41,9 +41,7 @@ class Collection:
                 collection = v["collection"],
                 name = v["name"],
                 description = v["description"],
-                tags = v["tags"],
-                created_at = v["created_at"],
-                updated_at = v["updated_at"]
+                tags = v["tags"]
             )
             for v in collections
         ]
@@ -76,9 +74,7 @@ class Collection:
                 collection = v["collection"],
                 name = v["name"],
                 description = v["description"],
-                tags = v["tags"],
-                created_at = v["created_at"],
-                updated_at = v["updated_at"]
+                tags = v["tags"]
             )
         return None
     except Exception as e:
@@ -1,6 +1,134 @@
"""
TrustGraph API Exceptions

Exception hierarchy for errors returned by TrustGraph services.
Each service error type maps to a specific exception class.
"""

# Protocol-level exceptions (communication errors)
class ProtocolException(Exception):
    """Raised when WebSocket protocol errors occur"""
    pass

-class ApplicationException(Exception):

# Base class for all TrustGraph application errors
class TrustGraphException(Exception):
    """Base class for all TrustGraph service errors"""
    def __init__(self, message: str, error_type: str = None):
        super().__init__(message)
        self.message = message
        self.error_type = error_type


# Service-specific exceptions
class AgentError(TrustGraphException):
    """Agent service error"""
    pass


class ConfigError(TrustGraphException):
    """Configuration service error"""
    pass


class DocumentRagError(TrustGraphException):
    """Document RAG retrieval error"""
    pass


class FlowError(TrustGraphException):
    """Flow management error"""
    pass


class GatewayError(TrustGraphException):
    """API Gateway error"""
    pass


class GraphRagError(TrustGraphException):
    """Graph RAG retrieval error"""
    pass


class LLMError(TrustGraphException):
    """LLM service error"""
    pass


class LoadError(TrustGraphException):
    """Data loading error"""
    pass


class LookupError(TrustGraphException):
    """Lookup/search error"""
    pass


class NLPQueryError(TrustGraphException):
    """NLP query service error"""
    pass


class ObjectsQueryError(TrustGraphException):
    """Objects query service error"""
    pass


class RequestError(TrustGraphException):
    """Request processing error"""
    pass


class StructuredQueryError(TrustGraphException):
    """Structured query service error"""
    pass


class UnexpectedError(TrustGraphException):
    """Unexpected/unknown error"""
    pass


# Mapping from error type string to exception class
ERROR_TYPE_MAPPING = {
    "agent-error": AgentError,
    "config-error": ConfigError,
    "document-rag-error": DocumentRagError,
    "flow-error": FlowError,
    "gateway-error": GatewayError,
    "graph-rag-error": GraphRagError,
    "llm-error": LLMError,
    "load-error": LoadError,
    "lookup-error": LookupError,
    "nlp-query-error": NLPQueryError,
    "objects-query-error": ObjectsQueryError,
    "request-error": RequestError,
    "structured-query-error": StructuredQueryError,
    "unexpected-error": UnexpectedError,
}


def raise_from_error_dict(error_dict: dict) -> None:
    """
    Raise appropriate exception from TrustGraph error dictionary.

    Args:
        error_dict: Dictionary with 'type' and 'message' keys

    Raises:
        Appropriate TrustGraphException subclass based on error type
    """
    error_type = error_dict.get("type", "unexpected-error")
    message = error_dict.get("message", "Unknown error")

    # Look up exception class, default to UnexpectedError
    exception_class = ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)

    # Raise the appropriate exception
    raise exception_class(message, error_type)


# Legacy exception for backwards compatibility
ApplicationException = TrustGraphException
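The error-type dispatch above can be exercised standalone. This sketch reproduces the lookup logic with just two of the exception classes, rather than importing the package:

```python
class TrustGraphException(Exception):
    # Base class carrying the original message and the wire-level error type
    def __init__(self, message, error_type=None):
        super().__init__(message)
        self.message = message
        self.error_type = error_type

class LLMError(TrustGraphException):
    pass

class UnexpectedError(TrustGraphException):
    pass

ERROR_TYPE_MAPPING = {"llm-error": LLMError}

def raise_from_error_dict(error_dict):
    error_type = error_dict.get("type", "unexpected-error")
    message = error_dict.get("message", "Unknown error")
    # Unknown types fall back to UnexpectedError
    raise ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)(message, error_type)

try:
    raise_from_error_dict({"type": "llm-error", "message": "quota exceeded"})
except TrustGraphException as e:
    caught = (type(e).__name__, e.error_type, e.message)

print(caught)  # ('LLMError', 'llm-error', 'quota exceeded')
```

Because every service class subclasses `TrustGraphException`, callers can catch the base class once and still branch on `error_type` when they need finer handling.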
@@ -160,14 +160,14 @@ class FlowInstance:
         )["answer"]
 
     def graph_rag(
-            self, question, user="trustgraph", collection="default",
+            self, query, user="trustgraph", collection="default",
             entity_limit=50, triple_limit=30, max_subgraph_size=150,
             max_path_length=2,
     ):
 
         # The input consists of a question
         input = {
-            "query": question,
+            "query": query,
             "user": user,
             "collection": collection,
             "entity-limit": entity_limit,
@@ -182,13 +182,13 @@ class FlowInstance:
         )["response"]
 
     def document_rag(
-            self, question, user="trustgraph", collection="default",
+            self, query, user="trustgraph", collection="default",
             doc_limit=10,
     ):
 
         # The input consists of a question
         input = {
-            "query": question,
+            "query": query,
             "user": user,
             "collection": collection,
             "doc-limit": doc_limit,
@@ -211,6 +211,21 @@ class FlowInstance:
             input
         )["vectors"]
 
+    def graph_embeddings_query(self, text, user, collection, limit=10):
+
+        # Query graph embeddings for semantic search
+        input = {
+            "text": text,
+            "user": user,
+            "collection": collection,
+            "limit": limit
+        }
+
+        return self.request(
+            "service/graph-embeddings",
+            input
+        )
+
     def prompt(self, id, variables):
 
         input = {
trustgraph-base/trustgraph/api/metrics.py (new file, 27 lines)

@@ -0,0 +1,27 @@
import requests
from typing import Optional, Dict


class Metrics:
    """Synchronous metrics client"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = url
        self.timeout: int = timeout
        self.token: Optional[str] = token

    def get(self) -> str:
        """Get Prometheus metrics as text"""
        url: str = f"{self.url}/api/metrics"

        headers: Dict[str, str] = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        resp = requests.get(url, timeout=self.timeout, headers=headers)

        if resp.status_code != 200:
            raise Exception(f"Status code {resp.status_code}")

        return resp.text
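`Metrics.get()` returns the gateway's metrics in Prometheus text exposition format. A minimal parse of that format for illustration — the sample metric names below are made up, not actual TrustGraph metrics, and the parser ignores labels containing spaces:

```python
def parse_metrics(text):
    # Tiny parser: skip comments/blanks, split each sample into name{labels} and value
    samples = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip HELP/TYPE comment lines and blanks
        name, _, value = line.rpartition(" ")
        samples[name] = float(value)
    return samples

# Illustrative exposition text, as returned by Metrics.get()
text = """# HELP requests_total Total requests.
# TYPE requests_total counter
requests_total{service="graph-rag"} 42
requests_total{service="embeddings"} 7
"""

samples = parse_metrics(text)
print(samples['requests_total{service="graph-rag"}'])  # 42.0
```

For anything beyond a quick look, a full exposition-format parser (such as the official Prometheus client library's text parser) handles escaping, timestamps, and label edge cases that this sketch does not.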
trustgraph-base/trustgraph/api/socket_client.py (new file, 457 lines)

@@ -0,0 +1,457 @@
import json
import asyncio
import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock

from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
from . exceptions import ProtocolException, raise_from_error_dict


class SocketClient:
    """Synchronous WebSocket client (wraps async websockets library)"""

    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
        self.url: str = self._convert_to_ws_url(url)
        self.timeout: int = timeout
        self.token: Optional[str] = token
        self._connection: Optional[Any] = None
        self._request_counter: int = 0
        self._lock: Lock = Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def _convert_to_ws_url(self, url: str) -> str:
        """Convert HTTP URL to WebSocket URL"""
        if url.startswith("http://"):
            return url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            return url.replace("https://", "wss://", 1)
        elif url.startswith("ws://") or url.startswith("wss://"):
            return url
        else:
            # Assume ws://
            return f"ws://{url}"

    def flow(self, flow_id: str) -> "SocketFlowInstance":
        """Get flow instance for WebSocket operations"""
        return SocketFlowInstance(self, flow_id)

    def _send_request_sync(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any],
        streaming: bool = False
    ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
        """Synchronous wrapper around async WebSocket communication"""
        # Create event loop if needed
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # If loop is running (e.g., in Jupyter), create new loop
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if streaming:
            # For streaming, we need to return an iterator
            # Create a generator that runs async code
            return self._streaming_generator(service, flow, request, loop)
        else:
            # For non-streaming, just run the async code and return result
            return loop.run_until_complete(self._send_request_async(service, flow, request))

    def _streaming_generator(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any],
        loop: asyncio.AbstractEventLoop
    ) -> Iterator[StreamingChunk]:
        """Generator that yields streaming chunks"""
        async_gen = self._send_request_async_streaming(service, flow, request)

        try:
            while True:
                try:
                    chunk = loop.run_until_complete(async_gen.__anext__())
                    yield chunk
                except StopAsyncIteration:
                    break
        finally:
            # Clean up async generator
            try:
                loop.run_until_complete(async_gen.aclose())
            except:
                pass

    async def _send_request_async(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Async implementation of WebSocket request (non-streaming)"""
        # Generate unique request ID
        with self._lock:
            self._request_counter += 1
            request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Wait for single response
            raw_message = await websocket.recv()
            response = json.loads(raw_message)

            if response.get("id") != request_id:
                raise ProtocolException("Response ID mismatch")

            if "error" in response:
                raise_from_error_dict(response["error"])

            if "response" not in response:
                raise ProtocolException("Missing response in message")

            return response["response"]

    async def _send_request_async_streaming(
        self,
        service: str,
        flow: Optional[str],
        request: Dict[str, Any]
    ) -> Iterator[StreamingChunk]:
        """Async implementation of WebSocket request (streaming)"""
        # Generate unique request ID
        with self._lock:
            self._request_counter += 1
            request_id = f"req-{self._request_counter}"

        # Build WebSocket URL with optional token
        ws_url = f"{self.url}/api/v1/socket"
        if self.token:
            ws_url = f"{ws_url}?token={self.token}"

        # Build request message
        message = {
            "id": request_id,
            "service": service,
            "request": request
        }
        if flow:
            message["flow"] = flow

        # Connect and send request
        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
            await websocket.send(json.dumps(message))

            # Yield chunks as they arrive
            async for raw_message in websocket:
                response = json.loads(raw_message)

                if response.get("id") != request_id:
                    continue  # Ignore messages for other requests

                if "error" in response:
                    raise_from_error_dict(response["error"])

                if "response" in response:
                    resp = response["response"]

                    # Check for errors in response chunks
                    if "error" in resp:
                        raise_from_error_dict(resp["error"])

                    # Parse different chunk types
                    chunk = self._parse_chunk(resp)
                    yield chunk

                    # Check if this is the final chunk
                    if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
                        break

    def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk:
        """Parse response chunk into appropriate type"""
        chunk_type = resp.get("chunk_type")

        if chunk_type == "thought":
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "observation":
            return AgentObservation(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        elif chunk_type == "answer" or chunk_type == "final-answer":
            return AgentAnswer(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False),
                end_of_dialog=resp.get("end_of_dialog", False)
            )
        elif chunk_type == "action":
            # Agent action chunks - treat as thoughts for display purposes
            return AgentThought(
                content=resp.get("content", ""),
                end_of_message=resp.get("end_of_message", False)
            )
        else:
            # RAG-style chunk (or generic chunk)
            # Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
            content = resp.get("response", resp.get("chunk", resp.get("text", "")))
            return RAGChunk(
                content=content,
                end_of_stream=resp.get("end_of_stream", False),
                error=None  # Errors are always thrown, never stored
            )

    def close(self) -> None:
        """Close WebSocket connection"""
        # Cleanup handled by context manager in async code
        pass


class SocketFlowInstance:
    """Synchronous WebSocket flow instance with same interface as REST FlowInstance"""

    def __init__(self, client: SocketClient, flow_id: str) -> None:
        self.client: SocketClient = client
        self.flow_id: str = flow_id

    def agent(
        self,
        question: str,
        user: str,
        state: Optional[Dict[str, Any]] = None,
        group: Optional[str] = None,
        history: Optional[List[Dict[str, Any]]] = None,
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
        """Agent with optional streaming"""
        request = {
            "question": question,
            "user": user,
            "streaming": streaming
        }
        if state is not None:
            request["state"] = state
        if group is not None:
            request["group"] = group
        if history is not None:
            request["history"] = history
        request.update(kwargs)

        return self.client._send_request_sync("agent", self.flow_id, request, streaming)

    def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
        """Text completion with optional streaming"""
        request = {
            "system": system,
            "prompt": prompt,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming)

        if streaming:
            # For text completion, yield just the content
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def graph_rag(
        self,
        query: str,
        user: str,
        collection: str,
        max_subgraph_size: int = 1000,
        max_subgraph_count: int = 5,
        max_entity_distance: int = 3,
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[str, Iterator[str]]:
        """Graph RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "max-subgraph-size": max_subgraph_size,
            "max-subgraph-count": max_subgraph_count,
            "max-entity-distance": max_entity_distance,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming)

        if streaming:
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def document_rag(
        self,
        query: str,
        user: str,
        collection: str,
        doc_limit: int = 10,
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[str, Iterator[str]]:
        """Document RAG with optional streaming"""
        request = {
            "query": query,
            "user": user,
            "collection": collection,
            "doc-limit": doc_limit,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming)

        if streaming:
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def prompt(
        self,
        id: str,
        variables: Dict[str, str],
        streaming: bool = False,
        **kwargs: Any
    ) -> Union[str, Iterator[str]]:
        """Execute prompt with optional streaming"""
        request = {
            "id": id,
            "variables": variables,
            "streaming": streaming
        }
        request.update(kwargs)

        result = self.client._send_request_sync("prompt", self.flow_id, request, streaming)

        if streaming:
            for chunk in result:
                if hasattr(chunk, 'content'):
                    yield chunk.content
        else:
            return result.get("response", "")

    def graph_embeddings_query(
        self,
        text: str,
        user: str,
        collection: str,
        limit: int = 10,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """Query graph embeddings for semantic search"""
        request = {
            "text": text,
            "user": user,
            "collection": collection,
            "limit": limit
        }
        request.update(kwargs)

        return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False)

    def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
        """Generate text embeddings"""
        request = {"text": text}
        request.update(kwargs)

        return self.client._send_request_sync("embeddings", self.flow_id, request, False)

    def triples_query(
        self,
        s: Optional[str] = None,
        p: Optional[str] = None,
        o: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
collection: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Triple pattern query"""
|
||||
request = {"limit": limit}
|
||||
if s is not None:
|
||||
request["s"] = str(s)
|
||||
if p is not None:
|
||||
request["p"] = str(p)
|
||||
if o is not None:
|
||||
request["o"] = str(o)
|
||||
if user is not None:
|
||||
request["user"] = user
|
||||
if collection is not None:
|
||||
request["collection"] = collection
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("triples", self.flow_id, request, False)
|
||||
|
||||
def objects_query(
|
||||
self,
|
||||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
operation_name: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""GraphQL query"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection
|
||||
}
|
||||
if variables:
|
||||
request["variables"] = variables
|
||||
if operation_name:
|
||||
request["operationName"] = operation_name
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("objects", self.flow_id, request, False)
|
||||
|
||||
def mcp_tool(
|
||||
self,
|
||||
name: str,
|
||||
parameters: Dict[str, Any],
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute MCP tool"""
|
||||
request = {
|
||||
"name": name,
|
||||
"parameters": parameters
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
||||
|
|
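Every method above builds its request the same way: a dict of named fields, followed by `request.update(kwargs)` so callers can override or extend fields before dispatch. A standalone sketch of that pattern (the `build_request` helper is illustrative, not part of the client):

```python
from typing import Any, Dict

def build_request(base: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    # Start from the named fields, then merge caller-supplied kwargs on top
    request = dict(base)
    request.update(overrides)
    return request

# Extra kwargs both override ("streaming") and extend ("timeout") the request
req = build_request({"query": "who?", "streaming": False},
                    streaming=True, timeout=30)
```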
@@ -1,7 +1,7 @@

import dataclasses
import datetime
from typing import List
from typing import List, Optional, Dict, Any
from .. knowledge import hash, Uri, Literal

@dataclasses.dataclass

@@ -49,5 +49,34 @@ class CollectionMetadata:
    name : str
    description : str
    tags : List[str]
    created_at : str
    updated_at : str

# Streaming chunk types

@dataclasses.dataclass
class StreamingChunk:
    """Base class for streaming chunks"""
    content: str
    end_of_message: bool = False

@dataclasses.dataclass
class AgentThought(StreamingChunk):
    """Agent reasoning chunk"""
    chunk_type: str = "thought"

@dataclasses.dataclass
class AgentObservation(StreamingChunk):
    """Agent tool observation chunk"""
    chunk_type: str = "observation"

@dataclasses.dataclass
class AgentAnswer(StreamingChunk):
    """Agent final answer chunk"""
    chunk_type: str = "final-answer"
    end_of_dialog: bool = False

@dataclasses.dataclass
class RAGChunk(StreamingChunk):
    """RAG streaming chunk"""
    chunk_type: str = "rag"
    end_of_stream: bool = False
    error: Optional[Dict[str, str]] = None
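Because `StreamingChunk` gives every field after `content` a default, each subclass only has to declare what it adds. A self-contained sketch redeclaring two of the chunk types above to show the inherited defaults at work:

```python
import dataclasses
from typing import Optional, Dict

@dataclasses.dataclass
class StreamingChunk:
    """Base class for streaming chunks"""
    content: str
    end_of_message: bool = False

@dataclasses.dataclass
class RAGChunk(StreamingChunk):
    """RAG streaming chunk; only the new defaulted fields are declared here"""
    chunk_type: str = "rag"
    end_of_stream: bool = False
    error: Optional[Dict[str, str]] = None

# Only `content` is required; everything else falls back to its default
chunk = RAGChunk(content="partial answer")
```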
@@ -1,11 +1,12 @@

from . pubsub import PulsarClient
from . pubsub import PulsarClient, get_pubsub
from . async_processor import AsyncProcessor
from . consumer import Consumer
from . producer import Producer
from . publisher import Publisher
from . subscriber import Subscriber
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from . logging import add_logging_args, setup_logging
from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec
from . parameter_spec import ParameterSpec

@@ -33,4 +34,5 @@ from . tool_service import ToolService
from . tool_client import ToolClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . collection_config_handler import CollectionConfigHandler
@@ -15,10 +15,11 @@ from prometheus_client import start_http_server, Info

from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
from . pubsub import PulsarClient
from . pubsub import PulsarClient, get_pubsub
from . producer import Producer
from . consumer import Consumer
from . metrics import ProcessorMetrics, ConsumerMetrics
from . logging import add_logging_args, setup_logging

default_config_queue = config_push_queue

@@ -33,8 +34,11 @@ class AsyncProcessor:
        # Store the identity
        self.id = params.get("id")

        # Register a pulsar client
        self.pulsar_client_object = PulsarClient(**params)
        # Create pub/sub backend via factory
        self.pubsub_backend = get_pubsub(**params)

        # Store pulsar_host for backward compatibility
        self._pulsar_host = params.get("pulsar_host", "pulsar://pulsar:6650")

        # Initialise metrics, records the parameters
        ProcessorMetrics(processor = self.id).info({

@@ -69,7 +73,7 @@ class AsyncProcessor:
        self.config_sub_task = Consumer(

            taskgroup = self.taskgroup,
            client = self.pulsar_client,
            backend = self.pubsub_backend,  # Changed from client to backend
            subscriber = config_subscriber_id,
            flow = None,

@@ -95,16 +99,16 @@ class AsyncProcessor:
    # This is called to stop all threads. An over-ride point for extra
    # functionality
    def stop(self):
        self.pulsar_client.close()
        self.pubsub_backend.close()
        self.running = False

    # Returns the pulsar host
    # Returns the pub/sub backend (new interface)
    @property
    def pulsar_host(self): return self.pulsar_client_object.pulsar_host
    def pubsub(self): return self.pubsub_backend

    # Returns the pulsar client
    # Returns the pulsar host (backward compatibility)
    @property
    def pulsar_client(self): return self.pulsar_client_object.client
    def pulsar_host(self): return self._pulsar_host

    # Register a new event handler for configuration change
    def register_config_handler(self, handler):

@@ -165,18 +169,9 @@ class AsyncProcessor:
            raise e

    @classmethod
    def setup_logging(cls, log_level='INFO'):
    def setup_logging(cls, args):
        """Configure logging for the entire application"""
        # Support environment variable override
        env_log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', log_level)

        # Configure logging
        logging.basicConfig(
            level=getattr(logging, env_log_level.upper()),
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler()]
        )
        logger.info(f"Logging configured with level: {env_log_level}")
        setup_logging(args)

    # Startup fabric. launch calls launch_async in async mode.
    @classmethod

@@ -203,7 +198,7 @@ class AsyncProcessor:
        args = vars(args)

        # Setup logging before anything else
        cls.setup_logging(args.get('log_level', 'INFO').upper())
        cls.setup_logging(args)

        # Debug
        logger.debug(f"Arguments: {args}")

@@ -255,12 +250,21 @@ class AsyncProcessor:
    @staticmethod
    def add_args(parser):

        # Pub/sub backend selection
        parser.add_argument(
            '--pubsub-backend',
            default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
            choices=['pulsar', 'mqtt'],
            help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
        )

        PulsarClient.add_args(parser)
        add_logging_args(parser)

        parser.add_argument(
            '--config-queue',
            '--config-push-queue',
            default=default_config_queue,
            help=f'Config push queue {default_config_queue}',
            help=f'Config push queue (default: {default_config_queue})',
        )

        parser.add_argument(
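The `--pubsub-backend` argument added above uses a common pattern: an environment variable seeds the argparse default, so configuration can come from the environment while an explicit CLI flag still wins. A standalone sketch of that pattern:

```python
import argparse
import os

os.environ.pop('PUBSUB_BACKEND', None)  # ensure a clean environment for the demo

parser = argparse.ArgumentParser()
parser.add_argument(
    '--pubsub-backend',
    default=os.getenv('PUBSUB_BACKEND', 'pulsar'),  # env var seeds the default
    choices=['pulsar', 'mqtt'],
)

no_flag = parser.parse_args([])                              # falls back to 'pulsar'
with_flag = parser.parse_args(['--pubsub-backend', 'mqtt'])  # explicit flag wins
```

Note that the environment is read once, when `add_argument` runs, not at parse time.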
148
trustgraph-base/trustgraph/base/backend.py
Normal file

@@ -0,0 +1,148 @@
"""
Backend abstraction interfaces for pub/sub systems.

This module defines Protocol classes that all pub/sub backends must implement,
allowing TrustGraph to work with different messaging systems (Pulsar, MQTT, Kafka, etc.)
"""

from typing import Protocol, Any, runtime_checkable


@runtime_checkable
class Message(Protocol):
    """Protocol for a received message."""

    def value(self) -> Any:
        """
        Get the deserialized message content.

        Returns:
            Dataclass instance representing the message
        """
        ...

    def properties(self) -> dict:
        """
        Get message properties/metadata.

        Returns:
            Dictionary of message properties
        """
        ...


@runtime_checkable
class BackendProducer(Protocol):
    """Protocol for backend-specific producer."""

    def send(self, message: Any, properties: dict = {}) -> None:
        """
        Send a message (dataclass instance) with optional properties.

        Args:
            message: Dataclass instance to send
            properties: Optional metadata properties
        """
        ...

    def flush(self) -> None:
        """Flush any buffered messages."""
        ...

    def close(self) -> None:
        """Close the producer."""
        ...


@runtime_checkable
class BackendConsumer(Protocol):
    """Protocol for backend-specific consumer."""

    def receive(self, timeout_millis: int = 2000) -> Message:
        """
        Receive a message from the topic.

        Args:
            timeout_millis: Timeout in milliseconds

        Returns:
            Message object

        Raises:
            TimeoutError: If no message received within timeout
        """
        ...

    def acknowledge(self, message: Message) -> None:
        """
        Acknowledge successful processing of a message.

        Args:
            message: The message to acknowledge
        """
        ...

    def negative_acknowledge(self, message: Message) -> None:
        """
        Negative acknowledge - triggers redelivery.

        Args:
            message: The message to negatively acknowledge
        """
        ...

    def unsubscribe(self) -> None:
        """Unsubscribe from the topic."""
        ...

    def close(self) -> None:
        """Close the consumer."""
        ...


@runtime_checkable
class PubSubBackend(Protocol):
    """Protocol defining the interface all pub/sub backends must implement."""

    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
        """
        Create a producer for a topic.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            schema: Dataclass type for messages
            **options: Backend-specific options (e.g., chunking_enabled)

        Returns:
            Backend-specific producer instance
        """
        ...

    def create_consumer(
        self,
        topic: str,
        subscription: str,
        schema: type,
        initial_position: str = 'latest',
        consumer_type: str = 'shared',
        **options
    ) -> BackendConsumer:
        """
        Create a consumer for a topic.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            subscription: Subscription/consumer group name
            schema: Dataclass type for messages
            initial_position: 'earliest' or 'latest' (some backends may ignore)
            consumer_type: 'shared', 'exclusive', 'failover' (some backends may ignore)
            **options: Backend-specific options

        Returns:
            Backend-specific consumer instance
        """
        ...

    def close(self) -> None:
        """Close the backend connection."""
        ...
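Since these protocols are `runtime_checkable`, any duck-typed class with the right methods passes an `isinstance` check without inheriting from the protocol. A toy in-memory producer illustrating this (the real backends wrap Pulsar or MQTT clients; the protocol is redeclared locally so the sketch is self-contained):

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class BackendProducer(Protocol):
    def send(self, message: Any, properties: dict = {}) -> None: ...
    def flush(self) -> None: ...
    def close(self) -> None: ...

class InMemoryProducer:
    """Toy producer: appends messages to a list instead of sending to a broker."""
    def __init__(self, topic: str):
        self.topic = topic
        self.sent: list = []

    def send(self, message: Any, properties: dict = {}) -> None:
        self.sent.append((message, dict(properties)))

    def flush(self) -> None:
        pass  # nothing is buffered in this toy

    def close(self) -> None:
        pass

prod = InMemoryProducer("persistent/tg/flow/queue")
prod.send({"text": "hello"})
```

Keep in mind `isinstance` against a runtime-checkable protocol only verifies that the methods exist, not their signatures.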
@@ -13,14 +13,15 @@ from typing import Optional, Tuple, List, Any
def get_cassandra_defaults() -> dict:
    """
    Get default Cassandra configuration values from environment variables or fallback defaults.

    Returns:
        dict: Dictionary with 'host', 'username', and 'password' keys
        dict: Dictionary with 'host', 'username', 'password', and 'keyspace' keys
    """
    return {
        'host': os.getenv('CASSANDRA_HOST', 'cassandra'),
        'username': os.getenv('CASSANDRA_USERNAME'),
        'password': os.getenv('CASSANDRA_PASSWORD')
        'password': os.getenv('CASSANDRA_PASSWORD'),
        'keyspace': os.getenv('CASSANDRA_KEYSPACE')
    }

@@ -53,82 +54,108 @@ def add_cassandra_args(parser: argparse.ArgumentParser) -> None:
        password_help += " (default: <set>)"
    if 'CASSANDRA_PASSWORD' in os.environ:
        password_help += " [from CASSANDRA_PASSWORD]"

    keyspace_help = "Cassandra keyspace (default: service-specific)"
    if defaults['keyspace']:
        keyspace_help = f"Cassandra keyspace (default: {defaults['keyspace']})"
    if 'CASSANDRA_KEYSPACE' in os.environ:
        keyspace_help += " [from CASSANDRA_KEYSPACE]"

    parser.add_argument(
        '--cassandra-host',
        default=defaults['host'],
        help=host_help
    )

    parser.add_argument(
        '--cassandra-username',
        default=defaults['username'],
        help=username_help
    )

    parser.add_argument(
        '--cassandra-password',
        default=defaults['password'],
        help=password_help
    )

    parser.add_argument(
        '--cassandra-keyspace',
        default=defaults['keyspace'],
        help=keyspace_help
    )


def resolve_cassandra_config(
    args: Optional[Any] = None,
    host: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None
) -> Tuple[List[str], Optional[str], Optional[str]]:
    password: Optional[str] = None,
    default_keyspace: Optional[str] = None
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
    """
    Resolve Cassandra configuration from various sources.

    Can accept either argparse args object or explicit parameters.
    Converts host string to list format for Cassandra driver.

    Args:
        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password
        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace
        host: Optional explicit host parameter (overrides args)
        username: Optional explicit username parameter (overrides args)
        password: Optional explicit password parameter (overrides args)
        default_keyspace: Optional default keyspace if not specified elsewhere

    Returns:
        tuple: (hosts_list, username, password)
        tuple: (hosts_list, username, password, keyspace)
    """
    # If args provided, extract values
    keyspace = None
    if args is not None:
        host = host or getattr(args, 'cassandra_host', None)
        username = username or getattr(args, 'cassandra_username', None)
        password = password or getattr(args, 'cassandra_password', None)
        keyspace = getattr(args, 'cassandra_keyspace', None)

    # Apply defaults if still None
    defaults = get_cassandra_defaults()
    host = host or defaults['host']
    username = username or defaults['username']
    password = password or defaults['password']
    keyspace = keyspace or defaults['keyspace'] or default_keyspace

    # Convert host string to list
    if isinstance(host, str):
        hosts = [h.strip() for h in host.split(',') if h.strip()]
    else:
        hosts = host

    return hosts, username, password
    return hosts, username, password, keyspace


def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[str], Optional[str]]:
def get_cassandra_config_from_params(
    params: dict,
    default_keyspace: Optional[str] = None
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
    """
    Extract and resolve Cassandra configuration from a parameters dictionary.

    Args:
        params: Dictionary of parameters that may contain Cassandra configuration
        default_keyspace: Optional default keyspace if not specified in params

    Returns:
        tuple: (hosts_list, username, password)
        tuple: (hosts_list, username, password, keyspace)
    """
    # Get Cassandra parameters
    host = params.get('cassandra_host')
    username = params.get('cassandra_username')
    password = params.get('cassandra_password')

    # Use resolve function to handle defaults and list conversion
    return resolve_cassandra_config(host=host, username=username, password=password)
    return resolve_cassandra_config(
        host=host,
        username=username,
        password=password,
        default_keyspace=default_keyspace
    )
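The host handling above turns a comma-separated `CASSANDRA_HOST` value into the host list the driver expects, tolerating stray whitespace and empty segments, and passing an already-built list through unchanged. The same logic in isolation:

```python
def hosts_to_list(host):
    # Accept either a comma-separated string or an already-built list
    if isinstance(host, str):
        return [h.strip() for h in host.split(',') if h.strip()]
    return host

# Whitespace is stripped; empty segments (from doubled commas) are dropped
hosts = hosts_to_list("cassandra-1, cassandra-2,,cassandra-3 ")
```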
128
trustgraph-base/trustgraph/base/collection_config_handler.py
Normal file

@@ -0,0 +1,128 @@
"""
Handler for storage services to process collection configuration from config push
"""

import json
import logging
from typing import Dict, Set

logger = logging.getLogger(__name__)

class CollectionConfigHandler:
    """
    Handles collection configuration from config push messages for storage services.

    Storage services should:
    1. Inherit from this class along with their service base class
    2. Call register_config_handler(self.on_collection_config) in __init__
    3. Implement create_collection(user, collection, metadata) method
    4. Implement delete_collection(user, collection) method
    """

    def __init__(self, **kwargs):
        # Track known collections: {(user, collection): metadata_dict}
        self.known_collections: Dict[tuple, dict] = {}
        # Pass remaining kwargs up the inheritance chain
        super().__init__(**kwargs)

    async def on_collection_config(self, config: dict, version: int):
        """
        Handle config push messages and extract collection information

        Args:
            config: Configuration dictionary from ConfigPush message
            version: Configuration version number
        """
        logger.info(f"Processing collection configuration (version {version})")

        # Extract collections from config (treat missing key as empty)
        collection_config = config.get("collection", {})

        # Track which collections we've seen in this config
        current_collections: Set[tuple] = set()

        # Process each collection in the config
        for key, value_json in collection_config.items():
            try:
                # Parse user:collection key
                if ":" not in key:
                    logger.warning(f"Invalid collection key format (expected user:collection): {key}")
                    continue

                user, collection = key.split(":", 1)
                current_collections.add((user, collection))

                # Parse metadata
                metadata = json.loads(value_json)

                # Check if this is a new collection or updated
                collection_key = (user, collection)
                if collection_key not in self.known_collections:
                    logger.info(f"New collection detected: {user}/{collection}")
                    await self.create_collection(user, collection, metadata)
                    self.known_collections[collection_key] = metadata
                else:
                    # Collection already exists, update metadata if changed
                    if self.known_collections[collection_key] != metadata:
                        logger.info(f"Collection metadata updated: {user}/{collection}")
                        # Most storage services don't need to do anything for metadata updates
                        # They just need to know the collection exists
                        self.known_collections[collection_key] = metadata

            except Exception as e:
                logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True)

        # Find collections that were deleted (in known but not in current)
        deleted_collections = set(self.known_collections.keys()) - current_collections
        for user, collection in deleted_collections:
            logger.info(f"Collection deleted: {user}/{collection}")
            try:
                # Remove from known_collections FIRST to immediately reject new writes
                # This eliminates race condition with worker threads
                del self.known_collections[(user, collection)]
                # Physical deletion happens after - worker threads already rejecting writes
                await self.delete_collection(user, collection)
            except Exception as e:
                logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True)
                # If physical deletion failed, should we re-add to known_collections?
                # For now, keep it removed - collection is logically deleted per config

        logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}")

    async def create_collection(self, user: str, collection: str, metadata: dict):
        """
        Create a collection in the storage backend.

        Subclasses must implement this method.

        Args:
            user: User ID
            collection: Collection ID
            metadata: Collection metadata dictionary
        """
        raise NotImplementedError("Storage service must implement create_collection method")

    async def delete_collection(self, user: str, collection: str):
        """
        Delete a collection from the storage backend.

        Subclasses must implement this method.

        Args:
            user: User ID
            collection: Collection ID
        """
        raise NotImplementedError("Storage service must implement delete_collection method")

    def collection_exists(self, user: str, collection: str) -> bool:
        """
        Check if a collection is known to exist

        Args:
            user: User ID
            collection: Collection ID

        Returns:
            True if collection exists, False otherwise
        """
        return (user, collection) in self.known_collections
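The handler diffs successive config snapshots: keys present for the first time trigger `create_collection`, keys that disappear trigger `delete_collection`. A trimmed, self-contained re-declaration of that reconciliation loop (logging and error handling omitted; `DemoStore` is a hypothetical subclass, not a real TrustGraph storage service):

```python
import asyncio
import json

class CollectionConfigHandler:
    """Trimmed sketch of the handler above, enough to demonstrate the flow."""
    def __init__(self):
        self.known_collections = {}

    async def on_collection_config(self, config, version):
        current = set()
        for key, value_json in config.get("collection", {}).items():
            user, collection = key.split(":", 1)
            current.add((user, collection))
            metadata = json.loads(value_json)
            if (user, collection) not in self.known_collections:
                await self.create_collection(user, collection, metadata)
            self.known_collections[(user, collection)] = metadata
        # Anything known but no longer in the config has been deleted
        for user, collection in set(self.known_collections) - current:
            del self.known_collections[(user, collection)]
            await self.delete_collection(user, collection)

class DemoStore(CollectionConfigHandler):
    def __init__(self):
        super().__init__()
        self.events = []
    async def create_collection(self, user, collection, metadata):
        self.events.append(("create", user, collection))
    async def delete_collection(self, user, collection):
        self.events.append(("delete", user, collection))

store = DemoStore()
asyncio.run(store.on_collection_config(
    {"collection": {"alice:docs": json.dumps({"name": "Docs"})}}, 1))
asyncio.run(store.on_collection_config({"collection": {}}, 2))
```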
@@ -9,9 +9,6 @@
# one handler, and a single thread of concurrency, nothing too outrageous
# will happen if synchronous / blocking code is used

from pulsar.schema import JsonSchema
import pulsar
import _pulsar
import asyncio
import time
import logging

@@ -21,11 +18,15 @@ from .. exceptions import TooManyRequests
# Module logger
logger = logging.getLogger(__name__)

# Timeout exception - can come from different backends
class TimeoutError(Exception):
    pass

class Consumer:

    def __init__(
        self, taskgroup, flow, client, topic, subscriber, schema,
        handler,
        self, taskgroup, flow, backend, topic, subscriber, schema,
        handler,
        metrics = None,
        start_of_messages=False,
        rate_limit_retry_time = 10, rate_limit_timeout = 7200,

@@ -35,7 +36,7 @@ class Consumer:

        self.taskgroup = taskgroup
        self.flow = flow
        self.client = client
        self.backend = backend  # Changed from 'client' to 'backend'
        self.topic = topic
        self.subscriber = subscriber
        self.schema = schema

@@ -96,18 +97,20 @@ class Consumer:

            logger.info(f"Subscribing to topic: {self.topic}")

            # Determine initial position
            if self.start_of_messages:
                pos = pulsar.InitialPosition.Earliest
                initial_pos = 'earliest'
            else:
                pos = pulsar.InitialPosition.Latest
                initial_pos = 'latest'

            # Create consumer via backend
            self.consumer = await asyncio.to_thread(
                self.client.subscribe,
                self.backend.create_consumer,
                topic = self.topic,
                subscription_name = self.subscriber,
                schema = JsonSchema(self.schema),
                initial_position = pos,
                consumer_type = pulsar.ConsumerType.Shared,
                subscription = self.subscriber,
                schema = self.schema,
                initial_position = initial_pos,
                consumer_type = 'shared',
            )

        except Exception as e:

@@ -159,9 +162,10 @@ class Consumer:
                    self.consumer.receive,
                    timeout_millis=2000
                )
            except _pulsar.Timeout:
                continue
            except Exception as e:
                # Handle timeout from any backend
                if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
                    continue
                raise e

            await self.handle_one_from_queue(msg)
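The broadened exception handler above keys off the substring 'timeout' in either the exception's type name or its message, so the receive loop works with any backend's timeout class without importing it. The same check in isolation (`PulsarTimeout` here is a stand-in, not the real `_pulsar.Timeout`):

```python
def is_timeout(exc: Exception) -> bool:
    # Match 'timeout' in the exception's type name or its message text
    return 'timeout' in str(type(exc)).lower() or 'timeout' in str(exc).lower()

class PulsarTimeout(Exception):
    """Stand-in for a backend-specific timeout type; the type name is what matches."""
    pass

matched_by_type = is_timeout(PulsarTimeout())
matched_by_message = is_timeout(RuntimeError("receive timeout after 2000 ms"))
not_matched = is_timeout(ValueError("bad payload"))
```

The trade-off of string matching is that a timeout signalled as, say, "timed out" in the message with a generic type name would slip through; the benefit is zero coupling to any backend's exception hierarchy.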
@@ -19,7 +19,7 @@ class ConsumerSpec(Spec):
        consumer = Consumer(
            taskgroup = processor.taskgroup,
            flow = flow,
            client = processor.pulsar_client,
            backend = processor.pubsub,
            topic = definition[self.name],
            subscriber = processor.id + "--" + flow.name + "--" + self.name,
            schema = self.schema,
159
trustgraph-base/trustgraph/base/logging.py
Normal file
159
trustgraph-base/trustgraph/base/logging.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
|
||||
"""
|
||||
Centralized logging configuration for TrustGraph server-side components.
|
||||
|
||||
This module provides standardized logging setup across all TrustGraph services,
|
||||
ensuring consistent log formats, levels, and command-line arguments.
|
||||
|
||||
Supports dual output to console and Loki for centralized log aggregation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
from queue import Queue
|
||||
import os
|
||||
|
||||
|
||||
def add_logging_args(parser):
|
||||
"""
|
||||
Add standard logging arguments to an argument parser.
|
||||
|
||||
Args:
|
||||
parser: argparse.ArgumentParser instance to add arguments to
|
||||
"""
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
default='INFO',
|
||||
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
help='Log level (default: INFO)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--loki-enabled',
|
||||
action='store_true',
|
||||
default=True,
|
||||
help='Enable Loki logging (default: True)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-loki-enabled',
|
||||
dest='loki_enabled',
|
||||
action='store_false',
|
||||
help='Disable Loki logging'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--loki-url',
|
||||
default=os.getenv('LOKI_URL', 'http://loki:3100/loki/api/v1/push'),
|
||||
help='Loki push URL (default: http://loki:3100/loki/api/v1/push)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--loki-username',
|
||||
default=os.getenv('LOKI_USERNAME', None),
|
||||
help='Loki username for authentication (optional)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--loki-password',
|
||||
default=os.getenv('LOKI_PASSWORD', None),
|
||||
help='Loki password for authentication (optional)'
|
||||
)
|
||||
|
||||
|
||||
def setup_logging(args):
|
||||
"""
|
||||
Configure logging from parsed command-line arguments.
|
||||
|
||||
Sets up logging with a standardized format and output to stdout.
|
||||
Optionally enables Loki integration for centralized log aggregation.
|
||||
|
||||
This should be called early in application startup, before any
|
||||
logging calls are made.
|
||||
|
||||
Args:
|
||||
        args: Dictionary of parsed arguments (typically from vars(args))
              Must contain 'log_level' key, optional Loki configuration
    """
    log_level = args.get('log_level', 'INFO')
    loki_enabled = args.get('loki_enabled', True)

    # Build list of handlers starting with console
    handlers = [logging.StreamHandler()]

    # Add Loki handler if enabled
    queue_listener = None
    if loki_enabled:
        loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
        loki_username = args.get('loki_username')
        loki_password = args.get('loki_password')
        processor_id = args.get('id')  # Processor identity (e.g., "config-svc", "text-completion")

        try:
            from logging_loki import LokiHandler

            # Create Loki handler with optional authentication and processor label
            loki_handler_kwargs = {
                'url': loki_url,
                'version': "1",
            }

            if loki_username and loki_password:
                loki_handler_kwargs['auth'] = (loki_username, loki_password)

            # Add processor label if available (for consistency with Prometheus metrics)
            if processor_id:
                loki_handler_kwargs['tags'] = {'processor': processor_id}

            loki_handler = LokiHandler(**loki_handler_kwargs)

            # Wrap in QueueHandler for non-blocking operation
            log_queue = Queue(maxsize=500)
            queue_handler = logging.handlers.QueueHandler(log_queue)
            handlers.append(queue_handler)

            # Start QueueListener in background thread
            queue_listener = logging.handlers.QueueListener(
                log_queue,
                loki_handler,
                respect_handler_level=True
            )
            queue_listener.start()

            # Store listener reference for potential cleanup
            # (attached to root logger for access if needed)
            logging.getLogger().loki_queue_listener = queue_listener

        except ImportError:
            # Graceful degradation if python-logging-loki not installed
            print("WARNING: python-logging-loki not installed, Loki logging disabled")
            print("Install with: pip install python-logging-loki")
        except Exception as e:
            # Graceful degradation if Loki connection fails
            print(f"WARNING: Failed to setup Loki logging: {e}")
            print("Continuing with console-only logging")

    # Get processor ID for log formatting (use 'unknown' if not available)
    processor_id = args.get('id', 'unknown')

    # Configure logging with all handlers
    # Use processor ID as the primary identifier in logs
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s',
        handlers=handlers,
        force=True  # Force reconfiguration if already configured
    )

    # Prevent recursive logging from Loki's HTTP client
    if loki_enabled and queue_listener:
        # Disable urllib3 logging to prevent infinite loop
        logging.getLogger('urllib3').setLevel(logging.WARNING)
        logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)

    logger = logging.getLogger(__name__)
    logger.info(f"Logging configured with level: {log_level}")
    if loki_enabled and queue_listener:
        logger.info(f"Loki logging enabled: {loki_url}")
    elif loki_enabled:
        logger.warning("Loki logging requested but not available")
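The QueueHandler/QueueListener wiring above can be exercised with the standard library alone. A minimal sketch of the same non-blocking pattern, with a hypothetical `CollectingHandler` standing in for the LokiHandler:

```python
import logging
import logging.handlers
from queue import Queue

records = []

# Stand-in for LokiHandler: just records messages so we can observe them
class CollectingHandler(logging.Handler):
    def emit(self, record):
        records.append(record.getMessage())

# Messages go to the queue immediately; the listener drains them on a
# background thread, so the caller's emit() never blocks on a slow handler.
log_queue = Queue(maxsize=500)
listener = logging.handlers.QueueListener(
    log_queue, CollectingHandler(), respect_handler_level=True
)
listener.start()

demo_logger = logging.getLogger("demo")
demo_logger.setLevel(logging.INFO)
demo_logger.addHandler(logging.handlers.QueueHandler(log_queue))
demo_logger.info("hello")

listener.stop()  # drains the queue and joins the background thread
```

This is the same reason the code above stores the listener on the root logger: something has to keep a reference so the queue can be drained at shutdown.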
@@ -1,5 +1,4 @@
from pulsar.schema import JsonSchema
import asyncio
import logging

@@ -8,10 +7,10 @@ logger = logging.getLogger(__name__)
class Producer:

    def __init__(self, client, topic, schema, metrics=None,
    def __init__(self, backend, topic, schema, metrics=None,
                 chunking_enabled=True):

        self.client = client
        self.backend = backend  # Changed from 'client' to 'backend'
        self.topic = topic
        self.schema = schema

@@ -44,9 +43,9 @@ class Producer:
        try:
            logger.info(f"Connecting publisher to {self.topic}...")
            self.producer = self.client.create_producer(
            self.producer = self.backend.create_producer(
                topic = self.topic,
                schema = JsonSchema(self.schema),
                schema = self.schema,
                chunking_enabled = self.chunking_enabled,
            )
            logger.info(f"Connected publisher to {self.topic}")

@@ -15,7 +15,7 @@ class ProducerSpec(Spec):
        )

        producer = Producer(
            client = processor.pulsar_client,
            backend = processor.pubsub,
            topic = definition[self.name],
            schema = self.schema,
            metrics = producer_metrics,

@@ -37,33 +37,34 @@ class PromptClient(RequestResponse):
        else:
            logger.info("DEBUG prompt_client: Streaming path")
            # Streaming path - collect all chunks
            full_text = ""
            full_object = None
            # Streaming path - just forward chunks, don't accumulate
            last_text = ""
            last_object = None

            async def collect_chunks(resp):
                nonlocal full_text, full_object
                logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
            async def forward_chunks(resp):
                nonlocal last_text, last_object
                logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")

                if resp.error:
                    logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
                    raise RuntimeError(resp.error.message)

                if resp.text:
                    full_text += resp.text
                    logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars")
                    # Call chunk callback if provided
                end_stream = getattr(resp, 'end_of_stream', False)

                # Always call callback if there's text OR if it's the final message
                if resp.text is not None:
                    last_text = resp.text
                    # Call chunk callback if provided with both chunk and end_of_stream flag
                    if chunk_callback:
                        logger.info(f"DEBUG prompt_client: Calling chunk_callback")
                        logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
                        if asyncio.iscoroutinefunction(chunk_callback):
                            await chunk_callback(resp.text)
                            await chunk_callback(resp.text, end_stream)
                        else:
                            chunk_callback(resp.text)
                            chunk_callback(resp.text, end_stream)
                elif resp.object:
                    logger.info(f"DEBUG prompt_client: Got object response")
                    full_object = resp.object
                    last_object = resp.object

                end_stream = getattr(resp, 'end_of_stream', False)
                logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
                return end_stream

@@ -79,17 +80,17 @@ class PromptClient(RequestResponse):
            logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
            await self.request(
                req,
                recipient=collect_chunks,
                recipient=forward_chunks,
                timeout=timeout
            )
            logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars")
            logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}")

            if full_text:
                logger.info("DEBUG prompt_client: Returning full_text")
                return full_text
            if last_text:
                logger.info("DEBUG prompt_client: Returning last_text")
                return last_text

            logger.info("DEBUG prompt_client: Returning parsed full_object")
            return json.loads(full_object)
            logger.info("DEBUG prompt_client: Returning parsed last_object")
            return json.loads(last_object) if last_object else None

    async def extract_definitions(self, text, timeout=600):
        return await self.prompt(

@@ -1,9 +1,6 @@
from pulsar.schema import JsonSchema

import asyncio
import time
import pulsar
import logging

# Module logger

@@ -11,9 +8,9 @@ logger = logging.getLogger(__name__)
class Publisher:

    def __init__(self, client, topic, schema=None, max_size=10,
    def __init__(self, backend, topic, schema=None, max_size=10,
                 chunking_enabled=True, drain_timeout=5.0):
        self.client = client
        self.backend = backend  # Changed from 'client' to 'backend'
        self.topic = topic
        self.schema = schema
        self.q = asyncio.Queue(maxsize=max_size)

@@ -47,9 +44,9 @@ class Publisher:
        try:

            producer = self.client.create_producer(
            producer = self.backend.create_producer(
                topic=self.topic,
                schema=JsonSchema(self.schema),
                schema=self.schema,
                chunking_enabled=self.chunking_enabled,
            )

@@ -4,8 +4,45 @@ import pulsar
import _pulsar
import uuid
from pulsar.schema import JsonSchema
import logging

from .. log_level import LogLevel
from .pulsar_backend import PulsarBackend

logger = logging.getLogger(__name__)


def get_pubsub(**config):
    """
    Factory function to create a pub/sub backend based on configuration.

    Args:
        config: Configuration dictionary from command-line args
                Must include 'pubsub_backend' key

    Returns:
        Backend instance (PulsarBackend, MQTTBackend, etc.)

    Example:
        backend = get_pubsub(
            pubsub_backend='pulsar',
            pulsar_host='pulsar://localhost:6650'
        )
    """
    backend_type = config.get('pubsub_backend', 'pulsar')

    if backend_type == 'pulsar':
        return PulsarBackend(
            host=config.get('pulsar_host', PulsarClient.default_pulsar_host),
            api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key),
            listener=config.get('pulsar_listener'),
        )
    elif backend_type == 'mqtt':
        # TODO: Implement MQTT backend
        raise NotImplementedError("MQTT backend not yet implemented")
    else:
        raise ValueError(f"Unknown pub/sub backend: {backend_type}")
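The factory's dispatch can be seen in isolation; a minimal sketch with stubbed-out backends (`StubPulsarBackend` and `get_pubsub_sketch` are hypothetical names used here for illustration, not TrustGraph symbols):

```python
# Hypothetical stand-in for PulsarBackend, so the dispatch runs without a broker
class StubPulsarBackend:
    def __init__(self, host, api_key=None, listener=None):
        self.host = host

def get_pubsub_sketch(**config):
    # Same shape as get_pubsub above: select backend by config key,
    # fall back to a default host, reject unknown backend names.
    backend_type = config.get('pubsub_backend', 'pulsar')
    if backend_type == 'pulsar':
        return StubPulsarBackend(
            host=config.get('pulsar_host', 'pulsar://localhost:6650'),
            api_key=config.get('pulsar_api_key'),
            listener=config.get('pulsar_listener'),
        )
    raise ValueError(f"Unknown pub/sub backend: {backend_type}")

backend = get_pubsub_sketch(pubsub_backend='pulsar')
```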


class PulsarClient:

@@ -71,10 +108,3 @@ class PulsarClient:
            '--pulsar-listener',
            help=f'Pulsar listener (default: none)',
        )

        parser.add_argument(
            '-l', '--log-level',
            default='INFO',
            choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
            help=f'Log level (default: INFO)'
        )

350
trustgraph-base/trustgraph/base/pulsar_backend.py
Normal file

@@ -0,0 +1,350 @@
"""
Pulsar backend implementation for pub/sub abstraction.

This module provides a Pulsar-specific implementation of the backend interfaces,
handling topic mapping, serialization, and Pulsar client management.
"""

import pulsar
import _pulsar
import json
import logging
import base64
import types
from dataclasses import asdict, is_dataclass
from typing import Any

from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message

logger = logging.getLogger(__name__)


def dataclass_to_dict(obj: Any) -> dict:
    """
    Recursively convert a dataclass to a dictionary, handling None values and bytes.

    None values are excluded from the dictionary (not serialized).
    Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
    Handles nested dataclasses, lists, and dictionaries recursively.
    """
    if obj is None:
        return None

    # Handle bytes - decode to UTF-8 for JSON serialization
    if isinstance(obj, bytes):
        return obj.decode('utf-8')

    # Handle dataclass - convert to dict then recursively process all values
    if is_dataclass(obj):
        result = {}
        for key, value in asdict(obj).items():
            result[key] = dataclass_to_dict(value) if value is not None else None
        return result

    # Handle list - recursively process all items
    if isinstance(obj, list):
        return [dataclass_to_dict(item) for item in obj]

    # Handle dict - recursively process all values
    if isinstance(obj, dict):
        return {k: dataclass_to_dict(v) for k, v in obj.items()}

    # Return primitive types as-is
    return obj


def dict_to_dataclass(data: dict, cls: type) -> Any:
    """
    Convert a dictionary back to a dataclass instance.

    Handles nested dataclasses and missing fields.
    """
    if data is None:
        return None

    if not is_dataclass(cls):
        return data

    # Get field types from the dataclass
    field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
    kwargs = {}

    for key, value in data.items():
        if key in field_types:
            field_type = field_types[key]

            # Handle modern union types (X | Y)
            if isinstance(field_type, types.UnionType):
                # Check if it's Optional (X | None)
                if type(None) in field_type.__args__:
                    # Get the non-None type
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    kwargs[key] = value
            # Check if this is a generic type (list, dict, etc.)
            elif hasattr(field_type, '__origin__'):
                # Handle list[T]
                if field_type.__origin__ == list:
                    item_type = field_type.__args__[0] if field_type.__args__ else None
                    if item_type and is_dataclass(item_type) and isinstance(value, list):
                        kwargs[key] = [
                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
                            for item in value
                        ]
                    else:
                        kwargs[key] = value
                # Handle old-style Optional[T] (which is Union[T, None])
                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
                    # Get the non-None type from Union
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    kwargs[key] = value
            # Handle direct dataclass fields
            elif is_dataclass(field_type) and isinstance(value, dict):
                kwargs[key] = dict_to_dataclass(value, field_type)
            # Handle bytes fields (UTF-8 encoded strings from JSON)
            elif field_type == bytes and isinstance(value, str):
                kwargs[key] = value.encode('utf-8')
            else:
                kwargs[key] = value

    return cls(**kwargs)
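The wire format these helpers implement (nested dataclasses as JSON objects, bytes as UTF-8 strings, None fields dropped) can be illustrated with a quick round-trip; this sketch uses `dataclasses.asdict` and hand-rolled reconstruction rather than the helpers above, and the `Inner`/`Outer` types are made up for the example:

```python
import json
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class Inner:
    name: str

@dataclass
class Outer:
    inner: Optional[Inner]
    payload: bytes

# Serialize: bytes are decoded to UTF-8 before JSON encoding,
# matching dataclass_to_dict's behavior
obj = Outer(inner=Inner(name="x"), payload=b"hi")
d = asdict(obj)
d["payload"] = d["payload"].decode("utf-8")
wire = json.dumps(d)

# Deserialize: rebuild the nested dataclass and re-encode the bytes field,
# matching dict_to_dataclass's behavior
raw = json.loads(wire)
back = Outer(inner=Inner(**raw["inner"]),
             payload=raw["payload"].encode("utf-8"))
```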

class PulsarMessage:
    """Wrapper for Pulsar messages to match Message protocol."""

    def __init__(self, pulsar_msg, schema_cls):
        self._msg = pulsar_msg
        self._schema_cls = schema_cls
        self._value = None

    def value(self) -> Any:
        """Deserialize and return the message value as a dataclass."""
        if self._value is None:
            # Get JSON string from Pulsar message
            json_data = self._msg.data().decode('utf-8')
            data_dict = json.loads(json_data)
            # Convert to dataclass
            self._value = dict_to_dataclass(data_dict, self._schema_cls)
        return self._value

    def properties(self) -> dict:
        """Return message properties."""
        return self._msg.properties()


class PulsarBackendProducer:
    """Pulsar-specific producer implementation."""

    def __init__(self, pulsar_producer, schema_cls):
        self._producer = pulsar_producer
        self._schema_cls = schema_cls

    def send(self, message: Any, properties: dict = {}) -> None:
        """Send a dataclass message."""
        # Convert dataclass to dict, excluding None values
        data_dict = dataclass_to_dict(message)
        # Serialize to JSON
        json_data = json.dumps(data_dict)
        # Send via Pulsar
        self._producer.send(json_data.encode('utf-8'), properties=properties)

    def flush(self) -> None:
        """Flush buffered messages."""
        self._producer.flush()

    def close(self) -> None:
        """Close the producer."""
        self._producer.close()


class PulsarBackendConsumer:
    """Pulsar-specific consumer implementation."""

    def __init__(self, pulsar_consumer, schema_cls):
        self._consumer = pulsar_consumer
        self._schema_cls = schema_cls

    def receive(self, timeout_millis: int = 2000) -> Message:
        """Receive a message."""
        pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
        return PulsarMessage(pulsar_msg, self._schema_cls)

    def acknowledge(self, message: Message) -> None:
        """Acknowledge a message."""
        if isinstance(message, PulsarMessage):
            self._consumer.acknowledge(message._msg)

    def negative_acknowledge(self, message: Message) -> None:
        """Negative acknowledge a message."""
        if isinstance(message, PulsarMessage):
            self._consumer.negative_acknowledge(message._msg)

    def unsubscribe(self) -> None:
        """Unsubscribe from the topic."""
        self._consumer.unsubscribe()

    def close(self) -> None:
        """Close the consumer."""
        self._consumer.close()


class PulsarBackend:
    """
    Pulsar backend implementation.

    Handles topic mapping, client management, and creation of Pulsar-specific
    producers and consumers.
    """

    def __init__(self, host: str, api_key: str = None, listener: str = None):
        """
        Initialize Pulsar backend.

        Args:
            host: Pulsar broker URL (e.g., pulsar://localhost:6650)
            api_key: Optional API key for authentication
            listener: Optional listener name for multi-homed setups
        """
        self.host = host
        self.api_key = api_key
        self.listener = listener

        # Create Pulsar client
        client_args = {'service_url': host}

        if listener:
            client_args['listener_name'] = listener

        if api_key:
            client_args['authentication'] = pulsar.AuthenticationToken(api_key)

        self.client = pulsar.Client(**client_args)
        logger.info(f"Pulsar client connected to {host}")

    def map_topic(self, generic_topic: str) -> str:
        """
        Map generic topic format to Pulsar URI.

        Format: qos/tenant/namespace/queue
        Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue

        Args:
            generic_topic: Generic topic string or already-formatted Pulsar URI

        Returns:
            Pulsar topic URI
        """
        # If already a Pulsar URI, return as-is
        if '://' in generic_topic:
            return generic_topic

        parts = generic_topic.split('/', 3)
        if len(parts) != 4:
            raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue")

        qos, tenant, namespace, queue = parts

        # Map QoS to persistence
        if qos == 'q0':
            persistence = 'non-persistent'
        elif qos in ['q1', 'q2']:
            persistence = 'persistent'
        else:
            raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2")

        return f"{persistence}://{tenant}/{namespace}/{queue}"

    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
        """
        Create a Pulsar producer.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            schema: Dataclass type for messages
            **options: Backend-specific options (e.g., chunking_enabled)

        Returns:
            PulsarBackendProducer instance
        """
        pulsar_topic = self.map_topic(topic)

        producer_args = {
            'topic': pulsar_topic,
            'schema': pulsar.schema.BytesSchema(),  # We handle serialization ourselves
        }

        # Add optional parameters
        if 'chunking_enabled' in options:
            producer_args['chunking_enabled'] = options['chunking_enabled']

        pulsar_producer = self.client.create_producer(**producer_args)
        logger.debug(f"Created producer for topic: {pulsar_topic}")

        return PulsarBackendProducer(pulsar_producer, schema)

    def create_consumer(
        self,
        topic: str,
        subscription: str,
        schema: type,
        initial_position: str = 'latest',
        consumer_type: str = 'shared',
        **options
    ) -> BackendConsumer:
        """
        Create a Pulsar consumer.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            subscription: Subscription name
            schema: Dataclass type for messages
            initial_position: 'earliest' or 'latest'
            consumer_type: 'shared', 'exclusive', or 'failover'
            **options: Backend-specific options

        Returns:
            PulsarBackendConsumer instance
        """
        pulsar_topic = self.map_topic(topic)

        # Map initial position
        if initial_position == 'earliest':
            pos = pulsar.InitialPosition.Earliest
        else:
            pos = pulsar.InitialPosition.Latest

        # Map consumer type
        if consumer_type == 'exclusive':
            ctype = pulsar.ConsumerType.Exclusive
        elif consumer_type == 'failover':
            ctype = pulsar.ConsumerType.Failover
        else:
            ctype = pulsar.ConsumerType.Shared

        consumer_args = {
            'topic': pulsar_topic,
            'subscription_name': subscription,
            'schema': pulsar.schema.BytesSchema(),  # We handle deserialization ourselves
            'initial_position': pos,
            'consumer_type': ctype,
        }

        pulsar_consumer = self.client.subscribe(**consumer_args)
        logger.debug(f"Created consumer for topic: {pulsar_topic}, subscription: {subscription}")

        return PulsarBackendConsumer(pulsar_consumer, schema)

    def close(self) -> None:
        """Close the Pulsar client."""
        self.client.close()
        logger.info("Pulsar client closed")
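The generic-topic scheme used by `map_topic` is easy to check standalone; a sketch with the same rules (pass Pulsar URIs through, map q0 to non-persistent, q1/q2 to persistent):

```python
def map_topic(generic_topic: str) -> str:
    # Already a Pulsar URI - return unchanged
    if '://' in generic_topic:
        return generic_topic
    parts = generic_topic.split('/', 3)
    if len(parts) != 4:
        raise ValueError(f"Invalid topic format: {generic_topic}")
    qos, tenant, namespace, queue = parts
    # QoS level selects the Pulsar persistence mode
    if qos == 'q0':
        persistence = 'non-persistent'
    elif qos in ('q1', 'q2'):
        persistence = 'persistent'
    else:
        raise ValueError(f"Invalid QoS level: {qos}")
    return f"{persistence}://{tenant}/{namespace}/{queue}"

print(map_topic("q1/tg/flow/my-queue"))  # persistent://tg/flow/my-queue
```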

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class RequestResponse(Subscriber):

    def __init__(
        self, client, subscription, consumer_name,
        self, backend, subscription, consumer_name,
        request_topic, request_schema,
        request_metrics,
        response_topic, response_schema,

@@ -22,7 +22,7 @@ class RequestResponse(Subscriber):
    ):

        super(RequestResponse, self).__init__(
            client = client,
            backend = backend,
            subscription = subscription,
            consumer_name = consumer_name,
            topic = response_topic,

@@ -31,7 +31,7 @@ class RequestResponse(Subscriber):
        )

        self.producer = Producer(
            client = client,
            backend = backend,
            topic = request_topic,
            schema = request_schema,
            metrics = request_metrics,

@@ -126,7 +126,7 @@ class RequestResponseSpec(Spec):
        )

        rr = self.impl(
            client = processor.pulsar_client,
            backend = processor.pubsub,

            # Make subscription names unique, so that all subscribers get
            # to see all response messages

@@ -3,9 +3,7 @@
# off of a queue and make it available using an internal broker system,
# so suitable for when multiple recipients are reading from the same queue

from pulsar.schema import JsonSchema
import asyncio
import _pulsar
import time
import logging
import uuid

@@ -13,12 +11,16 @@ import uuid
# Module logger
logger = logging.getLogger(__name__)

# Timeout exception - can come from different backends
class TimeoutError(Exception):
    pass

class Subscriber:

    def __init__(self, client, topic, subscription, consumer_name,
    def __init__(self, backend, topic, subscription, consumer_name,
                 schema=None, max_size=100, metrics=None,
                 backpressure_strategy="block", drain_timeout=5.0):
        self.client = client
        self.backend = backend  # Changed from 'client' to 'backend'
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name

@@ -43,18 +45,14 @@ class Subscriber:

    async def start(self):

        # Build subscribe arguments
        subscribe_args = {
            'topic': self.topic,
            'subscription_name': self.subscription,
            'consumer_name': self.consumer_name,
        }

        # Only add schema if provided (omit if None)
        if self.schema is not None:
            subscribe_args['schema'] = JsonSchema(self.schema)

        self.consumer = self.client.subscribe(**subscribe_args)
        # Create consumer via backend
        self.consumer = await asyncio.to_thread(
            self.backend.create_consumer,
            topic=self.topic,
            subscription=self.subscription,
            schema=self.schema,
            consumer_type='shared',
        )

        self.task = asyncio.create_task(self.run())

@@ -94,12 +92,13 @@ class Subscriber:
        drain_end_time = time.time() + self.drain_timeout
        logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")

        # Stop accepting new messages from Pulsar during drain
        if self.consumer:
        # Stop accepting new messages during drain
        # Note: Not all backends support pausing message listeners
        if self.consumer and hasattr(self.consumer, 'pause_message_listener'):
            try:
                self.consumer.pause_message_listener()
            except _pulsar.InvalidConfiguration:
                # Not all consumers have message listeners (e.g., blocking receive mode)
            except Exception:
                # Not all consumers support message listeners
                pass

        # Check drain timeout

@@ -133,9 +132,10 @@ class Subscriber:
                    self.consumer.receive,
                    timeout_millis=250
                )
            except _pulsar.Timeout:
                continue
            except Exception as e:
                # Handle timeout from any backend
                if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
                    continue
                logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
                raise e

@@ -157,19 +157,20 @@ class Subscriber:
        for msg in self.pending_acks.values():
            try:
                self.consumer.negative_acknowledge(msg)
            except _pulsar.AlreadyClosed:
                pass  # Consumer already closed
            except Exception:
                pass  # Consumer already closed or error
        self.pending_acks.clear()

        if self.consumer:
            try:
                self.consumer.unsubscribe()
            except _pulsar.AlreadyClosed:
                pass  # Already closed
            if hasattr(self.consumer, 'unsubscribe'):
                try:
                    self.consumer.unsubscribe()
                except Exception:
                    pass  # Already closed or error
            try:
                self.consumer.close()
            except _pulsar.AlreadyClosed:
                pass  # Already closed
            except Exception:
                pass  # Already closed or error
            self.consumer = None

@@ -16,7 +16,7 @@ class SubscriberSpec(Spec):
        )

        subscriber = Subscriber(
            client = processor.pulsar_client,
            backend = processor.pubsub,
            topic = definition[self.name],
            subscription = flow.id,
            consumer_name = flow.id,

@@ -7,6 +7,7 @@ import time
from pulsar.schema import JsonSchema

from .. exceptions import *
from ..base.pubsub import get_pubsub

# Default timeout for a request/response. In seconds.
DEFAULT_TIMEOUT=300

@@ -39,30 +40,25 @@ class BaseClient:
        if subscriber == None:
            subscriber = str(uuid.uuid4())

        if pulsar_api_key:
            auth = pulsar.AuthenticationToken(pulsar_api_key)
            self.client = pulsar.Client(
                pulsar_host,
                logger=pulsar.ConsoleLogger(log_level),
                authentication=auth,
                listener=listener,
            )
        else:
            self.client = pulsar.Client(
                pulsar_host,
                logger=pulsar.ConsoleLogger(log_level),
                listener_name=listener,
            )
        # Create backend using factory
        self.backend = get_pubsub(
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            pulsar_listener=listener,
            pubsub_backend='pulsar'
        )

        self.producer = self.client.create_producer(
        self.producer = self.backend.create_producer(
            topic=input_queue,
            schema=JsonSchema(input_schema),
            schema=input_schema,
            chunking_enabled=True,
        )

        self.consumer = self.client.subscribe(
            output_queue, subscriber,
            schema=JsonSchema(output_schema),
        self.consumer = self.backend.create_consumer(
            topic=output_queue,
            subscription=subscriber,
            schema=output_schema,
            consumer_type='shared',
        )

        self.input_schema = input_schema

@@ -136,10 +132,11 @@ class BaseClient:

        if hasattr(self, "consumer"):
            self.consumer.close()

        if hasattr(self, "producer"):
            self.producer.flush()
            self.producer.close()

        self.client.close()

        if hasattr(self, "backend"):
            self.backend.close()

@@ -64,7 +64,6 @@ class ConfigClient(BaseClient):
    def get(self, keys, timeout=300):

        resp = self.call(
            id=id,
            operation="get",
            keys=[
                ConfigKey(

@@ -88,7 +87,6 @@ class ConfigClient(BaseClient):
    def list(self, type, timeout=300):

        resp = self.call(
            id=id,
            operation="list",
            type=type,
            timeout=timeout

@@ -99,7 +97,6 @@ class ConfigClient(BaseClient):
    def getvalues(self, type, timeout=300):

        resp = self.call(
            id=id,
            operation="getvalues",
            type=type,
            timeout=timeout

@@ -117,7 +114,6 @@ class ConfigClient(BaseClient):
    def delete(self, keys, timeout=300):

        resp = self.call(
            id=id,
            operation="delete",
            keys=[
                ConfigKey(

@@ -134,7 +130,6 @@ class ConfigClient(BaseClient):
    def put(self, values, timeout=300):

        resp = self.call(
            id=id,
            operation="put",
            values=[
                ConfigValue(

@@ -152,7 +147,6 @@ class ConfigClient(BaseClient):
    def config(self, timeout=300):

        resp = self.call(
            id=id,
            operation="config",
            timeout=timeout
        )

@@ -15,8 +15,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
            name=data.get("name"),
            description=data.get("description"),
            tags=data.get("tags"),
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
            tag_filter=data.get("tag_filter"),
            limit=data.get("limit")
        )

@@ -38,10 +36,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
            result["description"] = obj.description
        if obj.tags is not None:
            result["tags"] = list(obj.tags)
        if obj.created_at is not None:
            result["created_at"] = obj.created_at
        if obj.updated_at is not None:
            result["updated_at"] = obj.updated_at
        if obj.tag_filter is not None:
            result["tag_filter"] = list(obj.tag_filter)
        if obj.limit is not None:

@@ -73,9 +67,7 @@ class CollectionManagementResponseTranslator(MessageTranslator):
                collection=coll_data.get("collection"),
                name=coll_data.get("name"),
                description=coll_data.get("description"),
                tags=coll_data.get("tags"),
                created_at=coll_data.get("created_at"),
                updated_at=coll_data.get("updated_at")
                tags=coll_data.get("tags", [])
            ))

        return CollectionManagementResponse(

@@ -104,9 +96,7 @@ class CollectionManagementResponseTranslator(MessageTranslator):
                "collection": coll.collection,
                "name": coll.name,
                "description": coll.description,
                "tags": list(coll.tags) if coll.tags else [],
                "created_at": coll.created_at,
                "updated_at": coll.updated_at
                "tags": list(coll.tags) if coll.tags else []
            })

        print("RESULT IS", result, flush=True)

@@ -57,7 +57,9 @@ class StructuredDataDiagnosisResponseTranslator(MessageTranslator):
            result["descriptor"] = obj.descriptor
        if obj.metadata:
            result["metadata"] = obj.metadata
        if obj.schema_matches is not None:
        # For schema-selection, always include schema_matches (even if empty)
        # For other operations, only include if non-empty
        if obj.operation == "schema-selection" or obj.schema_matches:
            result["schema-matches"] = obj.schema_matches

        return result

@@ -42,12 +42,17 @@ class PromptResponseTranslator(MessageTranslator):

    def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
        result = {}

        if obj.text:

        # Include text field if present (even if empty string)
        if obj.text is not None:
            result["text"] = obj.text
        if obj.object:
        # Include object field if present
        if obj.object is not None:
            result["object"] = obj.object

        # Always include end_of_stream flag for streaming support
        result["end_of_stream"] = getattr(obj, "end_of_stream", False)

        return result

    def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
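The point of switching from `if obj.text:` to `if obj.text is not None:` above is that an empty string is falsy in Python, so the old check silently dropped legitimate empty completions. A minimal sketch of the difference, using a stand-in object rather than the real PromptResponse class:

```python
from types import SimpleNamespace

# Hypothetical response carrying an empty (but valid) text field
obj = SimpleNamespace(text="", object=None)

old_result = {}
if obj.text:                  # falsy check: empty string is silently skipped
    old_result["text"] = obj.text

new_result = {}
if obj.text is not None:      # only skips a genuinely absent field
    new_result["text"] = obj.text

print(old_result)  # {}
print(new_result)  # {'text': ''}
```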
@@ -34,14 +34,12 @@ class DocumentRagResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
         result = {}

-        # Check if this is a streaming response (has chunk)
-        if hasattr(obj, 'chunk') and obj.chunk:
-            result["chunk"] = obj.chunk
-            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
-        else:
-            # Non-streaming response
-            if obj.response:
-                result["response"] = obj.response
+        # Include response content (even if empty string)
+        if obj.response is not None:
+            result["response"] = obj.response
+
+        # Include end_of_stream flag
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)

         # Always include error if present
         if hasattr(obj, 'error') and obj.error and obj.error.message:
@@ -51,13 +49,7 @@ class DocumentRagResponseTranslator(MessageTranslator):

     def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        # For streaming responses, check end_of_stream
-        if hasattr(obj, 'chunk') and obj.chunk:
-            is_final = getattr(obj, 'end_of_stream', False)
-        else:
-            # For non-streaming responses, it's always final
-            is_final = True
-
+        is_final = getattr(obj, 'end_of_stream', False)
         return self.from_pulsar(obj), is_final
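With the streaming and non-streaming branches unified, completion is driven entirely by the `end_of_stream` flag, and `getattr` with a default keeps the check safe for objects that never set the attribute. A small stand-in sketch of the pattern (not the real response classes):

```python
from types import SimpleNamespace

def is_final(obj):
    # Mirrors the simplified from_response_with_completion logic:
    # final iff end_of_stream is set and true; missing attribute means not final
    return getattr(obj, 'end_of_stream', False)

print(is_final(SimpleNamespace(end_of_stream=True)))   # True
print(is_final(SimpleNamespace(end_of_stream=False)))  # False
print(is_final(SimpleNamespace()))                     # False (attribute missing)
```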
@@ -98,14 +90,12 @@ class GraphRagResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
         result = {}

-        # Check if this is a streaming response (has chunk)
-        if hasattr(obj, 'chunk') and obj.chunk:
-            result["chunk"] = obj.chunk
-            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
-        else:
-            # Non-streaming response
-            if obj.response:
-                result["response"] = obj.response
+        # Include response content (even if empty string)
+        if obj.response is not None:
+            result["response"] = obj.response
+
+        # Include end_of_stream flag
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)

         # Always include error if present
         if hasattr(obj, 'error') and obj.error and obj.error.message:
@@ -115,11 +105,5 @@ class GraphRagResponseTranslator(MessageTranslator):

     def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        # For streaming responses, check end_of_stream
-        if hasattr(obj, 'chunk') and obj.chunk:
-            is_final = getattr(obj, 'end_of_stream', False)
-        else:
-            # For non-streaming responses, it's always final
-            is_final = True
-
+        is_final = getattr(obj, 'end_of_stream', False)
         return self.from_pulsar(obj), is_final
@@ -28,14 +28,17 @@ class TextCompletionResponseTranslator(MessageTranslator):

     def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]:
         result = {"response": obj.response}

         if obj.in_token:
             result["in_token"] = obj.in_token
         if obj.out_token:
             result["out_token"] = obj.out_token
         if obj.model:
             result["model"] = obj.model

+        # Always include end_of_stream flag for streaming support
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
+
         return result

     def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
@@ -1,16 +1,14 @@

-from pulsar.schema import Record, String, Array
+from dataclasses import dataclass, field
 from .primitives import Triple

-class Metadata(Record):
+@dataclass
+class Metadata:
     # Source identifier
-    id = String()
+    id: str = ""

     # Subgraph
-    metadata = Array(Triple())
+    metadata: list[Triple] = field(default_factory=list)

     # Collection management
-    user = String()
-    collection = String()
-
+    user: str = ""
+    collection: str = ""
@@ -1,34 +1,39 @@

-from pulsar.schema import Record, String, Boolean, Array, Integer
+from dataclasses import dataclass, field

-class Error(Record):
-    type = String()
-    message = String()
+@dataclass
+class Error:
+    type: str = ""
+    message: str = ""

-class Value(Record):
-    value = String()
-    is_uri = Boolean()
-    type = String()
+@dataclass
+class Value:
+    value: str = ""
+    is_uri: bool = False
+    type: str = ""

-class Triple(Record):
-    s = Value()
-    p = Value()
-    o = Value()
+@dataclass
+class Triple:
+    s: Value | None = None
+    p: Value | None = None
+    o: Value | None = None

-class Field(Record):
-    name = String()
+@dataclass
+class Field:
+    name: str = ""
     # int, string, long, bool, float, double, timestamp
-    type = String()
-    size = Integer()
-    primary = Boolean()
-    description = String()
+    type: str = ""
+    size: int = 0
+    primary: bool = False
+    description: str = ""
     # NEW FIELDS for structured data:
-    required = Boolean()  # Whether field is required
-    enum_values = Array(String())  # For enum type fields
-    indexed = Boolean()  # Whether field should be indexed
+    required: bool = False  # Whether field is required
+    enum_values: list[str] = field(default_factory=list)  # For enum type fields
+    indexed: bool = False  # Whether field should be indexed

-class RowSchema(Record):
-    name = String()
-    description = String()
-    fields = Array(Field())
+@dataclass
+class RowSchema:
+    name: str = ""
+    description: str = ""
+    fields: list[Field] = field(default_factory=list)
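One detail worth noting in the Record-to-dataclass conversion above: mutable defaults such as `enum_values` must use `field(default_factory=list)`, otherwise every instance would share a single list. A reduced sketch of the converted `Field` class shows each instance getting its own list:

```python
from dataclasses import dataclass, field

@dataclass
class Field:
    name: str = ""
    # default_factory makes a fresh list per instance
    enum_values: list[str] = field(default_factory=list)

a = Field(name="status")
b = Field(name="kind")
a.enum_values.append("open")

print(a.enum_values)  # ['open']
print(b.enum_values)  # [] -- not shared between instances
```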
@@ -1,4 +1,23 @@

-def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
-    return f"{kind}://{tenant}/{namespace}/{topic}"
+def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
+    """
+    Create a generic topic identifier that can be mapped by backends.
+
+    Args:
+        queue_name: The queue/topic name
+        qos: Quality of service
+            - 'q0' = best-effort (no ack)
+            - 'q1' = at-least-once (ack required)
+            - 'q2' = exactly-once (two-phase ack)
+        tenant: Tenant identifier for multi-tenancy
+        namespace: Namespace within tenant
+
+    Returns:
+        Generic topic string: qos/tenant/namespace/queue_name
+
+    Examples:
+        topic('my-queue')  # q1/tg/flow/my-queue
+        topic('config', qos='q2', namespace='config')  # q2/tg/config/config
+    """
+    return f"{qos}/{tenant}/{namespace}/{queue_name}"
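The rewritten helper is pure string formatting, so its behaviour can be checked directly. This sketch restates the function body from the diff and exercises the two examples given in its docstring:

```python
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
    # Generic topic identifier: qos/tenant/namespace/queue_name
    return f"{qos}/{tenant}/{namespace}/{queue_name}"

print(topic('my-queue'))                              # q1/tg/flow/my-queue
print(topic('config', qos='q2', namespace='config'))  # q2/tg/config/config
```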
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Bytes
+from dataclasses import dataclass

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -6,24 +6,27 @@ from ..core.topic import topic
 ############################################################################

 # PDF docs etc.
-class Document(Record):
-    metadata = Metadata()
-    data = Bytes()
+@dataclass
+class Document:
+    metadata: Metadata | None = None
+    data: bytes = b""

 ############################################################################

 # Text documents / text from PDF

-class TextDocument(Record):
-    metadata = Metadata()
-    text = Bytes()
+@dataclass
+class TextDocument:
+    metadata: Metadata | None = None
+    text: bytes = b""

 ############################################################################

 # Chunks of text

-class Chunk(Record):
-    metadata = Metadata()
-    chunk = Bytes()
+@dataclass
+class Chunk:
+    metadata: Metadata | None = None
+    chunk: bytes = b""

 ############################################################################
 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double, Map
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.primitives import Value, RowSchema
@@ -8,49 +8,55 @@ from ..core.topic import topic

 # Graph embeddings are embeddings associated with a graph entity

-class EntityEmbeddings(Record):
-    entity = Value()
-    vectors = Array(Array(Double()))
+@dataclass
+class EntityEmbeddings:
+    entity: Value | None = None
+    vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data
-class GraphEmbeddings(Record):
-    metadata = Metadata()
-    entities = Array(EntityEmbeddings())
+@dataclass
+class GraphEmbeddings:
+    metadata: Metadata | None = None
+    entities: list[EntityEmbeddings] = field(default_factory=list)

 ############################################################################

 # Document embeddings are embeddings associated with a chunk

-class ChunkEmbeddings(Record):
-    chunk = Bytes()
-    vectors = Array(Array(Double()))
+@dataclass
+class ChunkEmbeddings:
+    chunk: bytes = b""
+    vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data
-class DocumentEmbeddings(Record):
-    metadata = Metadata()
-    chunks = Array(ChunkEmbeddings())
+@dataclass
+class DocumentEmbeddings:
+    metadata: Metadata | None = None
+    chunks: list[ChunkEmbeddings] = field(default_factory=list)

 ############################################################################

 # Object embeddings are embeddings associated with the primary key of an
 # object

-class ObjectEmbeddings(Record):
-    metadata = Metadata()
-    vectors = Array(Array(Double()))
-    name = String()
-    key_name = String()
-    id = String()
+@dataclass
+class ObjectEmbeddings:
+    metadata: Metadata | None = None
+    vectors: list[list[float]] = field(default_factory=list)
+    name: str = ""
+    key_name: str = ""
+    id: str = ""

 ############################################################################

 # Structured object embeddings with enhanced capabilities

-class StructuredObjectEmbedding(Record):
-    metadata = Metadata()
-    vectors = Array(Array(Double()))
-    schema_name = String()
-    object_id = String()  # Primary key value
-    field_embeddings = Map(Array(Double()))  # Per-field embeddings
+@dataclass
+class StructuredObjectEmbedding:
+    metadata: Metadata | None = None
+    vectors: list[list[float]] = field(default_factory=list)
+    schema_name: str = ""
+    object_id: str = ""  # Primary key value
+    field_embeddings: dict[str, list[float]] = field(default_factory=dict)  # Per-field embeddings

 ############################################################################
 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Array
+from dataclasses import dataclass, field

 from ..core.primitives import Value, Triple
 from ..core.metadata import Metadata
@@ -8,21 +8,24 @@ from ..core.topic import topic

 # Entity context are an entity associated with textual context

-class EntityContext(Record):
-    entity = Value()
-    context = String()
+@dataclass
+class EntityContext:
+    entity: Value | None = None
+    context: str = ""

 # This is a 'batching' mechanism for the above data
-class EntityContexts(Record):
-    metadata = Metadata()
-    entities = Array(EntityContext())
+@dataclass
+class EntityContexts:
+    metadata: Metadata | None = None
+    entities: list[EntityContext] = field(default_factory=list)

 ############################################################################

 # Graph triples

-class Triples(Record):
-    metadata = Metadata()
-    triples = Array(Triple())
+@dataclass
+class Triples:
+    metadata: Metadata | None = None
+    triples: list[Triple] = field(default_factory=list)

 ############################################################################
 ############################################################################
@@ -1,5 +1,4 @@

-from pulsar.schema import Record, Bytes, String, Array, Long, Boolean
+from dataclasses import dataclass, field
 from ..core.primitives import Triple, Error
 from ..core.topic import topic
 from ..core.metadata import Metadata
@@ -22,40 +21,40 @@ from .embeddings import GraphEmbeddings
 # <- ()
 # <- (error)

-class KnowledgeRequest(Record):
-
+@dataclass
+class KnowledgeRequest:
     # get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
     # load-kg-core, unload-kg-core
-    operation = String()
+    operation: str = ""

     # list-kg-cores, delete-kg-core, put-kg-core
-    user = String()
+    user: str = ""

     # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core,
     # load-kg-core, unload-kg-core
-    id = String()
+    id: str = ""

     # load-kg-core
-    flow = String()
+    flow: str = ""

     # load-kg-core
-    collection = String()
+    collection: str = ""

     # put-kg-core
-    triples = Triples()
-    graph_embeddings = GraphEmbeddings()
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None

-class KnowledgeResponse(Record):
-    error = Error()
-    ids = Array(String())
-    eos = Boolean()  # Indicates end of knowledge core stream
-    triples = Triples()
-    graph_embeddings = GraphEmbeddings()
+@dataclass
+class KnowledgeResponse:
+    error: Error | None = None
+    ids: list[str] = field(default_factory=list)
+    eos: bool = False  # Indicates end of knowledge core stream
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None

 knowledge_request_queue = topic(
-    'knowledge', kind='non-persistent', namespace='request'
+    'knowledge', qos='q0', namespace='request'
 )
 knowledge_response_queue = topic(
-    'knowledge', kind='non-persistent', namespace='response',
+    'knowledge', qos='q0', namespace='response',
 )
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Boolean
+from dataclasses import dataclass

 from ..core.topic import topic

@@ -6,21 +6,25 @@ from ..core.topic import topic

 # NLP extraction data types

-class Definition(Record):
-    name = String()
-    definition = String()
+@dataclass
+class Definition:
+    name: str = ""
+    definition: str = ""

-class Topic(Record):
-    name = String()
-    definition = String()
+@dataclass
+class Topic:
+    name: str = ""
+    definition: str = ""

-class Relationship(Record):
-    s = String()
-    p = String()
-    o = String()
-    o_entity = Boolean()
+@dataclass
+class Relationship:
+    s: str = ""
+    p: str = ""
+    o: str = ""
+    o_entity: bool = False

-class Fact(Record):
-    s = String()
-    p = String()
-    o = String()
+@dataclass
+class Fact:
+    s: str = ""
+    p: str = ""
+    o: str = ""
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Map, Double, Array
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -7,11 +7,13 @@ from ..core.topic import topic

 # Extracted object from text processing

-class ExtractedObject(Record):
-    metadata = Metadata()
-    schema_name = String()  # Which schema this object belongs to
-    values = Array(Map(String()))  # Array of objects, each object is field name -> value
-    confidence = Double()
-    source_span = String()  # Text span where object was found
+@dataclass
+class ExtractedObject:
+    metadata: Metadata | None = None
+    schema_name: str = ""  # Which schema this object belongs to
+    values: list[dict[str, str]] = field(default_factory=list)  # Array of objects, each object is field name -> value
+    confidence: float = 0.0
+    source_span: str = ""  # Text span where object was found

 ############################################################################

 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Array, Map, String
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.primitives import RowSchema
@@ -8,9 +8,10 @@ from ..core.topic import topic

 # Stores rows of information

-class Rows(Record):
-    metadata = Metadata()
-    row_schema = RowSchema()
-    rows = Array(Map(String()))
+@dataclass
+class Rows:
+    metadata: Metadata | None = None
+    row_schema: RowSchema | None = None
+    rows: list[dict[str, str]] = field(default_factory=list)

 ############################################################################
 ############################################################################
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Bytes, Map
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -7,11 +7,13 @@ from ..core.topic import topic

 # Structured data submission for fire-and-forget processing

-class StructuredDataSubmission(Record):
-    metadata = Metadata()
-    format = String()  # "json", "csv", "xml"
-    schema_name = String()  # Reference to schema in config
-    data = Bytes()  # Raw data to ingest
-    options = Map(String())  # Format-specific options
+@dataclass
+class StructuredDataSubmission:
+    metadata: Metadata | None = None
+    format: str = ""  # "json", "csv", "xml"
+    schema_name: str = ""  # Reference to schema in config
+    data: bytes = b""  # Raw data to ingest
+    options: dict[str, str] = field(default_factory=dict)  # Format-specific options

 ############################################################################

 ############################################################################
@@ -1,5 +1,5 @@

-from pulsar.schema import Record, String, Array, Map, Boolean
+from dataclasses import dataclass, field

 from ..core.topic import topic
 from ..core.primitives import Error
@@ -8,33 +8,36 @@ from ..core.primitives import Error

 # Prompt services, abstract the prompt generation

-class AgentStep(Record):
-    thought = String()
-    action = String()
-    arguments = Map(String())
-    observation = String()
-    user = String()  # User context for the step
+@dataclass
+class AgentStep:
+    thought: str = ""
+    action: str = ""
+    arguments: dict[str, str] = field(default_factory=dict)
+    observation: str = ""
+    user: str = ""  # User context for the step

-class AgentRequest(Record):
-    question = String()
-    state = String()
-    group = Array(String())
-    history = Array(AgentStep())
-    user = String()  # User context for multi-tenancy
-    streaming = Boolean()  # NEW: Enable streaming response delivery (default false)
+@dataclass
+class AgentRequest:
+    question: str = ""
+    state: str = ""
+    group: list[str] | None = None
+    history: list[AgentStep] = field(default_factory=list)
+    user: str = ""  # User context for multi-tenancy
+    streaming: bool = False  # NEW: Enable streaming response delivery (default false)

-class AgentResponse(Record):
+@dataclass
+class AgentResponse:
     # Streaming-first design
-    chunk_type = String()  # "thought", "action", "observation", "answer", "error"
-    content = String()  # The actual content (interpretation depends on chunk_type)
-    end_of_message = Boolean()  # Current chunk type (thought/action/etc.) is complete
-    end_of_dialog = Boolean()  # Entire agent dialog is complete
+    chunk_type: str = ""  # "thought", "action", "observation", "answer", "error"
+    content: str = ""  # The actual content (interpretation depends on chunk_type)
+    end_of_message: bool = False  # Current chunk type (thought/action/etc.) is complete
+    end_of_dialog: bool = False  # Entire agent dialog is complete

     # Legacy fields (deprecated but kept for backward compatibility)
-    answer = String()
-    error = Error()
-    thought = String()
-    observation = String()
+    answer: str = ""
+    error: Error | None = None
+    thought: str = ""
+    observation: str = ""

 ############################################################################
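The streaming-first `AgentResponse` shape above implies a consumer that accumulates `content` per `chunk_type` until `end_of_message`, and stops at `end_of_dialog`. A hypothetical consumer sketch, with the chunk sequence invented for illustration:

```python
from dataclasses import dataclass

@dataclass
class AgentResponse:
    chunk_type: str = ""
    content: str = ""
    end_of_message: bool = False
    end_of_dialog: bool = False

def collect(chunks):
    # Accumulate streamed content into complete messages per chunk type
    messages, buffer = [], ""
    for c in chunks:
        buffer += c.content
        if c.end_of_message:
            messages.append((c.chunk_type, buffer))
            buffer = ""
        if c.end_of_dialog:
            break
    return messages

chunks = [
    AgentResponse("thought", "Looking up ", False, False),
    AgentResponse("thought", "the answer.", True, False),
    AgentResponse("answer", "42", True, True),
]
print(collect(chunks))  # [('thought', 'Looking up the answer.'), ('answer', '42')]
```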
@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Integer, Array
+from dataclasses import dataclass, field
 from datetime import datetime

 from ..core.primitives import Error
@@ -10,41 +10,40 @@ from ..core.topic import topic

 # Collection metadata operations (for librarian service)

-class CollectionMetadata(Record):
+@dataclass
+class CollectionMetadata:
     """Collection metadata record"""
-    user = String()
-    collection = String()
-    name = String()
-    description = String()
-    tags = Array(String())
-    created_at = String()  # ISO timestamp
-    updated_at = String()  # ISO timestamp
+    user: str = ""
+    collection: str = ""
+    name: str = ""
+    description: str = ""
+    tags: list[str] = field(default_factory=list)

 ############################################################################

-class CollectionManagementRequest(Record):
+@dataclass
+class CollectionManagementRequest:
     """Request for collection management operations"""
-    operation = String()  # e.g., "delete-collection"
+    operation: str = ""  # e.g., "delete-collection"

     # For 'list-collections'
-    user = String()
-    collection = String()
-    timestamp = String()  # ISO timestamp
-    name = String()
-    description = String()
-    tags = Array(String())
-    created_at = String()  # ISO timestamp
-    updated_at = String()  # ISO timestamp
+    user: str = ""
+    collection: str = ""
+    timestamp: str = ""  # ISO timestamp
+    name: str = ""
+    description: str = ""
+    tags: list[str] = field(default_factory=list)

     # For list
-    tag_filter = Array(String())  # Optional filter by tags
-    limit = Integer()
+    tag_filter: list[str] = field(default_factory=list)  # Optional filter by tags
+    limit: int = 0

-class CollectionManagementResponse(Record):
+@dataclass
+class CollectionManagementResponse:
     """Response for collection management operations"""
-    error = Error()  # Only populated if there's an error
-    timestamp = String()  # ISO timestamp
-    collections = Array(CollectionMetadata())
+    error: Error | None = None  # Only populated if there's an error
+    timestamp: str = ""  # ISO timestamp
+    collections: list[CollectionMetadata] = field(default_factory=list)

 ############################################################################
@@ -52,8 +51,9 @@ class CollectionManagementResponse(Record):
 # Topics

 collection_request_queue = topic(
-    'collection', kind='non-persistent', namespace='request'
+    'collection', qos='q0', namespace='request'
 )
 collection_response_queue = topic(
-    'collection', kind='non-persistent', namespace='response'
+    'collection', qos='q0', namespace='response'
 )
@@ -1,5 +1,5 @@

-from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
+from dataclasses import dataclass, field

 from ..core.topic import topic
 from ..core.primitives import Error
@@ -13,58 +13,61 @@ from ..core.primitives import Error
 # put(values) -> ()
 # delete(keys) -> ()
 # config() -> (version, config)
-class ConfigKey(Record):
-    type = String()
-    key = String()
+@dataclass
+class ConfigKey:
+    type: str = ""
+    key: str = ""

-class ConfigValue(Record):
-    type = String()
-    key = String()
-    value = String()
+@dataclass
+class ConfigValue:
+    type: str = ""
+    key: str = ""
+    value: str = ""

 # Prompt services, abstract the prompt generation
-class ConfigRequest(Record):
-
-    operation = String()  # get, list, getvalues, delete, put, config
+@dataclass
+class ConfigRequest:
+    operation: str = ""  # get, list, getvalues, delete, put, config

     # get, delete
-    keys = Array(ConfigKey())
+    keys: list[ConfigKey] = field(default_factory=list)

     # list, getvalues
-    type = String()
+    type: str = ""

     # put
-    values = Array(ConfigValue())
-
-class ConfigResponse(Record):
+    values: list[ConfigValue] = field(default_factory=list)

+@dataclass
+class ConfigResponse:
     # get, list, getvalues, config
-    version = Integer()
+    version: int = 0

     # get, getvalues
-    values = Array(ConfigValue())
+    values: list[ConfigValue] = field(default_factory=list)

     # list
-    directory = Array(String())
+    directory: list[str] = field(default_factory=list)

     # config
-    config = Map(Map(String()))
+    config: dict[str, dict[str, str]] = field(default_factory=dict)

     # Everything
-    error = Error()
+    error: Error | None = None

-class ConfigPush(Record):
-    version = Integer()
-    config = Map(Map(String()))
+@dataclass
+class ConfigPush:
+    version: int = 0
+    config: dict[str, dict[str, str]] = field(default_factory=dict)

 config_request_queue = topic(
-    'config', kind='non-persistent', namespace='request'
+    'config', qos='q0', namespace='request'
 )
 config_response_queue = topic(
-    'config', kind='non-persistent', namespace='response'
+    'config', qos='q0', namespace='response'
 )
 config_push_queue = topic(
-    'config', kind='persistent', namespace='config'
+    'config', qos='q2', namespace='config'
 )

 ############################################################################
@@ -1,33 +1,36 @@
-from pulsar.schema import Record, String, Map, Double, Array
+from dataclasses import dataclass, field
 from ..core.primitives import Error

 ############################################################################

 # Structured data diagnosis services

-class StructuredDataDiagnosisRequest(Record):
-    operation = String()  # "detect-type", "generate-descriptor", "diagnose", or "schema-selection"
-    sample = String()  # Data sample to analyze (text content)
-    type = String()  # Data type (csv, json, xml) - optional, required for generate-descriptor
-    schema_name = String()  # Target schema name for descriptor generation - optional
+@dataclass
+class StructuredDataDiagnosisRequest:
+    operation: str = ""  # "detect-type", "generate-descriptor", "diagnose", or "schema-selection"
+    sample: str = ""  # Data sample to analyze (text content)
+    type: str = ""  # Data type (csv, json, xml) - optional, required for generate-descriptor
+    schema_name: str = ""  # Target schema name for descriptor generation - optional

     # JSON encoded options (e.g., delimiter for CSV)
-    options = Map(String())
+    options: dict[str, str] = field(default_factory=dict)

-class StructuredDataDiagnosisResponse(Record):
-    error = Error()
+@dataclass
+class StructuredDataDiagnosisResponse:
+    error: Error | None = None

-    operation = String()  # The operation that was performed
-    detected_type = String()  # Detected data type (for detect-type/diagnose) - optional
-    confidence = Double()  # Confidence score for type detection - optional
+    operation: str = ""  # The operation that was performed
+    detected_type: str = ""  # Detected data type (for detect-type/diagnose) - optional
+    confidence: float = 0.0  # Confidence score for type detection - optional

     # JSON encoded descriptor (for generate-descriptor/diagnose) - optional
-    descriptor = String()
+    descriptor: str = ""

     # JSON encoded additional metadata (e.g., field count, sample records)
-    metadata = Map(String())
+    metadata: dict[str, str] = field(default_factory=dict)

     # Array of matching schema IDs (for schema-selection operation) - optional
-    schema_matches = Array(String())
+    schema_matches: list[str] = field(default_factory=list)

 ############################################################################

 ############################################################################
@@ -1,5 +1,5 @@

-from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
+from dataclasses import dataclass, field

 from ..core.topic import topic
 from ..core.primitives import Error
@@ -11,61 +11,61 @@ from ..core.primitives import Error
 # get_class(classname) -> (class)
 # put_class(class) -> (class)
 # delete_class(classname) -> ()
 #
 #
 # list_flows() -> (flowid[])
 # get_flow(flowid) -> (flow)
 # start_flow(flowid, classname) -> ()
 # stop_flow(flowid) -> ()

 # Prompt services, abstract the prompt generation
-class FlowRequest(Record):
-
-    operation = String()  # list-classes, get-class, put-class, delete-class
+@dataclass
+class FlowRequest:
+    operation: str = ""  # list-classes, get-class, put-class, delete-class
     # list-flows, get-flow, start-flow, stop-flow

     # get_class, put_class, delete_class, start_flow
-    class_name = String()
+    class_name: str = ""

     # put_class
-    class_definition = String()
+    class_definition: str = ""

     # start_flow
-    description = String()
+    description: str = ""

     # get_flow, start_flow, stop_flow
-    flow_id = String()
+    flow_id: str = ""

     # start_flow - optional parameters for flow customization
-    parameters = Map(String())
-
-class FlowResponse(Record):
+    parameters: dict[str, str] = field(default_factory=dict)

+@dataclass
+class FlowResponse:
     # list_classes
-    class_names = Array(String())
+    class_names: list[str] = field(default_factory=list)

     # list_flows
-    flow_ids = Array(String())
+    flow_ids: list[str] = field(default_factory=list)

     # get_class
-    class_definition = String()
+    class_definition: str = ""

     # get_flow
-    flow = String()
+    flow: str = ""

     # get_flow
-    description = String()
+    description: str = ""

     # get_flow - parameters used when flow was started
-    parameters = Map(String())
+    parameters: dict[str, str] = field(default_factory=dict)

     # Everything
-    error = Error()
+    error: Error | None = None

 flow_request_queue = topic(
-    'flow', kind='non-persistent', namespace='request'
+    'flow', qos='q0', namespace='request'
|
||||
)
|
||||
flow_response_queue = topic(
|
||||
'flow', kind='non-persistent', namespace='response'
|
||||
'flow', qos='q0', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
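Beyond the schema conversion, this hunk also swaps the topic declarations from `kind='non-persistent'` to `qos='q0'`, moving the persistence decision behind a quality-of-service label. The dataclass side can be exercised standalone; this sketch mirrors the converted `FlowRequest` field-for-field, while the request values (`"doc-rag"`, `"flow-001"`, `"chunk-size"`) and the JSON round-trip are illustrative, not TrustGraph's actual transport:

```python
from __future__ import annotations
from dataclasses import dataclass, field, asdict
import json

# Field-for-field sketch of the converted FlowRequest above.
@dataclass
class FlowRequest:
    operation: str = ""        # list-classes, get-class, put-class, delete-class,
                               # list-flows, get-flow, start-flow, stop-flow
    class_name: str = ""
    class_definition: str = ""
    description: str = ""
    flow_id: str = ""
    parameters: dict[str, str] = field(default_factory=dict)

# Unlike a Pulsar Record, a dataclass carries no wire schema of its
# own, so a JSON round-trip is one way such a message could travel.
req = FlowRequest(
    operation="start-flow",
    class_name="doc-rag",
    flow_id="flow-001",
    parameters={"chunk-size": "2000"},
)
wire = json.dumps(asdict(req))
decoded = FlowRequest(**json.loads(wire))
assert decoded == req   # dataclasses get value equality for free
```

Because every field has a default, callers only name the fields the operation needs, matching how the `Record` version behaved.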
@@ -1,9 +1,8 @@
-from pulsar.schema import Record, Bytes, String, Array, Long
+from dataclasses import dataclass, field
 from ..core.primitives import Triple, Error
 from ..core.topic import topic
 from ..core.metadata import Metadata
-from ..knowledge.document import Document, TextDocument
+# Note: Document imports will be updated after knowledge schemas are converted

 # add-document
 # -> (document_id, document_metadata, content)

@@ -50,76 +49,79 @@ from ..knowledge.document import Document, TextDocument
 # <- (processing_metadata[])
 # <- (error)

-class DocumentMetadata(Record):
-    id = String()
-    time = Long()
-    kind = String()
-    title = String()
-    comments = String()
-    metadata = Array(Triple())
-    user = String()
-    tags = Array(String())
+@dataclass
+class DocumentMetadata:
+    id: str = ""
+    time: int = 0
+    kind: str = ""
+    title: str = ""
+    comments: str = ""
+    metadata: list[Triple] = field(default_factory=list)
+    user: str = ""
+    tags: list[str] = field(default_factory=list)

-class ProcessingMetadata(Record):
-    id = String()
-    document_id = String()
-    time = Long()
-    flow = String()
-    user = String()
-    collection = String()
-    tags = Array(String())
+@dataclass
+class ProcessingMetadata:
+    id: str = ""
+    document_id: str = ""
+    time: int = 0
+    flow: str = ""
+    user: str = ""
+    collection: str = ""
+    tags: list[str] = field(default_factory=list)

-class Criteria(Record):
-    key = String()
-    value = String()
-    operator = String()
-
-class LibrarianRequest(Record):
+@dataclass
+class Criteria:
+    key: str = ""
+    value: str = ""
+    operator: str = ""
+
+@dataclass
+class LibrarianRequest:

     # add-document, remove-document, update-document, get-document-metadata,
     # get-document-content, add-processing, remove-processing, list-documents,
     # list-processing
-    operation = String()
+    operation: str = ""

     # add-document, remove-document, update-document, get-document-metadata,
     # get-document-content
-    document_id = String()
+    document_id: str = ""

     # add-processing, remove-processing
-    processing_id = String()
+    processing_id: str = ""

     # add-document, update-document
-    document_metadata = DocumentMetadata()
+    document_metadata: DocumentMetadata | None = None

     # add-processing
-    processing_metadata = ProcessingMetadata()
+    processing_metadata: ProcessingMetadata | None = None

     # add-document
-    content = Bytes()
+    content: bytes = b""

     # list-documents, list-processing
-    user = String()
+    user: str = ""

     # list-documents?, list-processing?
-    collection = String()
+    collection: str = ""

     #
-    criteria = Array(Criteria())
+    criteria: list[Criteria] = field(default_factory=list)

-class LibrarianResponse(Record):
-    error = Error()
-    document_metadata = DocumentMetadata()
-    content = Bytes()
-    document_metadatas = Array(DocumentMetadata())
-    processing_metadatas = Array(ProcessingMetadata())
+@dataclass
+class LibrarianResponse:
+    error: Error | None = None
+    document_metadata: DocumentMetadata | None = None
+    content: bytes = b""
+    document_metadatas: list[DocumentMetadata] = field(default_factory=list)
+    processing_metadatas: list[ProcessingMetadata] = field(default_factory=list)

 # FIXME: Is this right? Using persistence on librarian so that
 # message chunking works

 librarian_request_queue = topic(
-    'librarian', kind='persistent', namespace='request'
+    'librarian', qos='q1', namespace='request'
 )
 librarian_response_queue = topic(
-    'librarian', kind='persistent', namespace='response',
+    'librarian', qos='q1', namespace='response',
 )
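One behavioural difference in this hunk is the default for nested records: embedded fields like `document_metadata = DocumentMetadata()` previously defaulted to an empty record, while the dataclass version defaults to `None`, so handlers can distinguish "no metadata supplied" from "empty metadata". A trimmed sketch (the `Triple`-typed `metadata` field is omitted so the example stays self-contained, and the document values are invented):

```python
from __future__ import annotations
from dataclasses import dataclass, field

# Trimmed versions of the converted librarian records above.
@dataclass
class DocumentMetadata:
    id: str = ""
    time: int = 0
    kind: str = ""
    title: str = ""
    comments: str = ""
    user: str = ""
    tags: list[str] = field(default_factory=list)

@dataclass
class LibrarianRequest:
    operation: str = ""
    document_id: str = ""
    document_metadata: DocumentMetadata | None = None
    content: bytes = b""
    user: str = ""
    collection: str = ""

req = LibrarianRequest(
    operation="add-document",
    document_id="doc-001",
    document_metadata=DocumentMetadata(
        id="doc-001", kind="application/pdf", title="Example", user="alice",
    ),
    content=b"%PDF-1.5 ...",
    user="alice",
)

# The None default makes "metadata absent" an explicit state.
assert LibrarianRequest().document_metadata is None
assert req.document_metadata is not None
```

The librarian topics stay on persistent delivery (`qos='q1'`), consistent with the FIXME note about message chunking, whereas the flow topics use `qos='q0'`.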