Merge pull request #604 from trustgraph-ai/release/v1.8

Merge 1.8 into master
cybermaggedon 2026-01-08 08:57:36 +00:00 committed by GitHub
commit 8dff90f36f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
233 changed files with 13294 additions and 4542 deletions


@@ -22,7 +22,7 @@ jobs:
       uses: actions/checkout@v3
     - name: Setup packages
-      run: make update-package-versions VERSION=1.7.999
+      run: make update-package-versions VERSION=1.8.999
     - name: Setup environment
      run: python3 -m venv env


@@ -1,8 +1,8 @@
 # TrustGraph Librarian API
 This API provides document library management for TrustGraph. It handles document storage,
-metadata management, and processing orchestration using hybrid storage (MinIO for content,
-Cassandra for metadata) with multi-user support.
+metadata management, and processing orchestration using hybrid storage (S3-compatible object
+storage for content, Cassandra for metadata) with multi-user support.
 ## Request/response
@@ -374,13 +374,14 @@ await client.add_processing(
 ## Features
-- **Hybrid Storage**: MinIO for content, Cassandra for metadata
+- **Hybrid Storage**: S3-compatible object storage (MinIO, Ceph RGW, AWS S3, etc.) for content, Cassandra for metadata
 - **Multi-user Support**: User-based document ownership and access control
 - **Rich Metadata**: RDF-style metadata triples and tagging system
 - **Processing Integration**: Automatic triggering of document processing workflows
 - **Content Types**: Support for multiple document formats (PDF, text, etc.)
 - **Collection Management**: Optional document grouping by collection
 - **Metadata Search**: Query documents by metadata criteria
+- **Flexible Storage Backend**: Works with any S3-compatible storage (MinIO, Ceph RADOS Gateway, AWS S3, Cloudflare R2, etc.)
 ## Use Cases


@@ -233,9 +233,13 @@ When a user initiates collection deletion through the librarian service:
 #### Collection Management Interface
-All store writers implement a standardized collection management interface with a common schema:
-**Message Schema (`StorageManagementRequest`):**
+**⚠️ LEGACY APPROACH - REPLACED BY CONFIG-BASED PATTERN**
+The queue-based architecture described below has been replaced with a config-based approach using `CollectionConfigHandler`. All storage backends now receive collection updates via config push messages instead of dedicated management queues.
+~~All store writers implement a standardized collection management interface with a common schema:~~
+~~**Message Schema (`StorageManagementRequest`):**~~
 ```json
 {
   "operation": "create-collection" | "delete-collection",
@@ -244,24 +248,26 @@ All store writers implement a standardized collection management interface with
 }
 ```
-**Queue Architecture:**
-- **Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores
-- **Object Store Management Queue** (`object-storage-management`): Object/document stores
-- **Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores
-- **Storage Response Queue** (`storage-management-response`): All responses sent here
-Each store writer implements:
-- **Collection Management Handler**: Processes `StorageManagementRequest` messages
-- **Create Collection Operation**: Establishes collection in storage backend
-- **Delete Collection Operation**: Removes all data associated with collection
-- **Collection State Tracking**: Maintains knowledge of which collections exist
-- **Message Processing**: Consumes from dedicated management queue
-- **Status Reporting**: Returns success/failure via `StorageManagementResponse`
-- **Idempotent Operations**: Safe to call create/delete multiple times
-**Supported Operations:**
-- `create-collection`: Create collection in storage backend
-- `delete-collection`: Remove all collection data from storage backend
+~~**Queue Architecture:**~~
+- ~~**Vector Store Management Queue** (`vector-storage-management`): Vector/embedding stores~~
+- ~~**Object Store Management Queue** (`object-storage-management`): Object/document stores~~
+- ~~**Triple Store Management Queue** (`triples-storage-management`): Graph/RDF stores~~
+- ~~**Storage Response Queue** (`storage-management-response`): All responses sent here~~
+**Current Implementation:**
+All storage backends now use `CollectionConfigHandler`:
+- **Config Push Integration**: Storage services register for config push notifications
+- **Automatic Synchronization**: Collections created/deleted based on config changes
+- **Declarative Model**: Collections defined in config service, backends sync to match
+- **No Request/Response**: Eliminates coordination overhead and response tracking
+- **Collection State Tracking**: Maintained via `known_collections` cache
+- **Idempotent Operations**: Safe to process same config multiple times
+Each storage backend implements:
+- `create_collection(user: str, collection: str, metadata: dict)` - Create collection structures
+- `delete_collection(user: str, collection: str)` - Remove all collection data
+- `collection_exists(user: str, collection: str) -> bool` - Validate before writes
 #### Cassandra Triple Store Refactor
@@ -365,62 +371,33 @@ Comprehensive testing will cover:
 - `triples_collection` table for SPO queries and deletion tracking
 - Collection deletion implemented with read-then-delete pattern
-### 🔄 In Progress Components
-1. **Collection Creation Broadcast** (`trustgraph-flow/trustgraph/librarian/collection_manager.py`)
-   - Update `update_collection()` to send "create-collection" to storage backends
-   - Wait for confirmations from all storage processors
-   - Handle creation failures appropriately
-2. **Document Submission Handler** (`trustgraph-flow/trustgraph/librarian/service.py` or similar)
-   - Check if collection exists when document submitted
-   - If not exists: Create collection with defaults before processing document
-   - Trigger same "create-collection" broadcast as `tg-set-collection`
-   - Ensure collection established before document flows to storage processors
-### ❌ Pending Components
-1. **Collection State Tracking** - Need to implement in each storage backend:
-   - **Cassandra Triples**: Use `triples_collection` table with marker triples
-   - **Neo4j/Memgraph/FalkorDB**: Create `:CollectionMetadata` nodes
-   - **Qdrant/Milvus/Pinecone**: Use native collection APIs
-   - **Cassandra Objects**: Add collection metadata tracking
-2. **Storage Management Handlers** - Need "create-collection" support in 12 files:
-   - `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
-   - `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
-   - `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
-   - `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
-   - `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
-   - `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
-   - `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
-   - `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
-   - `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
-   - `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
-   - `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
-   - Plus any other storage implementations
-3. **Write Operation Validation** - Add collection existence checks to all `store_*` methods
-4. **Query Operation Handling** - Update queries to return empty for non-existent collections
-### Next Implementation Steps
-**Phase 1: Core Infrastructure (2-3 days)**
-1. Add collection state tracking methods to all storage backends
-2. Implement `collection_exists()` and `create_collection()` methods
-**Phase 2: Storage Handlers (1 week)**
-3. Add "create-collection" handlers to all storage processors
-4. Add write validation to reject non-existent collections
-5. Update query handling for non-existent collections
-**Phase 3: Collection Manager (2-3 days)**
-6. Update collection_manager to broadcast creates
-7. Implement response tracking and error handling
-**Phase 4: Testing (3-5 days)**
-8. End-to-end testing of explicit creation workflow
-9. Test all storage backends
-10. Validate error handling and edge cases
+### ✅ Migration to Config-Based Pattern - COMPLETED
+**All storage backends have been migrated from the queue-based pattern to the config-based `CollectionConfigHandler` pattern.**
+Completed migrations:
+- ✅ `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
+- ✅ `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
+All backends now:
+- Inherit from `CollectionConfigHandler`
+- Register for config push notifications via `self.register_config_handler(self.on_collection_config)`
+- Implement `create_collection(user, collection, metadata)` and `delete_collection(user, collection)`
+- Use `collection_exists(user, collection)` to validate before writes
+- Automatically sync with config service changes
+Legacy queue-based infrastructure removed:
+- ✅ Removed `StorageManagementRequest` and `StorageManagementResponse` schemas
+- ✅ Removed storage management queue topic definitions
+- ✅ Removed storage management consumer/producer from all backends
+- ✅ Removed `on_storage_management` handlers from all backends


@@ -2,17 +2,29 @@
 ## Overview
-TrustGraph uses Python's built-in `logging` module for all logging operations. This provides a standardized, flexible approach to logging across all components of the system.
+TrustGraph uses Python's built-in `logging` module for all logging operations, with centralized configuration and optional Loki integration for log aggregation. This provides a standardized, flexible approach to logging across all components of the system.
 ## Default Configuration
 ### Logging Level
 - **Default Level**: `INFO`
-- **Debug Mode**: `DEBUG` (enabled via command-line argument)
-- **Production**: `WARNING` or `ERROR` as appropriate
-### Output Destination
-All logs should be written to **standard output (stdout)** to ensure compatibility with containerized environments and log aggregation systems.
+- **Configurable via**: `--log-level` command-line argument
+- **Choices**: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
+### Output Destinations
+1. **Console (stdout)**: Always enabled - ensures compatibility with containerized environments
+2. **Loki**: Optional centralized log aggregation (enabled by default, can be disabled)
+## Centralized Logging Module
+All logging configuration is managed by the `trustgraph.base.logging` module, which provides:
+- `add_logging_args(parser)` - Adds standard logging CLI arguments
+- `setup_logging(args)` - Configures logging from parsed arguments
+This module is used by all server-side components:
+- AsyncProcessor-based services
+- API Gateway
+- MCP Server
 ## Implementation Guidelines
@@ -26,39 +38,80 @@ import logging
 logger = logging.getLogger(__name__)
 ```
-### 2. Centralized Configuration
-The logging configuration should be centralized in `async_processor.py` (or a dedicated logging configuration module) since it's inherited by much of the codebase:
+The logger name is automatically used as a label in Loki for filtering and searching.
+### 2. Service Initialization
+All server-side services automatically get logging configuration through the centralized module:
 ```python
-import logging
+from trustgraph.base import add_logging_args, setup_logging
 import argparse
-def setup_logging(log_level='INFO'):
-    """Configure logging for the entire application"""
-    logging.basicConfig(
-        level=getattr(logging, log_level.upper()),
-        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-        handlers=[logging.StreamHandler()]
-    )
-def parse_args():
+def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--log-level',
-        default='INFO',
-        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
-        help='Set the logging level (default: INFO)'
-    )
-    return parser.parse_args()
-# In main execution
-if __name__ == '__main__':
-    args = parse_args()
-    setup_logging(args.log_level)
+    # Add standard logging arguments (includes Loki configuration)
+    add_logging_args(parser)
+    # Add your service-specific arguments
+    parser.add_argument('--port', type=int, default=8080)
+    args = parser.parse_args()
+    args = vars(args)
+    # Setup logging early in startup
+    setup_logging(args)
+    # Rest of your service initialization
+    logger = logging.getLogger(__name__)
+    logger.info("Service starting...")
 ```
-### 3. Logging Best Practices
+### 3. Command-Line Arguments
+All services support these logging arguments:
+**Log Level:**
+```bash
+--log-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
+```
+**Loki Configuration:**
+```bash
+--loki-enabled            # Enable Loki (default)
+--no-loki-enabled         # Disable Loki
+--loki-url URL            # Loki push URL (default: http://loki:3100/loki/api/v1/push)
+--loki-username USERNAME  # Optional authentication
+--loki-password PASSWORD  # Optional authentication
+```
+**Examples:**
+```bash
+# Default - INFO level, Loki enabled
+./my-service
+# Debug mode, console only
+./my-service --log-level DEBUG --no-loki-enabled
+# Custom Loki server with auth
+./my-service --loki-url http://loki.prod:3100/loki/api/v1/push \
+  --loki-username admin --loki-password secret
+```
+### 4. Environment Variables
+Loki configuration supports environment variable fallbacks:
+```bash
+export LOKI_URL=http://loki.prod:3100/loki/api/v1/push
+export LOKI_USERNAME=admin
+export LOKI_PASSWORD=secret
+```
+Command-line arguments take precedence over environment variables.
+### 5. Logging Best Practices
 #### Log Levels Usage
 - **DEBUG**: Detailed information for diagnosing problems (variable values, function entry/exit)
@@ -89,20 +142,25 @@ if logger.isEnabledFor(logging.DEBUG):
     logger.debug(f"Debug data: {debug_data}")
 ```
-### 4. Structured Logging
-For complex data, use structured logging:
+### 6. Structured Logging with Loki
+For complex data, use structured logging with extra tags for Loki:
 ```python
 logger.info("Request processed", extra={
+    'tags': {
     'request_id': request_id,
-    'duration_ms': duration,
-    'status_code': status_code,
-    'user_id': user_id
+        'user_id': user_id,
+        'status': 'success'
+    }
 })
 ```
-### 5. Exception Logging
+These tags become searchable labels in Loki, in addition to automatic labels:
+- `severity` - Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+- `logger` - Module name (from `__name__`)
+### 7. Exception Logging
 Always include stack traces for exceptions:
@@ -114,9 +172,13 @@ except Exception as e:
     raise
 ```
-### 6. Async Logging Considerations
-For async code, ensure thread-safe logging:
+### 8. Async Logging Considerations
+The logging system uses non-blocking queued handlers for Loki:
+- Console output is synchronous (fast)
+- Loki output is queued with 500-message buffer
+- Background thread handles Loki transmission
+- No blocking of main application code
 ```python
 import asyncio
@@ -124,46 +186,165 @@ import logging
 async def async_operation():
     logger = logging.getLogger(__name__)
+    # Logging is thread-safe and won't block async operations
     logger.info(f"Starting async operation in task: {asyncio.current_task().get_name()}")
 ```
-## Environment Variables
-Support environment-based configuration as a fallback:
+## Loki Integration
+### Architecture
+The logging system uses Python's built-in `QueueHandler` and `QueueListener` for non-blocking Loki integration:
+1. **QueueHandler**: Logs are placed in a 500-message queue (non-blocking)
+2. **Background Thread**: QueueListener sends logs to Loki asynchronously
+3. **Graceful Degradation**: If Loki is unavailable, console logging continues
+### Automatic Labels
+Every log sent to Loki includes:
+- `processor`: Processor identity (e.g., `config-svc`, `text-completion`, `embeddings`)
+- `severity`: Log level (DEBUG, INFO, etc.)
+- `logger`: Module name (e.g., `trustgraph.gateway.service`, `trustgraph.agent.react.service`)
+### Custom Labels
+Add custom labels via the `extra` parameter:
 ```python
-import os
-log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', 'INFO')
+logger.info("User action", extra={
+    'tags': {
+        'user_id': user_id,
+        'action': 'document_upload',
+        'collection': collection_name
+    }
+})
 ```
+### Querying Logs in Loki
+```logql
+# All logs from a specific processor (recommended - matches Prometheus metrics)
+{processor="config-svc"}
+{processor="text-completion"}
+{processor="embeddings"}
+# Error logs from a specific processor
+{processor="config-svc", severity="ERROR"}
+# Error logs from all processors
+{severity="ERROR"}
+# Logs from a specific processor with text filter
+{processor="text-completion"} |= "Processing"
+# All logs from API gateway
+{processor="api-gateway"}
+# Logs from processors matching pattern
+{processor=~".*-completion"}
+# Logs with custom tags
+{processor="api-gateway"} | json | user_id="12345"
+```
+### Graceful Degradation
+If Loki is unavailable or `python-logging-loki` is not installed:
+- Warning message printed to console
+- Console logging continues normally
+- Application continues running
+- No retry logic for Loki connection (fail fast, degrade gracefully)
 ## Testing
 During tests, consider using a different logging configuration:
 ```python
 # In test setup
-logging.getLogger().setLevel(logging.WARNING)  # Reduce noise during tests
+import logging
+# Reduce noise during tests
+logging.getLogger().setLevel(logging.WARNING)
+# Or disable Loki for tests
+setup_logging({'log_level': 'WARNING', 'loki_enabled': False})
 ```
 ## Monitoring Integration
-Ensure log format is compatible with monitoring tools:
-- Include timestamps in ISO format
-- Use consistent field names
-- Include correlation IDs where applicable
-- Structure logs for easy parsing (JSON format for production)
+### Standard Format
+All logs use a consistent format:
+```
+2025-01-09 10:30:45,123 - trustgraph.gateway.service - INFO - Request processed
+```
+Format components:
+- Timestamp (ISO format with milliseconds)
+- Logger name (module path)
+- Log level
+- Message
+### Loki Queries for Monitoring
+Common monitoring queries:
+```logql
+# Error rate by processor
+rate({severity="ERROR"}[5m]) by (processor)
+# Top error-producing processors
+topk(5, count_over_time({severity="ERROR"}[1h]) by (processor))
+# Recent errors with processor name
+{severity="ERROR"} | line_format "{{.processor}}: {{.message}}"
+# All agent processors
+{processor=~".*agent.*"} |= "exception"
+# Specific processor error count
+count_over_time({processor="config-svc", severity="ERROR"}[1h])
+```
 ## Security Considerations
-- Never log sensitive information (passwords, API keys, personal data)
-- Sanitize user input before logging
-- Use placeholders for sensitive fields: `user_id=****1234`
+- **Never log sensitive information** (passwords, API keys, personal data, tokens)
+- **Sanitize user input** before logging
+- **Use placeholders** for sensitive fields: `user_id=****1234`
+- **Loki authentication**: Use `--loki-username` and `--loki-password` for secure deployments
+- **Secure transport**: Use HTTPS for Loki URL in production: `https://loki.prod:3100/loki/api/v1/push`
+## Dependencies
+The centralized logging module requires:
+- `python-logging-loki` - For Loki integration (optional, graceful degradation if missing)
+Already included in `trustgraph-base/pyproject.toml` and `requirements.txt`.
 ## Migration Path
-For existing code using print statements:
-1. Replace `print()` with appropriate logger calls
-2. Choose appropriate log levels based on message importance
-3. Add context to make logs more useful
-4. Test logging output at different levels
+For existing code:
+1. **Services already using AsyncProcessor**: No changes needed, Loki support is automatic
+2. **Services not using AsyncProcessor** (api-gateway, mcp-server): Already updated
+3. **CLI tools**: Out of scope - continue using print() or simple logging
+### From print() to logging:
+```python
+# Before
+print(f"Processing document {doc_id}")
+# After
+logger = logging.getLogger(__name__)
+logger.info(f"Processing document {doc_id}")
+```
+## Configuration Summary
+| Argument | Default | Environment Variable | Description |
+|----------|---------|---------------------|-------------|
+| `--log-level` | `INFO` | - | Console and Loki log level |
+| `--loki-enabled` | `True` | - | Enable Loki logging |
+| `--loki-url` | `http://loki:3100/loki/api/v1/push` | `LOKI_URL` | Loki push endpoint |
+| `--loki-username` | `None` | `LOKI_USERNAME` | Loki auth username |
+| `--loki-password` | `None` | `LOKI_PASSWORD` | Loki auth password |


@@ -0,0 +1,258 @@
# Tech Spec: S3-Compatible Storage Backend Support
## Overview
The Librarian service uses S3-compatible object storage for document blob storage. This spec documents the implementation that enables support for any S3-compatible backend including MinIO, Ceph RADOS Gateway (RGW), AWS S3, Cloudflare R2, DigitalOcean Spaces, and others.
## Architecture
### Storage Components
- **Blob Storage**: S3-compatible object storage via `minio` Python client library
- **Metadata Storage**: Cassandra (stores object_id mapping and document metadata)
- **Affected Component**: Librarian service only
- **Storage Pattern**: Hybrid storage with metadata in Cassandra, content in S3-compatible storage
### Implementation
- **Library**: `minio` Python client (supports any S3-compatible API)
- **Location**: `trustgraph-flow/trustgraph/librarian/blob_store.py`
- **Operations**:
- `add()` - Store blob with UUID object_id
- `get()` - Retrieve blob by object_id
- `remove()` - Delete blob by object_id
- `ensure_bucket()` - Create bucket if not exists
- **Bucket**: `library`
- **Object Path**: `doc/{object_id}`
- **Supported MIME Types**: `text/plain`, `application/pdf`
### Key Files
1. `trustgraph-flow/trustgraph/librarian/blob_store.py` - BlobStore implementation
2. `trustgraph-flow/trustgraph/librarian/librarian.py` - BlobStore initialization
3. `trustgraph-flow/trustgraph/librarian/service.py` - Service configuration
4. `trustgraph-flow/pyproject.toml` - Dependencies (`minio` package)
5. `docs/apis/api-librarian.md` - API documentation
## Supported Storage Backends
The implementation works with any S3-compatible object storage system:
### Tested/Supported
- **Ceph RADOS Gateway (RGW)** - Distributed storage system with S3 API (default configuration)
- **MinIO** - Lightweight self-hosted object storage
- **Garage** - Lightweight geo-distributed S3-compatible storage
### Should Work (S3-Compatible)
- **AWS S3** - Amazon's cloud object storage
- **Cloudflare R2** - Cloudflare's S3-compatible storage
- **DigitalOcean Spaces** - DigitalOcean's object storage
- **Wasabi** - S3-compatible cloud storage
- **Backblaze B2** - S3-compatible backup storage
- Any other service implementing the S3 REST API
## Configuration
### CLI Arguments
```bash
librarian \
--object-store-endpoint <hostname:port> \
--object-store-access-key <access_key> \
--object-store-secret-key <secret_key> \
[--object-store-use-ssl] \
[--object-store-region <region>]
```
**Note:** Do not include `http://` or `https://` in the endpoint. Use `--object-store-use-ssl` to enable HTTPS.
### Environment Variables (Alternative)
```bash
OBJECT_STORE_ENDPOINT=<hostname:port>
OBJECT_STORE_ACCESS_KEY=<access_key>
OBJECT_STORE_SECRET_KEY=<secret_key>
OBJECT_STORE_USE_SSL=true|false # Optional, default: false
OBJECT_STORE_REGION=<region> # Optional
```
### Examples
**Ceph RADOS Gateway (default):**
```bash
--object-store-endpoint ceph-rgw:7480 \
--object-store-access-key object-user \
--object-store-secret-key object-password
```
**MinIO:**
```bash
--object-store-endpoint minio:9000 \
--object-store-access-key minioadmin \
--object-store-secret-key minioadmin
```
**Garage (S3-compatible):**
```bash
--object-store-endpoint garage:3900 \
--object-store-access-key GK000000000000000000000001 \
--object-store-secret-key b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
```
**AWS S3 with SSL:**
```bash
--object-store-endpoint s3.amazonaws.com \
--object-store-access-key AKIAIOSFODNN7EXAMPLE \
--object-store-secret-key wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY \
--object-store-use-ssl \
--object-store-region us-east-1
```
## Authentication
All S3-compatible backends require AWS Signature Version 4 (or v2) authentication:
- **Access Key** - Public identifier (like username)
- **Secret Key** - Private signing key (like password)
The MinIO Python client handles all signature calculation automatically.
### Creating Credentials
**For MinIO:**
```bash
# Use default credentials or create user via MinIO Console
minioadmin / minioadmin
```
**For Ceph RGW:**
```bash
radosgw-admin user create --uid="trustgraph" --display-name="TrustGraph Service"
# Returns access_key and secret_key
```
**For AWS S3:**
- Create IAM user with S3 permissions
- Generate access key in AWS Console
## Library Selection: MinIO Python Client
**Rationale:**
- Lightweight (~500KB vs boto3's ~50MB)
- S3-compatible - works with any S3 API endpoint
- Simpler API than boto3 for basic operations
- Already in use, no migration needed
- Battle-tested with MinIO and other S3 systems
## BlobStore Implementation
**Location:** `trustgraph-flow/trustgraph/librarian/blob_store.py`
```python
from minio import Minio
import io
import logging

logger = logging.getLogger(__name__)

class BlobStore:
    """
    S3-compatible blob storage for document content.
    Supports MinIO, Ceph RGW, AWS S3, and other S3-compatible backends.
    """

    def __init__(self, endpoint, access_key, secret_key, bucket_name,
                 use_ssl=False, region=None):
        """
        Initialize S3-compatible blob storage.

        Args:
            endpoint: S3 endpoint (e.g., "minio:9000", "ceph-rgw:7480")
            access_key: S3 access key
            secret_key: S3 secret key
            bucket_name: Bucket name for storage
            use_ssl: Use HTTPS instead of HTTP (default: False)
            region: S3 region (optional, e.g., "us-east-1")
        """
        self.client = Minio(
            endpoint=endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=use_ssl,
            region=region,
        )
        self.bucket_name = bucket_name
        protocol = "https" if use_ssl else "http"
        logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}")
        self.ensure_bucket()

    def ensure_bucket(self):
        """Create bucket if it doesn't exist"""
        found = self.client.bucket_exists(bucket_name=self.bucket_name)
        if not found:
            self.client.make_bucket(bucket_name=self.bucket_name)
            logger.info(f"Created bucket {self.bucket_name}")
        else:
            logger.debug(f"Bucket {self.bucket_name} already exists")

    async def add(self, object_id, blob, kind):
        """Store blob in S3-compatible storage"""
        self.client.put_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
            length=len(blob),
            data=io.BytesIO(blob),
            content_type=kind,
        )
        logger.debug("Add blob complete")

    async def remove(self, object_id):
        """Delete blob from S3-compatible storage"""
        self.client.remove_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
        )
        logger.debug("Remove blob complete")

    async def get(self, object_id):
        """Retrieve blob from S3-compatible storage"""
        resp = self.client.get_object(
            bucket_name=self.bucket_name,
            object_name=f"doc/{object_id}",
        )
        return resp.read()
```
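The add/get/remove contract and object-path convention (`doc/{object_id}` in the `library` bucket) can be exercised without a live S3 endpoint against an in-memory stand-in for the client; the `FakeS3Client` below is purely illustrative and not part of TrustGraph:

```python
import io

# Illustrative in-memory stand-in for the S3 client (not part of
# TrustGraph); it mimics the three Minio calls BlobStore relies on.
class FakeS3Client:
    def __init__(self):
        self.objects = {}

    def put_object(self, bucket_name, object_name, length, data, content_type):
        self.objects[(bucket_name, object_name)] = data.read()

    def get_object(self, bucket_name, object_name):
        return io.BytesIO(self.objects[(bucket_name, object_name)])

    def remove_object(self, bucket_name, object_name):
        del self.objects[(bucket_name, object_name)]

client = FakeS3Client()
blob = b"hello, librarian"

# Same conventions as BlobStore: bucket "library", path "doc/{object_id}"
client.put_object("library", "doc/abc-123", len(blob),
                  io.BytesIO(blob), "text/plain")
retrieved = client.get_object("library", "doc/abc-123").read()
client.remove_object("library", "doc/abc-123")
```

A real test of `BlobStore` itself would inject a `Minio` client pointed at one of the backends listed above.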
## Key Benefits
1. **No Vendor Lock-in** - Works with any S3-compatible storage
2. **Lightweight** - MinIO client is only ~500KB
3. **Simple Configuration** - Just endpoint + credentials
4. **No Data Migration** - Drop-in replacement between backends
5. **Battle-Tested** - MinIO client works with all major S3 implementations
## Implementation Status
All code has been updated to use generic S3 parameter names:
- ✅ `blob_store.py` - Updated to accept `endpoint`, `access_key`, `secret_key`
- ✅ `librarian.py` - Updated parameter names
- ✅ `service.py` - Updated CLI arguments and configuration
- ✅ Documentation updated
## Future Enhancements
1. **SSL/TLS Support** - Add `--s3-use-ssl` flag for HTTPS
2. **Retry Logic** - Implement exponential backoff for transient failures
3. **Presigned URLs** - Generate temporary upload/download URLs
4. **Multi-region Support** - Replicate blobs across regions
5. **CDN Integration** - Serve blobs via CDN
6. **Storage Classes** - Use S3 storage classes for cost optimization
7. **Lifecycle Policies** - Automatic archival/deletion
8. **Versioning** - Store multiple versions of blobs
## References
- MinIO Python Client: https://min.io/docs/minio/linux/developers/python/API.html
- Ceph RGW S3 API: https://docs.ceph.com/en/latest/radosgw/s3/
- S3 API Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/Welcome.html


@@ -0,0 +1,772 @@
# Technical Specification: Multi-Tenant Support
## Overview
Enable multi-tenant deployments by fixing parameter name mismatches that prevent queue customization and adding Cassandra keyspace parameterization.
## Architecture Context
### Flow-Based Queue Resolution
The TrustGraph system uses a **flow-based architecture** for dynamic queue resolution, which inherently supports multi-tenancy:
- **Flow Definitions** are stored in Cassandra and specify queue names via interface definitions
- **Queue names use templates** with `{id}` variables that are replaced with flow instance IDs
- **Services dynamically resolve queues** by looking up flow configurations at request time
- **Each tenant can have unique flows** with different queue names, providing isolation
Example flow interface definition:
```json
{
"interfaces": {
"triples-store": "persistent://tg/flow/triples-store:{id}",
"graph-embeddings-store": "persistent://tg/flow/graph-embeddings-store:{id}"
}
}
```
When tenant A starts flow `tenant-a-prod` and tenant B starts flow `tenant-b-prod`, they automatically get isolated queues:
- `persistent://tg/flow/triples-store:tenant-a-prod`
- `persistent://tg/flow/triples-store:tenant-b-prod`
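The substitution itself is plain string templating; a minimal sketch (the `resolve_queue` helper is illustrative, not the actual resolver code):

```python
# Minimal sketch of flow-based queue resolution: the {id} placeholder in
# an interface definition is filled with the flow instance ID, so each
# flow instance gets its own isolated queue names.
interfaces = {
    "triples-store": "persistent://tg/flow/triples-store:{id}",
    "graph-embeddings-store": "persistent://tg/flow/graph-embeddings-store:{id}",
}

def resolve_queue(interface: str, flow_id: str) -> str:
    """Resolve a queue name for a given flow instance."""
    return interfaces[interface].format(id=flow_id)

queue_a = resolve_queue("triples-store", "tenant-a-prod")
queue_b = resolve_queue("triples-store", "tenant-b-prod")
```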
**Services correctly designed for multi-tenancy:**
- ✅ **Knowledge Management (cores)** - Dynamically resolves queues from flow configuration passed in requests
**Services needing fixes:**
- 🔴 **Config Service** - Parameter name mismatch prevents queue customization
- 🔴 **Librarian Service** - Hardcoded storage management topics (discussed below)
- 🔴 **All Services** - Cannot customize Cassandra keyspace
## Problem Statement
### Issue #1: Parameter Name Mismatch in AsyncProcessor
- **CLI defines:** `--config-queue` (unclear naming)
- **Argparse converts to:** `config_queue` (in params dict)
- **Code looks for:** `config_push_queue`
- **Result:** Parameter is ignored, defaults to `persistent://tg/config/config`
- **Impact:** Affects all 32+ services inheriting from AsyncProcessor
- **Blocks:** Multi-tenant deployments cannot use tenant-specific config queues
- **Solution:** Rename CLI parameter to `--config-push-queue` for clarity (breaking change acceptable since feature is currently broken)
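The mismatch is easy to reproduce in isolation; the snippet below mimics the lookup (queue names are from this spec, the surrounding service code is omitted):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--config-queue', default='persistent://tg/config/config')

# argparse converts dashes to underscores: the value lands under "config_queue"
params = vars(parser.parse_args(['--config-queue', 'persistent://tenant-a/config/config']))

# The service reads a different key, so the CLI override is silently ignored
queue = params.get('config_push_queue', 'persistent://tg/config/config')
print(queue)  # → persistent://tg/config/config (the tenant override is lost)
```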
### Issue #2: Parameter Name Mismatch in Config Service
- **CLI defines:** `--push-queue` (ambiguous naming)
- **Argparse converts to:** `push_queue` (in params dict)
- **Code looks for:** `config_push_queue`
- **Result:** Parameter is ignored
- **Impact:** Config service cannot use custom push queue
- **Solution:** Rename CLI parameter to `--config-push-queue` for consistency and clarity (breaking change acceptable)
### Issue #3: Hardcoded Cassandra Keyspace
- **Current:** Keyspace hardcoded as `"config"`, `"knowledge"`, `"librarian"` in various services
- **Result:** Cannot customize keyspace for multi-tenant deployments
- **Impact:** Config, cores, and librarian services
- **Blocks:** Multiple tenants cannot use separate Cassandra keyspaces
### Issue #4: Collection Management Architecture ✅ COMPLETED
- **Previous:** Collections stored in Cassandra librarian keyspace via separate collections table
- **Previous:** Librarian used 4 hardcoded storage management topics to coordinate collection create/delete:
- `vector_storage_management_topic`
- `object_storage_management_topic`
- `triples_storage_management_topic`
- `storage_management_response_topic`
- **Problems (Resolved):**
- Hardcoded topics could not be customized for multi-tenant deployments
- Complex async coordination between librarian and 4+ storage services
- Separate Cassandra table and management infrastructure
- Non-persistent request/response queues for critical operations
- **Solution Implemented:** Migrated collections to config service storage, use config push for distribution
- **Status:** All storage backends migrated to `CollectionConfigHandler` pattern
## Solution
This spec addresses Issues #1, #2, #3, and #4.
### Part 1: Fix Parameter Name Mismatches
#### Change 1: AsyncProcessor Base Class - Rename CLI Parameter
**File:** `trustgraph-base/trustgraph/base/async_processor.py`
**Line:** 260-264
**Current:**
```python
parser.add_argument(
    '--config-queue',
    default=default_config_queue,
    help=f'Config push queue {default_config_queue}',
)
```
**Fixed:**
```python
parser.add_argument(
    '--config-push-queue',
    default=default_config_queue,
    help=f'Config push queue (default: {default_config_queue})',
)
```
**Rationale:**
- Clearer, more explicit naming
- Matches the internal variable name `config_push_queue`
- Breaking change acceptable since feature is currently non-functional
- No code change needed in params.get() - it already looks for the correct name
#### Change 2: Config Service - Rename CLI Parameter
**File:** `trustgraph-flow/trustgraph/config/service/service.py`
**Line:** 276-279
**Current:**
```python
parser.add_argument(
    '--push-queue',
    default=default_config_push_queue,
    help=f'Config push queue (default: {default_config_push_queue})'
)
```
**Fixed:**
```python
parser.add_argument(
    '--config-push-queue',
    default=default_config_push_queue,
    help=f'Config push queue (default: {default_config_push_queue})'
)
```
**Rationale:**
- Clearer naming - "config-push-queue" is more explicit than just "push-queue"
- Matches the internal variable name `config_push_queue`
- Consistent with AsyncProcessor's `--config-push-queue` parameter
- Breaking change acceptable since feature is currently non-functional
- No code change needed in params.get() - it already looks for the correct name
### Part 2: Add Cassandra Keyspace Parameterization
#### Change 3: Add Keyspace Parameter to cassandra_config Module
**File:** `trustgraph-base/trustgraph/base/cassandra_config.py`
**Add CLI argument** (in `add_cassandra_args()` function):
```python
parser.add_argument(
    '--cassandra-keyspace',
    default=None,
    help='Cassandra keyspace (default: service-specific)'
)
```
**Add environment variable support** (in `resolve_cassandra_config()` function):
```python
keyspace = params.get(
    "cassandra_keyspace",
    os.environ.get("CASSANDRA_KEYSPACE")
)
```
**Update return value** of `resolve_cassandra_config()`:
- Currently returns: `(hosts, username, password)`
- Change to return: `(hosts, username, password, keyspace)`
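A sketch of the updated resolver under these changes; the host/credential lookups are assumptions about the existing function's shape, only the keyspace handling is specified above:

```python
import os

def resolve_cassandra_config(params, default_keyspace=None):
    # Existing lookups (assumed shape; actual code is in cassandra_config.py)
    hosts = params.get("cassandra_host", os.environ.get("CASSANDRA_HOST"))
    username = params.get("cassandra_username", os.environ.get("CASSANDRA_USERNAME"))
    password = params.get("cassandra_password", os.environ.get("CASSANDRA_PASSWORD"))

    # New: keyspace from CLI param, then environment, then service default
    keyspace = params.get(
        "cassandra_keyspace",
        os.environ.get("CASSANDRA_KEYSPACE")
    ) or default_keyspace

    return hosts, username, password, keyspace
```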
**Rationale:**
- Consistent with existing Cassandra configuration pattern
- Available to all services via `add_cassandra_args()`
- Supports both CLI and environment variable configuration
#### Change 4: Config Service - Use Parameterized Keyspace
**File:** `trustgraph-flow/trustgraph/config/service/service.py`
**Line 30** - Remove hardcoded keyspace:
```python
# DELETE THIS LINE:
keyspace = "config"
```
**Lines 69-73** - Update cassandra config resolution:
**Current:**
```python
cassandra_host, cassandra_username, cassandra_password = \
    resolve_cassandra_config(params)
```
**Fixed:**
```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="config")
```
**Rationale:**
- Maintains backward compatibility with "config" as default
- Allows override via `--cassandra-keyspace` or `CASSANDRA_KEYSPACE`
#### Change 5: Cores/Knowledge Service - Use Parameterized Keyspace
**File:** `trustgraph-flow/trustgraph/cores/service.py`
**Line 37** - Remove hardcoded keyspace:
```python
# DELETE THIS LINE:
keyspace = "knowledge"
```
**Update cassandra config resolution** (similar location as config service):
```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="knowledge")
```
#### Change 6: Librarian Service - Use Parameterized Keyspace
**File:** `trustgraph-flow/trustgraph/librarian/service.py`
**Line 51** - Remove hardcoded keyspace:
```python
# DELETE THIS LINE:
keyspace = "librarian"
```
**Update cassandra config resolution** (similar location as config service):
```python
cassandra_host, cassandra_username, cassandra_password, keyspace = \
    resolve_cassandra_config(params, default_keyspace="librarian")
```
### Part 3: Migrate Collection Management to Config Service
#### Overview
Migrate collections from Cassandra librarian keyspace to config service storage. This eliminates hardcoded storage management topics and simplifies the architecture by using the existing config push mechanism for distribution.
#### Current Architecture
```
API Request → Gateway → Librarian Service
    ↓
CollectionManager
    ↓
Cassandra Collections Table (librarian keyspace)
    ↓
Broadcast to 4 Storage Management Topics (hardcoded)
    ↓
Wait for 4+ Storage Service Responses
    ↓
Response to Gateway
```
#### New Architecture
```
API Request → Gateway → Librarian Service
    ↓
CollectionManager
    ↓
Config Service API (put/delete/getvalues)
    ↓
Cassandra Config Table (class='collections', key='user:collection')
    ↓
Config Push (to all subscribers on config-push-queue)
    ↓
All Storage Services receive config update independently
```
#### Change 7: Collection Manager - Use Config Service API
**File:** `trustgraph-flow/trustgraph/librarian/collection_manager.py`
**Remove:**
- `LibraryTableStore` usage (Lines 33, 40-41)
- Storage management producers initialization (Lines 86-140)
- `on_storage_response` method (Lines 400-430)
- `pending_deletions` tracking (Lines 57, 90-96, and usage throughout)
**Add:**
- Config service client for API calls (request/response pattern)
**Config Client Setup:**
```python
# In __init__, add config request/response producers/consumers

from trustgraph.schema.services.config import ConfigRequest, ConfigResponse

# Producer for config requests
self.config_request_producer = Producer(
    client=pulsar_client,
    topic=config_request_queue,
    schema=ConfigRequest,
)

# Consumer for config responses (with correlation ID)
self.config_response_consumer = Consumer(
    taskgroup=taskgroup,
    client=pulsar_client,
    flow=None,
    topic=config_response_queue,
    subscriber=f"{id}-config",
    schema=ConfigResponse,
    handler=self.on_config_response,
)

# Tracking for pending config requests
self.pending_config_requests = {}  # request_id -> asyncio.Event
```
**Modify `list_collections` (Lines 145-180):**
```python
async def list_collections(self, user, tag_filter=None, limit=None):
    """List collections from config service"""

    # Send getvalues request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='getvalues',
        type='collections',
    )

    # Send request and wait for response
    response = await self.send_config_request(request)

    # Parse collections from response
    collections = []
    for key, value_json in response.values.items():
        if ":" in key:
            coll_user, collection = key.split(":", 1)
            if coll_user == user:
                metadata = json.loads(value_json)
                collections.append(CollectionMetadata(**metadata))

    # Apply tag filtering in-memory (as before)
    if tag_filter:
        collections = [
            c for c in collections
            if any(tag in c.tags for tag in tag_filter)
        ]

    # Apply limit
    if limit:
        collections = collections[:limit]

    return collections

async def send_config_request(self, request):
    """Send config request and wait for response"""
    event = asyncio.Event()
    self.pending_config_requests[request.id] = event
    await self.config_request_producer.send(request)
    await event.wait()
    # Clean up both the event entry and the stored response
    self.pending_config_requests.pop(request.id, None)
    return self.pending_config_requests.pop(request.id + "_response")

async def on_config_response(self, message, consumer, flow):
    """Handle config response"""
    response = message.value()
    if response.id in self.pending_config_requests:
        self.pending_config_requests[response.id + "_response"] = response
        self.pending_config_requests[response.id].set()
```
**Modify `update_collection` (Lines 182-312):**
```python
async def update_collection(self, user, collection, name, description, tags):
    """Update collection via config service"""

    # Create metadata
    metadata = CollectionMetadata(
        user=user,
        collection=collection,
        name=name,
        description=description,
        tags=tags,
    )

    # Send put request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='put',
        type='collections',
        key=f'{user}:{collection}',
        value=json.dumps(metadata.to_dict()),
    )

    response = await self.send_config_request(request)

    if response.error:
        raise RuntimeError(f"Config update failed: {response.error.message}")

    # Config service will trigger config push automatically
    # Storage services will receive update and create collections
```
**Modify `delete_collection` (Lines 314-398):**
```python
async def delete_collection(self, user, collection):
    """Delete collection via config service"""

    # Send delete request to config service
    request = ConfigRequest(
        id=str(uuid.uuid4()),
        operation='delete',
        type='collections',
        key=f'{user}:{collection}',
    )

    response = await self.send_config_request(request)

    if response.error:
        raise RuntimeError(f"Config delete failed: {response.error.message}")

    # Config service will trigger config push automatically
    # Storage services will receive update and delete collections
```
**Collection Metadata Format:**
- Stored in config table as: `class='collections', key='user:collection'`
- Value is JSON-serialized CollectionMetadata (without timestamp fields)
- Fields: `user`, `collection`, `name`, `description`, `tags`
- Example: `class='collections', key='alice:my-docs', value='{"user":"alice","collection":"my-docs","name":"My Documents","description":"...","tags":["work"]}'`
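The key/value convention above round-trips cleanly; a small sketch (the helper names are illustrative, not from the codebase):

```python
import json

# Encode a collection as a config entry under class='collections'
def collection_key(user, collection):
    return f"{user}:{collection}"

# Decode a config entry back into (user, collection, metadata)
def parse_collection_entry(key, value_json):
    user, collection = key.split(":", 1)  # split once: user names contain no ':'
    return user, collection, json.loads(value_json)

value = json.dumps({"user": "alice", "collection": "my-docs",
                    "name": "My Documents", "description": "...", "tags": ["work"]})
user, collection, metadata = parse_collection_entry(collection_key("alice", "my-docs"), value)
print(user, collection, metadata["tags"])  # → alice my-docs ['work']
```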
#### Change 8: Librarian Service - Remove Storage Management Infrastructure
**File:** `trustgraph-flow/trustgraph/librarian/service.py`
**Remove:**
- Storage management producers (Lines 173-190):
- `vector_storage_management_producer`
- `object_storage_management_producer`
- `triples_storage_management_producer`
- Storage response consumer (Lines 192-201)
- `on_storage_response` handler (Lines 467-473)
**Modify:**
- CollectionManager initialization (Lines 215-224) - remove storage producer parameters
**Note:** External collection API remains unchanged:
- `list-collections`
- `update-collection`
- `delete-collection`
#### Change 9: Remove Collections Table from LibraryTableStore
**File:** `trustgraph-flow/trustgraph/tables/library.py`
**Delete:**
- Collections table CREATE statement (Lines 114-127)
- Collections prepared statements (Lines 205-240)
- All collection methods (Lines 578-717):
- `ensure_collection_exists`
- `list_collections`
- `update_collection`
- `delete_collection`
- `get_collection`
- `create_collection`
**Rationale:**
- Collections now stored in config table
- Breaking change acceptable - no data migration needed
- Simplifies librarian service significantly
#### Change 10: Storage Services - Config-Based Collection Management ✅ COMPLETED
**Status:** All 11 storage backends have been migrated to use `CollectionConfigHandler`.
**Affected Services (11 total):**
- Document embeddings: milvus, pinecone, qdrant
- Graph embeddings: milvus, pinecone, qdrant
- Object storage: cassandra
- Triples storage: cassandra, falkordb, memgraph, neo4j
**Files:**
- `trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py`
- `trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py`
- `trustgraph-flow/trustgraph/storage/objects/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/cassandra/write.py`
- `trustgraph-flow/trustgraph/storage/triples/falkordb/write.py`
- `trustgraph-flow/trustgraph/storage/triples/memgraph/write.py`
- `trustgraph-flow/trustgraph/storage/triples/neo4j/write.py`
**Implementation Pattern (all services):**
1. **Register config handler in `__init__`:**
```python
# Add after AsyncProcessor initialization
self.register_config_handler(self.on_collection_config)
self.known_collections = set() # Track (user, collection) tuples
```
2. **Implement config handler:**
```python
async def on_collection_config(self, config, version):
    """Handle collection configuration updates"""
    logger.info(f"Collection config version: {version}")

    if "collections" not in config:
        return

    # Parse collections from config
    # Key format: "user:collection" in config["collections"]
    config_collections = set()
    for key in config["collections"].keys():
        if ":" in key:
            user, collection = key.split(":", 1)
            config_collections.add((user, collection))

    # Determine changes
    to_create = config_collections - self.known_collections
    to_delete = self.known_collections - config_collections

    # Create new collections (idempotent)
    for user, collection in to_create:
        try:
            await self.create_collection_internal(user, collection)
            self.known_collections.add((user, collection))
            logger.info(f"Created collection: {user}/{collection}")
        except Exception as e:
            logger.error(f"Failed to create {user}/{collection}: {e}")

    # Delete removed collections (idempotent)
    for user, collection in to_delete:
        try:
            await self.delete_collection_internal(user, collection)
            self.known_collections.discard((user, collection))
            logger.info(f"Deleted collection: {user}/{collection}")
        except Exception as e:
            logger.error(f"Failed to delete {user}/{collection}: {e}")
```
3. **Initialize known collections on startup:**
```python
async def start(self):
    """Start the processor"""
    await super().start()
    await self.sync_known_collections()

async def sync_known_collections(self):
    """Query backend to populate known_collections set"""
    # Backend-specific implementation:
    # - Milvus/Pinecone/Qdrant: List collections/indexes matching naming pattern
    # - Cassandra: Query keyspaces or collection metadata
    # - Neo4j/Memgraph/FalkorDB: Query CollectionMetadata nodes
    pass
```
4. **Refactor existing handler methods:**
```python
# Rename and remove response sending:
#   handle_create_collection → create_collection_internal
#   handle_delete_collection → delete_collection_internal

async def create_collection_internal(self, user, collection):
    """Create collection (idempotent)"""
    # Same logic as current handle_create_collection
    # But remove response producer calls
    # Handle "already exists" gracefully
    pass

async def delete_collection_internal(self, user, collection):
    """Delete collection (idempotent)"""
    # Same logic as current handle_delete_collection
    # But remove response producer calls
    # Handle "not found" gracefully
    pass
```
5. **Remove storage management infrastructure:**
- Remove `self.storage_request_consumer` setup and start
- Remove `self.storage_response_producer` setup
- Remove `on_storage_management` dispatcher method
- Remove metrics for storage management
- Remove imports: `StorageManagementRequest`, `StorageManagementResponse`
**Backend-Specific Considerations:**
- **Vector stores (Milvus, Pinecone, Qdrant):** Track logical `(user, collection)` in `known_collections`, but may create multiple backend collections per dimension. Continue lazy creation pattern. Delete operations must remove all dimension variants.
- **Cassandra Objects:** Collections are row properties, not structures. Track keyspace-level information.
- **Graph stores (Neo4j, Memgraph, FalkorDB):** Query `CollectionMetadata` nodes on startup. Create/delete metadata nodes on sync.
- **Cassandra Triples:** Use `KnowledgeGraph` API for collection operations.
**Key Design Points:**
- **Eventual consistency:** No request/response mechanism, config push is broadcast
- **Idempotency:** All create/delete operations must be safe to retry
- **Error handling:** Log errors but don't block config updates
- **Self-healing:** Failed operations will retry on next config push
- **Collection key format:** `"user:collection"` in `config["collections"]`
#### Change 11: Update Collection Schema - Remove Timestamps
**File:** `trustgraph-base/trustgraph/schema/services/collection.py`
**Modify CollectionMetadata (Lines 13-21):**
Remove `created_at` and `updated_at` fields:
```python
class CollectionMetadata(Record):
    user = String()
    collection = String()
    name = String()
    description = String()
    tags = Array(String())
    # Remove: created_at = String()
    # Remove: updated_at = String()
```
**Modify CollectionManagementRequest (Lines 25-47):**
Remove timestamp fields:
```python
class CollectionManagementRequest(Record):
    operation = String()
    user = String()
    collection = String()
    timestamp = String()
    name = String()
    description = String()
    tags = Array(String())
    # Remove: created_at = String()
    # Remove: updated_at = String()
    tag_filter = Array(String())
    limit = Integer()
```
**Rationale:**
- Timestamps don't add value for collections
- Config service maintains its own version tracking
- Simplifies schema and reduces storage
#### Benefits of Config Service Migration
1. ✅ **Eliminates hardcoded storage management topics** - Solves multi-tenant blocker
2. ✅ **Simpler coordination** - No complex async waiting for 4+ storage responses
3. ✅ **Eventual consistency** - Storage services update independently via config push
4. ✅ **Better reliability** - Persistent config push vs non-persistent request/response
5. ✅ **Unified configuration model** - Collections treated as configuration
6. ✅ **Reduces complexity** - Removes ~300 lines of coordination code
7. ✅ **Multi-tenant ready** - Config already supports tenant isolation via keyspace
8. ✅ **Version tracking** - Config service version mechanism provides audit trail
## Implementation Notes
### Backward Compatibility
**Parameter Changes:**
- CLI parameter renames are breaking changes but acceptable (feature currently non-functional)
- Services work without parameters (use defaults)
- Default keyspaces preserved: "config", "knowledge", "librarian"
- Default queue: `persistent://tg/config/config`
**Collection Management:**
- **Breaking change:** Collections table removed from librarian keyspace
- **No data migration provided** - acceptable for this phase
- External collection API unchanged (list/update/delete operations)
- Collection metadata format simplified (timestamps removed)
### Testing Requirements
**Parameter Testing:**
1. Verify `--config-push-queue` parameter works on graph-embeddings service
2. Verify `--config-push-queue` parameter works on text-completion service
3. Verify `--config-push-queue` parameter works on config service
4. Verify `--cassandra-keyspace` parameter works for config service
5. Verify `--cassandra-keyspace` parameter works for cores service
6. Verify `--cassandra-keyspace` parameter works for librarian service
7. Verify services work without parameters (uses defaults)
8. Verify multi-tenant deployment with custom queue names and keyspace
**Collection Management Testing:**
9. Verify `list-collections` operation via config service
10. Verify `update-collection` creates/updates in config table
11. Verify `delete-collection` removes from config table
12. Verify config push is triggered on collection updates
13. Verify tag filtering works with config-based storage
14. Verify collection operations work without timestamp fields
### Multi-Tenant Deployment Example
```bash
# Tenant: tg-dev
graph-embeddings \
-p pulsar+ssl://broker:6651 \
--pulsar-api-key <KEY> \
--config-push-queue persistent://tg-dev/config/config
config-service \
-p pulsar+ssl://broker:6651 \
--pulsar-api-key <KEY> \
--config-push-queue persistent://tg-dev/config/config \
--cassandra-keyspace tg_dev_config
```
## Impact Analysis
### Services Affected by Change 1-2 (CLI Parameter Rename)
All services inheriting from AsyncProcessor or FlowProcessor:
- config-service
- cores-service
- librarian-service
- graph-embeddings
- document-embeddings
- text-completion-* (all providers)
- extract-* (all extractors)
- query-* (all query services)
- retrieval-* (all RAG services)
- storage-* (all storage services)
- And 20+ more services
### Services Affected by Changes 3-6 (Cassandra Keyspace)
- config-service
- cores-service
- librarian-service
### Services Affected by Changes 7-11 (Collection Management)
**Immediate Changes:**
- librarian-service (collection_manager.py, service.py)
- tables/library.py (collections table removal)
- schema/services/collection.py (timestamp removal)
**Completed Changes (Change 10):** ✅
- All storage services (11 total) - migrated to config push for collection updates via `CollectionConfigHandler`
- Storage management schema removed from `storage.py`
## Future Considerations
### Per-User Keyspace Model
Some services use **per-user keyspaces** dynamically, where each user gets their own Cassandra keyspace:
**Services with per-user keyspaces:**
1. **Triples Query Service** (`trustgraph-flow/trustgraph/query/triples/cassandra/service.py:65`)
- Uses `keyspace=query.user`
2. **Objects Query Service** (`trustgraph-flow/trustgraph/query/objects/cassandra/service.py:479`)
- Uses `keyspace=self.sanitize_name(user)`
3. **KnowledgeGraph Direct Access** (`trustgraph-flow/trustgraph/direct/cassandra_kg.py:18`)
- Default parameter `keyspace="trustgraph"`
**Status:** These are **not modified** in this specification.
**Future Review Required:**
- Evaluate whether per-user keyspace model creates tenant isolation issues
- Consider if multi-tenant deployments need keyspace prefix patterns (e.g., `tenant_a_user1`)
- Review for potential user ID collision across tenants
- Assess if single shared keyspace per tenant with user-based row isolation is preferable
**Note:** This does not block the current multi-tenant implementation but should be reviewed before production multi-tenant deployments.
## Implementation Phases
### Phase 1: Parameter Fixes (Changes 1-6)
- Fix `--config-push-queue` parameter naming
- Add `--cassandra-keyspace` parameter support
- **Outcome:** Multi-tenant queue and keyspace configuration enabled
### Phase 2: Collection Management Migration (Changes 7-9, 11)
- Migrate collection storage to config service
- Remove collections table from librarian
- Update collection schema (remove timestamps)
- **Outcome:** Eliminates hardcoded storage management topics, simplifies librarian
### Phase 3: Storage Service Updates (Change 10) ✅ COMPLETED
- Updated all storage services to use config push for collections via `CollectionConfigHandler`
- Removed storage management request/response infrastructure
- Removed legacy schema definitions
- **Outcome:** Complete config-based collection management achieved
## References
- GitHub Issue: https://github.com/trustgraph-ai/trustgraph/issues/582
- Related Files:
- `trustgraph-base/trustgraph/base/async_processor.py`
- `trustgraph-base/trustgraph/base/cassandra_config.py`
- `trustgraph-base/trustgraph/schema/core/topic.py`
- `trustgraph-base/trustgraph/schema/services/collection.py`
- `trustgraph-flow/trustgraph/config/service/service.py`
- `trustgraph-flow/trustgraph/cores/service.py`
- `trustgraph-flow/trustgraph/librarian/service.py`
- `trustgraph-flow/trustgraph/librarian/collection_manager.py`
- `trustgraph-flow/trustgraph/tables/library.py`

# Ontology Knowledge Extraction - Phase 2 Refactor
**Status**: Draft
**Author**: Analysis Session 2025-12-03
**Related**: `ontology.md`, `ontorag.md`
## Overview
This document identifies inconsistencies in the current ontology-based knowledge extraction system and proposes a refactor to improve LLM performance and reduce information loss.
## Current Implementation
### How It Works Now
1. **Ontology Loading** (`ontology_loader.py`)
- Loads ontology JSON with keys like `"fo/Recipe"`, `"fo/Food"`, `"fo/produces"`
- Class IDs include namespace prefix in the key itself
- Example from `food.ontology`:
```json
"classes": {
    "fo/Recipe": {
        "uri": "http://purl.org/ontology/fo/Recipe",
        "rdfs:comment": "A Recipe is a combination..."
    }
}
```
2. **Prompt Construction** (`extract.py:299-307`, `ontology-prompt.md`)
- Template receives `classes`, `object_properties`, `datatype_properties` dicts
- Template iterates: `{% for class_id, class_def in classes.items() %}`
- LLM sees: `**fo/Recipe**: A Recipe is a combination...`
- Example output format shows:
```json
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}
{"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"}
```
3. **Response Parsing** (`extract.py:382-428`)
- Expects JSON array: `[{"subject": "...", "predicate": "...", "object": "..."}]`
- Validates against ontology subset
- Expands URIs via `expand_uri()` (extract.py:473-521)
4. **URI Expansion** (`extract.py:473-521`)
- Checks if value is in `ontology_subset.classes` dict
- If found, extracts URI from class definition
- If not found, constructs URI: `f"https://trustgraph.ai/ontology/{ontology_id}#{value}"`
### Data Flow Example
**Ontology JSON → Loader → Prompt:**
```
"fo/Recipe" → classes["fo/Recipe"] → LLM sees "**fo/Recipe**"
```
**LLM → Parser → Output:**
```
"Recipe" → not in classes["fo/Recipe"] → constructs URI → LOSES original URI
"fo/Recipe" → found in classes → uses original URI → PRESERVES URI
```
## Problems Identified
### 1. **Inconsistent Examples in Prompt**
**Issue**: The prompt template shows class IDs with prefixes (`fo/Recipe`) but the example output uses unprefixed class names (`Recipe`).
**Location**: `ontology-prompt.md:5-52`
```markdown
## Ontology Classes:
- **fo/Recipe**: A Recipe is...
## Example Output:
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"}
```
**Impact**: LLM receives conflicting signals about what format to use.
### 2. **Information Loss in URI Expansion**
**Issue**: When LLM returns unprefixed class names following the example, `expand_uri()` can't find them in the ontology dict and constructs fallback URIs, losing the original proper URIs.
**Location**: `extract.py:494-500`
```python
if value in ontology_subset.classes:                  # Looks for "Recipe"
    class_def = ontology_subset.classes[value]        # But key is "fo/Recipe"
    if isinstance(class_def, dict) and 'uri' in class_def:
        return class_def['uri']                       # Never reached!
return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"  # Fallback
```
**Impact**:
- Original URI: `http://purl.org/ontology/fo/Recipe`
- Constructed URI: `https://trustgraph.ai/ontology/food#Recipe`
- Semantic meaning lost, breaks interoperability
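The lookup miss is reproducible standalone; this repro inlines the ontology subset shape from the example above:

```python
# Standalone repro of the expand_uri fallback (dict shape from food.ontology)
ontology_classes = {
    "fo/Recipe": {"uri": "http://purl.org/ontology/fo/Recipe"},
}

def expand_uri(value, ontology_id="food"):
    if value in ontology_classes:
        class_def = ontology_classes[value]
        if isinstance(class_def, dict) and 'uri' in class_def:
            return class_def['uri']
    return f"https://trustgraph.ai/ontology/{ontology_id}#{value}"

print(expand_uri("Recipe"))     # → https://trustgraph.ai/ontology/food#Recipe (URI lost)
print(expand_uri("fo/Recipe"))  # → http://purl.org/ontology/fo/Recipe (URI preserved)
```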
### 3. **Ambiguous Entity Instance Format**
**Issue**: No clear guidance on entity instance URI format.
**Examples in prompt**:
- `"recipe:cornish-pasty"` (namespace-like prefix)
- `"ingredient:flour"` (different prefix)
**Actual behavior** (extract.py:517-520):
```python
# Treat as entity instance - construct unique URI
normalized = value.replace(" ", "-").lower()
return f"https://trustgraph.ai/{ontology_id}/{normalized}"
```
**Impact**: LLM must guess prefixing convention with no ontology context.
### 4. **No Namespace Prefix Guidance**
**Issue**: The ontology JSON contains namespace definitions (line 10-25 in food.ontology):
```json
"namespaces": {
    "fo": "http://purl.org/ontology/fo/",
    "rdf": "http://www.w3.org/1999/02/22-rdf-syntax-ns#",
    ...
}
```
But these are never surfaced to the LLM. The LLM doesn't know:
- What "fo" means
- What prefix to use for entities
- Which namespace applies to which elements
### 5. **Labels Not Used in Prompt**
**Issue**: Every class has `rdfs:label` fields (e.g., `{"value": "Recipe", "lang": "en-gb"}`), but the prompt template doesn't use them.
**Current**: Shows only `class_id` and `comment`
```jinja
- **{{class_id}}**{% if class_def.comment %}: {{class_def.comment}}{% endif %}
```
**Available but unused**:
```python
"rdfs:label": [{"value": "Recipe", "lang": "en-gb"}]
```
**Impact**: Could provide human-readable names alongside technical IDs.
## Proposed Solutions
### Option A: Normalize to Unprefixed IDs
**Approach**: Strip prefixes from class IDs before showing to LLM.
**Changes**:
1. Modify `build_extraction_variables()` to transform keys:
```python
classes_for_prompt = {
    k.split('/')[-1]: v  # "fo/Recipe" → "Recipe"
    for k, v in ontology_subset.classes.items()
}
```
2. Update prompt example to match (already uses unprefixed names)
3. Modify `expand_uri()` to handle both formats:
```python
# Try exact match first
if value in ontology_subset.classes:
    return ontology_subset.classes[value]['uri']

# Try with prefix
for prefix in ['fo/', 'rdf:', 'rdfs:']:
    prefixed = f"{prefix}{value}"
    if prefixed in ontology_subset.classes:
        return ontology_subset.classes[prefixed]['uri']
```
**Pros**:
- Cleaner, more human-readable
- Matches existing prompt examples
- LLMs work better with simpler tokens
**Cons**:
- Class name collisions if multiple ontologies have same class name
- Loses namespace information
- Requires fallback logic for lookups
### Option B: Use Full Prefixed IDs Consistently
**Approach**: Update examples to use prefixed IDs matching what's shown in the class list.
**Changes**:
1. Update prompt example (ontology-prompt.md:46-52):
```json
[
    {"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "fo/Recipe"},
    {"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
    {"subject": "recipe:cornish-pasty", "predicate": "fo/produces", "object": "food:cornish-pasty"},
    {"subject": "food:cornish-pasty", "predicate": "rdf:type", "object": "fo/Food"}
]
```
2. Add namespace explanation to prompt:
```markdown
## Namespace Prefixes:
- **fo/**: Food Ontology (http://purl.org/ontology/fo/)
- **rdf:**: RDF Schema
- **rdfs:**: RDF Schema
Use these prefixes exactly as shown when referencing classes and properties.
```
3. Keep `expand_uri()` as-is (works correctly when matches found)
**Pros**:
- Input = Output consistency
- No information loss
- Preserves namespace semantics
- Works with multiple ontologies
**Cons**:
- More verbose tokens for LLM
- Requires LLM to track prefixes
### Option C: Hybrid - Show Both Label and ID
**Approach**: Enhance prompt to show both human-readable labels and technical IDs.
**Changes**:
1. Update prompt template:
```jinja
{% for class_id, class_def in classes.items() %}
- **{{class_id}}** (label: "{{class_def.labels[0].value if class_def.labels else class_id}}"){% if class_def.comment %}: {{class_def.comment}}{% endif %}
{% endfor %}
```
Example output:
```markdown
- **fo/Recipe** (label: "Recipe"): A Recipe is a combination...
```
2. Update instructions:
```markdown
When referencing classes:
- Use the full prefixed ID (e.g., "fo/Recipe") in JSON output
- The label (e.g., "Recipe") is for human understanding only
```
**Pros**:
- Clearest for LLM
- Preserves all information
- Explicit about what to use
**Cons**:
- Longer prompt
- More complex template
## Implemented Approach
**Simplified Entity-Relationship-Attribute Format** - completely replaces the old triple-based format.
The new approach was chosen because:
1. **No Information Loss**: Original URIs preserved correctly
2. **Simpler Logic**: No transformation needed, direct dict lookups work
3. **Namespace Safety**: Handles multiple ontologies without collisions
4. **Semantic Correctness**: Maintains RDF/OWL semantics
## Implementation Complete
### What Was Built:
1. **New Prompt Template** (`prompts/ontology-extract-v2.txt`)
- ✅ Clear sections: Entity Types, Relationships, Attributes
- ✅ Example using full type identifiers (`fo/Recipe`, `fo/has_ingredient`)
- ✅ Instructions to use exact identifiers from schema
- ✅ New JSON format with entities/relationships/attributes arrays
2. **Entity Normalization** (`entity_normalizer.py`)
- ✅ `normalize_entity_name()` - Converts names to URI-safe format
- ✅ `normalize_type_identifier()` - Handles slashes in types (`fo/Recipe` → `fo-recipe`)
- ✅ `build_entity_uri()` - Creates unique URIs using (name, type) tuple
- ✅ `EntityRegistry` - Tracks entities for deduplication
3. **JSON Parser** (`simplified_parser.py`)
- ✅ Parses new format: `{entities: [...], relationships: [...], attributes: [...]}`
- ✅ Supports kebab-case and snake_case field names
- ✅ Returns structured dataclasses
- ✅ Graceful error handling with logging
4. **Triple Converter** (`triple_converter.py`)
- ✅ `convert_entity()` - Generates type + label triples automatically
- ✅ `convert_relationship()` - Connects entity URIs via properties
- ✅ `convert_attribute()` - Adds literal values
- ✅ Looks up full URIs from ontology definitions
5. **Updated Main Processor** (`extract.py`)
- ✅ Removed old triple-based extraction code
- ✅ Added `extract_with_simplified_format()` method
- ✅ Now exclusively uses new simplified format
- ✅ Calls prompt with `extract-with-ontologies-v2` ID
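The pieces above can be sketched end to end: normalize entity names, parse the simplified JSON, and emit the type and label triples. Everything below is illustrative only — the real logic lives in `entity_normalizer.py`, `simplified_parser.py`, and `triple_converter.py`, and the `BASE` namespace and exact normalization rules are assumptions, not the shipped code:

```python
import json
import re
from dataclasses import dataclass

BASE = "https://trustgraph.ai/food/"  # assumed namespace, configurable in practice
RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"

@dataclass
class Value:
    value: str
    is_uri: bool

@dataclass
class Triple:
    s: Value
    p: Value
    o: Value

def normalize(text: str) -> str:
    """URI-safe form: 'fo/Recipe' -> 'fo-recipe', 'Cornish pasty' -> 'cornish-pasty'."""
    return re.sub(r"[^a-z0-9]+", "-", text.lower()).strip("-")

def build_entity_uri(name: str, type_id: str) -> str:
    """Key the URI on the (name, type) tuple so same name, different type never collides."""
    return f"{BASE}{normalize(type_id)}-{normalize(name)}"

def convert(raw: str, class_uris: dict[str, str]) -> list[Triple]:
    """Parse the simplified JSON and emit type + automatic label triples per entity."""
    data = json.loads(raw)
    triples = []
    for e in data.get("entities", []):
        uri = build_entity_uri(e["entity"], e["type"])
        triples.append(Triple(Value(uri, True), Value(RDF_TYPE, True),
                              Value(class_uris[e["type"]], True)))
        triples.append(Triple(Value(uri, True), Value(RDFS_LABEL, True),
                              Value(e["entity"], False)))
    return triples
```

Because the (name, type) tuple drives the URI, deduplication falls out for free: every mention of the same entity at the same type resolves to the same URI.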
## Test Cases
### Test 1: URI Preservation
```python
# Given ontology class
classes = {"fo/Recipe": {"uri": "http://purl.org/ontology/fo/Recipe", ...}}
# When LLM returns
llm_output = {"subject": "x", "predicate": "rdf:type", "object": "fo/Recipe"}
# Then expanded URI should be
assert expanded == "http://purl.org/ontology/fo/Recipe"
# Not: "https://trustgraph.ai/ontology/food#Recipe"
```
### Test 2: Multi-Ontology Collision
```python
# Given two ontologies
ont1 = {"fo/Recipe": {...}}
ont2 = {"cooking/Recipe": {...}}
# LLM should use full prefix to disambiguate
llm_output = {"object": "fo/Recipe"} # Not just "Recipe"
```
### Test 3: Entity Instance Format
```python
# Given prompt with food ontology
# LLM should create instances like
{"subject": "recipe:cornish-pasty"} # Namespace-style
{"subject": "food:beef"} # Consistent prefix
```
## Open Questions
1. **Should entity instances use namespace prefixes?**
- Current: `"recipe:cornish-pasty"` (arbitrary)
- Alternative: Use ontology prefix `"fo:cornish-pasty"`?
- Alternative: No prefix, expand in URI `"cornish-pasty"` → full URI?
2. **How to handle domain/range in prompt?**
- Currently shows: `(Recipe → Food)`
- Should it be: `(fo/Recipe → fo/Food)`?
3. **Should we validate domain/range constraints?**
- TODO comment at extract.py:470
- Would catch more errors but more complex
4. **What about inverse properties and equivalences?**
- Ontology has `owl:inverseOf`, `owl:equivalentClass`
- Not currently used in extraction
- Should they be?
## Success Metrics
- ✅ Zero URI information loss (100% preservation of original URIs)
- ✅ LLM output format matches input format
- ✅ No ambiguous examples in prompt
- ✅ Tests pass with multiple ontologies
- ✅ Improved extraction quality (measured by valid triple %)
## Alternative Approach: Simplified Extraction Format
### Philosophy
Instead of asking the LLM to understand RDF/OWL semantics, ask it to do what it's good at: **find entities and relationships in text**.
Let the code handle URI construction, RDF conversion, and semantic web formalities.
### Example: Entity Classification
**Input Text:**
```
Cornish pasty is a traditional British pastry filled with meat and vegetables.
```
**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
- Food: A food is something that can be eaten
- Ingredient: An ingredient combines a quantity and a food
```
**What LLM Returns (Simple JSON):**
```json
{
"entities": [
{
"entity": "Cornish pasty",
"type": "Recipe"
}
]
}
```
**What Code Produces (RDF Triples):**
```python
# 1. Normalize entity name + type to ID (type prevents collisions)
entity_id = "recipe-cornish-pasty" # normalize("Cornish pasty", "Recipe")
entity_uri = "https://trustgraph.ai/food/recipe-cornish-pasty"
# Note: Same name, different type = different URI
# "Cornish pasty" (Recipe) → recipe-cornish-pasty
# "Cornish pasty" (Food) → food-cornish-pasty
# 2. Generate triples
triples = [
# Type triple
Triple(
s=Value(value=entity_uri, is_uri=True),
p=Value(value="http://www.w3.org/1999/02/22-rdf-syntax-ns#type", is_uri=True),
o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True)
),
# Label triple (automatic)
Triple(
s=Value(value=entity_uri, is_uri=True),
p=Value(value="http://www.w3.org/2000/01/rdf-schema#label", is_uri=True),
o=Value(value="Cornish pasty", is_uri=False)
)
]
```
### Benefits
1. **LLM doesn't need to:**
- Understand URI syntax
- Invent identifier prefixes (`recipe:`, `ingredient:`)
- Know about `rdf:type` or `rdfs:label`
- Construct semantic web identifiers
2. **LLM just needs to:**
- Find entities in text
- Map them to ontology classes
- Extract relationships and attributes
3. **Code handles:**
- URI normalization and construction
- RDF triple generation
- Automatic label assignment
- Namespace management
### Why This Works Better
- **Simpler prompt** = less confusion = fewer errors
- **Consistent IDs** = code controls normalization rules
- **Auto-generated labels** = no missing rdfs:label triples
- **LLM focuses on extraction** = what it's actually good at
### Example: Entity Relationships
**Input Text:**
```
Cornish pasty is a traditional British pastry filled with beef and potatoes.
```
**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
- Food: A food is something that can be eaten
- Ingredient: An ingredient combines a quantity and a food
## Relationships:
- has_ingredient: Relates a recipe to an ingredient it uses (Recipe → Ingredient)
- food: Relates an ingredient to the food that is required (Ingredient → Food)
```
**What LLM Returns (Simple JSON):**
```json
{
"entities": [
{
"entity": "Cornish pasty",
"type": "Recipe"
},
{
"entity": "beef",
"type": "Food"
},
{
"entity": "potatoes",
"type": "Food"
}
],
"relationships": [
{
"subject": "Cornish pasty",
"subject-type": "Recipe",
"relation": "has_ingredient",
"object": "beef",
"object-type": "Food"
},
{
"subject": "Cornish pasty",
"subject-type": "Recipe",
"relation": "has_ingredient",
"object": "potatoes",
"object-type": "Food"
}
]
}
```
**What Code Produces (RDF Triples):**
```python
# Normalize entity names to URIs
cornish_pasty_uri = "https://trustgraph.ai/food/cornish-pasty"
beef_uri = "https://trustgraph.ai/food/beef"
potatoes_uri = "https://trustgraph.ai/food/potatoes"
# Look up relation URI from ontology
has_ingredient_uri = "http://purl.org/ontology/fo/ingredients" # from fo/has_ingredient
triples = [
# Entity type triples (as before)
Triple(s=cornish_pasty_uri, p=rdf_type, o="http://purl.org/ontology/fo/Recipe"),
Triple(s=cornish_pasty_uri, p=rdfs_label, o="Cornish pasty"),
Triple(s=beef_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"),
Triple(s=beef_uri, p=rdfs_label, o="beef"),
Triple(s=potatoes_uri, p=rdf_type, o="http://purl.org/ontology/fo/Food"),
Triple(s=potatoes_uri, p=rdfs_label, o="potatoes"),
# Relationship triples
Triple(
s=Value(value=cornish_pasty_uri, is_uri=True),
p=Value(value=has_ingredient_uri, is_uri=True),
o=Value(value=beef_uri, is_uri=True)
),
Triple(
s=Value(value=cornish_pasty_uri, is_uri=True),
p=Value(value=has_ingredient_uri, is_uri=True),
o=Value(value=potatoes_uri, is_uri=True)
)
]
```
**Key Points:**
- LLM returns natural language entity names: `"Cornish pasty"`, `"beef"`, `"potatoes"`
- LLM includes types to disambiguate: `subject-type`, `object-type`
- LLM uses relation name from schema: `"has_ingredient"`
- Code derives consistent IDs using (name, type): `("Cornish pasty", "Recipe")` → `recipe-cornish-pasty`
- Code looks up relation URI from ontology: `fo/has_ingredient` → full URI
- Same (name, type) tuple always gets same URI (deduplication)
### Example: Entity Name Disambiguation
**Problem:** Same name can refer to different entity types.
**Real-world case:**
```
"Cornish pasty" can be:
- A Recipe (instructions for making it)
- A Food (the dish itself)
```
**How It's Handled:**
LLM returns both as separate entities:
```json
{
"entities": [
{"entity": "Cornish pasty", "type": "Recipe"},
{"entity": "Cornish pasty", "type": "Food"}
],
"relationships": [
{
"subject": "Cornish pasty",
"subject-type": "Recipe",
"relation": "produces",
"object": "Cornish pasty",
"object-type": "Food"
}
]
}
```
**Code Resolution:**
```python
# Different types → different URIs
recipe_uri = normalize("Cornish pasty", "Recipe")
# → "https://trustgraph.ai/food/recipe-cornish-pasty"
food_uri = normalize("Cornish pasty", "Food")
# → "https://trustgraph.ai/food/food-cornish-pasty"
# Relationship connects them correctly
triple = Triple(
s=recipe_uri, # The Recipe
p="http://purl.org/ontology/fo/produces",
o=food_uri # The Food
)
```
**Why This Works:**
- Type is included in ALL references (entities, relationships, attributes)
- Code uses `(name, type)` tuple as lookup key
- No ambiguity, no collisions
### Example: Entity Attributes
**Input Text:**
```
This Cornish pasty recipe serves 4-6 people and takes 45 minutes to prepare.
```
**Ontology Schema (shown to LLM):**
```markdown
## Entity Types:
- Recipe: A recipe is a combination of ingredients and a method
## Attributes:
- serves: Indicates what the recipe is intended to serve (Recipe → text)
- preparation_time: Time needed to prepare the recipe (Recipe → text)
```
**What LLM Returns (Simple JSON):**
```json
{
"entities": [
{
"entity": "Cornish pasty recipe",
"type": "Recipe"
}
],
"attributes": [
{
"entity": "Cornish pasty recipe",
"entity-type": "Recipe",
"attribute": "serves",
"value": "4-6 people"
},
{
"entity": "Cornish pasty recipe",
"entity-type": "Recipe",
"attribute": "preparation_time",
"value": "45 minutes"
}
]
}
```
**What Code Produces (RDF Triples):**
```python
# Normalize entity name to URI
recipe_uri = "https://trustgraph.ai/food/cornish-pasty-recipe"
# Look up attribute URIs from ontology
serves_uri = "http://purl.org/ontology/fo/serves" # from fo/serves
prep_time_uri = "http://purl.org/ontology/fo/preparation_time" # from fo/preparation_time
triples = [
# Entity type triple
Triple(
s=Value(value=recipe_uri, is_uri=True),
p=Value(value=rdf_type, is_uri=True),
o=Value(value="http://purl.org/ontology/fo/Recipe", is_uri=True)
),
# Label triple (automatic)
Triple(
s=Value(value=recipe_uri, is_uri=True),
p=Value(value=rdfs_label, is_uri=True),
o=Value(value="Cornish pasty recipe", is_uri=False)
),
# Attribute triples (objects are literals, not URIs)
Triple(
s=Value(value=recipe_uri, is_uri=True),
p=Value(value=serves_uri, is_uri=True),
o=Value(value="4-6 people", is_uri=False) # Literal value!
),
Triple(
s=Value(value=recipe_uri, is_uri=True),
p=Value(value=prep_time_uri, is_uri=True),
o=Value(value="45 minutes", is_uri=False) # Literal value!
)
]
```
**Key Points:**
- LLM extracts literal values: `"4-6 people"`, `"45 minutes"`
- LLM includes entity type for disambiguation: `entity-type`
- LLM uses attribute name from schema: `"serves"`, `"preparation_time"`
- Code looks up attribute URI from ontology datatype properties
- **Object is literal** (`is_uri=False`), not a URI reference
- Values stay as natural text, no normalization needed
**Difference from Relationships:**
- Relationships: both subject and object are entities (URIs)
- Attributes: subject is entity (URI), object is literal value (string/number)
### Complete Example: Entities + Relationships + Attributes
**Input Text:**
```
Cornish pasty is a savory pastry filled with beef and potatoes.
This recipe serves 4 people.
```
**What LLM Returns:**
```json
{
"entities": [
{
"entity": "Cornish pasty",
"type": "Recipe"
},
{
"entity": "beef",
"type": "Food"
},
{
"entity": "potatoes",
"type": "Food"
}
],
"relationships": [
{
"subject": "Cornish pasty",
"subject-type": "Recipe",
"relation": "has_ingredient",
"object": "beef",
"object-type": "Food"
},
{
"subject": "Cornish pasty",
"subject-type": "Recipe",
"relation": "has_ingredient",
"object": "potatoes",
"object-type": "Food"
}
],
"attributes": [
{
"entity": "Cornish pasty",
"entity-type": "Recipe",
"attribute": "serves",
"value": "4 people"
}
]
}
```
**Result:** 9 RDF triples generated:
- 3 entity type triples (rdf:type)
- 3 entity label triples (rdfs:label) - automatic
- 2 relationship triples (has_ingredient)
- 1 attribute triple (serves)
All from simple, natural language extractions by the LLM!
## References
- Current implementation: `trustgraph-flow/trustgraph/extract/kg/ontology/extract.py`
- Prompt template: `ontology-prompt.md`
- Test cases: `tests/unit/test_extract/test_ontology/`
- Example ontology: `e2e/test-data/food.ontology`

# Pub/Sub Infrastructure
## Overview
This document catalogs all connections between the TrustGraph codebase and the pub/sub infrastructure. Currently, the system is hardcoded to use Apache Pulsar. This analysis identifies all integration points to inform future refactoring toward a configurable pub/sub abstraction.
## Current State: Pulsar Integration Points
### 1. Direct Pulsar Client Usage
**Location:** `trustgraph-flow/trustgraph/gateway/service.py`
The API gateway directly imports and instantiates the Pulsar client:
- **Line 20:** `import pulsar`
- **Lines 54-61:** Direct instantiation of `pulsar.Client()` with optional `pulsar.AuthenticationToken()`
- **Lines 33-35:** Default Pulsar host configuration from environment variables
- **Lines 178-192:** CLI arguments for `--pulsar-host`, `--pulsar-api-key`, and `--pulsar-listener`
- **Lines 78, 124:** Passes `pulsar_client` to `ConfigReceiver` and `DispatcherManager`
This is the only location that directly instantiates a Pulsar client outside of the abstraction layer.
### 2. Base Processor Framework
**Location:** `trustgraph-base/trustgraph/base/async_processor.py`
The base class for all processors provides Pulsar connectivity:
- **Line 9:** `import _pulsar` (for exception handling)
- **Line 18:** `from . pubsub import PulsarClient`
- **Line 38:** Creates `pulsar_client_object = PulsarClient(**params)`
- **Lines 104-108:** Properties exposing `pulsar_host` and `pulsar_client`
- **Line 250:** Static method `add_args()` calls `PulsarClient.add_args(parser)` for CLI arguments
- **Lines 223-225:** Exception handling for `_pulsar.Interrupted`
All processors inherit from `AsyncProcessor`, making this the central integration point.
### 3. Consumer Abstraction
**Location:** `trustgraph-base/trustgraph/base/consumer.py`
Consumes messages from queues and invokes handler functions:
**Pulsar imports:**
- **Line 12:** `from pulsar.schema import JsonSchema`
- **Line 13:** `import pulsar`
- **Line 14:** `import _pulsar`
**Pulsar-specific usage:**
- **Lines 100, 102:** `pulsar.InitialPosition.Earliest` / `pulsar.InitialPosition.Latest`
- **Line 108:** `JsonSchema(self.schema)` wrapper
- **Line 110:** `pulsar.ConsumerType.Shared`
- **Lines 104-111:** `self.client.subscribe()` with Pulsar-specific parameters
- **Lines 143, 150, 65:** `consumer.unsubscribe()` and `consumer.close()` methods
- **Line 162:** `_pulsar.Timeout` exception
- **Lines 182, 205, 232:** `consumer.acknowledge()` / `consumer.negative_acknowledge()`
**Spec file:** `trustgraph-base/trustgraph/base/consumer_spec.py`
- **Line 22:** References `processor.pulsar_client`
### 4. Producer Abstraction
**Location:** `trustgraph-base/trustgraph/base/producer.py`
Sends messages to queues:
**Pulsar imports:**
- **Line 2:** `from pulsar.schema import JsonSchema`
**Pulsar-specific usage:**
- **Line 49:** `JsonSchema(self.schema)` wrapper
- **Lines 47-51:** `self.client.create_producer()` with Pulsar-specific parameters (topic, schema, chunking_enabled)
- **Lines 31, 76:** `producer.close()` method
- **Lines 64-65:** `producer.send()` with message and properties
**Spec file:** `trustgraph-base/trustgraph/base/producer_spec.py`
- **Line 18:** References `processor.pulsar_client`
### 5. Publisher Abstraction
**Location:** `trustgraph-base/trustgraph/base/publisher.py`
Asynchronous message publishing with queue buffering:
**Pulsar imports:**
- **Line 2:** `from pulsar.schema import JsonSchema`
- **Line 6:** `import pulsar`
**Pulsar-specific usage:**
- **Line 52:** `JsonSchema(self.schema)` wrapper
- **Lines 50-54:** `self.client.create_producer()` with Pulsar-specific parameters
- **Lines 101, 103:** `producer.send()` with message and optional properties
- **Lines 106-107:** `producer.flush()` and `producer.close()` methods
### 6. Subscriber Abstraction
**Location:** `trustgraph-base/trustgraph/base/subscriber.py`
Provides multi-recipient message distribution from queues:
**Pulsar imports:**
- **Line 6:** `from pulsar.schema import JsonSchema`
- **Line 8:** `import _pulsar`
**Pulsar-specific usage:**
- **Line 55:** `JsonSchema(self.schema)` wrapper
- **Line 57:** `self.client.subscribe(**subscribe_args)`
- **Lines 101, 136, 160, 167-172:** Pulsar exceptions: `_pulsar.Timeout`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed`
- **Lines 159, 166, 170:** Consumer methods: `negative_acknowledge()`, `unsubscribe()`, `close()`
- **Lines 247, 251:** Message acknowledgment: `acknowledge()`, `negative_acknowledge()`
**Spec file:** `trustgraph-base/trustgraph/base/subscriber_spec.py`
- **Line 19:** References `processor.pulsar_client`
### 7. Schema System (Heart of Darkness)
**Location:** `trustgraph-base/trustgraph/schema/`
Every message schema in the system is defined using Pulsar's schema framework.
**Core primitives:** `schema/core/primitives.py`
- **Line 2:** `from pulsar.schema import Record, String, Boolean, Array, Integer`
- All schemas inherit from Pulsar's `Record` base class
- All field types are Pulsar types: `String()`, `Integer()`, `Boolean()`, `Array()`, `Map()`, `Double()`
**Example schemas:**
- `schema/services/llm.py` (Line 2): `from pulsar.schema import Record, String, Array, Double, Integer, Boolean`
- `schema/services/config.py` (Line 2): `from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer`
**Topic naming:** `schema/core/topic.py`
- **Lines 2-3:** Topic format: `{kind}://{tenant}/{namespace}/{topic}`
- This URI structure is Pulsar-specific (e.g., `persistent://tg/flow/config`)
**Impact:**
- All request/response message definitions throughout the codebase use Pulsar schemas
- This includes services for: config, flow, llm, prompt, query, storage, agent, collection, diagnosis, library, lookup, nlp_query, objects_query, retrieval, structured_query
- Schema definitions are imported and used extensively across all processors and services
## Summary
### Pulsar Dependencies by Category
1. **Client instantiation:**
- Direct: `gateway/service.py`
- Abstracted: `async_processor.py` → `pubsub.py` (PulsarClient)
2. **Message transport:**
- Consumer: `consumer.py`, `consumer_spec.py`
- Producer: `producer.py`, `producer_spec.py`
- Publisher: `publisher.py`
- Subscriber: `subscriber.py`, `subscriber_spec.py`
3. **Schema system:**
- Base types: `schema/core/primitives.py`
- All service schemas: `schema/services/*.py`
- Topic naming: `schema/core/topic.py`
4. **Pulsar-specific concepts required:**
- Topic-based messaging
- Schema system (Record, field types)
- Shared subscriptions
- Message acknowledgment (positive/negative)
- Consumer positioning (earliest/latest)
- Message properties
- Initial positions and consumer types
- Chunking support
- Persistent vs non-persistent topics
### Refactoring Challenges
The good news: The abstraction layer (Consumer, Producer, Publisher, Subscriber) provides a clean encapsulation of most Pulsar interactions.
The challenges:
1. **Schema system pervasiveness:** Every message definition uses `pulsar.schema.Record` and Pulsar field types
2. **Pulsar-specific enums:** `InitialPosition`, `ConsumerType`
3. **Pulsar exceptions:** `_pulsar.Timeout`, `_pulsar.Interrupted`, `_pulsar.InvalidConfiguration`, `_pulsar.AlreadyClosed`
4. **Method signatures:** `acknowledge()`, `negative_acknowledge()`, `subscribe()`, `create_producer()`, etc.
5. **Topic URI format:** Pulsar's `kind://tenant/namespace/topic` structure
### Next Steps
To make the pub/sub infrastructure configurable, we need to:
1. Create an abstraction interface for the client/schema system
2. Abstract Pulsar-specific enums and exceptions
3. Create schema wrappers or alternative schema definitions
4. Implement the interface for both Pulsar and alternative systems (Kafka, RabbitMQ, Redis Streams, etc.)
5. Update `pubsub.py` to be configurable and support multiple backends
6. Provide migration path for existing deployments
## Approach Draft 1: Adapter Pattern with Schema Translation Layer
### Key Insight
The **schema system** is the deepest integration point - everything else flows from it. We need to solve this first, or we'll be rewriting the entire codebase.
### Strategy: Minimal Disruption with Adapters
**1. Keep Pulsar schemas as the internal representation**
- Don't rewrite all the schema definitions
- Schemas remain `pulsar.schema.Record` internally
- Use adapters to translate at the boundary between our code and the pub/sub backend
**2. Create a pub/sub abstraction layer:**
```
┌─────────────────────────────────────┐
│ Existing Code (unchanged) │
│ - Uses Pulsar schemas internally │
│ - Consumer/Producer/Publisher │
└──────────────┬──────────────────────┘
┌──────────────┴──────────────────────┐
│ PubSubFactory (configurable) │
│ - Creates backend-specific client │
└──────────────┬──────────────────────┘
┌──────┴──────┐
│ │
┌───────▼─────┐ ┌────▼─────────┐
│ PulsarAdapter│ │ KafkaAdapter │ etc...
│ (passthrough)│ │ (translates) │
└──────────────┘ └──────────────┘
```
**3. Define abstract interfaces:**
- `PubSubClient` - client connection
- `PubSubProducer` - sending messages
- `PubSubConsumer` - receiving messages
- `SchemaAdapter` - translating Pulsar schemas to/from JSON or backend-specific formats
**4. Implementation details:**
For **Pulsar adapter**: Nearly passthrough, minimal translation
For **other backends** (Kafka, RabbitMQ, etc.):
- Serialize Pulsar Record objects to JSON/bytes
- Map concepts like:
- `InitialPosition.Earliest/Latest` → Kafka's auto.offset.reset
- `acknowledge()` → Kafka's commit
- `negative_acknowledge()` → Re-queue or DLQ pattern
- Topic URIs → Backend-specific topic names
### Analysis
**Pros:**
- ✅ Minimal code changes to existing services
- ✅ Schemas stay as-is (no massive rewrite)
- ✅ Gradual migration path
- ✅ Pulsar users see no difference
- ✅ New backends added via adapters
**Cons:**
- ⚠️ Still carries Pulsar dependency (for schema definitions)
- ⚠️ Some impedance mismatch translating concepts
### Alternative Consideration
Create a **TrustGraph schema system** that's pub/sub agnostic (using dataclasses or Pydantic), then generate Pulsar/Kafka/etc schemas from it. This would require rewriting every schema file and could introduce breaking changes.
### Recommendation for Draft 1
Start with the **adapter approach** because:
1. It's pragmatic - works with existing code
2. Proves the concept with minimal risk
3. Can evolve to a native schema system later if needed
4. Configuration-driven: one env var switches backends
## Approach Draft 2: Backend-Agnostic Schema System with Dataclasses
### Core Concept
Use Python **dataclasses** as the neutral schema definition format. Each pub/sub backend provides its own serialization/deserialization for dataclasses, eliminating the need for Pulsar schemas to remain in the codebase.
### Schema Polymorphism at the Factory Level
Instead of translating Pulsar schemas, **each backend provides its own schema handling** that works with standard Python dataclasses.
### Publisher Flow
```python
# 1. Get the configured backend from factory
pubsub = get_pubsub() # Returns PulsarBackend, MQTTBackend, etc.
# 2. Get schema class from the backend
# (Can be imported directly - backend-agnostic)
from trustgraph.schema.services.llm import TextCompletionRequest
# 3. Create a producer/publisher for a specific topic
producer = pubsub.create_producer(
topic="text-completion-requests",
schema=TextCompletionRequest # Tells backend what schema to use
)
# 4. Create message instances (same API regardless of backend)
request = TextCompletionRequest(
system="You are helpful",
prompt="Hello world",
streaming=False
)
# 5. Send the message
producer.send(request) # Backend serializes appropriately
```
### Consumer Flow
```python
# 1. Get the configured backend
pubsub = get_pubsub()
# 2. Create a consumer
consumer = pubsub.subscribe(
topic="text-completion-requests",
schema=TextCompletionRequest # Tells backend how to deserialize
)
# 3. Receive and deserialize
msg = consumer.receive()
request = msg.value() # Returns TextCompletionRequest dataclass instance
# 4. Use the data (type-safe access)
print(request.system) # "You are helpful"
print(request.prompt) # "Hello world"
print(request.streaming) # False
```
### What Happens Behind the Scenes
**For Pulsar backend:**
- `create_producer()` → creates Pulsar producer with JSON schema or dynamically generated Record
- `send(request)` → serializes dataclass to JSON/Pulsar format, sends to Pulsar
- `receive()` → gets Pulsar message, deserializes back to dataclass
**For MQTT backend:**
- `create_producer()` → connects to MQTT broker, no schema registration needed
- `send(request)` → converts dataclass to JSON, publishes to MQTT topic
- `receive()` → subscribes to MQTT topic, deserializes JSON to dataclass
**For Kafka backend:**
- `create_producer()` → creates Kafka producer, registers Avro schema if needed
- `send(request)` → serializes dataclass to Avro format, sends to Kafka
- `receive()` → gets Kafka message, deserializes Avro back to dataclass
### Key Design Points
1. **Schema object creation**: The dataclass instance (`TextCompletionRequest(...)`) is identical regardless of backend
2. **Backend handles encoding**: Each backend knows how to serialize its dataclass to the wire format
3. **Schema definition at creation**: When creating producer/consumer, you specify the schema type
4. **Type safety preserved**: You get back a proper `TextCompletionRequest` object, not a dict
5. **No backend leakage**: Application code never imports backend-specific libraries
### Example Transformation
**Current (Pulsar-specific):**
```python
# schema/services/llm.py
from pulsar.schema import Record, String, Boolean, Integer
class TextCompletionRequest(Record):
system = String()
prompt = String()
streaming = Boolean()
```
**New (Backend-agnostic):**
```python
# schema/services/llm.py
from dataclasses import dataclass
@dataclass
class TextCompletionRequest:
system: str
prompt: str
streaming: bool = False
```
### Backend Integration
Each backend handles serialization/deserialization of dataclasses:
**Pulsar backend:**
- Dynamically generate `pulsar.schema.Record` classes from dataclasses
- Or serialize dataclasses to JSON and use Pulsar's JSON schema
- Maintains compatibility with existing Pulsar deployments
**MQTT/Redis backend:**
- Direct JSON serialization of dataclass instances
- Use `dataclasses.asdict()` / `from_dict()`
- Lightweight, no schema registry needed
**Kafka backend:**
- Generate Avro schemas from dataclass definitions
- Use Confluent's schema registry
- Type-safe serialization with schema evolution support
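For the JSON-based backends, the serializer/deserializer pair can be very small. This sketch assumes plain dataclasses and drops unknown keys on deserialization for a little forward compatibility; the helper names are hypothetical:

```python
import json
from dataclasses import dataclass, asdict, fields

@dataclass
class TextCompletionRequest:
    system: str
    prompt: str
    streaming: bool = False

def serialize(msg) -> bytes:
    """Dataclass -> JSON bytes, suitable for an MQTT/Redis payload."""
    return json.dumps(asdict(msg)).encode("utf-8")

def deserialize(cls, payload: bytes):
    """JSON bytes -> dataclass, ignoring fields the class doesn't know about."""
    data = json.loads(payload.decode("utf-8"))
    names = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in data.items() if k in names})
```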
### Architecture
```
┌─────────────────────────────────────┐
│ Application Code │
│ - Uses dataclass schemas │
│ - Backend-agnostic │
└──────────────┬──────────────────────┘
┌──────────────┴──────────────────────┐
│ PubSubFactory (configurable) │
│ - get_pubsub() returns backend │
└──────────────┬──────────────────────┘
┌──────┴──────┐
│ │
┌───────▼─────────┐ ┌────▼──────────────┐
│ PulsarBackend │ │ MQTTBackend │
│ - JSON schema │ │ - JSON serialize │
│ - or dynamic │ │ - Simple queues │
│ Record gen │ │ │
└─────────────────┘ └───────────────────┘
```
### Implementation Details
**1. Schema definitions:** Plain dataclasses with type hints
- `str`, `int`, `bool`, `float` for primitives
- `list[T]` for arrays
- `dict[str, T]` for maps
- Nested dataclasses for complex types
**2. Each backend provides:**
- Serializer: `dataclass → bytes/wire format`
- Deserializer: `bytes/wire format → dataclass`
- Schema registration (if needed, like Pulsar/Kafka)
**3. Consumer/Producer abstraction:**
- Already exists (consumer.py, producer.py)
- Update to use backend's serialization
- Remove direct Pulsar imports
**4. Type mappings:**
- Pulsar `String()` → Python `str`
- Pulsar `Integer()` → Python `int`
- Pulsar `Boolean()` → Python `bool`
- Pulsar `Array(T)` → Python `list[T]`
- Pulsar `Map(K, V)` → Python `dict[K, V]`
- Pulsar `Double()` → Python `float`
- Pulsar `Bytes()` → Python `bytes`
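The type mappings above can be derived mechanically from a dataclass's field annotations. This sketch produces a neutral description that a backend could translate into its native schema format; the `describe`/`schema_of` helpers are hypothetical:

```python
from dataclasses import dataclass, fields
from typing import get_origin, get_args

# Neutral names matching the mapping table above
PRIMITIVES = {str: "String", int: "Integer", bool: "Boolean",
              float: "Double", bytes: "Bytes"}

def describe(tp) -> str:
    """Map a type annotation to a neutral schema type name, recursively."""
    origin = get_origin(tp)
    if origin is list:
        return f"Array({describe(get_args(tp)[0])})"
    if origin is dict:
        k, v = get_args(tp)
        return f"Map({describe(k)}, {describe(v)})"
    return PRIMITIVES.get(tp, getattr(tp, "__name__", str(tp)))

def schema_of(cls) -> dict[str, str]:
    """Field name -> neutral type name for every field of a dataclass."""
    return {f.name: describe(f.type) for f in fields(cls)}

@dataclass
class GraphQuery:
    query: str
    limit: int
    tags: list[str]
```

Nested dataclasses fall through to the `__name__` branch, so a backend can recurse into them the same way.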
### Migration Path
1. **Create dataclass versions** of all schemas in `trustgraph/schema/`
2. **Update backend classes** (Consumer, Producer, Publisher, Subscriber) to use backend-provided serialization
3. **Implement PulsarBackend** with JSON schema or dynamic Record generation
4. **Test with Pulsar** to ensure backward compatibility with existing deployments
5. **Add new backends** (MQTT, Kafka, Redis, etc.) as needed
6. **Remove Pulsar imports** from schema files
### Benefits
- ✅ **No pub/sub dependency** in schema definitions
- ✅ **Standard Python** - easy to understand, type-check, document
- ✅ **Modern tooling** - works with mypy, IDE autocomplete, linters
- ✅ **Backend-optimized** - each backend uses native serialization
- ✅ **No translation overhead** - direct serialization, no adapters
- ✅ **Type safety** - real objects with proper types
- ✅ **Easy validation** - can use Pydantic if needed
### Challenges & Solutions
**Challenge:** Pulsar's `Record` has runtime field validation
**Solution:** Use Pydantic dataclasses for validation if needed, or Python 3.10+ dataclass features with `__post_init__`
**Challenge:** Some Pulsar-specific features (like `Bytes` type)
**Solution:** Map to `bytes` type in dataclass, backend handles encoding appropriately
**Challenge:** Topic naming (`persistent://tenant/namespace/topic`)
**Solution:** Abstract topic names in schema definitions, backend converts to proper format
**Challenge:** Schema evolution and versioning
**Solution:** Each backend handles this according to its capabilities (Pulsar schema versions, Kafka schema registry, etc.)
**Challenge:** Nested complex types
**Solution:** Use nested dataclasses, backends recursively serialize/deserialize
### Design Decisions
1. **Plain dataclasses or Pydantic?**
- ✅ **Decision: Use plain Python dataclasses**
- Simpler, no additional dependencies
- Validation not required in practice
- Easier to understand and maintain
2. **Schema evolution:**
- ✅ **Decision: No versioning mechanism needed**
- Schemas are stable and long-lasting
- Updates typically add new fields (backward compatible)
- Backends handle schema evolution according to their capabilities
3. **Backward compatibility:**
- ✅ **Decision: Major version change, no backward compatibility required**
- Will be a breaking change with migration instructions
- Clean break allows for better design
- Migration guide will be provided for existing deployments
4. **Nested types and complex structures:**
- ✅ **Decision: Use nested dataclasses naturally**
- Python dataclasses handle nesting perfectly
- `list[T]` for arrays, `dict[K, V]` for maps
- Backends recursively serialize/deserialize
- Example:
```python
@dataclass
class Value:
value: str
is_uri: bool
@dataclass
class Triple:
s: Value # Nested dataclass
p: Value
o: Value
@dataclass
class GraphQuery:
triples: list[Triple] # Array of nested dataclasses
metadata: dict[str, str]
```
5. **Default values and optional fields:**
- ✅ **Decision: Mix of required, defaults, and optional fields**
- Required fields: No default value
- Fields with defaults: Always present, have sensible default
- Truly optional fields: `T | None = None`, omitted from serialization when `None`
- Example:
```python
@dataclass
class TextCompletionRequest:
system: str # Required, no default
prompt: str # Required, no default
streaming: bool = False # Optional with default value
metadata: dict | None = None # Truly optional, can be absent
```
**Important serialization semantics:**
When `metadata = None`:
```json
{
"system": "...",
"prompt": "...",
"streaming": false
// metadata field NOT PRESENT
}
```
When `metadata = {}` (explicitly empty):
```json
{
"system": "...",
"prompt": "...",
"streaming": false,
"metadata": {} // Field PRESENT but empty
}
```
**Key distinction:**
- `None` → field absent from JSON (not serialized)
- Empty value (`{}`, `[]`, `""`) → field present with empty value
- This matters semantically: "not provided" vs "explicitly empty"
- Serialization backends must skip `None` fields, not encode as `null`
## Approach Draft 3: Implementation Details
### Generic Queue Naming Format
Replace backend-specific queue names with a generic format that backends can map appropriately.
**Format:** `{qos}/{tenant}/{namespace}/{queue-name}`
Where:
- `qos`: Quality of Service level
- `q0` = best-effort (fire and forget, no acknowledgment)
- `q1` = at-least-once (requires acknowledgment)
- `q2` = exactly-once (two-phase acknowledgment)
- `tenant`: Logical grouping for multi-tenancy
- `namespace`: Sub-grouping within tenant
- `queue-name`: Actual queue/topic name
**Examples:**
```
q1/tg/flow/text-completion-requests
q2/tg/config/config-push
q0/tg/metrics/stats
```
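A hypothetical parser/validator for this identifier format, which a backend might use before mapping to its native topic scheme:

```python
def parse_queue_id(queue_id: str) -> dict:
    """Split a generic queue identifier into its four components."""
    qos, tenant, namespace, name = queue_id.split("/", 3)
    if qos not in ("q0", "q1", "q2"):
        raise ValueError(f"unknown QoS level: {qos}")
    return {"qos": qos, "tenant": tenant, "namespace": namespace, "queue": name}

info = parse_queue_id("q1/tg/flow/text-completion-requests")
```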
### Backend Topic Mapping
Each backend maps the generic format to its native format:
**Pulsar Backend:**
```python
def map_topic(self, generic_topic: str) -> str:
# Parse: q1/tg/flow/text-completion-requests
qos, tenant, namespace, queue = generic_topic.split('/', 3)
# Map QoS to persistence
persistence = 'persistent' if qos in ['q1', 'q2'] else 'non-persistent'
# Return Pulsar URI: persistent://tg/flow/text-completion-requests
return f"{persistence}://{tenant}/{namespace}/{queue}"
```
**MQTT Backend:**
```python
def map_topic(self, generic_topic: str) -> tuple[str, int]:
# Parse: q1/tg/flow/text-completion-requests
qos, tenant, namespace, queue = generic_topic.split('/', 3)
# Map QoS level
qos_level = {'q0': 0, 'q1': 1, 'q2': 2}[qos]
# Build MQTT topic including tenant/namespace for proper namespacing
mqtt_topic = f"{tenant}/{namespace}/{queue}"
return mqtt_topic, qos_level
```
### Updated Topic Helper Function
```python
# schema/core/topic.py
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
"""
Create a generic topic identifier that can be mapped by backends.
Args:
queue_name: The queue/topic name
qos: Quality of service
- 'q0' = best-effort (no ack)
- 'q1' = at-least-once (ack required)
- 'q2' = exactly-once (two-phase ack)
tenant: Tenant identifier for multi-tenancy
namespace: Namespace within tenant
Returns:
Generic topic string: qos/tenant/namespace/queue_name
Examples:
topic('my-queue') # q1/tg/flow/my-queue
topic('config', qos='q2', namespace='config') # q2/tg/config/config
"""
return f"{qos}/{tenant}/{namespace}/{queue_name}"
```
### Configuration and Initialization
**Command-Line Arguments + Environment Variables:**
```python
# In base/async_processor.py - add_args() method
@staticmethod
def add_args(parser):
# Pub/sub backend selection
parser.add_argument(
'--pubsub-backend',
default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
choices=['pulsar', 'mqtt'],
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)'
)
# Pulsar-specific configuration
parser.add_argument(
'--pulsar-host',
default=os.getenv('PULSAR_HOST', 'pulsar://localhost:6650'),
help='Pulsar host (default: pulsar://localhost:6650, env: PULSAR_HOST)'
)
parser.add_argument(
'--pulsar-api-key',
default=os.getenv('PULSAR_API_KEY', None),
help='Pulsar API key (env: PULSAR_API_KEY)'
)
parser.add_argument(
'--pulsar-listener',
default=os.getenv('PULSAR_LISTENER', None),
help='Pulsar listener name (env: PULSAR_LISTENER)'
)
# MQTT-specific configuration
parser.add_argument(
'--mqtt-host',
default=os.getenv('MQTT_HOST', 'localhost'),
help='MQTT broker host (default: localhost, env: MQTT_HOST)'
)
parser.add_argument(
'--mqtt-port',
type=int,
default=int(os.getenv('MQTT_PORT', '1883')),
help='MQTT broker port (default: 1883, env: MQTT_PORT)'
)
parser.add_argument(
'--mqtt-username',
default=os.getenv('MQTT_USERNAME', None),
help='MQTT username (env: MQTT_USERNAME)'
)
parser.add_argument(
'--mqtt-password',
default=os.getenv('MQTT_PASSWORD', None),
help='MQTT password (env: MQTT_PASSWORD)'
)
```
**Factory Function:**
```python
# In base/pubsub.py or base/pubsub_factory.py
def get_pubsub(**config) -> PubSubBackend:
"""
Create and return a pub/sub backend based on configuration.
Args:
config: Configuration dict from command-line args
Must include 'pubsub_backend' key
Returns:
Backend instance (PulsarBackend, MQTTBackend, etc.)
"""
backend_type = config.get('pubsub_backend', 'pulsar')
if backend_type == 'pulsar':
return PulsarBackend(
host=config.get('pulsar_host'),
api_key=config.get('pulsar_api_key'),
listener=config.get('pulsar_listener'),
)
elif backend_type == 'mqtt':
return MQTTBackend(
host=config.get('mqtt_host'),
port=config.get('mqtt_port'),
username=config.get('mqtt_username'),
password=config.get('mqtt_password'),
)
else:
raise ValueError(f"Unknown pub/sub backend: {backend_type}")
```
**Usage in AsyncProcessor:**
```python
# In async_processor.py
class AsyncProcessor:
def __init__(self, **params):
self.id = params.get("id")
# Create backend from config (replaces PulsarClient)
self.pubsub = get_pubsub(**params)
# Rest of initialization...
```
### Backend Interface
```python
class PubSubBackend(Protocol):
"""Protocol defining the interface all pub/sub backends must implement."""
def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
"""
Create a producer for a topic.
Args:
topic: Generic topic format (qos/tenant/namespace/queue)
schema: Dataclass type for messages
options: Backend-specific options (e.g., chunking_enabled)
Returns:
Backend-specific producer instance
"""
...
def create_consumer(
self,
topic: str,
subscription: str,
schema: type,
initial_position: str = 'latest',
consumer_type: str = 'shared',
**options
) -> BackendConsumer:
"""
Create a consumer for a topic.
Args:
topic: Generic topic format (qos/tenant/namespace/queue)
subscription: Subscription/consumer group name
schema: Dataclass type for messages
initial_position: 'earliest' or 'latest' (MQTT may ignore)
consumer_type: 'shared', 'exclusive', 'failover' (MQTT may ignore)
options: Backend-specific options
Returns:
Backend-specific consumer instance
"""
...
def close(self) -> None:
"""Close the backend connection."""
...
```
```python
class BackendProducer(Protocol):
"""Protocol for backend-specific producer."""
def send(self, message: Any, properties: dict = {}) -> None:
"""Send a message (dataclass instance) with optional properties."""
...
def flush(self) -> None:
"""Flush any buffered messages."""
...
def close(self) -> None:
"""Close the producer."""
...
```
```python
class BackendConsumer(Protocol):
"""Protocol for backend-specific consumer."""
def receive(self, timeout_millis: int = 2000) -> Message:
"""
Receive a message from the topic.
Raises:
TimeoutError: If no message received within timeout
"""
...
def acknowledge(self, message: Message) -> None:
"""Acknowledge successful processing of a message."""
...
def negative_acknowledge(self, message: Message) -> None:
"""Negative acknowledge - triggers redelivery."""
...
def unsubscribe(self) -> None:
"""Unsubscribe from the topic."""
...
def close(self) -> None:
"""Close the consumer."""
...
```
```python
class Message(Protocol):
"""Protocol for a received message."""
def value(self) -> Any:
"""Get the deserialized message (dataclass instance)."""
...
def properties(self) -> dict:
"""Get message properties/metadata."""
...
```
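To show how these protocols compose, here is a toy in-memory backend satisfying the producer, consumer, and message interfaces. It is purely illustrative (names like `InMemoryBackend` are not part of the design) and ignores QoS, subscriptions, and acknowledgment semantics:

```python
import queue
from dataclasses import dataclass

class InMemoryMessage:
    def __init__(self, value, properties):
        self._value, self._properties = value, properties
    def value(self):
        return self._value
    def properties(self):
        return self._properties

class InMemoryBackend:
    def __init__(self):
        self.queues = {}  # topic -> queue.Queue

    def _q(self, topic):
        return self.queues.setdefault(topic, queue.Queue())

    def create_producer(self, topic, schema, **options):
        backend = self
        class Producer:
            def send(self, message, properties=None):
                backend._q(topic).put(InMemoryMessage(message, dict(properties or {})))
            def flush(self): pass
            def close(self): pass
        return Producer()

    def create_consumer(self, topic, subscription, schema, **options):
        backend = self
        class Consumer:
            def receive(self, timeout_millis=2000):
                try:
                    return backend._q(topic).get(timeout=timeout_millis / 1000)
                except queue.Empty:
                    raise TimeoutError("no message")
            def acknowledge(self, message): pass
            def negative_acknowledge(self, message): pass
            def unsubscribe(self): pass
            def close(self): pass
        return Consumer()

    def close(self): pass

@dataclass
class Ping:
    text: str

backend = InMemoryBackend()
topic = "q1/tg/flow/ping"
backend.create_producer(topic, Ping).send(Ping("hello"))
msg = backend.create_consumer(topic, "sub", Ping).receive()
```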
### Existing Classes Refactoring
The existing `Consumer`, `Producer`, `Publisher`, `Subscriber` classes remain largely intact:
**Current responsibilities (keep):**
- Async threading model and taskgroups
- Reconnection logic and retry handling
- Metrics collection
- Rate limiting
- Concurrency management
**Changes needed:**
- Remove direct Pulsar imports (`pulsar.schema`, `pulsar.InitialPosition`, etc.)
- Accept `BackendProducer`/`BackendConsumer` instead of Pulsar client
- Delegate actual pub/sub operations to backend instances
- Map generic concepts to backend calls
**Example refactoring:**
```python
# OLD - consumer.py
class Consumer:
def __init__(self, client, topic, subscriber, schema, ...):
self.client = client # Direct Pulsar client
# ...
async def consumer_run(self):
# Uses pulsar.InitialPosition, pulsar.ConsumerType
self.consumer = self.client.subscribe(
topic=self.topic,
schema=JsonSchema(self.schema),
initial_position=pulsar.InitialPosition.Earliest,
consumer_type=pulsar.ConsumerType.Shared,
)
# NEW - consumer.py
class Consumer:
def __init__(self, backend_consumer, schema, ...):
self.backend_consumer = backend_consumer # Backend-specific consumer
self.schema = schema
# ...
async def consumer_run(self):
# Backend consumer already created with right settings
# Just use it directly
while self.running:
msg = await asyncio.to_thread(
self.backend_consumer.receive,
timeout_millis=2000
)
await self.handle_message(msg)
```
### Backend-Specific Behaviors
**Pulsar Backend:**
- Maps `q0``non-persistent://`, `q1`/`q2``persistent://`
- Supports all consumer types (shared, exclusive, failover)
- Supports initial position (earliest/latest)
- Native message acknowledgment
- Schema registry support
**MQTT Backend:**
- Maps `q0`/`q1`/`q2` → MQTT QoS levels 0/1/2
- Includes tenant/namespace in topic path for namespacing
- Auto-generates client IDs from subscription names
- Ignores initial position (no message history in basic MQTT)
- Ignores consumer type (MQTT uses client IDs, not consumer groups)
- Simple publish/subscribe model
### Design Decisions Summary
1. ✅ **Generic queue naming**: `qos/tenant/namespace/queue-name` format
2. ✅ **QoS in queue ID**: Determined by queue definition, not configuration
3. ✅ **Reconnection**: Handled by Consumer/Producer classes, not backends
4. ✅ **MQTT topics**: Include tenant/namespace for proper namespacing
5. ✅ **Message history**: MQTT ignores `initial_position` parameter (future enhancement)
6. ✅ **Client IDs**: MQTT backend auto-generates from subscription name
### Future Enhancements
**MQTT message history:**
- Could add optional persistence layer (e.g., retained messages, external store)
- Would allow supporting `initial_position='earliest'`
- Not required for initial implementation

(Diff for one file suppressed because it is too large.)

ontology-prompt.md (new file, 54 lines):
You are a knowledge extraction expert. Extract structured triples from text using ONLY the provided ontology elements.
## Ontology Classes:
{% for class_id, class_def in classes.items() %}
- **{{class_id}}**{% if class_def.subclass_of %} (subclass of {{class_def.subclass_of}}){% endif %}{% if class_def.comment %}: {{class_def.comment}}{% endif %}
{% endfor %}
## Object Properties (connect entities):
{% for prop_id, prop_def in object_properties.items() %}
- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %}
{% endfor %}
## Datatype Properties (entity attributes):
{% for prop_id, prop_def in datatype_properties.items() %}
- **{{prop_id}}**{% if prop_def.domain and prop_def.range %} ({{prop_def.domain}} → {{prop_def.range}}){% endif %}{% if prop_def.comment %}: {{prop_def.comment}}{% endif %}
{% endfor %}
## Text to Analyze:
{{text}}
## Extraction Rules:
1. Only use classes defined above for entity types
2. Only use properties defined above for relationships and attributes
3. Respect domain and range constraints where specified
4. For class instances, use `rdf:type` as the predicate
5. Include `rdfs:label` for new entities to provide human-readable names
6. Extract all relevant triples that can be inferred from the text
7. Use entity URIs or meaningful identifiers as subjects/objects
## Output Format:
Return ONLY a valid JSON array (no markdown, no code blocks) containing objects with these fields:
- "subject": the subject entity (URI or identifier)
- "predicate": the property (from ontology or rdf:type/rdfs:label)
- "object": the object entity or literal value
Important: Return raw JSON only, with no markdown formatting, no code blocks, and no backticks.
## Example Output:
[
{"subject": "recipe:cornish-pasty", "predicate": "rdf:type", "object": "Recipe"},
{"subject": "recipe:cornish-pasty", "predicate": "rdfs:label", "object": "Cornish Pasty"},
{"subject": "recipe:cornish-pasty", "predicate": "has_ingredient", "object": "ingredient:flour"},
{"subject": "ingredient:flour", "predicate": "rdf:type", "object": "Ingredient"},
{"subject": "ingredient:flour", "predicate": "rdfs:label", "object": "plain flour"}
]
Now extract triples from the text above.
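For illustration, the class-listing loop in this template renders roughly as follows, shown here as a plain-Python mirror of the Jinja loop over a small hypothetical ontology:

```python
# Hypothetical ontology fragment; the real structure is supplied by TrustGraph.
classes = {
    "Recipe": {"subclass_of": None, "comment": "A prepared dish"},
    "Ingredient": {"subclass_of": None, "comment": "A food component"},
}

lines = []
for class_id, class_def in classes.items():
    line = f"- **{class_id}**"
    if class_def["subclass_of"]:
        line += f" (subclass of {class_def['subclass_of']})"
    if class_def["comment"]:
        line += f": {class_def['comment']}"
    lines.append(line)

print("\n".join(lines))
# → - **Recipe**: A prepared dish
#   - **Ingredient**: A food component
```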

@@ -21,3 +21,4 @@ prometheus-client
 pyarrow
 boto3
 ollama
+python-logging-loki

tests/conftest.py (new file, 60 lines):
"""
Global pytest configuration for all tests.
This conftest.py applies to all test directories.
"""
import pytest
# import asyncio
# import tracemalloc
# import warnings
from unittest.mock import MagicMock
# Uncomment the lines below to enable asyncio debug mode and tracemalloc
# for tracing unawaited coroutines and their creation points
# tracemalloc.start()
# asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
# warnings.simplefilter("always", ResourceWarning)
# warnings.simplefilter("always", RuntimeWarning)
@pytest.fixture(scope="session", autouse=True)
def mock_loki_handler(session_mocker=None):
"""
Mock LokiHandler to prevent connection attempts during tests.
This fixture runs once per test session and prevents the logging
module from trying to connect to a Loki server that doesn't exist
in the test environment.
"""
# Try to import logging_loki and mock it if available
try:
import logging_loki
# Create a mock LokiHandler that does nothing
original_loki_handler = logging_loki.LokiHandler
class MockLokiHandler:
"""Mock LokiHandler that doesn't make network calls."""
def __init__(self, *args, **kwargs):
pass
def emit(self, record):
pass
def flush(self):
pass
def close(self):
pass
# Replace the real LokiHandler with our mock
logging_loki.LokiHandler = MockLokiHandler
yield
# Restore original after tests
logging_loki.LokiHandler = original_loki_handler
except ImportError:
# If logging_loki isn't installed, no need to mock
yield

@@ -257,7 +257,6 @@ class TestAgentMessageContracts:
         # Act
         request = AgentRequest(
             question="What comes next?",
-            plan="Multi-step plan",
             state="processing",
             history=history_steps
         )
@@ -588,7 +587,6 @@ class TestSerializationContracts:
         request = AgentRequest(
             question="Test with array",
-            plan="Test plan",
             state="Test state",
             history=steps
         )

@@ -189,6 +189,7 @@ class TestObjectsCassandraContracts:
         assert result == expected_val
         assert isinstance(result, expected_type) or result is None

+    @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
     def test_extracted_object_serialization_contract(self):
         """Test that ExtractedObject can be serialized/deserialized correctly"""
         # Create test object
@@ -408,6 +409,7 @@ class TestObjectsCassandraContractsBatch:
         assert isinstance(single_batch_object.values[0], dict)
         assert single_batch_object.values[0]["customer_id"] == "CUST999"

+    @pytest.mark.skip(reason="ExtractedObject is a dataclass, not a Pulsar Record type")
     def test_extracted_object_batch_serialization_contract(self):
         """Test that batched ExtractedObject can be serialized/deserialized correctly"""
         # Create batch object

@@ -480,11 +480,15 @@ def streaming_chunk_collector():
     class ChunkCollector:
         def __init__(self):
             self.chunks = []
+            self.end_of_stream_flags = []
             self.complete = False

-        async def collect(self, chunk):
-            """Async callback to collect chunks"""
+        async def collect(self, chunk, end_of_stream=False):
+            """Async callback to collect chunks with end_of_stream flag"""
             self.chunks.append(chunk)
+            self.end_of_stream_flags.append(end_of_stream)
+            if end_of_stream:
+                self.complete = True

         def get_full_text(self):
             """Concatenate all chunk content"""
@@ -496,6 +500,14 @@ def streaming_chunk_collector():
                 return [c.get("chunk_type") for c in self.chunks]
             return []

+        def verify_streaming_protocol(self):
+            """Verify that streaming protocol is correct"""
+            assert len(self.chunks) > 0, "Should have received at least one chunk"
+            assert len(self.chunks) == len(self.end_of_stream_flags), "Each chunk should have an end_of_stream flag"
+            assert self.end_of_stream_flags.count(True) == 1, "Exactly one chunk should have end_of_stream=True"
+            assert self.end_of_stream_flags[-1] is True, "Last chunk should have end_of_stream=True"
+            assert self.complete is True, "Should be marked complete after final chunk"
+
     return ChunkCollector

@@ -47,8 +47,9 @@ Args: {
             "}"
         ]

-        for chunk in chunks:
-            await chunk_callback(chunk)
+        for i, chunk in enumerate(chunks):
+            is_final = (i == len(chunks) - 1)
+            await chunk_callback(chunk, is_final)

         return full_text
     else:
@@ -312,8 +313,10 @@ Final Answer: AI is the simulation of human intelligence in machines."""
         call_count += 1

         if streaming and chunk_callback:
-            for chunk in response.split():
-                await chunk_callback(chunk + " ")
+            chunks = response.split()
+            for i, chunk in enumerate(chunks):
+                is_final = (i == len(chunks) - 1)
+                await chunk_callback(chunk + " ", is_final)
             return response

         return response

@@ -373,13 +373,13 @@ class TestMultipleHostsHandling:
         from trustgraph.base.cassandra_config import resolve_cassandra_config

         # Test various whitespace scenarios
-        hosts1, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
+        hosts1, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
         assert hosts1 == ['host1', 'host2', 'host3']

-        hosts2, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
+        hosts2, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
         assert hosts2 == ['host1', 'host2', 'host3']

-        hosts3, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
+        hosts3, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
         assert hosts3 == ['host1', 'host2']

@@ -46,9 +46,16 @@ class TestDocumentRagStreaming:
             full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

             if streaming and chunk_callback:
-                # Simulate streaming chunks
+                # Simulate streaming chunks with end_of_stream flags
+                chunks = []
                 async for chunk in mock_streaming_llm_response():
-                    await chunk_callback(chunk)
+                    chunks.append(chunk)
+
+                # Send all chunks with end_of_stream=False except the last
+                for i, chunk in enumerate(chunks):
+                    is_final = (i == len(chunks) - 1)
+                    await chunk_callback(chunk, is_final)
+
                 return full_text
             else:
                 # Non-streaming response - same text
@@ -89,6 +96,9 @@ class TestDocumentRagStreaming:
         assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
         assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

+        # Verify streaming protocol compliance
+        collector.verify_streaming_protocol()
+
         # Verify full response matches concatenated chunks
         full_from_chunks = collector.get_full_text()
         assert result == full_from_chunks
@@ -117,7 +127,7 @@ class TestDocumentRagStreaming:
         # Act - Streaming
         streaming_chunks = []

-        async def collect(chunk):
+        async def collect(chunk, end_of_stream):
             streaming_chunks.append(chunk)

         streaming_result = await document_rag_streaming.query(

@@ -59,9 +59,16 @@ class TestGraphRagStreaming:
             full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

             if streaming and chunk_callback:
-                # Simulate streaming chunks
+                # Simulate streaming chunks with end_of_stream flags
+                chunks = []
                 async for chunk in mock_streaming_llm_response():
-                    await chunk_callback(chunk)
+                    chunks.append(chunk)
+
+                # Send all chunks with end_of_stream=False except the last
+                for i, chunk in enumerate(chunks):
+                    is_final = (i == len(chunks) - 1)
+                    await chunk_callback(chunk, is_final)
+
                 return full_text
             else:
                 # Non-streaming response - same text
@@ -102,6 +109,9 @@ class TestGraphRagStreaming:
         assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
         assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

+        # Verify streaming protocol compliance
+        collector.verify_streaming_protocol()
+
         # Verify full response matches concatenated chunks
         full_from_chunks = collector.get_full_text()
         assert result == full_from_chunks
@@ -128,7 +138,7 @@ class TestGraphRagStreaming:
         # Act - Streaming
         streaming_chunks = []

-        async def collect(chunk):
+        async def collect(chunk, end_of_stream):
             streaming_chunks.append(chunk)

         streaming_result = await graph_rag_streaming.query(

@@ -59,16 +59,16 @@ class MockWebSocket:
 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for integration testing."""
-    client = MagicMock()
+def mock_backend():
+    """Mock backend for integration testing."""
+    backend = MagicMock()

     # Mock producer
     producer = MagicMock()
     producer.send = MagicMock()
     producer.flush = MagicMock()
     producer.close = MagicMock()
-    client.create_producer.return_value = producer
+    backend.create_producer.return_value = producer

     # Mock consumer
     consumer = MagicMock()
@@ -78,17 +78,15 @@ def mock_backend():
     consumer.pause_message_listener = MagicMock()
     consumer.unsubscribe = MagicMock()
     consumer.close = MagicMock()
-    client.subscribe.return_value = consumer
+    backend.create_consumer.return_value = consumer

-    return client
+    return backend


 @pytest.mark.asyncio
-async def test_import_graceful_shutdown_integration():
+async def test_import_graceful_shutdown_integration(mock_backend):
     """Test import path handles shutdown gracefully with real message flow."""
-    mock_client = MagicMock()
-    mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_producer = mock_backend.create_producer.return_value

     # Track sent messages
     sent_messages = []
@@ -104,7 +102,7 @@ async def test_import_graceful_shutdown_integration():
     import_handler = TriplesImport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend,
         queue="test-triples-import"
     )
@@ -151,11 +149,9 @@ async def test_import_graceful_shutdown_integration():
 @pytest.mark.asyncio
-async def test_export_no_message_loss_integration():
+async def test_export_no_message_loss_integration(mock_backend):
     """Test export path doesn't lose acknowledged messages."""
-    mock_client = MagicMock()
-    mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_consumer = mock_backend.create_consumer.return_value

     # Create test messages
     test_messages = []
@@ -202,7 +198,7 @@ async def test_export_no_message_loss_integration():
     export_handler = TriplesExport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend,
         queue="test-triples-export",
         consumer="test-consumer",
         subscriber="test-subscriber"
@@ -245,14 +241,14 @@ async def test_export_no_message_loss_integration():
 async def test_concurrent_import_export_shutdown():
     """Test concurrent import and export shutdown scenarios."""
     # Setup mock clients
-    import_client = MagicMock()
-    export_client = MagicMock()
+    import_backend = MagicMock()
+    export_backend = MagicMock()

     import_producer = MagicMock()
     export_consumer = MagicMock()

-    import_client.create_producer.return_value = import_producer
-    export_client.subscribe.return_value = export_consumer
+    import_backend.create_producer.return_value = import_producer
+    export_backend.subscribe.return_value = export_consumer

     # Track operations
     import_operations = []
@@ -280,14 +276,14 @@ async def test_concurrent_import_export_shutdown():
     import_handler = TriplesImport(
         ws=import_ws,
         running=import_running,
-        pulsar_client=import_client,
+        backend=import_backend,
         queue="concurrent-import"
     )

     export_handler = TriplesExport(
         ws=export_ws,
         running=export_running,
-        pulsar_client=export_client,
+        backend=export_backend,
         queue="concurrent-export",
         consumer="concurrent-consumer",
         subscriber="concurrent-subscriber"
@@ -328,9 +324,9 @@ async def test_concurrent_import_export_shutdown():
 @pytest.mark.asyncio
 async def test_websocket_close_during_message_processing():
     """Test graceful handling when websocket closes during active message processing."""
-    mock_client = MagicMock()
+    mock_backend_local = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend_local.create_producer.return_value = mock_producer

     # Simulate slow message processing
     processed_messages = []
@@ -346,7 +342,7 @@ async def test_websocket_close_during_message_processing():
     import_handler = TriplesImport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend_local,
         queue="slow-processing-import"
     )
@@ -395,9 +391,9 @@ async def test_websocket_close_during_message_processing():
 @pytest.mark.asyncio
 async def test_backpressure_during_shutdown():
     """Test graceful shutdown under backpressure conditions."""
-    mock_client = MagicMock()
+    mock_backend_local = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend_local.subscribe.return_value = mock_consumer

     # Mock slow websocket
     class SlowWebSocket(MockWebSocket):
@@ -411,7 +407,7 @@ async def test_backpressure_during_shutdown():
     export_handler = TriplesExport(
         ws=ws,
         running=running,
-        pulsar_client=mock_client,
+        backend=mock_backend_local,
         queue="backpressure-export",
         consumer="backpressure-consumer",
         subscriber="backpressure-subscriber"

@@ -117,7 +117,7 @@ class TestObjectsCassandraIntegration:
         assert "customer_records" in processor.schemas

         # Step 1.5: Create the collection first (simulate tg-set-collection)
-        await processor.create_collection("test_user", "import_2024")
+        await processor.create_collection("test_user", "import_2024", {})

         # Step 2: Process an ExtractedObject
         test_obj = ExtractedObject(
@@ -213,8 +213,8 @@ class TestObjectsCassandraIntegration:
         assert len(processor.schemas) == 2

         # Create collections first
-        await processor.create_collection("shop", "catalog")
-        await processor.create_collection("shop", "sales")
+        await processor.create_collection("shop", "catalog", {})
+        await processor.create_collection("shop", "sales", {})

         # Process objects for different schemas
         product_obj = ExtractedObject(
@@ -263,7 +263,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("test", "test")
+        await processor.create_collection("test", "test", {})

         # Create object missing required field
         test_obj = ExtractedObject(
@@ -302,7 +302,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("logger", "app_events")
+        await processor.create_collection("logger", "app_events", {})

         # Process object
         test_obj = ExtractedObject(
@@ -407,7 +407,7 @@ class TestObjectsCassandraIntegration:
         # Create all collections first
         for coll in collections:
-            await processor.create_collection("analytics", coll)
+            await processor.create_collection("analytics", coll, {})

         for coll in collections:
             obj = ExtractedObject(
@@ -486,7 +486,7 @@ class TestObjectsCassandraIntegration:
         )

         # Create collection first
-        await processor.create_collection("test_user", "batch_import")
+        await processor.create_collection("test_user", "batch_import", {})
msg = MagicMock() msg = MagicMock()
msg.value.return_value = batch_obj msg.value.return_value = batch_obj
@ -532,7 +532,7 @@ class TestObjectsCassandraIntegration:
) )
# Create collection first # Create collection first
await processor.create_collection("test", "empty") await processor.create_collection("test", "empty", {})
# Process empty batch object # Process empty batch object
empty_obj = ExtractedObject( empty_obj = ExtractedObject(
@ -573,7 +573,7 @@ class TestObjectsCassandraIntegration:
) )
# Create collection first # Create collection first
await processor.create_collection("test", "mixed") await processor.create_collection("test", "mixed", {})
# Single object (backward compatibility) # Single object (backward compatibility)
single_obj = ExtractedObject( single_obj = ExtractedObject(

View file

@ -0,0 +1,351 @@
"""
Integration tests for RAG service streaming protocol compliance.
These tests verify that RAG services correctly forward end_of_stream flags
and don't duplicate final chunks, ensuring proper streaming semantics.
"""
import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
class TestGraphRagStreamingProtocol:
"""Integration tests for GraphRAG streaming protocol"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
client.embed.return_value = [[0.1, 0.2, 0.3]]
return client
@pytest.fixture
def mock_graph_embeddings_client(self):
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = ["entity1", "entity2"]
return client
@pytest.fixture
def mock_triples_client(self):
"""Mock triples client"""
client = AsyncMock()
client.query.return_value = []
return client
@pytest.fixture
def mock_streaming_prompt_client(self):
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
client = AsyncMock()
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
await chunk_callback("The", False)
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
else:
return "The answer is here."
client.kg_prompt.side_effect = kg_prompt_side_effect
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
prompt_client=mock_streaming_prompt_client,
verbose=False
)
@pytest.mark.asyncio
async def test_callback_receives_end_of_stream_parameter(self, graph_rag):
"""Test that callback receives end_of_stream parameter"""
# Arrange
callback = AsyncMock()
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
)
# Assert - callback should receive (chunk, end_of_stream) signature
assert callback.call_count == 4
# All calls should have 2 arguments
for call_args in callback.call_args_list:
assert len(call_args.args) == 2, "Callback should receive (chunk, end_of_stream)"
@pytest.mark.asyncio
async def test_end_of_stream_flag_forwarded_correctly(self, graph_rag):
"""Test that end_of_stream flags are forwarded correctly"""
# Arrange
chunks_with_flags = []
async def collect(chunk, end_of_stream):
chunks_with_flags.append((chunk, end_of_stream))
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
)
# Assert
assert len(chunks_with_flags) == 4
# First three chunks should have end_of_stream=False
assert chunks_with_flags[0] == ("The", False)
assert chunks_with_flags[1] == (" answer", False)
assert chunks_with_flags[2] == (" is here.", False)
# Final chunk should have end_of_stream=True
assert chunks_with_flags[3] == ("", True)
@pytest.mark.asyncio
async def test_no_duplicate_final_chunk(self, graph_rag):
"""Test that final chunk is not duplicated"""
# Arrange
chunks = []
async def collect(chunk, end_of_stream):
chunks.append(chunk)
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
)
# Assert - should have exactly 4 chunks, no duplicates
assert len(chunks) == 4
assert chunks == ["The", " answer", " is here.", ""]
# The last chunk appears exactly once
assert chunks.count("") == 1
@pytest.mark.asyncio
async def test_exactly_one_end_of_stream_true(self, graph_rag):
"""Test that exactly one message has end_of_stream=True"""
# Arrange
end_of_stream_flags = []
async def collect(chunk, end_of_stream):
end_of_stream_flags.append(end_of_stream)
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
)
# Assert - exactly one True
assert end_of_stream_flags.count(True) == 1
assert end_of_stream_flags.count(False) == 3
@pytest.mark.asyncio
async def test_empty_final_chunk_preserved(self, graph_rag):
"""Test that empty final chunks are preserved and forwarded"""
# Arrange
final_chunk = None
final_flag = None
async def collect(chunk, end_of_stream):
nonlocal final_chunk, final_flag
if end_of_stream:
final_chunk = chunk
final_flag = end_of_stream
# Act
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
)
# Assert
assert final_flag is True
assert final_chunk == "", "Empty final chunk should be preserved"
class TestDocumentRagStreamingProtocol:
"""Integration tests for DocumentRAG streaming protocol"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
client.embed.return_value = [[0.1, 0.2, 0.3]]
return client
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client"""
client = AsyncMock()
client.query.return_value = ["doc1", "doc2"]
return client
@pytest.fixture
def mock_streaming_prompt_client(self):
"""Mock prompt client with streaming support"""
client = AsyncMock()
async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None):
if streaming and chunk_callback:
# Simulate streaming with non-empty final chunk (some LLMs do this)
await chunk_callback("Document", False)
await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk
return ""
else:
return "Document summary."
client.document_prompt.side_effect = document_prompt_side_effect
return client
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client,
verbose=False
)
@pytest.mark.asyncio
async def test_callback_receives_end_of_stream_parameter(self, document_rag):
"""Test that callback receives end_of_stream parameter"""
# Arrange
callback = AsyncMock()
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 3
for call_args in callback.call_args_list:
assert len(call_args.args) == 2
@pytest.mark.asyncio
async def test_non_empty_final_chunk_preserved(self, document_rag):
"""Test that non-empty final chunks are preserved with correct flag"""
# Arrange
chunks_with_flags = []
async def collect(chunk, end_of_stream):
chunks_with_flags.append((chunk, end_of_stream))
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
)
# Assert
assert len(chunks_with_flags) == 3
assert chunks_with_flags[0] == ("Document", False)
assert chunks_with_flags[1] == (" summary", False)
assert chunks_with_flags[2] == (".", True) # Non-empty final chunk
@pytest.mark.asyncio
async def test_no_duplicate_final_chunk(self, document_rag):
"""Test that final chunk is not duplicated"""
# Arrange
chunks = []
async def collect(chunk, end_of_stream):
chunks.append(chunk)
# Act
await document_rag.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collect
)
# Assert - final "." appears exactly once
assert chunks.count(".") == 1
assert chunks == ["Document", " summary", "."]
class TestStreamingProtocolEdgeCases:
"""Test edge cases in streaming protocol"""
@pytest.mark.asyncio
async def test_multiple_empty_chunks_before_final(self):
"""Test handling of multiple empty chunks (edge case)"""
# Arrange
client = AsyncMock()
async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
else:
return "textmore"
client.kg_prompt.side_effect = kg_prompt_with_empties
rag = GraphRag(
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
prompt_client=client,
verbose=False
)
chunks_with_flags = []
async def collect(chunk, end_of_stream):
chunks_with_flags.append((chunk, end_of_stream))
# Act
await rag.query(
query="test",
streaming=True,
chunk_callback=collect
)
# Assert
assert len(chunks_with_flags) == 4
assert chunks_with_flags[-1] == ("", True) # Final empty chunk
end_of_stream_flags = [f for _, f in chunks_with_flags]
assert end_of_stream_flags.count(True) == 1
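The invariant these tests pin down — every content chunk arrives with `end_of_stream=False`, exactly one message carries `end_of_stream=True`, and the final chunk is never duplicated — can be sketched outside the test harness. The helper below is illustrative only (it is not part of TrustGraph), but it demonstrates the same callback contract:

```python
import asyncio

async def stream_chunks(chunks, chunk_callback):
    """Emit chunks, flagging only the last one with end_of_stream=True."""
    for i, chunk in enumerate(chunks):
        await chunk_callback(chunk, i == len(chunks) - 1)

async def main():
    received = []

    async def collect(chunk, end_of_stream):
        received.append((chunk, end_of_stream))

    # Same shape as the mocked prompt client: content chunks, then an
    # empty final chunk carrying the end_of_stream flag.
    await stream_chunks(["The", " answer", " is here.", ""], collect)

    # Exactly one message carries end_of_stream=True, and it is the last one
    assert [eos for _, eos in received].count(True) == 1
    assert received[-1] == ("", True)
    return received

received = asyncio.run(main())
```

Whether the final flagged chunk is empty (as GraphRAG's mock simulates) or carries text (as DocumentRAG's mock simulates) is up to the upstream LLM; the consumer only relies on the flag.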

View file

@ -14,15 +14,16 @@ from trustgraph.base.async_processor import AsyncProcessor
class TestAsyncProcessorSimple(IsolatedAsyncioTestCase): class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
"""Test AsyncProcessor base class functionality""" """Test AsyncProcessor base class functionality"""
@patch('trustgraph.base.async_processor.PulsarClient') @patch('trustgraph.base.async_processor.get_pubsub')
@patch('trustgraph.base.async_processor.Consumer') @patch('trustgraph.base.async_processor.Consumer')
@patch('trustgraph.base.async_processor.ProcessorMetrics') @patch('trustgraph.base.async_processor.ProcessorMetrics')
@patch('trustgraph.base.async_processor.ConsumerMetrics') @patch('trustgraph.base.async_processor.ConsumerMetrics')
async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics, async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics,
mock_consumer, mock_pulsar_client): mock_consumer, mock_get_pubsub):
"""Test basic AsyncProcessor initialization""" """Test basic AsyncProcessor initialization"""
# Arrange # Arrange
mock_pulsar_client.return_value = MagicMock() mock_backend = MagicMock()
mock_get_pubsub.return_value = mock_backend
mock_consumer.return_value = MagicMock() mock_consumer.return_value = MagicMock()
mock_processor_metrics.return_value = MagicMock() mock_processor_metrics.return_value = MagicMock()
mock_consumer_metrics.return_value = MagicMock() mock_consumer_metrics.return_value = MagicMock()
@ -43,8 +44,8 @@ class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
assert hasattr(processor, 'config_handlers') assert hasattr(processor, 'config_handlers')
assert processor.config_handlers == [] assert processor.config_handlers == []
# Verify PulsarClient was created # Verify get_pubsub was called to create backend
mock_pulsar_client.assert_called_once_with(**config) mock_get_pubsub.assert_called_once_with(**config)
# Verify metrics were initialized # Verify metrics were initialized
mock_processor_metrics.assert_called_once() mock_processor_metrics.assert_called_once()
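The `PulsarClient` → `get_pubsub` change patched in this test follows a common backend-factory pattern: the processor asks a factory for a pub/sub backend instead of constructing a Pulsar client directly. The sketch below is hypothetical — `get_pubsub` and the backend class here are illustrative, not TrustGraph's real implementation:

```python
class InMemoryBackend:
    """Trivial stand-in backend, useful for local testing."""
    def __init__(self, **config):
        self.config = config
        self.topics = {}

    def create_producer(self, topic, schema=None):
        queue = self.topics.setdefault(topic, [])
        class Producer:
            def send(self, msg):
                queue.append(msg)  # closure over this topic's queue
        return Producer()

def get_pubsub(backend="memory", **config):
    """Select a backend implementation from configuration."""
    backends = {"memory": InMemoryBackend}
    return backends[backend](**config)

backend = get_pubsub()
producer = backend.create_producer("test-topic")
producer.send({"x": 1})
```

A factory like this is what makes the tests above possible: patching one function swaps the entire messaging layer for a mock, instead of patching Pulsar internals in every test.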

View file

@ -145,7 +145,7 @@ class TestResolveCassandraConfig:
def test_default_configuration(self): def test_default_configuration(self):
"""Test resolution with no parameters or environment variables.""" """Test resolution with no parameters or environment variables."""
with patch.dict(os.environ, {}, clear=True): with patch.dict(os.environ, {}, clear=True):
hosts, username, password = resolve_cassandra_config() hosts, username, password, keyspace = resolve_cassandra_config()
assert hosts == ['cassandra'] assert hosts == ['cassandra']
assert username is None assert username is None
@ -160,7 +160,7 @@ class TestResolveCassandraConfig:
} }
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
hosts, username, password = resolve_cassandra_config() hosts, username, password, keyspace = resolve_cassandra_config()
assert hosts == ['env1', 'env2', 'env3'] assert hosts == ['env1', 'env2', 'env3']
assert username == 'env-user' assert username == 'env-user'
@ -175,7 +175,7 @@ class TestResolveCassandraConfig:
} }
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
hosts, username, password = resolve_cassandra_config( hosts, username, password, keyspace = resolve_cassandra_config(
host='explicit-host', host='explicit-host',
username='explicit-user', username='explicit-user',
password='explicit-pass' password='explicit-pass'
@ -188,19 +188,19 @@ class TestResolveCassandraConfig:
def test_host_list_parsing(self): def test_host_list_parsing(self):
"""Test different host list formats.""" """Test different host list formats."""
# Single host # Single host
hosts, _, _ = resolve_cassandra_config(host='single-host') hosts, _, _, _ = resolve_cassandra_config(host='single-host')
assert hosts == ['single-host'] assert hosts == ['single-host']
# Multiple hosts with spaces # Multiple hosts with spaces
hosts, _, _ = resolve_cassandra_config(host='host1, host2 ,host3') hosts, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
assert hosts == ['host1', 'host2', 'host3'] assert hosts == ['host1', 'host2', 'host3']
# Empty elements filtered out # Empty elements filtered out
hosts, _, _ = resolve_cassandra_config(host='host1,,host2,') hosts, _, _, _ = resolve_cassandra_config(host='host1,,host2,')
assert hosts == ['host1', 'host2'] assert hosts == ['host1', 'host2']
# Already a list # Already a list
hosts, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2']) hosts, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
assert hosts == ['list-host1', 'list-host2'] assert hosts == ['list-host1', 'list-host2']
def test_args_object_resolution(self): def test_args_object_resolution(self):
@ -212,7 +212,7 @@ class TestResolveCassandraConfig:
cassandra_password = 'args-pass' cassandra_password = 'args-pass'
args = MockArgs() args = MockArgs()
hosts, username, password = resolve_cassandra_config(args) hosts, username, password, keyspace = resolve_cassandra_config(args)
assert hosts == ['args-host1', 'args-host2'] assert hosts == ['args-host1', 'args-host2']
assert username == 'args-user' assert username == 'args-user'
@ -233,7 +233,7 @@ class TestResolveCassandraConfig:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
args = PartialArgs() args = PartialArgs()
hosts, username, password = resolve_cassandra_config(args) hosts, username, password, keyspace = resolve_cassandra_config(args)
assert hosts == ['args-host'] # From args assert hosts == ['args-host'] # From args
assert username == 'env-user' # From env assert username == 'env-user' # From env
@ -251,7 +251,7 @@ class TestGetCassandraConfigFromParams:
'cassandra_password': 'new-pass' 'cassandra_password': 'new-pass'
} }
hosts, username, password = get_cassandra_config_from_params(params) hosts, username, password, keyspace = get_cassandra_config_from_params(params)
assert hosts == ['new-host1', 'new-host2'] assert hosts == ['new-host1', 'new-host2']
assert username == 'new-user' assert username == 'new-user'
@ -265,7 +265,7 @@ class TestGetCassandraConfigFromParams:
'graph_password': 'old-pass' 'graph_password': 'old-pass'
} }
hosts, username, password = get_cassandra_config_from_params(params) hosts, username, password, keyspace = get_cassandra_config_from_params(params)
# Should use defaults since graph_* params are not recognized # Should use defaults since graph_* params are not recognized
assert hosts == ['cassandra'] # Default assert hosts == ['cassandra'] # Default
@ -280,7 +280,7 @@ class TestGetCassandraConfigFromParams:
'cassandra_password': 'compat-pass' 'cassandra_password': 'compat-pass'
} }
hosts, username, password = get_cassandra_config_from_params(params) hosts, username, password, keyspace = get_cassandra_config_from_params(params)
assert hosts == ['compat-host'] assert hosts == ['compat-host']
assert username is None # cassandra_user is not recognized assert username is None # cassandra_user is not recognized
@ -298,7 +298,7 @@ class TestGetCassandraConfigFromParams:
'graph_password': 'old-pass' 'graph_password': 'old-pass'
} }
hosts, username, password = get_cassandra_config_from_params(params) hosts, username, password, keyspace = get_cassandra_config_from_params(params)
assert hosts == ['new-host'] # Only cassandra_* params work assert hosts == ['new-host'] # Only cassandra_* params work
assert username == 'new-user' # Only cassandra_* params work assert username == 'new-user' # Only cassandra_* params work
@ -314,7 +314,7 @@ class TestGetCassandraConfigFromParams:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
params = {} params = {}
hosts, username, password = get_cassandra_config_from_params(params) hosts, username, password, keyspace = get_cassandra_config_from_params(params)
assert hosts == ['fallback-host1', 'fallback-host2'] assert hosts == ['fallback-host1', 'fallback-host2']
assert username == 'fallback-user' assert username == 'fallback-user'
@ -334,7 +334,7 @@ class TestConfigurationPriority:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
# CLI args should override everything # CLI args should override everything
hosts, username, password = resolve_cassandra_config( hosts, username, password, keyspace = resolve_cassandra_config(
host='cli-host', host='cli-host',
username='cli-user', username='cli-user',
password='cli-pass' password='cli-pass'
@ -354,7 +354,7 @@ class TestConfigurationPriority:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
# Only provide host via CLI # Only provide host via CLI
hosts, username, password = resolve_cassandra_config( hosts, username, password, keyspace = resolve_cassandra_config(
host='cli-host' host='cli-host'
# username and password not provided # username and password not provided
) )
@ -366,7 +366,7 @@ class TestConfigurationPriority:
def test_no_config_defaults(self): def test_no_config_defaults(self):
"""Test that defaults are used when no configuration is provided.""" """Test that defaults are used when no configuration is provided."""
with patch.dict(os.environ, {}, clear=True): with patch.dict(os.environ, {}, clear=True):
hosts, username, password = resolve_cassandra_config() hosts, username, password, keyspace = resolve_cassandra_config()
assert hosts == ['cassandra'] # Default assert hosts == ['cassandra'] # Default
assert username is None # Default assert username is None # Default
@ -378,17 +378,17 @@ class TestEdgeCases:
def test_empty_host_string(self): def test_empty_host_string(self):
"""Test handling of empty host string falls back to default.""" """Test handling of empty host string falls back to default."""
hosts, _, _ = resolve_cassandra_config(host='') hosts, _, _, _ = resolve_cassandra_config(host='')
assert hosts == ['cassandra'] # Falls back to default assert hosts == ['cassandra'] # Falls back to default
def test_whitespace_only_host(self): def test_whitespace_only_host(self):
"""Test handling of whitespace-only host string.""" """Test handling of whitespace-only host string."""
hosts, _, _ = resolve_cassandra_config(host=' ') hosts, _, _, _ = resolve_cassandra_config(host=' ')
assert hosts == [] # Empty after stripping whitespace assert hosts == [] # Empty after stripping whitespace
def test_none_values_preserved(self): def test_none_values_preserved(self):
"""Test that None values are preserved correctly.""" """Test that None values are preserved correctly."""
hosts, username, password = resolve_cassandra_config( hosts, username, password, keyspace = resolve_cassandra_config(
host=None, host=None,
username=None, username=None,
password=None password=None
@ -401,7 +401,7 @@ class TestEdgeCases:
def test_mixed_none_and_values(self): def test_mixed_none_and_values(self):
"""Test mixing None and actual values.""" """Test mixing None and actual values."""
hosts, username, password = resolve_cassandra_config( hosts, username, password, keyspace = resolve_cassandra_config(
host='mixed-host', host='mixed-host',
username=None, username=None,
password='mixed-pass' password='mixed-pass'
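The resolution order these tests encode — explicit parameter, then environment variable, then default — together with the new 4-tuple return and the comma-list host parsing, can be sketched as follows. This is an illustrative reimplementation, not the actual TrustGraph function, and the environment-variable names are assumptions:

```python
import os

def resolve_cassandra_config(host=None, username=None, password=None, keyspace=None):
    """Sketch: explicit args win over env vars, which win over defaults."""
    host = host or os.environ.get("CASSANDRA_HOST")
    username = username or os.environ.get("CASSANDRA_USERNAME")
    password = password or os.environ.get("CASSANDRA_PASSWORD")
    keyspace = keyspace or os.environ.get("CASSANDRA_KEYSPACE")

    if host is None or host == "":
        hosts = ["cassandra"]  # default when nothing is configured
    elif isinstance(host, list):
        hosts = host  # already a list: pass through
    else:
        # Comma-separated string: strip whitespace, drop empty elements
        hosts = [h.strip() for h in host.split(",") if h.strip()]

    return hosts, username, password, keyspace
```

Note the edge case the tests distinguish: an empty string falls back to the default, but a whitespace-only string parses to an empty host list.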

View file

@ -0,0 +1,260 @@
"""
Unit tests for PromptClient streaming callback behavior.
These tests verify that the prompt client correctly passes the end_of_stream
flag to chunk callbacks, ensuring proper streaming protocol compliance.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, call, patch
from trustgraph.base.prompt_client import PromptClient
from trustgraph.schema import PromptResponse
class TestPromptClientStreamingCallback:
"""Test PromptClient streaming callback behavior"""
@pytest.fixture
def prompt_client(self):
"""Create a PromptClient with mocked dependencies"""
# Mock all the required initialization parameters
with patch.object(PromptClient, '__init__', lambda self: None):
client = PromptClient()
return client
@pytest.fixture
def mock_request_response(self):
"""Create a mock request/response handler"""
async def mock_request(request, recipient=None, timeout=600):
if recipient:
# Simulate streaming responses
responses = [
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
PromptResponse(text=" world", object=None, error=None, end_of_stream=False),
PromptResponse(text="!", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
else:
# Non-streaming response
return PromptResponse(text="Hello world!", object=None, error=None)
return mock_request
@pytest.mark.asyncio
async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response):
"""Test that callback receives both chunk text and end_of_stream flag"""
# Arrange
prompt_client.request = mock_request_response
callback = AsyncMock()
# Act
await prompt_client.prompt(
id="test-prompt",
variables={"query": "test"},
streaming=True,
chunk_callback=callback
)
# Assert - callback should be called with (chunk, end_of_stream) signature
assert callback.call_count == 4
# Verify first chunk: text + end_of_stream=False
assert callback.call_args_list[0] == call("Hello", False)
# Verify second chunk
assert callback.call_args_list[1] == call(" world", False)
# Verify third chunk
assert callback.call_args_list[2] == call("!", False)
# Verify final chunk: empty text + end_of_stream=True
assert callback.call_args_list[3] == call("", True)
@pytest.mark.asyncio
async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response):
"""Test that empty final chunks are passed to callback"""
# Arrange
prompt_client.request = mock_request_response
chunks_received = []
async def collect_chunks(chunk, end_of_stream):
chunks_received.append((chunk, end_of_stream))
# Act
await prompt_client.prompt(
id="test-prompt",
variables={"query": "test"},
streaming=True,
chunk_callback=collect_chunks
)
# Assert - should receive the empty final chunk
final_chunk = chunks_received[-1]
assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True"
@pytest.mark.asyncio
async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client):
"""Test callback signature when LLM sends non-empty final chunk"""
# Arrange
async def mock_request_non_empty_final(request, recipient=None, timeout=600):
if recipient:
# Some LLMs send content in the final chunk
responses = [
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
PromptResponse(text=" world!", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request_non_empty_final
callback = AsyncMock()
# Act
await prompt_client.prompt(
id="test-prompt",
variables={"query": "test"},
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 2
assert callback.call_args_list[0] == call("Hello", False)
assert callback.call_args_list[1] == call(" world!", True)
@pytest.mark.asyncio
async def test_callback_not_called_without_text(self, prompt_client):
"""Test that callback is not called for responses without text"""
# Arrange
async def mock_request_no_text(request, recipient=None, timeout=600):
if recipient:
# Response with only end_of_stream, no text
responses = [
PromptResponse(text="Content", object=None, error=None, end_of_stream=False),
PromptResponse(text=None, object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request_no_text
callback = AsyncMock()
# Act
await prompt_client.prompt(
id="test-prompt",
variables={"query": "test"},
streaming=True,
chunk_callback=callback
)
# Assert - callback should only be called once (for "Content")
assert callback.call_count == 1
assert callback.call_args_list[0] == call("Content", False)
@pytest.mark.asyncio
async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client):
"""Test that synchronous callbacks also receive end_of_stream parameter"""
# Arrange
async def mock_request(request, recipient=None, timeout=600):
if recipient:
responses = [
PromptResponse(text="test", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request
callback = MagicMock() # Synchronous mock
# Act
await prompt_client.prompt(
id="test-prompt",
variables={"query": "test"},
streaming=True,
chunk_callback=callback
)
# Assert - synchronous callback should also get both parameters
assert callback.call_count == 2
assert callback.call_args_list[0] == call("test", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that kg_prompt correctly passes streaming parameters"""
# Arrange
async def mock_request(request, recipient=None, timeout=600):
if recipient:
responses = [
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request
callback = AsyncMock()
# Act
await prompt_client.kg_prompt(
query="What is machine learning?",
kg=[("subject", "predicate", "object")],
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 2
assert callback.call_args_list[0] == call("Answer", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that document_prompt correctly passes streaming parameters"""
# Arrange
async def mock_request(request, recipient=None, timeout=600):
if recipient:
responses = [
PromptResponse(text="Summary", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request
callback = AsyncMock()
# Act
await prompt_client.document_prompt(
query="Summarize this",
documents=["doc1", "doc2"],
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 2
assert callback.call_args_list[0] == call("Summary", False)
assert callback.call_args_list[1] == call("", True)

View file

@@ -8,22 +8,22 @@ from trustgraph.base.publisher import Publisher

 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for testing."""
-    client = MagicMock()
+def mock_pulsar_backend():
+    """Mock Pulsar backend for testing."""
+    backend = MagicMock()
     producer = AsyncMock()
     producer.send = MagicMock()
     producer.flush = MagicMock()
     producer.close = MagicMock()
-    client.create_producer.return_value = producer
-    return client
+    backend.create_producer.return_value = producer
+    return backend

 @pytest.fixture
-def publisher(mock_pulsar_client):
+def publisher(mock_pulsar_backend):
     """Create Publisher instance for testing."""
     return Publisher(
-        client=mock_pulsar_client,
+        backend=mock_pulsar_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -34,12 +34,12 @@ def publisher(mock_pulsar_client):
 @pytest.mark.asyncio
 async def test_publisher_queue_drain():
     """Verify Publisher drains queue on shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer
     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -85,12 +85,12 @@ async def test_publisher_queue_drain():
 @pytest.mark.asyncio
 async def test_publisher_rejects_messages_during_drain():
     """Verify Publisher rejects new messages during shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer
     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -113,12 +113,12 @@ async def test_publisher_rejects_messages_during_drain():
 @pytest.mark.asyncio
 async def test_publisher_drain_timeout():
     """Verify Publisher respects drain timeout."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer
     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -169,12 +169,12 @@ async def test_publisher_drain_timeout():
 @pytest.mark.asyncio
 async def test_publisher_successful_drain():
     """Verify Publisher drains successfully under normal conditions."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer
     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -224,12 +224,12 @@ async def test_publisher_successful_drain():
 @pytest.mark.asyncio
 async def test_publisher_state_transitions():
     """Test Publisher state transitions during graceful shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer
     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
@@ -276,9 +276,9 @@ async def test_publisher_state_transitions():
 @pytest.mark.asyncio
 async def test_publisher_exception_handling():
     """Test Publisher handles exceptions during drain gracefully."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_producer = MagicMock()
-    mock_client.create_producer.return_value = mock_producer
+    mock_backend.create_producer.return_value = mock_producer

     # Mock producer.send to raise exception on second call
     call_count = 0
@@ -291,7 +291,7 @@ async def test_publisher_exception_handling():
     mock_producer.send.side_effect = failing_send
     publisher = Publisher(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         schema=dict,
         max_size=10,
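The rename in this diff is mechanical: every fixture that used to build a Pulsar `client` mock now builds a `backend` mock with the same shape, i.e. `create_producer()` returns a producer exposing synchronous `send`/`flush`/`close`. A small standalone sketch of that fixture shape, using only `unittest.mock` (no TrustGraph imports required):

```python
from unittest.mock import AsyncMock, MagicMock

def make_mock_backend():
    # Backend whose create_producer() hands back a canned producer,
    # mirroring the mock_pulsar_backend fixture in the diff above.
    backend = MagicMock()
    producer = AsyncMock()
    producer.send = MagicMock()
    producer.flush = MagicMock()
    producer.close = MagicMock()
    backend.create_producer.return_value = producer
    return backend

backend = make_mock_backend()
producer = backend.create_producer()
producer.send({"msg": "hello"})
# The mock records the interaction for later assertions.
producer.send.assert_called_once_with({"msg": "hello"})
```

Because `create_producer.return_value` is fixed, every call returns the same canned producer, which is what lets the tests above assert on `mock_producer` after exercising the `Publisher`.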

View file

@@ -6,23 +6,11 @@ import uuid
 from unittest.mock import AsyncMock, MagicMock, patch
 from trustgraph.base.subscriber import Subscriber

-# Mock JsonSchema globally to avoid schema issues in tests
-# Patch at the module level where it's imported in subscriber
-@patch('trustgraph.base.subscriber.JsonSchema')
-def mock_json_schema_global(mock_schema):
-    mock_schema.return_value = MagicMock()
-    return mock_schema
-
-# Apply the global patch
-_json_schema_patch = patch('trustgraph.base.subscriber.JsonSchema')
-_mock_json_schema = _json_schema_patch.start()
-_mock_json_schema.return_value = MagicMock()
-
 @pytest.fixture
-def mock_pulsar_client():
-    """Mock Pulsar client for testing."""
-    client = MagicMock()
+def mock_pulsar_backend():
+    """Mock Pulsar backend for testing."""
+    backend = MagicMock()
     consumer = MagicMock()
     consumer.receive = MagicMock()
     consumer.acknowledge = MagicMock()
@@ -30,15 +18,15 @@ def mock_pulsar_client():
     consumer.pause_message_listener = MagicMock()
     consumer.unsubscribe = MagicMock()
     consumer.close = MagicMock()
-    client.subscribe.return_value = consumer
-    return client
+    backend.create_consumer.return_value = consumer
+    return backend

 @pytest.fixture
-def subscriber(mock_pulsar_client):
+def subscriber(mock_pulsar_backend):
     """Create Subscriber instance for testing."""
     return Subscriber(
-        client=mock_pulsar_client,
+        backend=mock_pulsar_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -60,12 +48,12 @@ def create_mock_message(message_id="test-id", data=None):
 @pytest.mark.asyncio
 async def test_subscriber_deferred_acknowledgment_success():
     """Verify Subscriber only acks on successful delivery."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -102,12 +90,12 @@ async def test_subscriber_deferred_acknowledgment_success():
 @pytest.mark.asyncio
 async def test_subscriber_deferred_acknowledgment_failure():
     """Verify Subscriber negative acks on delivery failure."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -140,13 +128,13 @@ async def test_subscriber_deferred_acknowledgment_failure():
 @pytest.mark.asyncio
 async def test_subscriber_backpressure_strategies():
     """Test different backpressure strategies."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     # Test drop_oldest strategy
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -187,12 +175,12 @@ async def test_subscriber_backpressure_strategies():
 @pytest.mark.asyncio
 async def test_subscriber_graceful_shutdown():
     """Test Subscriber graceful shutdown with queue draining."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -253,12 +241,12 @@ async def test_subscriber_graceful_shutdown():
 @pytest.mark.asyncio
 async def test_subscriber_drain_timeout():
     """Test Subscriber respects drain timeout."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -288,12 +276,12 @@ async def test_subscriber_drain_timeout():
 @pytest.mark.asyncio
 async def test_subscriber_pending_acks_cleanup():
     """Test Subscriber cleans up pending acknowledgments on shutdown."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",
@@ -342,12 +330,12 @@ async def test_subscriber_pending_acks_cleanup():
 @pytest.mark.asyncio
 async def test_subscriber_multiple_subscribers():
     """Test Subscriber with multiple concurrent subscribers."""
-    mock_client = MagicMock()
+    mock_backend = MagicMock()
     mock_consumer = MagicMock()
-    mock_client.subscribe.return_value = mock_consumer
+    mock_backend.create_consumer.return_value = mock_consumer
     subscriber = Subscriber(
-        client=mock_client,
+        backend=mock_backend,
         topic="test-topic",
         subscription="test-subscription",
         consumer_name="test-consumer",

View file

@@ -108,7 +108,8 @@ class TestListConfigItems:
         mock_list.assert_called_once_with(
             url='http://custom.com',
             config_type='prompt',
-            format_type='json'
+            format_type='json',
+            token=None
         )

     def test_list_main_uses_defaults(self):
@@ -126,7 +127,8 @@
         mock_list.assert_called_once_with(
             url='http://localhost:8088/',
             config_type='prompt',
-            format_type='text'
+            format_type='text',
+            token=None
         )
@@ -193,7 +195,8 @@ class TestGetConfigItem:
             url='http://custom.com',
             config_type='prompt',
             key='template-1',
-            format_type='json'
+            format_type='json',
+            token=None
         )
@@ -249,7 +252,8 @@ class TestPutConfigItem:
             url='http://custom.com',
             config_type='prompt',
             key='new-template',
-            value='Custom prompt: {input}'
+            value='Custom prompt: {input}',
+            token=None
         )

     def test_put_main_with_stdin_arg(self):
@@ -273,7 +277,8 @@
             url='http://localhost:8088/',
             config_type='prompt',
             key='stdin-template',
-            value=stdin_content
+            value=stdin_content,
+            token=None
         )

     def test_put_main_mutually_exclusive_args(self):
@@ -328,7 +333,8 @@ class TestDeleteConfigItem:
         mock_delete.assert_called_once_with(
             url='http://custom.com',
             config_type='prompt',
-            key='old-template'
+            key='old-template',
+            token=None
         )

View file

@@ -2,17 +2,16 @@
 Unit tests for the load_knowledge CLI module.

 Tests the business logic of loading triples and entity contexts from Turtle files
-while mocking WebSocket connections and external dependencies.
+using the BulkClient API.
 """

 import pytest
-import json
 import tempfile
-import asyncio
-from unittest.mock import AsyncMock, Mock, patch, mock_open, MagicMock
+from unittest.mock import Mock, patch, MagicMock, call
 from pathlib import Path

 from trustgraph.cli.load_knowledge import KnowledgeLoader, main
+from trustgraph.api import Triple

 @pytest.fixture
@@ -43,26 +42,6 @@ def temp_turtle_file(sample_turtle_content):
     Path(f.name).unlink(missing_ok=True)

-@pytest.fixture
-def mock_websocket():
-    """Mock WebSocket connection."""
-    mock_ws = MagicMock()
-
-    async def async_send(data):
-        return None
-
-    async def async_recv():
-        return ""
-
-    async def async_close():
-        return None
-
-    mock_ws.send = Mock(side_effect=async_send)
-    mock_ws.recv = Mock(side_effect=async_recv)
-    mock_ws.close = Mock(side_effect=async_close)
-    return mock_ws
-
 @pytest.fixture
 def knowledge_loader():
     """Create a KnowledgeLoader instance with test parameters."""
@@ -72,125 +51,66 @@ def knowledge_loader():
         user="test-user",
         collection="test-collection",
         document_id="test-doc-123",
-        url="ws://test.example.com/"
+        url="http://test.example.com/",
+        token=None
     )

 class TestKnowledgeLoader:
     """Test the KnowledgeLoader class business logic."""

-    def test_init_constructs_urls_correctly(self):
-        """Test that URLs are constructed properly."""
+    def test_init_stores_parameters_correctly(self):
+        """Test that initialization stores parameters correctly."""
         loader = KnowledgeLoader(
-            files=["test.ttl"],
+            files=["file1.ttl", "file2.ttl"],
             flow="my-flow",
             user="user1",
             collection="col1",
             document_id="doc1",
-            url="ws://example.com/"
+            url="http://example.com/",
+            token="test-token"
         )
-        assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples"
-        assert loader.entity_contexts_url == "ws://example.com/api/v1/flow/my-flow/import/entity-contexts"
+        assert loader.files == ["file1.ttl", "file2.ttl"]
+        assert loader.flow == "my-flow"
         assert loader.user == "user1"
         assert loader.collection == "col1"
         assert loader.document_id == "doc1"
+        assert loader.url == "http://example.com/"
+        assert loader.token == "test-token"

-    def test_init_adds_trailing_slash(self):
-        """Test that trailing slash is added to URL if missing."""
-        loader = KnowledgeLoader(
-            files=["test.ttl"],
-            flow="my-flow",
-            user="user1",
-            collection="col1",
-            document_id="doc1",
-            url="ws://example.com"  # No trailing slash
-        )
-        assert loader.triples_url == "ws://example.com/api/v1/flow/my-flow/import/triples"
-
-    @pytest.mark.asyncio
-    async def test_load_triples_sends_correct_messages(self, temp_turtle_file, mock_websocket):
-        """Test that triple loading sends correctly formatted messages."""
-        loader = KnowledgeLoader(
-            files=[temp_turtle_file],
-            flow="test-flow",
-            user="test-user",
-            collection="test-collection",
-            document_id="test-doc"
-        )
-        await loader.load_triples(temp_turtle_file, mock_websocket)
-        # Verify WebSocket send was called
-        assert mock_websocket.send.call_count > 0
-        # Check message format for one of the calls
-        sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]
-        # Verify message structure
-        sample_message = sent_messages[0]
-        assert "metadata" in sample_message
-        assert "triples" in sample_message
-        metadata = sample_message["metadata"]
-        assert metadata["id"] == "test-doc"
-        assert metadata["user"] == "test-user"
-        assert metadata["collection"] == "test-collection"
-        assert isinstance(metadata["metadata"], list)
-        triple = sample_message["triples"][0]
-        assert "s" in triple
-        assert "p" in triple
-        assert "o" in triple
-        # Check Value structure
-        assert "v" in triple["s"]
-        assert "e" in triple["s"]
-        assert triple["s"]["e"] is True  # Subject should be URI
+    def test_load_triples_from_file_yields_triples(self, temp_turtle_file, knowledge_loader):
+        """Test that load_triples_from_file yields Triple objects."""
+        triples = list(knowledge_loader.load_triples_from_file(temp_turtle_file))
+        # Should have triples for all statements in the file
+        assert len(triples) > 0
+        # Verify they are Triple objects
+        for triple in triples:
+            assert isinstance(triple, Triple)
+            assert hasattr(triple, 's')
+            assert hasattr(triple, 'p')
+            assert hasattr(triple, 'o')
+            assert isinstance(triple.s, str)
+            assert isinstance(triple.p, str)
+            assert isinstance(triple.o, str)

-    @pytest.mark.asyncio
-    async def test_load_entity_contexts_processes_literals_only(self, temp_turtle_file, mock_websocket):
+    def test_load_entity_contexts_from_file_yields_literals_only(self, temp_turtle_file, knowledge_loader):
         """Test that entity contexts are created only for literals."""
-        loader = KnowledgeLoader(
-            files=[temp_turtle_file],
-            flow="test-flow",
-            user="test-user",
-            collection="test-collection",
-            document_id="test-doc"
-        )
-        await loader.load_entity_contexts(temp_turtle_file, mock_websocket)
-        # Get all sent messages
-        sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]
-        # Verify we got entity context messages
-        assert len(sent_messages) > 0
-        for message in sent_messages:
-            assert "metadata" in message
-            assert "entities" in message
-            metadata = message["metadata"]
-            assert metadata["id"] == "test-doc"
-            assert metadata["user"] == "test-user"
-            assert metadata["collection"] == "test-collection"
-            entity_context = message["entities"][0]
-            assert "entity" in entity_context
-            assert "context" in entity_context
-            entity = entity_context["entity"]
-            assert "v" in entity
-            assert "e" in entity
-            assert entity["e"] is True  # Entity should be URI (subject)
-            # Context should be a string (the literal value)
-            assert isinstance(entity_context["context"], str)
+        contexts = list(knowledge_loader.load_entity_contexts_from_file(temp_turtle_file))
+        # Should have contexts for literal objects (foaf:name, foaf:age, foaf:email)
+        assert len(contexts) > 0
+        # Verify format: (entity, context) tuples
+        for entity, context in contexts:
+            assert isinstance(entity, str)
+            assert isinstance(context, str)
+            # Entity should be a URI (subject)
+            assert entity.startswith("http://")

-    @pytest.mark.asyncio
-    async def test_load_entity_contexts_skips_uri_objects(self, mock_websocket):
+    def test_load_entity_contexts_skips_uri_objects(self):
         """Test that URI objects don't generate entity contexts."""
         # Create turtle with only URI objects (no literals)
         turtle_content = """
@@ -208,63 +128,68 @@ ex:mary ex:knows ex:bob .
             flow="test-flow",
             user="test-user",
             collection="test-collection",
-            document_id="test-doc"
+            document_id="test-doc",
+            url="http://test.example.com/"
         )
-        await loader.load_entity_contexts(f.name, mock_websocket)
+        contexts = list(loader.load_entity_contexts_from_file(f.name))
         Path(f.name).unlink(missing_ok=True)
-        # Should not send any messages since there are no literals
-        mock_websocket.send.assert_not_called()
+        # Should have no contexts since there are no literals
+        assert len(contexts) == 0

-    @pytest.mark.asyncio
-    @patch('trustgraph.cli.load_knowledge.connect')
-    async def test_run_calls_both_loaders(self, mock_connect, knowledge_loader, temp_turtle_file):
-        """Test that run() calls both triple and entity context loaders."""
-        knowledge_loader.files = [temp_turtle_file]
-        # Create a simple mock websocket
-        mock_ws = MagicMock()
-        async def mock_send(data):
-            pass
-        mock_ws.send = mock_send
-        # Create async context manager mock
-        async def mock_aenter(self):
-            return mock_ws
-        async def mock_aexit(self, exc_type, exc_val, exc_tb):
-            return None
-        mock_connection = MagicMock()
-        mock_connection.__aenter__ = mock_aenter
-        mock_connection.__aexit__ = mock_aexit
-        mock_connect.return_value = mock_connection
-        # Create AsyncMock objects that can track calls properly
-        mock_load_triples = AsyncMock(return_value=None)
-        mock_load_contexts = AsyncMock(return_value=None)
-        with patch.object(knowledge_loader, 'load_triples', mock_load_triples), \
-             patch.object(knowledge_loader, 'load_entity_contexts', mock_load_contexts):
-            await knowledge_loader.run()
-        # Verify both methods were called
-        mock_load_triples.assert_called_once_with(temp_turtle_file, mock_ws)
-        mock_load_contexts.assert_called_once_with(temp_turtle_file, mock_ws)
-        # Verify WebSocket connections were made to both URLs
-        assert mock_connect.call_count == 2
+    @patch('trustgraph.cli.load_knowledge.Api')
+    def test_run_calls_bulk_api(self, mock_api_class, temp_turtle_file):
+        """Test that run() uses BulkClient API."""
+        # Setup mocks
+        mock_api = MagicMock()
+        mock_bulk = MagicMock()
+        mock_api_class.return_value = mock_api
+        mock_api.bulk.return_value = mock_bulk
+        loader = KnowledgeLoader(
+            files=[temp_turtle_file],
+            flow="test-flow",
+            user="test-user",
+            collection="test-collection",
+            document_id="test-doc",
+            url="http://test.example.com/",
+            token="test-token"
+        )
+        loader.run()
+        # Verify Api was created with correct parameters
+        mock_api_class.assert_called_once_with(
+            url="http://test.example.com/",
+            token="test-token"
+        )
+        # Verify bulk client was obtained
+        mock_api.bulk.assert_called_once()
+        # Verify import_triples was called
+        assert mock_bulk.import_triples.call_count == 1
+        call_args = mock_bulk.import_triples.call_args
+        assert call_args[1]['flow'] == "test-flow"
+        assert call_args[1]['metadata']['id'] == "test-doc"
+        assert call_args[1]['metadata']['user'] == "test-user"
+        assert call_args[1]['metadata']['collection'] == "test-collection"
+        # Verify import_entity_contexts was called
+        assert mock_bulk.import_entity_contexts.call_count == 1
+        call_args = mock_bulk.import_entity_contexts.call_args
+        assert call_args[1]['flow'] == "test-flow"
+        assert call_args[1]['metadata']['id'] == "test-doc"

 class TestCLIArgumentParsing:
     """Test CLI argument parsing and main function."""

     @patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
-    @patch('trustgraph.cli.load_knowledge.asyncio.run')
-    def test_main_parses_args_correctly(self, mock_asyncio_run, mock_loader_class):
+    @patch('trustgraph.cli.load_knowledge.time.sleep')
+    def test_main_parses_args_correctly(self, mock_sleep, mock_loader_class):
         """Test that main() parses arguments correctly."""
         mock_loader_instance = MagicMock()
         mock_loader_class.return_value = mock_loader_instance
@@ -275,7 +200,8 @@ class TestCLIArgumentParsing:
             '-f', 'my-flow',
             '-U', 'my-user',
             '-C', 'my-collection',
-            '-u', 'ws://custom.example.com/',
+            '-u', 'http://custom.example.com/',
+            '-t', 'my-token',
             'file1.ttl',
             'file2.ttl'
         ]
@@ -286,19 +212,20 @@
         # Verify KnowledgeLoader was instantiated with correct args
         mock_loader_class.assert_called_once_with(
             document_id='doc-123',
-            url='ws://custom.example.com/',
+            url='http://custom.example.com/',
+            token='my-token',
             flow='my-flow',
             files=['file1.ttl', 'file2.ttl'],
             user='my-user',
             collection='my-collection'
         )
-        # Verify asyncio.run was called once
-        mock_asyncio_run.assert_called_once()
+        # Verify run was called
+        mock_loader_instance.run.assert_called_once()

     @patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
-    @patch('trustgraph.cli.load_knowledge.asyncio.run')
-    def test_main_uses_defaults(self, mock_asyncio_run, mock_loader_class):
+    @patch('trustgraph.cli.load_knowledge.time.sleep')
+    def test_main_uses_defaults(self, mock_sleep, mock_loader_class):
         """Test that main() uses default values when not specified."""
         mock_loader_instance = MagicMock()
         mock_loader_class.return_value = mock_loader_instance
@@ -317,80 +244,69 @@ class TestCLIArgumentParsing:
         assert call_args['flow'] == 'default'
         assert call_args['user'] == 'trustgraph'
         assert call_args['collection'] == 'default'
-        assert call_args['url'] == 'ws://localhost:8088/'
+        assert call_args['url'] == 'http://localhost:8088/'
+        assert call_args['token'] is None

 class TestErrorHandling:
     """Test error handling scenarios."""

-    @pytest.mark.asyncio
-    async def test_load_triples_handles_invalid_turtle(self, mock_websocket):
+    def test_load_triples_handles_invalid_turtle(self, knowledge_loader):
         """Test handling of invalid Turtle content."""
         # Create file with invalid Turtle content
         with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
             f.write("Invalid Turtle Content {{{")
             f.flush()
-            loader = KnowledgeLoader(
-                files=[f.name],
-                flow="test-flow",
-                user="test-user",
-                collection="test-collection",
-                document_id="test-doc"
-            )
             # Should raise an exception for invalid Turtle
             with pytest.raises(Exception):
-                await loader.load_triples(f.name, mock_websocket)
+                list(knowledge_loader.load_triples_from_file(f.name))
         Path(f.name).unlink(missing_ok=True)

-    @pytest.mark.asyncio
-    async def test_load_entity_contexts_handles_invalid_turtle(self, mock_websocket):
+    def test_load_entity_contexts_handles_invalid_turtle(self, knowledge_loader):
         """Test handling of invalid Turtle content in entity contexts."""
         # Create file with invalid Turtle content
         with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
             f.write("Invalid Turtle Content {{{")
             f.flush()
-            loader = KnowledgeLoader(
-                files=[f.name],
-                flow="test-flow",
-                user="test-user",
-                collection="test-collection",
-                document_id="test-doc"
-            )
             # Should raise an exception for invalid Turtle
             with pytest.raises(Exception):
-                await loader.load_entity_contexts(f.name, mock_websocket)
+                list(knowledge_loader.load_entity_contexts_from_file(f.name))
         Path(f.name).unlink(missing_ok=True)

-    @pytest.mark.asyncio
-    @patch('trustgraph.cli.load_knowledge.connect')
+    @patch('trustgraph.cli.load_knowledge.Api')
     @patch('builtins.print')  # Mock print to avoid output during tests
-    async def test_run_handles_connection_errors(self, mock_print, mock_connect, knowledge_loader, temp_turtle_file):
-        """Test handling of WebSocket connection errors."""
-        knowledge_loader.files = [temp_turtle_file]
-        # Mock connection failure
-        mock_connect.side_effect = ConnectionError("Failed to connect")
-        # Should not raise exception, just print error
-        await knowledge_loader.run()
+    def test_run_handles_api_errors(self, mock_print, mock_api_class, temp_turtle_file):
+        """Test handling of API errors."""
+        # Mock API to raise an error
+        mock_api_class.side_effect = Exception("API connection failed")
+        loader = KnowledgeLoader(
+            files=[temp_turtle_file],
+            flow="test-flow",
+            user="test-user",
+            collection="test-collection",
+            document_id="test-doc",
+            url="http://test.example.com/"
+        )
+        # Should raise the exception
+        with pytest.raises(Exception, match="API connection failed"):
+            loader.run()

     @patch('trustgraph.cli.load_knowledge.KnowledgeLoader')
-    @patch('trustgraph.cli.load_knowledge.asyncio.run')
     @patch('trustgraph.cli.load_knowledge.time.sleep')
     @patch('builtins.print')  # Mock print to avoid output during tests
-    def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_asyncio_run, mock_loader_class):
+    def test_main_retries_on_exception(self, mock_print, mock_sleep, mock_loader_class):
         """Test that main() retries on exceptions."""
         mock_loader_instance = MagicMock()
         mock_loader_class.return_value = mock_loader_instance
         # First call raises exception, second succeeds
-        mock_asyncio_run.side_effect = [Exception("Test error"), None]
+        mock_loader_instance.run.side_effect = [Exception("Test error"), None]

         test_args = [
             'tg-load-knowledge',
@ -402,38 +318,29 @@ class TestErrorHandling:
main() main()
# Should have been called twice (first failed, second succeeded) # Should have been called twice (first failed, second succeeded)
assert mock_asyncio_run.call_count == 2 assert mock_loader_instance.run.call_count == 2
mock_sleep.assert_called_once_with(10) mock_sleep.assert_called_once_with(10)
class TestDataValidation: class TestDataValidation:
"""Test data validation and edge cases.""" """Test data validation and edge cases."""
@pytest.mark.asyncio def test_empty_turtle_file(self, knowledge_loader):
async def test_empty_turtle_file(self, mock_websocket):
"""Test handling of empty Turtle files.""" """Test handling of empty Turtle files."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f: with tempfile.NamedTemporaryFile(mode='w', suffix='.ttl', delete=False) as f:
f.write("") # Empty file f.write("") # Empty file
f.flush() f.flush()
loader = KnowledgeLoader( triples = list(knowledge_loader.load_triples_from_file(f.name))
files=[f.name], contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name))
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)
await loader.load_triples(f.name, mock_websocket) # Should return empty lists for empty file
await loader.load_entity_contexts(f.name, mock_websocket) assert len(triples) == 0
assert len(contexts) == 0
# Should not send any messages for empty file
mock_websocket.send.assert_not_called()
Path(f.name).unlink(missing_ok=True) Path(f.name).unlink(missing_ok=True)
@pytest.mark.asyncio def test_turtle_with_mixed_literals_and_uris(self, knowledge_loader):
async def test_turtle_with_mixed_literals_and_uris(self, mock_websocket):
"""Test handling of Turtle with mixed literal and URI objects.""" """Test handling of Turtle with mixed literal and URI objects."""
turtle_content = """ turtle_content = """
@prefix ex: <http://example.org/> . @prefix ex: <http://example.org/> .
@ -448,32 +355,18 @@ ex:mary ex:name "Mary Johnson" .
f.write(turtle_content) f.write(turtle_content)
f.flush() f.flush()
loader = KnowledgeLoader( contexts = list(knowledge_loader.load_entity_contexts_from_file(f.name))
files=[f.name],
flow="test-flow",
user="test-user",
collection="test-collection",
document_id="test-doc"
)
await loader.load_entity_contexts(f.name, mock_websocket)
sent_messages = [json.loads(call.args[0]) for call in mock_websocket.send.call_args_list]
# Should have 4 entity contexts (for the 4 literals: "John Smith", "25", "New York", "Mary Johnson") # Should have 4 entity contexts (for the 4 literals: "John Smith", "25", "New York", "Mary Johnson")
# URI ex:mary should be skipped # URI ex:mary should be skipped
assert len(sent_messages) == 4 assert len(contexts) == 4
# Verify all contexts are for literals (subjects should be URIs) # Verify all contexts are for literals (subjects should be URIs)
contexts = [] context_values = [context for entity, context in contexts]
for message in sent_messages:
entity_context = message["entities"][0]
assert entity_context["entity"]["e"] is True # Subject is URI
contexts.append(entity_context["context"])
assert "John Smith" in contexts assert "John Smith" in context_values
assert "25" in contexts assert "25" in context_values
assert "New York" in contexts assert "New York" in context_values
assert "Mary Johnson" in contexts assert "Mary Johnson" in context_values
Path(f.name).unlink(missing_ok=True) Path(f.name).unlink(missing_ok=True)
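The retry behaviour that `test_main_retries_on_exception` above asserts (run once, sleep 10 seconds on failure, run again) can be sketched as a plain loop. The `run_with_retries` name and its parameters are hypothetical illustrations, not the actual `tg-load-knowledge` implementation:

```python
import time

def run_with_retries(loader, attempts=2, delay=10, sleep=time.sleep):
    """Call loader.run(), sleeping `delay` seconds between failed attempts.

    Hypothetical sketch of the retry loop the test exercises; `sleep` is
    injectable so tests can record calls instead of waiting.
    """
    for n in range(attempts):
        try:
            loader.run()
            return
        except Exception as e:
            print(f"Attempt {n + 1} failed: {e}")
            if n + 1 == attempts:
                raise  # out of retries: propagate the last error
            sleep(delay)
```

Injecting `sleep` is the same trick the test plays with `@patch('trustgraph.cli.load_knowledge.time.sleep')`: the retry delay becomes observable without slowing the suite down.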

View file

@@ -135,7 +135,8 @@ class TestSetToolStructuredQuery:
             arguments=[],
             group=None,
             state=None,
-            applicable_states=None
+            applicable_states=None,
+            token=None
         )

     def test_set_main_structured_query_no_arguments_needed(self):
@@ -313,7 +314,7 @@ class TestShowToolsStructuredQuery:
             show_main()

-            mock_show.assert_called_once_with(url='http://custom.com')
+            mock_show.assert_called_once_with(url='http://custom.com', token=None)


 class TestStructuredQueryToolValidation:

View file

@@ -22,18 +22,18 @@ class TestConfigReceiver:

     def test_config_receiver_initialization(self):
         """Test ConfigReceiver initialization"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

-        assert config_receiver.pulsar_client == mock_pulsar_client
+        assert config_receiver.backend == mock_backend
         assert config_receiver.flow_handlers == []
         assert config_receiver.flows == {}

     def test_add_handler(self):
         """Test adding flow handlers"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         handler1 = Mock()
         handler2 = Mock()
@@ -48,8 +48,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_with_new_flows(self):
         """Test on_config method with new flows"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Track calls manually instead of using AsyncMock
         start_flow_calls = []
@@ -87,8 +87,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_with_removed_flows(self):
         """Test on_config method with removed flows"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Pre-populate with existing flows
         config_receiver.flows = {
@@ -128,8 +128,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_with_no_flows(self):
         """Test on_config method with no flows in config"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Mock the start_flow and stop_flow methods with async functions
         async def mock_start_flow(*args):
@@ -158,8 +158,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_exception_handling(self):
         """Test on_config method handles exceptions gracefully"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Create mock message that will cause an exception
         mock_msg = Mock()
@@ -174,8 +174,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_start_flow_with_handlers(self):
         """Test start_flow method with multiple handlers"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handlers
         handler1 = Mock()
@@ -197,8 +197,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_start_flow_with_handler_exception(self):
         """Test start_flow method handles handler exceptions"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handler that raises exception
         handler = Mock()
@@ -217,8 +217,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_stop_flow_with_handlers(self):
         """Test stop_flow method with multiple handlers"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handlers
         handler1 = Mock()
@@ -240,8 +240,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_stop_flow_with_handler_exception(self):
         """Test stop_flow method handles handler exceptions"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Add mock handler that raises exception
         handler = Mock()
@@ -260,9 +260,9 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_config_loader_creates_consumer(self):
         """Test config_loader method creates Pulsar consumer"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Temporarily restore the real config_loader for this test
         config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
@@ -292,7 +292,7 @@ class TestConfigReceiver:
         mock_consumer_class.assert_called_once()
         call_args = mock_consumer_class.call_args
-        assert call_args[1]['client'] == mock_pulsar_client
+        assert call_args[1]['backend'] == mock_backend
         assert call_args[1]['subscriber'] == "gateway-test-uuid"
         assert call_args[1]['handler'] == config_receiver.on_config
         assert call_args[1]['start_of_messages'] is True
@@ -301,8 +301,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_start_creates_config_loader_task(self, mock_create_task):
         """Test start method creates config loader task"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Mock create_task to avoid actually creating tasks with real coroutines
         mock_task = Mock()
@@ -320,8 +320,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_mixed_flow_operations(self):
         """Test on_config with mixed add/remove operations"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Pre-populate with existing flows
         config_receiver.flows = {
@@ -380,8 +380,8 @@ class TestConfigReceiver:
     @pytest.mark.asyncio
     async def test_on_config_invalid_json_flow_data(self):
         """Test on_config handles invalid JSON in flow data"""
-        mock_pulsar_client = Mock()
-        config_receiver = ConfigReceiver(mock_pulsar_client)
+        mock_backend = Mock()
+        config_receiver = ConfigReceiver(mock_backend)

         # Mock the start_flow method with an async function
         async def mock_start_flow(*args):

View file

@@ -24,10 +24,10 @@ class TestConfigRequestor:
         mock_translator_registry.get_response_translator.return_value = mock_response_translator

         # Mock dependencies
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()

         requestor = ConfigRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             consumer="test-consumer",
             subscriber="test-subscriber",
             timeout=60
@@ -55,7 +55,7 @@ class TestConfigRequestor:
         with patch.object(ServiceRequestor, 'start', return_value=None), \
              patch.object(ServiceRequestor, 'process', return_value=None):
             requestor = ConfigRequestor(
-                pulsar_client=Mock(),
+                backend=Mock(),
                 consumer="test-consumer",
                 subscriber="test-subscriber"
             )
@@ -79,7 +79,7 @@ class TestConfigRequestor:
         mock_response_translator.from_response_with_completion.return_value = "translated_response"

         requestor = ConfigRequestor(
-            pulsar_client=Mock(),
+            backend=Mock(),
             consumer="test-consumer",
             subscriber="test-subscriber"
         )

View file

@@ -39,12 +39,12 @@ class TestDispatcherManager:

     def test_dispatcher_manager_initialization(self):
         """Test DispatcherManager initialization"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

-        assert manager.pulsar_client == mock_pulsar_client
+        assert manager.backend == mock_backend
         assert manager.config_receiver == mock_config_receiver
         assert manager.prefix == "api-gateway"  # default prefix
         assert manager.flows == {}
@@ -55,19 +55,19 @@ class TestDispatcherManager:
     def test_dispatcher_manager_initialization_with_custom_prefix(self):
         """Test DispatcherManager initialization with custom prefix"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver, prefix="custom-prefix")
+        manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix")

         assert manager.prefix == "custom-prefix"

     @pytest.mark.asyncio
     async def test_start_flow(self):
         """Test start_flow method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()

-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         flow_data = {"name": "test_flow", "steps": []}
@@ -79,9 +79,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_stop_flow(self):
         """Test stop_flow method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Pre-populate with a flow
         flow_data = {"name": "test_flow", "steps": []}
@@ -93,9 +93,9 @@ class TestDispatcherManager:
     def test_dispatch_global_service_returns_wrapper(self):
         """Test dispatch_global_service returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_global_service()
@@ -104,9 +104,9 @@ class TestDispatcherManager:
     def test_dispatch_core_export_returns_wrapper(self):
         """Test dispatch_core_export returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_core_export()
@@ -115,9 +115,9 @@ class TestDispatcherManager:
     def test_dispatch_core_import_returns_wrapper(self):
         """Test dispatch_core_import returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_core_import()
@@ -127,9 +127,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_core_import(self):
         """Test process_core_import method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
             mock_importer = Mock()
@@ -138,16 +138,16 @@ class TestDispatcherManager:
             result = await manager.process_core_import("data", "error", "ok", "request")

-            mock_core_import.assert_called_once_with(mock_pulsar_client)
+            mock_core_import.assert_called_once_with(mock_backend)
             mock_importer.process.assert_called_once_with("data", "error", "ok", "request")
             assert result == "import_result"

     @pytest.mark.asyncio
     async def test_process_core_export(self):
         """Test process_core_export method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
             mock_exporter = Mock()
@@ -156,16 +156,16 @@ class TestDispatcherManager:
             result = await manager.process_core_export("data", "error", "ok", "request")

-            mock_core_export.assert_called_once_with(mock_pulsar_client)
+            mock_core_export.assert_called_once_with(mock_backend)
             mock_exporter.process.assert_called_once_with("data", "error", "ok", "request")
             assert result == "export_result"

     @pytest.mark.asyncio
     async def test_process_global_service(self):
         """Test process_global_service method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         manager.invoke_global_service = AsyncMock(return_value="global_result")
@@ -178,9 +178,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_global_service_with_existing_dispatcher(self):
         """Test invoke_global_service with existing dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Pre-populate with existing dispatcher
         mock_dispatcher = Mock()
@@ -195,9 +195,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_global_service_creates_new_dispatcher(self):
         """Test invoke_global_service creates new dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
             mock_dispatcher_class = Mock()
@@ -211,10 +211,12 @@ class TestDispatcherManager:
             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 timeout=120,
                 consumer="api-gateway-config-request",
-                subscriber="api-gateway-config-request"
+                subscriber="api-gateway-config-request",
+                request_queue=None,
+                response_queue=None
             )
             mock_dispatcher.start.assert_called_once()
             mock_dispatcher.process.assert_called_once_with("data", "responder")
@@ -225,9 +227,9 @@ class TestDispatcherManager:
     def test_dispatch_flow_import_returns_method(self):
         """Test dispatch_flow_import returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_flow_import()
@@ -235,9 +237,9 @@ class TestDispatcherManager:
     def test_dispatch_flow_export_returns_method(self):
         """Test dispatch_flow_export returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_flow_export()
@@ -245,9 +247,9 @@ class TestDispatcherManager:
     def test_dispatch_socket_returns_method(self):
         """Test dispatch_socket returns correct method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         result = manager.dispatch_socket()
@@ -255,9 +257,9 @@ class TestDispatcherManager:
     def test_dispatch_flow_service_returns_wrapper(self):
         """Test dispatch_flow_service returns DispatcherWrapper"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         wrapper = manager.dispatch_flow_service()
@@ -267,9 +269,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_import_with_valid_flow_and_kind(self):
         """Test process_flow_import with valid flow and kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@@ -292,7 +294,7 @@ class TestDispatcherManager:
             result = await manager.process_flow_import("ws", "running", params)

             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 ws="ws",
                 running="running",
                 queue={"queue": "test_queue"}
@@ -303,9 +305,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_import_with_invalid_flow(self):
         """Test process_flow_import with invalid flow"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         params = {"flow": "invalid_flow", "kind": "triples"}
@@ -318,9 +320,9 @@ class TestDispatcherManager:
         import warnings
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", RuntimeWarning)
-            mock_pulsar_client = Mock()
+            mock_backend = Mock()
             mock_config_receiver = Mock()
-            manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+            manager = DispatcherManager(mock_backend, mock_config_receiver)

             # Setup test flow
             manager.flows["test_flow"] = {
@@ -340,9 +342,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_export_with_valid_flow_and_kind(self):
         """Test process_flow_export with valid flow and kind"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@@ -364,7 +366,7 @@ class TestDispatcherManager:
             result = await manager.process_flow_export("ws", "running", params)

             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 ws="ws",
                 running="running",
                 queue={"queue": "test_queue"},
@@ -376,9 +378,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_socket(self):
         """Test process_socket method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
             mock_mux_instance = Mock()
@@ -392,9 +394,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_process_flow_service(self):
         """Test process_flow_service method"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         manager.invoke_flow_service = AsyncMock(return_value="flow_result")
@@ -407,9 +409,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_with_existing_dispatcher(self):
         """Test invoke_flow_service with existing dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Add flow to the flows dictionary
         manager.flows["test_flow"] = {"services": {"agent": {}}}
@@ -427,9 +429,9 @@ class TestDispatcherManager:
    @pytest.mark.asyncio
     async def test_invoke_flow_service_creates_request_response_dispatcher(self):
         """Test invoke_flow_service creates request-response dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@@ -454,7 +456,7 @@ class TestDispatcherManager:
             # Verify dispatcher was created with correct parameters
             mock_dispatcher_class.assert_called_once_with(
-                pulsar_client=mock_pulsar_client,
+                backend=mock_backend,
                 request_queue="agent_request_queue",
                 response_queue="agent_response_queue",
                 timeout=120,
@@ -471,9 +473,9 @@ class TestDispatcherManager:
     @pytest.mark.asyncio
     async def test_invoke_flow_service_creates_sender_dispatcher(self):
         """Test invoke_flow_service creates sender dispatcher"""
-        mock_pulsar_client = Mock()
+        mock_backend = Mock()
         mock_config_receiver = Mock()
-        manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
+        manager = DispatcherManager(mock_backend, mock_config_receiver)

         # Setup test flow
         manager.flows["test_flow"] = {
@ -498,7 +500,7 @@ class TestDispatcherManager:
# Verify dispatcher was created with correct parameters # Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with( mock_dispatcher_class.assert_called_once_with(
pulsar_client=mock_pulsar_client, backend=mock_backend,
queue={"queue": "text_load_queue"} queue={"queue": "text_load_queue"}
) )
mock_dispatcher.start.assert_called_once() mock_dispatcher.start.assert_called_once()
@ -511,9 +513,9 @@ class TestDispatcherManager:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_flow_service_invalid_flow(self): async def test_invoke_flow_service_invalid_flow(self):
"""Test invoke_flow_service with invalid flow""" """Test invoke_flow_service with invalid flow"""
mock_pulsar_client = Mock() mock_backend = Mock()
mock_config_receiver = Mock() mock_config_receiver = Mock()
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
with pytest.raises(RuntimeError, match="Invalid flow"): with pytest.raises(RuntimeError, match="Invalid flow"):
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent") await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
@ -521,9 +523,9 @@ class TestDispatcherManager:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_flow_service_unsupported_kind_by_flow(self): async def test_invoke_flow_service_unsupported_kind_by_flow(self):
"""Test invoke_flow_service with kind not supported by flow""" """Test invoke_flow_service with kind not supported by flow"""
mock_pulsar_client = Mock() mock_backend = Mock()
mock_config_receiver = Mock() mock_config_receiver = Mock()
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow without agent interface # Setup test flow without agent interface
manager.flows["test_flow"] = { manager.flows["test_flow"] = {
@ -538,9 +540,9 @@ class TestDispatcherManager:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_flow_service_invalid_kind(self): async def test_invoke_flow_service_invalid_kind(self):
"""Test invoke_flow_service with invalid kind""" """Test invoke_flow_service with invalid kind"""
mock_pulsar_client = Mock() mock_backend = Mock()
mock_config_receiver = Mock() mock_config_receiver = Mock()
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver) manager = DispatcherManager(mock_backend, mock_config_receiver)
# Setup test flow with interface but unsupported kind # Setup test flow with interface but unsupported kind
manager.flows["test_flow"] = { manager.flows["test_flow"] = {

View file

@@ -15,12 +15,12 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_initialization(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor initialization"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_request_schema = MagicMock()
         mock_response_schema = MagicMock()

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-request-queue",
             request_schema=mock_request_schema,
             response_queue="test-response-queue",
@@ -32,12 +32,12 @@ class TestServiceRequestor:

         # Verify Publisher was created correctly
         mock_publisher.assert_called_once_with(
-            mock_pulsar_client, "test-request-queue", schema=mock_request_schema
+            mock_backend, "test-request-queue", schema=mock_request_schema
         )

         # Verify Subscriber was created correctly
         mock_subscriber.assert_called_once_with(
-            mock_pulsar_client, "test-response-queue",
+            mock_backend, "test-response-queue",
             "test-subscription", "test-consumer", mock_response_schema
         )
@@ -48,12 +48,12 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_with_defaults(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor initialization with default parameters"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_request_schema = MagicMock()
         mock_response_schema = MagicMock()

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=mock_request_schema,
             response_queue="response-queue",
@@ -62,7 +62,7 @@ class TestServiceRequestor:

         # Verify default values
         mock_subscriber.assert_called_once_with(
-            mock_pulsar_client, "response-queue",
+            mock_backend, "response-queue",
             "api-gateway", "api-gateway", mock_response_schema
         )
         assert requestor.timeout == 600  # Default timeout
@@ -72,14 +72,14 @@ class TestServiceRequestor:
     @pytest.mark.asyncio
     async def test_service_requestor_start(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor start method"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_sub_instance = AsyncMock()
         mock_pub_instance = AsyncMock()
         mock_subscriber.return_value = mock_sub_instance
         mock_publisher.return_value = mock_pub_instance

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=MagicMock(),
             response_queue="response-queue",
@@ -98,14 +98,14 @@ class TestServiceRequestor:
     @patch('trustgraph.gateway.dispatch.requestor.Subscriber')
     def test_service_requestor_attributes(self, mock_subscriber, mock_publisher):
         """Test ServiceRequestor has correct attributes"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_pub_instance = AsyncMock()
         mock_sub_instance = AsyncMock()
         mock_publisher.return_value = mock_pub_instance
         mock_subscriber.return_value = mock_sub_instance

         requestor = ServiceRequestor(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             request_queue="test-queue",
             request_schema=MagicMock(),
             response_queue="response-queue",

View file

@@ -14,18 +14,18 @@ class TestServiceSender:
     @patch('trustgraph.gateway.dispatch.sender.Publisher')
     def test_service_sender_initialization(self, mock_publisher):
         """Test ServiceSender initialization"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_schema = MagicMock()

         sender = ServiceSender(
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue",
             schema=mock_schema
         )

         # Verify Publisher was created correctly
         mock_publisher.assert_called_once_with(
-            mock_pulsar_client, "test-queue", schema=mock_schema
+            mock_backend, "test-queue", schema=mock_schema
         )

     @patch('trustgraph.gateway.dispatch.sender.Publisher')
@@ -36,7 +36,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )
@@ -55,7 +55,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )
@@ -70,7 +70,7 @@ class TestServiceSender:
     def test_service_sender_to_request_not_implemented(self, mock_publisher):
         """Test ServiceSender to_request method raises RuntimeError"""
         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )
@@ -91,7 +91,7 @@ class TestServiceSender:
                 return {"processed": request}

         sender = ConcreteSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )
@@ -111,7 +111,7 @@ class TestServiceSender:
         mock_publisher.return_value = mock_pub_instance

         sender = ServiceSender(
-            pulsar_client=MagicMock(),
+            backend=MagicMock(),
             queue="test-queue",
             schema=MagicMock()
         )

View file
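The hunks above and below apply the same mechanical change: the concrete `pulsar_client` constructor argument becomes a generic `backend` handle passed straight through to `Publisher`/`Subscriber`. A minimal, self-contained sketch of that pattern — the class bodies here are illustrative stand-ins, not the actual `trustgraph.gateway` code — shows why the tests only need an opaque mock:

```python
# Illustrative sketch of the pulsar_client -> backend rename.
# Publisher and ServiceSender below are hypothetical stand-ins.
from unittest.mock import MagicMock


class Publisher:
    """Minimal stand-in: a publisher bound to a backend and a queue."""
    def __init__(self, backend, queue, schema=None):
        self.backend = backend
        self.queue = queue
        self.schema = schema


class ServiceSender:
    """Sends messages via a pluggable pub/sub backend."""
    def __init__(self, backend, queue, schema):
        # The backend is opaque here: it may wrap Pulsar or any other
        # pub/sub implementation, which is what makes mocking trivial.
        self.pub = Publisher(backend, queue, schema=schema)


mock_backend = MagicMock()
sender = ServiceSender(backend=mock_backend, queue="test-queue", schema=MagicMock())
assert sender.pub.backend is mock_backend
assert sender.pub.queue == "test-queue"
```

Because the tests never call through the backend, renaming the parameter leaves every assertion structurally identical — only the mock's name changes.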

@@ -16,7 +16,7 @@ from trustgraph.schema import Metadata, ExtractedObject

 @pytest.fixture
-def mock_pulsar_client():
+def mock_backend():
     """Mock Pulsar client."""
     client = Mock()
     return client

@@ -96,7 +96,7 @@ class TestObjectsImportInitialization:
     """Test ObjectsImport initialization."""

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
-    def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    def test_init_creates_publisher_with_correct_params(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that ObjectsImport creates Publisher with correct parameters."""
         mock_publisher_instance = Mock()
         mock_publisher_class.return_value = mock_publisher_instance
@@ -104,13 +104,13 @@ class TestObjectsImportInitialization:
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-objects-queue"
         )

         # Verify Publisher was created with correct parameters
         mock_publisher_class.assert_called_once_with(
-            mock_pulsar_client,
+            mock_backend,
             topic="test-objects-queue",
             schema=ExtractedObject
         )
@@ -121,12 +121,12 @@ class TestObjectsImportInitialization:
         assert objects_import.publisher == mock_publisher_instance

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
-    def test_init_stores_references_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    def test_init_stores_references_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that ObjectsImport stores all required references."""
         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="objects-queue"
         )
@@ -139,7 +139,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_start_calls_publisher_start(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_start_calls_publisher_start(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that start() calls publisher.start()."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.start = AsyncMock()
@@ -148,7 +148,7 @@ class TestObjectsImportLifecycle:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -158,7 +158,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_destroy_stops_and_closes_properly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that destroy() properly stops publisher and closes websocket."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.stop = AsyncMock()
@@ -167,7 +167,7 @@ class TestObjectsImportLifecycle:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -180,7 +180,7 @@ class TestObjectsImportLifecycle:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_pulsar_client, mock_running):
+    async def test_destroy_handles_none_websocket(self, mock_publisher_class, mock_backend, mock_running):
         """Test that destroy() handles None websocket gracefully."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.stop = AsyncMock()
@@ -189,7 +189,7 @@ class TestObjectsImportLifecycle:

         objects_import = ObjectsImport(
             ws=None,  # None websocket
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -205,7 +205,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
+    async def test_receive_processes_full_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
         """Test that receive() processes complete message correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()
@@ -214,7 +214,7 @@ class TestObjectsImportMessageProcessing:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -248,7 +248,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, minimal_objects_message):
+    async def test_receive_handles_minimal_message(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, minimal_objects_message):
         """Test that receive() handles message with minimal required fields."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()
@@ -257,7 +257,7 @@ class TestObjectsImportMessageProcessing:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -281,7 +281,7 @@ class TestObjectsImportMessageProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_uses_default_values(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_receive_uses_default_values(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that receive() uses appropriate default values for optional fields."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()
@@ -290,7 +290,7 @@ class TestObjectsImportMessageProcessing:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -323,7 +323,7 @@ class TestObjectsImportRunMethod:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
     @pytest.mark.asyncio
-    async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_run_loops_while_running(self, mock_sleep, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that run() loops while running.get() returns True."""
         mock_sleep.return_value = None
         mock_publisher_class.return_value = Mock()
@@ -334,7 +334,7 @@ class TestObjectsImportRunMethod:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -353,7 +353,7 @@ class TestObjectsImportRunMethod:
     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @patch('trustgraph.gateway.dispatch.objects_import.asyncio.sleep')
     @pytest.mark.asyncio
-    async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_pulsar_client, mock_running):
+    async def test_run_handles_none_websocket_gracefully(self, mock_sleep, mock_publisher_class, mock_backend, mock_running):
         """Test that run() handles None websocket gracefully."""
         mock_sleep.return_value = None
         mock_publisher_class.return_value = Mock()
@@ -363,7 +363,7 @@ class TestObjectsImportRunMethod:

         objects_import = ObjectsImport(
             ws=None,  # None websocket
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -417,7 +417,7 @@ class TestObjectsImportBatchProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, batch_objects_message):
+    async def test_receive_processes_batch_message_correctly(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, batch_objects_message):
         """Test that receive() processes batch message correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()
@@ -426,7 +426,7 @@ class TestObjectsImportBatchProcessing:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -467,7 +467,7 @@ class TestObjectsImportBatchProcessing:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_receive_handles_empty_batch(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that receive() handles empty batch correctly."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock()
@@ -476,7 +476,7 @@ class TestObjectsImportBatchProcessing:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -507,7 +507,7 @@ class TestObjectsImportErrorHandling:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running, sample_objects_message):
+    async def test_receive_propagates_publisher_errors(self, mock_publisher_class, mock_backend, mock_websocket, mock_running, sample_objects_message):
         """Test that receive() propagates publisher send errors."""
         mock_publisher_instance = Mock()
         mock_publisher_instance.send = AsyncMock(side_effect=Exception("Publisher error"))
@@ -516,7 +516,7 @@ class TestObjectsImportErrorHandling:

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )
@@ -528,14 +528,14 @@ class TestObjectsImportErrorHandling:

     @patch('trustgraph.gateway.dispatch.objects_import.Publisher')
     @pytest.mark.asyncio
-    async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_pulsar_client, mock_websocket, mock_running):
+    async def test_receive_handles_malformed_json(self, mock_publisher_class, mock_backend, mock_websocket, mock_running):
         """Test that receive() handles malformed JSON appropriately."""
         mock_publisher_class.return_value = Mock()

         objects_import = ObjectsImport(
             ws=mock_websocket,
             running=mock_running,
-            pulsar_client=mock_pulsar_client,
+            backend=mock_backend,
             queue="test-queue"
         )

View file
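The Api tests below change tactic: instead of patching `pulsar.Client` (and `pulsar.AuthenticationToken`) directly, they patch a `get_pubsub` factory function, so the broker choice is hidden behind one seam. A self-contained sketch of that testing pattern — `Service`/`get_backend` are hypothetical names, not the actual trustgraph API:

```python
# Illustrative sketch of patching a backend factory instead of a
# concrete client class. Names are hypothetical stand-ins.
from unittest.mock import MagicMock, patch


class Service:
    # Factory seam: tests patch Service.get_backend rather than the
    # underlying client library.
    @staticmethod
    def get_backend(**config):
        raise RuntimeError("would connect to a real broker")

    def __init__(self, **config):
        # Resolve the backend through the factory, so a single patch
        # keeps construction fully offline in tests.
        self.backend = self.get_backend(**config)


with patch.object(Service, "get_backend") as mock_get:
    mock_get.return_value = MagicMock()
    svc = Service(api_token="secret")
    mock_get.assert_called_once_with(api_token="secret")
    assert svc.backend is mock_get.return_value
```

Patching the factory rather than the client class means tests no longer encode the client constructor's exact signature (host, listener name, authentication object), which is what the deletions in the hunks below remove.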

@@ -19,8 +19,9 @@ class TestApi:

     def test_api_initialization_with_defaults(self):
         """Test Api initialization with default values"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_backend = Mock()
+            mock_get_pubsub.return_value = mock_backend

             api = Api()

@@ -31,11 +32,8 @@ class TestApi:
             assert api.prometheus_url == default_prometheus_url + "/"
             assert api.auth.allow_all is True

-            # Verify Pulsar client was created without API key
-            mock_client.assert_called_once_with(
-                default_pulsar_host,
-                listener_name=None
-            )
+            # Verify get_pubsub was called
+            mock_get_pubsub.assert_called_once()

     def test_api_initialization_with_custom_config(self):
         """Test Api initialization with custom configuration"""
@@ -49,10 +47,9 @@ class TestApi:
             "api_token": "secret-token"
         }

-        with patch('pulsar.Client') as mock_client, \
-             patch('pulsar.AuthenticationToken') as mock_auth:
-            mock_client.return_value = Mock()
-            mock_auth.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_backend = Mock()
+            mock_get_pubsub.return_value = mock_backend

             api = Api(**config)

@@ -64,34 +61,24 @@ class TestApi:
             assert api.auth.token == "secret-token"
             assert api.auth.allow_all is False

-            # Verify Pulsar client was created with API key
-            mock_auth.assert_called_once_with("test-api-key")
-            mock_client.assert_called_once_with(
-                "pulsar://custom-host:6650",
-                listener_name="custom-listener",
-                authentication=mock_auth.return_value
-            )
+            # Verify get_pubsub was called with config
+            mock_get_pubsub.assert_called_once_with(**config)

     def test_api_initialization_with_pulsar_api_key(self):
         """Test Api initialization with Pulsar API key authentication"""
-        with patch('pulsar.Client') as mock_client, \
-             patch('pulsar.AuthenticationToken') as mock_auth:
-            mock_client.return_value = Mock()
-            mock_auth.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api(pulsar_api_key="test-key")

-            mock_auth.assert_called_once_with("test-key")
-            mock_client.assert_called_once_with(
-                default_pulsar_host,
-                listener_name=None,
-                authentication=mock_auth.return_value
-            )
+            # Verify api key was stored
+            assert api.pulsar_api_key == "test-key"
+            mock_get_pubsub.assert_called_once()

     def test_api_initialization_prometheus_url_normalization(self):
         """Test that prometheus_url gets normalized with trailing slash"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             # Test URL without trailing slash
             api = Api(prometheus_url="http://prometheus:9090")
@@ -103,16 +90,16 @@ class TestApi:

     def test_api_initialization_empty_api_token_means_no_auth(self):
         """Test that empty API token results in allow_all authentication"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api(api_token="")

             assert api.auth.allow_all is True

     def test_api_initialization_none_api_token_means_no_auth(self):
         """Test that None API token results in allow_all authentication"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api(api_token=None)

             assert api.auth.allow_all is True
@@ -120,8 +107,8 @@ class TestApi:
     @pytest.mark.asyncio
     async def test_app_factory_creates_application(self):
         """Test that app_factory creates aiohttp application"""
-        with patch('pulsar.Client') as mock_client:
-            mock_client.return_value = Mock()
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
+            mock_get_pubsub.return_value = Mock()

             api = Api()
@@ -147,8 +134,8 @@ class TestApi:
     @pytest.mark.asyncio
     async def test_app_factory_with_custom_endpoints(self):
         """Test app_factory with custom endpoints"""
-        with patch('pulsar.Client') as mock_client:
+        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_client.return_value = Mock() mock_get_pubsub.return_value = Mock()
api = Api() api = Api()
@ -180,9 +167,9 @@ class TestApi:
def test_run_method_calls_web_run_app(self): def test_run_method_calls_web_run_app(self):
"""Test that run method calls web.run_app""" """Test that run method calls web.run_app"""
with patch('pulsar.Client') as mock_client, \ with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
patch('aiohttp.web.run_app') as mock_run_app: patch('aiohttp.web.run_app') as mock_run_app:
mock_client.return_value = Mock() mock_get_pubsub.return_value = Mock()
api = Api(port=8080) api = Api(port=8080)
api.run() api.run()
@ -195,8 +182,8 @@ class TestApi:
def test_api_components_initialization(self): def test_api_components_initialization(self):
"""Test that all API components are properly initialized""" """Test that all API components are properly initialized"""
with patch('pulsar.Client') as mock_client: with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
mock_client.return_value = Mock() mock_get_pubsub.return_value = Mock()
api = Api() api = Api()
@ -207,7 +194,7 @@ class TestApi:
assert api.endpoints == [] assert api.endpoints == []
# Verify component relationships # Verify component relationships
assert api.dispatcher_manager.pulsar_client == api.pulsar_client assert api.dispatcher_manager.backend == api.pubsub_backend
assert api.dispatcher_manager.config_receiver == api.config_receiver assert api.dispatcher_manager.config_receiver == api.config_receiver
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
# EndpointManager doesn't store auth directly, it passes it to individual endpoints # EndpointManager doesn't store auth directly, it passes it to individual endpoints
@@ -129,7 +129,17 @@ async def test_handle_normal_flow():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(return_value=None)
-        mock_tg.create_task = MagicMock(return_value=AsyncMock())
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg
 
         result = await socket_endpoint.handle(request)
@@ -176,11 +186,25 @@ async def test_handle_exception_group_cleanup():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
-        mock_tg.create_task = MagicMock(side_effect=TestException("test"))
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg
 
-        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
-            mock_wait_for.return_value = None
+        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for:
+            # Make wait_for consume the coroutine passed to it
+            async def wait_for_side_effect(coro, timeout=None):
+                coro.close()  # Consume the coroutine
+                return None
+            mock_wait_for.side_effect = wait_for_side_effect
 
             result = await socket_endpoint.handle(request)
@@ -227,12 +251,26 @@ async def test_handle_dispatcher_cleanup_timeout():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
-        mock_tg.create_task = MagicMock(side_effect=Exception("test"))
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg
 
         # Mock asyncio.wait_for to raise TimeoutError
-        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
-            mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
+        with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for', new_callable=AsyncMock) as mock_wait_for:
+            # Make wait_for consume the coroutine before raising
+            async def wait_for_timeout(coro, timeout=None):
+                coro.close()  # Consume the coroutine
+                raise asyncio.TimeoutError("Cleanup timeout")
+            mock_wait_for.side_effect = wait_for_timeout
 
             result = await socket_endpoint.handle(request)
@@ -314,7 +352,17 @@ async def test_handle_websocket_already_closed():
         mock_tg = AsyncMock()
         mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
         mock_tg.__aexit__ = AsyncMock(return_value=None)
-        mock_tg.create_task = MagicMock(return_value=AsyncMock())
+
+        # Create proper mock tasks that look like asyncio.Task objects
+        def create_task_mock(coro):
+            # Consume the coroutine to avoid "was never awaited" warning
+            coro.close()
+            task = AsyncMock()
+            task.done = MagicMock(return_value=True)
+            task.cancelled = MagicMock(return_value=False)
+            return task
+
+        mock_tg.create_task = MagicMock(side_effect=create_task_mock)
         mock_task_group.return_value = mock_tg
 
         result = await socket_endpoint.handle(request)
@@ -0,0 +1,326 @@
"""
Unit tests for streaming behavior in message translators.
These tests verify that translators correctly handle empty strings and
end_of_stream flags in streaming responses, preventing bugs where empty
final chunks could be dropped due to falsy value checks.
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.messaging.translators.retrieval import (
GraphRagResponseTranslator,
DocumentRagResponseTranslator,
)
from trustgraph.messaging.translators.prompt import PromptResponseTranslator
from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator
from trustgraph.schema import (
GraphRagResponse,
DocumentRagResponse,
PromptResponse,
TextCompletionResponse,
)
class TestGraphRagResponseTranslator:
"""Test GraphRagResponseTranslator streaming behavior"""
def test_from_pulsar_with_empty_response(self):
"""Test that empty response strings are preserved"""
# Arrange
translator = GraphRagResponseTranslator()
response = GraphRagResponse(
response="",
end_of_stream=True,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert - Empty string should be included in result
assert "response" in result
assert result["response"] == ""
assert result["end_of_stream"] is True
def test_from_pulsar_with_non_empty_response(self):
"""Test that non-empty responses work correctly"""
# Arrange
translator = GraphRagResponseTranslator()
response = GraphRagResponse(
response="Some text",
end_of_stream=False,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert result["response"] == "Some text"
assert result["end_of_stream"] is False
def test_from_pulsar_with_none_response(self):
"""Test that None response is handled correctly"""
# Arrange
translator = GraphRagResponseTranslator()
response = GraphRagResponse(
response=None,
end_of_stream=True,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert - None should not be included
assert "response" not in result
assert result["end_of_stream"] is True
def test_from_response_with_completion_returns_correct_flag(self):
"""Test that from_response_with_completion returns correct is_final flag"""
# Arrange
translator = GraphRagResponseTranslator()
# Test non-final chunk
response_chunk = GraphRagResponse(
response="chunk",
end_of_stream=False,
error=None
)
# Act
result, is_final = translator.from_response_with_completion(response_chunk)
# Assert
assert is_final is False
assert result["end_of_stream"] is False
# Test final chunk with empty content
final_response = GraphRagResponse(
response="",
end_of_stream=True,
error=None
)
# Act
result, is_final = translator.from_response_with_completion(final_response)
# Assert
assert is_final is True
assert result["response"] == ""
assert result["end_of_stream"] is True
class TestDocumentRagResponseTranslator:
"""Test DocumentRagResponseTranslator streaming behavior"""
def test_from_pulsar_with_empty_response(self):
"""Test that empty response strings are preserved"""
# Arrange
translator = DocumentRagResponseTranslator()
response = DocumentRagResponse(
response="",
end_of_stream=True,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert "response" in result
assert result["response"] == ""
assert result["end_of_stream"] is True
def test_from_pulsar_with_non_empty_response(self):
"""Test that non-empty responses work correctly"""
# Arrange
translator = DocumentRagResponseTranslator()
response = DocumentRagResponse(
response="Document content",
end_of_stream=False,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert result["response"] == "Document content"
assert result["end_of_stream"] is False
class TestPromptResponseTranslator:
"""Test PromptResponseTranslator streaming behavior"""
def test_from_pulsar_with_empty_text(self):
"""Test that empty text strings are preserved"""
# Arrange
translator = PromptResponseTranslator()
response = PromptResponse(
text="",
object=None,
end_of_stream=True,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert "text" in result
assert result["text"] == ""
assert result["end_of_stream"] is True
def test_from_pulsar_with_non_empty_text(self):
"""Test that non-empty text works correctly"""
# Arrange
translator = PromptResponseTranslator()
response = PromptResponse(
text="Some prompt response",
object=None,
end_of_stream=False,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert result["text"] == "Some prompt response"
assert result["end_of_stream"] is False
def test_from_pulsar_with_none_text(self):
"""Test that None text is handled correctly"""
# Arrange
translator = PromptResponseTranslator()
response = PromptResponse(
text=None,
object='{"result": "data"}',
end_of_stream=True,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert "text" not in result
assert "object" in result
assert result["end_of_stream"] is True
def test_from_pulsar_includes_end_of_stream(self):
"""Test that end_of_stream flag is always included"""
# Arrange
translator = PromptResponseTranslator()
# Test with end_of_stream=False
response = PromptResponse(
text="chunk",
object=None,
end_of_stream=False,
error=None
)
# Act
result = translator.from_pulsar(response)
# Assert
assert "end_of_stream" in result
assert result["end_of_stream"] is False
class TestTextCompletionResponseTranslator:
"""Test TextCompletionResponseTranslator streaming behavior"""
def test_from_pulsar_always_includes_response(self):
"""Test that response field is always included, even if empty"""
# Arrange
translator = TextCompletionResponseTranslator()
response = TextCompletionResponse(
response="",
end_of_stream=True,
error=None,
in_token=100,
out_token=5,
model="test-model"
)
# Act
result = translator.from_pulsar(response)
# Assert - Response should always be present
assert "response" in result
assert result["response"] == ""
def test_from_response_with_completion_with_empty_final(self):
"""Test that empty final response is handled correctly"""
# Arrange
translator = TextCompletionResponseTranslator()
response = TextCompletionResponse(
response="",
end_of_stream=True,
error=None,
in_token=100,
out_token=5,
model="test-model"
)
# Act
result, is_final = translator.from_response_with_completion(response)
# Assert
assert is_final is True
assert result["response"] == ""
class TestStreamingProtocolCompliance:
"""Test that all translators follow streaming protocol conventions"""
@pytest.mark.parametrize("translator_class,response_class,field_name", [
(GraphRagResponseTranslator, GraphRagResponse, "response"),
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
(PromptResponseTranslator, PromptResponse, "text"),
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
])
def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name):
"""Test that all translators preserve empty final chunks"""
# Arrange
translator = translator_class()
kwargs = {
field_name: "",
"end_of_stream": True,
"error": None,
}
response = response_class(**kwargs)
# Act
result = translator.from_pulsar(response)
# Assert
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string"
@pytest.mark.parametrize("translator_class,response_class,field_name", [
(GraphRagResponseTranslator, GraphRagResponse, "response"),
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
])
def test_end_of_stream_flag_included(self, translator_class, response_class, field_name):
"""Test that end_of_stream flag is included in all response translators"""
# Arrange
translator = translator_class()
kwargs = {
field_name: "test content",
"end_of_stream": True,
"error": None,
}
response = response_class(**kwargs)
# Act
result = translator.from_pulsar(response)
# Assert
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
assert result["end_of_stream"] is True
@@ -0,0 +1,446 @@
"""
Unit tests for TrustGraph Python API client library
These tests use mocks and do not require a running server.
"""
import pytest
from unittest.mock import Mock, patch, MagicMock, call
import json
from trustgraph.api import (
Api,
Triple,
AgentThought,
AgentObservation,
AgentAnswer,
RAGChunk,
)
class TestApiInstantiation:
"""Test Api class instantiation and configuration"""
def test_api_instantiation_defaults(self):
"""Test Api with default parameters"""
api = Api()
assert api.url == "http://localhost:8088/api/v1/"
assert api.timeout == 60
assert api.token is None
def test_api_instantiation_with_url(self):
"""Test Api with custom URL"""
api = Api(url="http://test-server:9000/")
assert api.url == "http://test-server:9000/api/v1/"
def test_api_instantiation_with_url_trailing_slash(self):
"""Test Api adds trailing slash if missing"""
api = Api(url="http://test-server:9000")
assert api.url == "http://test-server:9000/api/v1/"
def test_api_instantiation_with_token(self):
"""Test Api with authentication token"""
api = Api(token="test-token-123")
assert api.token == "test-token-123"
def test_api_instantiation_with_timeout(self):
"""Test Api with custom timeout"""
api = Api(timeout=120)
assert api.timeout == 120
class TestApiLazyInitialization:
"""Test lazy initialization of client components"""
def test_socket_client_lazy_init(self):
"""Test socket client is created on first access"""
api = Api(url="http://test/", token="token")
assert api._socket_client is None
socket = api.socket()
assert api._socket_client is not None
assert socket is api._socket_client
# Second access returns same instance
socket2 = api.socket()
assert socket2 is socket
def test_bulk_client_lazy_init(self):
"""Test bulk client is created on first access"""
api = Api(url="http://test/")
assert api._bulk_client is None
bulk = api.bulk()
assert api._bulk_client is not None
def test_async_flow_lazy_init(self):
"""Test async flow is created on first access"""
api = Api(url="http://test/")
assert api._async_flow is None
async_flow = api.async_flow()
assert api._async_flow is not None
def test_metrics_lazy_init(self):
"""Test metrics client is created on first access"""
api = Api(url="http://test/")
assert api._metrics is None
metrics = api.metrics()
assert api._metrics is not None
class TestApiContextManager:
"""Test context manager functionality"""
def test_sync_context_manager(self):
"""Test synchronous context manager"""
with Api(url="http://test/") as api:
assert api is not None
assert isinstance(api, Api)
# Should exit cleanly
@pytest.mark.asyncio
async def test_async_context_manager(self):
"""Test asynchronous context manager"""
async with Api(url="http://test/") as api:
assert api is not None
assert isinstance(api, Api)
# Should exit cleanly
class TestFlowClient:
"""Test Flow client functionality"""
@patch('requests.post')
def test_flow_list(self, mock_post):
"""Test listing flows"""
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"flow-ids": ["flow1", "flow2"]}
api = Api(url="http://test/")
flows = api.flow().list()
assert flows == ["flow1", "flow2"]
assert mock_post.called
@patch('requests.post')
def test_flow_list_with_token(self, mock_post):
"""Test flow listing includes auth token"""
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"flow-ids": []}
api = Api(url="http://test/", token="my-token")
api.flow().list()
# Verify Authorization header was set
call_args = mock_post.call_args
headers = call_args[1]['headers'] if 'headers' in call_args[1] else {}
assert 'Authorization' in headers
assert headers['Authorization'] == 'Bearer my-token'
@patch('requests.post')
def test_flow_get(self, mock_post):
"""Test getting flow definition"""
flow_def = {"name": "test-flow", "description": "Test"}
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"flow": json.dumps(flow_def)}
api = Api(url="http://test/")
result = api.flow().get("test-flow")
assert result == flow_def
def test_flow_instance_creation(self):
"""Test creating flow instance"""
api = Api(url="http://test/")
flow_instance = api.flow().id("my-flow")
assert flow_instance is not None
assert flow_instance.id == "my-flow"
def test_flow_instance_has_methods(self):
"""Test flow instance has expected methods"""
api = Api(url="http://test/")
flow_instance = api.flow().id("my-flow")
expected_methods = [
'text_completion', 'agent', 'graph_rag', 'document_rag',
'graph_embeddings_query', 'embeddings', 'prompt',
'triples_query', 'objects_query'
]
for method in expected_methods:
assert hasattr(flow_instance, method), f"Missing method: {method}"
class TestSocketClient:
"""Test WebSocket client functionality"""
def test_socket_client_url_conversion_http(self):
"""Test HTTP URL converted to WebSocket"""
api = Api(url="http://test-server:8088/")
socket = api.socket()
assert socket.url.startswith("ws://")
assert "test-server" in socket.url
def test_socket_client_url_conversion_https(self):
"""Test HTTPS URL converted to secure WebSocket"""
api = Api(url="https://test-server:8088/")
socket = api.socket()
assert socket.url.startswith("wss://")
def test_socket_client_token_passed(self):
"""Test token is passed to socket client"""
api = Api(url="http://test/", token="socket-token")
socket = api.socket()
assert socket.token == "socket-token"
def test_socket_flow_instance(self):
"""Test creating socket flow instance"""
api = Api(url="http://test/")
socket = api.socket()
flow_instance = socket.flow("test-flow")
assert flow_instance is not None
assert flow_instance.flow_id == "test-flow"
def test_socket_flow_has_methods(self):
"""Test socket flow instance has expected methods"""
api = Api(url="http://test/")
flow_instance = api.socket().flow("test-flow")
expected_methods = [
'agent', 'text_completion', 'graph_rag', 'document_rag',
'prompt', 'graph_embeddings_query', 'embeddings',
'triples_query', 'objects_query', 'mcp_tool'
]
for method in expected_methods:
assert hasattr(flow_instance, method), f"Missing method: {method}"
class TestBulkClient:
"""Test bulk operations client"""
def test_bulk_client_url_conversion(self):
"""Test bulk client uses WebSocket URL"""
api = Api(url="http://test/")
bulk = api.bulk()
assert bulk.url.startswith("ws://")
def test_bulk_client_has_import_methods(self):
"""Test bulk client has import methods"""
api = Api(url="http://test/")
bulk = api.bulk()
import_methods = [
'import_triples',
'import_graph_embeddings',
'import_document_embeddings',
'import_entity_contexts',
'import_objects'
]
for method in import_methods:
assert hasattr(bulk, method), f"Missing method: {method}"
def test_bulk_client_has_export_methods(self):
"""Test bulk client has export methods"""
api = Api(url="http://test/")
bulk = api.bulk()
export_methods = [
'export_triples',
'export_graph_embeddings',
'export_document_embeddings',
'export_entity_contexts'
]
for method in export_methods:
assert hasattr(bulk, method), f"Missing method: {method}"
class TestMetricsClient:
"""Test metrics client"""
@patch('requests.get')
def test_metrics_get(self, mock_get):
"""Test getting metrics"""
mock_get.return_value.status_code = 200
mock_get.return_value.text = "# HELP metric_name\nmetric_name 42"
api = Api(url="http://test/")
metrics_text = api.metrics().get()
assert "metric_name" in metrics_text
assert mock_get.called
@patch('requests.get')
def test_metrics_with_token(self, mock_get):
"""Test metrics request includes token"""
mock_get.return_value.status_code = 200
mock_get.return_value.text = "metrics"
api = Api(url="http://test/", token="metrics-token")
api.metrics().get()
# Verify token in headers
call_args = mock_get.call_args
headers = call_args[1].get('headers', {})
assert 'Authorization' in headers
class TestStreamingTypes:
"""Test streaming chunk types"""
def test_agent_thought_creation(self):
"""Test creating AgentThought chunk"""
chunk = AgentThought(content="thinking...", end_of_message=False)
assert chunk.content == "thinking..."
assert chunk.end_of_message is False
assert chunk.chunk_type == "thought"
def test_agent_observation_creation(self):
"""Test creating AgentObservation chunk"""
chunk = AgentObservation(content="observing...", end_of_message=False)
assert chunk.content == "observing..."
assert chunk.chunk_type == "observation"
def test_agent_answer_creation(self):
"""Test creating AgentAnswer chunk"""
chunk = AgentAnswer(
content="answer",
end_of_message=True,
end_of_dialog=True
)
assert chunk.content == "answer"
assert chunk.end_of_message is True
assert chunk.end_of_dialog is True
assert chunk.chunk_type == "final-answer"
def test_rag_chunk_creation(self):
"""Test creating RAGChunk"""
chunk = RAGChunk(
content="response chunk",
end_of_stream=False,
error=None
)
assert chunk.content == "response chunk"
assert chunk.end_of_stream is False
assert chunk.error is None
def test_rag_chunk_with_error(self):
"""Test RAGChunk with error"""
error_dict = {"type": "error", "message": "failed"}
chunk = RAGChunk(
content="",
end_of_stream=True,
error=error_dict
)
assert chunk.error == error_dict
class TestTripleType:
"""Test Triple data structure"""
def test_triple_creation(self):
"""Test creating Triple"""
triple = Triple(s="subject", p="predicate", o="object")
assert triple.s == "subject"
assert triple.p == "predicate"
assert triple.o == "object"
def test_triple_with_uris(self):
"""Test Triple with URI values"""
triple = Triple(
s="http://example.org/entity1",
p="http://example.org/relation",
o="http://example.org/entity2"
)
assert triple.s.startswith("http://")
assert triple.p.startswith("http://")
assert triple.o.startswith("http://")
class TestAsyncClients:
"""Test async client availability"""
def test_async_flow_creation(self):
"""Test creating async flow client"""
api = Api(url="http://test/")
async_flow = api.async_flow()
assert async_flow is not None
def test_async_socket_creation(self):
"""Test creating async socket client"""
api = Api(url="http://test/")
async_socket = api.async_socket()
assert async_socket is not None
assert async_socket.url.startswith("ws://")
def test_async_bulk_creation(self):
"""Test creating async bulk client"""
api = Api(url="http://test/")
async_bulk = api.async_bulk()
assert async_bulk is not None
def test_async_metrics_creation(self):
"""Test creating async metrics client"""
api = Api(url="http://test/")
async_metrics = api.async_metrics()
assert async_metrics is not None
class TestErrorHandling:
"""Test error handling"""
@patch('requests.post')
def test_protocol_exception_on_non_200(self, mock_post):
"""Test ProtocolException raised on non-200 status"""
from trustgraph.api.exceptions import ProtocolException
mock_post.return_value.status_code = 500
api = Api(url="http://test/")
with pytest.raises(ProtocolException):
api.flow().list()
@patch('requests.post')
def test_application_exception_on_error_response(self, mock_post):
"""Test ApplicationException on error in response"""
from trustgraph.api.exceptions import ApplicationException
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"error": {
"type": "ValidationError",
"message": "Invalid input"
}
}
api = Api(url="http://test/")
with pytest.raises(ApplicationException):
api.flow().list()
# Run tests with: pytest tests/unit/test_python_api_client.py -v
if __name__ == "__main__":
pytest.main([__file__, "-v"])
@@ -23,9 +23,9 @@ class TestStructuredDiagnosisSchemaContract:
         assert request.operation == "detect-type"
         assert request.sample == "test data"
-        assert request.type is None  # Optional, defaults to None
-        assert request.schema_name is None  # Optional, defaults to None
-        assert request.options is None  # Optional, defaults to None
+        assert request.type == ""  # Optional, defaults to empty string
+        assert request.schema_name == ""  # Optional, defaults to empty string
+        assert request.options == {}  # Optional, defaults to empty dict
 
     def test_request_schema_all_operations(self):
         """Test request schema supports all operations"""
@@ -66,9 +66,9 @@ class TestStructuredDiagnosisSchemaContract:
         assert response.detected_type == "xml"
         assert response.confidence == 0.9
         assert response.error is None
-        assert response.descriptor is None
-        assert response.metadata is None
-        assert response.schema_matches is None  # New field, defaults to None
+        assert response.descriptor == ""  # Defaults to empty string
+        assert response.metadata == {}  # Defaults to empty dict
+        assert response.schema_matches == []  # Defaults to empty list
 
     def test_response_schema_with_error(self):
         """Test response schema with error"""
@@ -140,6 +140,7 @@ class TestStructuredDiagnosisSchemaContract:
         assert response.metadata == metadata
         assert response.metadata["field_count"] == "5"
 
+    @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses")
     def test_schema_serialization(self):
         """Test that schemas can be serialized and deserialized correctly"""
         # Test request serialization
@@ -158,6 +159,7 @@ class TestStructuredDiagnosisSchemaContract:
         assert deserialized.sample == request.sample
         assert deserialized.options == request.options
 
+    @pytest.mark.skip(reason="JsonSchema requires Pulsar Record types, not dataclasses")
     def test_response_serialization_with_schema_matches(self):
         """Test response serialization with schema_matches array"""
         response = StructuredDataDiagnosisResponse(
@@ -185,7 +187,7 @@ class TestStructuredDiagnosisSchemaContract:
         )
 
         # Verify default value for new field
-        assert response.schema_matches is None  # Defaults to None when not set
+        assert response.schema_matches == []  # Defaults to empty list when not set
 
         # Verify old fields still work
         assert response.detected_type == "json"
@@ -221,7 +223,7 @@ class TestStructuredDiagnosisSchemaContract:
         )
 
         assert error_response.error is not None
-        assert error_response.schema_matches is None  # Default None when not set
+        assert error_response.schema_matches == []  # Default empty list when not set
 
     def test_all_operations_supported(self):
         """Verify all operations are properly supported in the contract"""
@@ -72,7 +72,7 @@ class TestMessageDispatcher:
         assert dispatcher.max_workers == 10
         assert dispatcher.semaphore._value == 10
         assert dispatcher.active_tasks == set()
-        assert dispatcher.pulsar_client is None
+        assert dispatcher.backend is None
         assert dispatcher.dispatcher_manager is None
         assert len(dispatcher.service_mapping) > 0
@@ -86,7 +86,7 @@ class TestMessageDispatcher:
     @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
     def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
         """Test MessageDispatcher initialization with pulsar_client and config_receiver"""
-        mock_pulsar_client = MagicMock()
+        mock_backend = MagicMock()
         mock_config_receiver = MagicMock()
         mock_dispatcher_instance = MagicMock()
         mock_dispatcher_manager.return_value = mock_dispatcher_instance
@@ -94,14 +94,14 @@ class TestMessageDispatcher:
         dispatcher = MessageDispatcher(
             max_workers=8,
             config_receiver=mock_config_receiver,
-            pulsar_client=mock_pulsar_client
+            backend=mock_backend
         )

         assert dispatcher.max_workers == 8
-        assert dispatcher.pulsar_client == mock_pulsar_client
+        assert dispatcher.backend == mock_backend
         assert dispatcher.dispatcher_manager == mock_dispatcher_instance
         mock_dispatcher_manager.assert_called_once_with(
-            mock_pulsar_client, mock_config_receiver, prefix="rev-gateway"
+            mock_backend, mock_config_receiver, prefix="rev-gateway"
         )

     def test_message_dispatcher_service_mapping(self):
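The pattern running through these tests — patch the `get_pubsub` factory where it is looked up, rather than `pulsar.Client` itself — can be reproduced with the stdlib `unittest.mock` alone. The module layout below is illustrative, not the real trustgraph structure:

```python
from unittest.mock import MagicMock, patch

# Hypothetical stand-in for a service module: the service obtains its
# messaging backend from a factory function instead of pulsar.Client.
def get_pubsub(pulsar_host, pulsar_api_key=None, pulsar_listener=None):
    raise RuntimeError("would connect to Pulsar in production")

class Service:
    def __init__(self):
        # Looked up as a module global at call time, so patching the
        # module attribute swaps the backend in without touching Pulsar.
        self.backend = get_pubsub("pulsar://pulsar:6650")

with patch(f"{__name__}.get_pubsub") as mock_get_pubsub:
    mock_get_pubsub.return_value = MagicMock()
    svc = Service()
    mock_get_pubsub.assert_called_once_with("pulsar://pulsar:6650")
```

Patching the factory keeps the tests independent of the concrete client library, which is what allows the assertions above to target `backend` rather than a Pulsar client instance.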

View file

@@ -16,11 +16,11 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_defaults(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with default parameters"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
@@ -38,11 +38,11 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_custom_params(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with custom parameters"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway(
             websocket_uri="wss://example.com:8080/websocket",
@@ -78,53 +78,49 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_invalid_scheme(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with invalid WebSocket scheme"""
         with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
             ReverseGateway(websocket_uri="http://example.com")

     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_initialization_missing_hostname(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway initialization with missing hostname"""
         with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
             ReverseGateway(websocket_uri="ws://")

     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_pulsar_client_with_auth(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
-        """Test ReverseGateway creates Pulsar client with authentication"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
-
-        with patch('pulsar.AuthenticationToken') as mock_auth:
-            mock_auth_instance = MagicMock()
-            mock_auth.return_value = mock_auth_instance
-
-            gateway = ReverseGateway(
-                pulsar_api_key="test-key",
-                pulsar_listener="test-listener"
-            )
-
-            mock_auth.assert_called_once_with("test-key")
-            mock_pulsar_client.assert_called_once_with(
-                "pulsar://pulsar:6650",
-                listener_name="test-listener",
-                authentication=mock_auth_instance
-            )
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
+        """Test ReverseGateway creates backend with authentication"""
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend
+
+        gateway = ReverseGateway(
+            pulsar_api_key="test-key",
+            pulsar_listener="test-listener"
+        )
+
+        # Verify get_pubsub was called with the correct parameters
+        mock_get_pubsub.assert_called_once_with(
+            pulsar_host="pulsar://pulsar:6650",
+            pulsar_api_key="test-key",
+            pulsar_listener="test-listener"
+        )

     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @patch('trustgraph.rev_gateway.service.ClientSession')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_connect_success(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway successful connection"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         mock_session = AsyncMock()
         mock_ws = AsyncMock()
@@ -142,13 +138,13 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @patch('trustgraph.rev_gateway.service.ClientSession')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway connection failure"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         mock_session = AsyncMock()
         mock_session.ws_connect.side_effect = Exception("Connection failed")
@@ -162,12 +158,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_disconnect(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway disconnect"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
@@ -189,12 +185,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_send_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway send message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
@@ -211,12 +207,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_send_message_closed_connection(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway send message with closed connection"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
@@ -234,12 +230,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_handle_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway handle message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         mock_dispatcher_instance = AsyncMock()
         mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
@@ -263,12 +259,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_handle_message_invalid_json(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway handle message with invalid JSON"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
@@ -285,12 +281,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_listen_text_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway listen with text message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
         gateway.running = True
@@ -318,12 +314,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_listen_binary_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway listen with binary message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
         gateway.running = True
@@ -351,12 +347,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_listen_close_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway listen with close message"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
         gateway.running = True
@@ -383,12 +379,12 @@ class TestReverseGateway:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_shutdown(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway shutdown"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         mock_dispatcher_instance = AsyncMock()
         mock_dispatcher.return_value = mock_dispatcher_instance
@@ -404,15 +400,15 @@ class TestReverseGateway:
         assert gateway.running is False
         mock_dispatcher_instance.shutdown.assert_called_once()
         gateway.disconnect.assert_called_once()
-        mock_client_instance.close.assert_called_once()
+        mock_backend.close.assert_called_once()

     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
-    def test_reverse_gateway_stop(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
+    def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway stop"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         gateway = ReverseGateway()
         gateway.running = True
@@ -427,12 +423,12 @@ class TestReverseGatewayRun:
     @patch('trustgraph.rev_gateway.service.ConfigReceiver')
     @patch('trustgraph.rev_gateway.service.MessageDispatcher')
-    @patch('pulsar.Client')
+    @patch('trustgraph.rev_gateway.service.get_pubsub')
     @pytest.mark.asyncio
-    async def test_reverse_gateway_run_successful_cycle(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
+    async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver):
         """Test ReverseGateway run method with successful connect/listen cycle"""
-        mock_client_instance = MagicMock()
-        mock_pulsar_client.return_value = mock_client_instance
+        mock_backend = MagicMock()
+        mock_get_pubsub.return_value = mock_backend

         mock_config_receiver_instance = AsyncMock()
         mock_config_receiver.return_value = mock_config_receiver_instance
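The URI checks these tests assert — reject non-`ws`/`wss` schemes, require a hostname — can be expressed with the stdlib `urllib.parse`. This is a sketch reproducing the exact error messages from the tests; the real ReverseGateway validation may differ in detail:

```python
from urllib.parse import urlparse

def validate_websocket_uri(uri: str) -> None:
    # Mirrors the two ValueError messages asserted above
    parsed = urlparse(uri)
    if parsed.scheme not in ("ws", "wss"):
        raise ValueError("WebSocket URI must use ws:// or wss:// scheme")
    if not parsed.hostname:
        raise ValueError("WebSocket URI must include hostname")

validate_websocket_uri("wss://example.com:8080/websocket")  # accepted
```

`urlparse` leaves `hostname` as `None` for `ws://`, which is what makes the missing-hostname case distinguishable from a bad scheme.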

View file

@@ -15,11 +15,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     """Test Qdrant document embeddings storage functionality"""

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_basic(self, mock_qdrant_client):
         """Test basic Qdrant processor initialization"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -34,9 +32,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Assert
-        # Verify base class initialization was called
-        mock_base_init.assert_called_once()
-
         # Verify QdrantClient was created with correct parameters
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
@@ -45,11 +40,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert processor.qdrant == mock_qdrant_instance

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
         """Test processor initialization with default values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -68,11 +61,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_basic(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with basic message"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True  # Collection already exists
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -88,6 +79,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('test_user', 'test_collection')] = {}
+
         # Create mock message with chunks and vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'test_user'
@@ -121,11 +115,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_multiple_chunks(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with multiple chunks"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -141,6 +133,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('multi_user', 'multi_collection')] = {}
+
         # Create mock message with multiple chunks
         mock_message = MagicMock()
         mock_message.metadata.user = 'multi_user'
@@ -180,11 +175,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_uuid, mock_qdrant_client):
         """Test storing document embeddings with multiple vectors per chunk"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -200,6 +193,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('vector_user', 'vector_collection')] = {}
+
         # Create mock message with chunk having multiple vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'vector_user'
@@ -237,11 +233,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
             assert point.payload['doc'] == 'multi-vector document chunk'

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
+    async def test_store_document_embeddings_empty_chunk(self, mock_qdrant_client):
         """Test storing document embeddings skips empty chunks"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True  # Collection exists
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -277,11 +271,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_collection_creation_when_not_exists(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_collection_creation_when_not_exists(self, mock_uuid, mock_qdrant_client):
         """Test that writing to non-existent collection creates it lazily"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = False  # Collection doesn't exist
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -297,6 +289,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('new_user', 'new_collection')] = {}
+
         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.user = 'new_user'
@@ -326,11 +321,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
-    async def test_collection_creation_exception(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_collection_creation_exception(self, mock_uuid, mock_qdrant_client):
         """Test that collection creation errors are propagated"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = False  # Collection doesn't exist
         # Simulate creation failure
@@ -348,6 +341,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('error_user', 'error_collection')] = {}
+
         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.user = 'error_user'
@@ -364,12 +360,10 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
             await processor.store_document_embeddings(mock_message)

     @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
     @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
-    async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
+    async def test_collection_validation_on_write(self, mock_uuid, mock_qdrant_client):
"""Test collection validation checks collection exists before writing""" """Test collection validation checks collection exists before writing"""
# Arrange # Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
@ -385,6 +379,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('cache_user', 'cache_collection')] = {}
# Create first mock message # Create first mock message
mock_message1 = MagicMock() mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user' mock_message1.metadata.user = 'cache_user'
@ -428,11 +425,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') async def test_different_dimensions_different_collections(self, mock_uuid, mock_qdrant_client):
async def test_different_dimensions_different_collections(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test that different vector dimensions create different collections""" """Test that different vector dimensions create different collections"""
# Arrange # Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
@ -448,6 +443,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('dim_user', 'dim_collection')] = {}
# Create mock message with different dimension vectors # Create mock message with different dimension vectors
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'dim_user' mock_message.metadata.user = 'dim_user'
@ -482,11 +480,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3' assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') async def test_add_args_calls_parent(self, mock_qdrant_client):
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
"""Test that add_args() calls parent add_args method""" """Test that add_args() calls parent add_args method"""
# Arrange # Arrange
mock_base_init.return_value = None
mock_qdrant_client.return_value = MagicMock() mock_qdrant_client.return_value = MagicMock()
mock_parser = MagicMock() mock_parser = MagicMock()
@ -502,11 +498,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') async def test_utf8_decoding_handling(self, mock_uuid, mock_qdrant_client):
async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
"""Test proper UTF-8 decoding of chunk text""" """Test proper UTF-8 decoding of chunk text"""
# Arrange # Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
@ -522,6 +516,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('utf8_user', 'utf8_collection')] = {}
# Create mock message with UTF-8 encoded text # Create mock message with UTF-8 encoded text
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'utf8_user' mock_message.metadata.user = 'utf8_user'
@ -546,11 +543,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé' assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__') async def test_chunk_decode_exception_handling(self, mock_qdrant_client):
async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
"""Test handling of chunk decode exceptions""" """Test handling of chunk decode exceptions"""
# Arrange # Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_qdrant_client.return_value = mock_qdrant_instance mock_qdrant_client.return_value = mock_qdrant_instance
@ -563,6 +558,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Add collection to known_collections (simulates config push)
processor.known_collections[('decode_user', 'decode_collection')] = {}
# Create mock message with decode error # Create mock message with decode error
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'decode_user' mock_message.metadata.user = 'decode_user'
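The repeated `known_collections` setup in these tests reflects the new gating behaviour: the processor only accepts writes for collections previously announced via a config push, and the collection name encodes user, collection, and vector dimension (see the `d_dim_user_dim_collection_3` assertion above). A rough standalone sketch of that gate — the class and method names here are illustrative, not the actual processor code:

```python
# Illustrative sketch only: mirrors the gate the tests simulate by
# populating known_collections before calling the processor.

class CollectionGate:
    def __init__(self):
        # In the real service this is populated by a config push
        self.known_collections = {}

    def check(self, user, collection):
        # Reject writes for collections that were never announced
        if (user, collection) not in self.known_collections:
            raise RuntimeError(f"Unknown collection: {user}/{collection}")

    def collection_name(self, user, collection, dim):
        # Name format matching the assertions in the tests above
        return f"d_{user}_{collection}_{dim}"

gate = CollectionGate()
gate.known_collections[("dim_user", "dim_collection")] = {}
gate.check("dim_user", "dim_collection")
print(gate.collection_name("dim_user", "dim_collection", 3))
```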


@@ -15,11 +15,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
     """Test Qdrant graph embeddings storage functionality"""

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_basic(self, mock_qdrant_client):
         """Test basic Qdrant processor initialization"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -34,9 +32,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Assert
-        # Verify base class initialization was called
-        mock_base_init.assert_called_once()
-
         # Verify QdrantClient was created with correct parameters
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
@@ -46,11 +41,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_basic(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with basic message"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True  # Collection already exists
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -65,6 +58,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('test_user', 'test_collection')] = {}
+
         # Create mock message with entities and vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'test_user'
@@ -98,11 +94,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_multiple_entities(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with multiple entities"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -117,6 +111,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('multi_user', 'multi_collection')] = {}
+
         # Create mock message with multiple entities
         mock_message = MagicMock()
         mock_message.metadata.user = 'multi_user'
@@ -156,11 +153,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
     @patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
+    async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_uuid, mock_qdrant_client):
         """Test storing graph embeddings with multiple vectors per entity"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_instance.collection_exists.return_value = True
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -175,6 +170,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

+        # Add collection to known_collections (simulates config push)
+        processor.known_collections[('vector_user', 'vector_collection')] = {}
+
         # Create mock message with entity having multiple vectors
         mock_message = MagicMock()
         mock_message.metadata.user = 'vector_user'
@@ -212,11 +210,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert point.payload['entity'] == 'multi_vector_entity'

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
+    async def test_store_graph_embeddings_empty_entity_value(self, mock_qdrant_client):
         """Test storing graph embeddings skips empty entity values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -253,11 +249,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         mock_qdrant_instance.collection_exists.assert_not_called()

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
+    async def test_processor_initialization_with_defaults(self, mock_qdrant_client):
         """Test processor initialization with default values"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_instance = MagicMock()
         mock_qdrant_client.return_value = mock_qdrant_instance
@@ -275,11 +269,9 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
         mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)

     @patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
-    @patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
-    async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
+    async def test_add_args_calls_parent(self, mock_qdrant_client):
         """Test that add_args() calls parent add_args method"""
         # Arrange
-        mock_base_init.return_value = None
         mock_qdrant_client.return_value = MagicMock()

         mock_parser = MagicMock()


@@ -13,6 +13,7 @@ dependencies = [
     "pulsar-client",
     "prometheus-client",
     "requests",
+    "python-logging-loki",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",


@@ -1,3 +1,114 @@
-from . api import *
+# Core API
+from .api import Api
+
+# Flow clients
+from .flow import Flow, FlowInstance
+from .async_flow import AsyncFlow, AsyncFlowInstance
+
+# WebSocket clients
+from .socket_client import SocketClient, SocketFlowInstance
+from .async_socket_client import AsyncSocketClient, AsyncSocketFlowInstance
+
+# Bulk operation clients
+from .bulk_client import BulkClient
+from .async_bulk_client import AsyncBulkClient
+
+# Metrics clients
+from .metrics import Metrics
+from .async_metrics import AsyncMetrics
+
+# Types
+from .types import (
+    Triple,
+    ConfigKey,
+    ConfigValue,
+    DocumentMetadata,
+    ProcessingMetadata,
+    CollectionMetadata,
+    StreamingChunk,
+    AgentThought,
+    AgentObservation,
+    AgentAnswer,
+    RAGChunk,
+)
+
+# Exceptions
+from .exceptions import (
+    ProtocolException,
+    TrustGraphException,
+    AgentError,
+    ConfigError,
+    DocumentRagError,
+    FlowError,
+    GatewayError,
+    GraphRagError,
+    LLMError,
+    LoadError,
+    LookupError,
+    NLPQueryError,
+    ObjectsQueryError,
+    RequestError,
+    StructuredQueryError,
+    UnexpectedError,
+    # Legacy alias
+    ApplicationException,
+)
+
+__all__ = [
+    # Core API
+    "Api",
+    # Flow clients
+    "Flow",
+    "FlowInstance",
+    "AsyncFlow",
+    "AsyncFlowInstance",
+    # WebSocket clients
+    "SocketClient",
+    "SocketFlowInstance",
+    "AsyncSocketClient",
+    "AsyncSocketFlowInstance",
+    # Bulk operation clients
+    "BulkClient",
+    "AsyncBulkClient",
+    # Metrics clients
+    "Metrics",
+    "AsyncMetrics",
+    # Types
+    "Triple",
+    "ConfigKey",
+    "ConfigValue",
+    "DocumentMetadata",
+    "ProcessingMetadata",
+    "CollectionMetadata",
+    "StreamingChunk",
+    "AgentThought",
+    "AgentObservation",
+    "AgentAnswer",
+    "RAGChunk",
+    # Exceptions
+    "ProtocolException",
+    "TrustGraphException",
+    "AgentError",
+    "ConfigError",
+    "DocumentRagError",
+    "FlowError",
+    "GatewayError",
+    "GraphRagError",
+    "LLMError",
+    "LoadError",
+    "LookupError",
+    "NLPQueryError",
+    "ObjectsQueryError",
+    "RequestError",
+    "StructuredQueryError",
+    "UnexpectedError",
+    "ApplicationException",  # Legacy alias
+]


@@ -3,6 +3,7 @@ import requests
 import json
 import base64
 import time
+from typing import Optional

 from . library import Library
 from . flow import Flow
@@ -26,7 +27,7 @@ def check_error(response):

 class Api:

-    def __init__(self, url="http://localhost:8088/", timeout=60):
+    def __init__(self, url="http://localhost:8088/", timeout=60, token: Optional[str] = None):

         self.url = url
@@ -36,6 +37,16 @@ class Api:
             self.url += "api/v1/"

         self.timeout = timeout
+        self.token = token
+
+        # Lazy initialization for new clients
+        self._socket_client = None
+        self._bulk_client = None
+        self._async_flow = None
+        self._async_socket_client = None
+        self._async_bulk_client = None
+        self._metrics = None
+        self._async_metrics = None

     def flow(self):
         return Flow(api=self)
@@ -50,8 +61,12 @@ class Api:

         url = f"{self.url}{path}"

+        headers = {}
+        if self.token:
+            headers["Authorization"] = f"Bearer {self.token}"
+
         # Invoke the API, input is passed as JSON
-        resp = requests.post(url, json=request, timeout=self.timeout)
+        resp = requests.post(url, json=request, timeout=self.timeout, headers=headers)

         # Should be a 200 status code
         if resp.status_code != 200:
@@ -72,3 +87,96 @@ class Api:

     def collection(self):
         return Collection(self)
+
+    # New synchronous methods
+
+    def socket(self):
+        """Synchronous WebSocket-based interface for streaming operations"""
+        if self._socket_client is None:
+            from . socket_client import SocketClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._socket_client = SocketClient(base_url, self.timeout, self.token)
+        return self._socket_client
+
+    def bulk(self):
+        """Synchronous bulk operations interface for import/export"""
+        if self._bulk_client is None:
+            from . bulk_client import BulkClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._bulk_client = BulkClient(base_url, self.timeout, self.token)
+        return self._bulk_client
+
+    def metrics(self):
+        """Synchronous metrics interface"""
+        if self._metrics is None:
+            from . metrics import Metrics
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._metrics = Metrics(base_url, self.timeout, self.token)
+        return self._metrics
+
+    # New asynchronous methods
+
+    def async_flow(self):
+        """Asynchronous REST-based flow interface"""
+        if self._async_flow is None:
+            from . async_flow import AsyncFlow
+            self._async_flow = AsyncFlow(self.url, self.timeout, self.token)
+        return self._async_flow
+
+    def async_socket(self):
+        """Asynchronous WebSocket-based interface for streaming operations"""
+        if self._async_socket_client is None:
+            from . async_socket_client import AsyncSocketClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_socket_client = AsyncSocketClient(base_url, self.timeout, self.token)
+        return self._async_socket_client
+
+    def async_bulk(self):
+        """Asynchronous bulk operations interface for import/export"""
+        if self._async_bulk_client is None:
+            from . async_bulk_client import AsyncBulkClient
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
+        return self._async_bulk_client
+
+    def async_metrics(self):
+        """Asynchronous metrics interface"""
+        if self._async_metrics is None:
+            from . async_metrics import AsyncMetrics
+            # Extract base URL (remove api/v1/ suffix)
+            base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
+            self._async_metrics = AsyncMetrics(base_url, self.timeout, self.token)
+        return self._async_metrics
+
+    # Resource management
+
+    def close(self):
+        """Close all synchronous connections"""
+        if self._socket_client:
+            self._socket_client.close()
+        if self._bulk_client:
+            self._bulk_client.close()
+
+    async def aclose(self):
+        """Close all asynchronous connections"""
+        if self._async_socket_client:
+            await self._async_socket_client.aclose()
+        if self._async_bulk_client:
+            await self._async_bulk_client.aclose()
+        if self._async_flow:
+            await self._async_flow.aclose()
+
+    # Context manager support
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *args):
+        await self.aclose()
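Two small pieces of logic from the `Api` changes above can be exercised in isolation: bearer-token header construction and the base-URL extraction used by the lazy client factories. A minimal sketch — `build_headers` and `base_url` are illustrative helper names, not part of the library:

```python
# Sketch of the header and base-URL logic from the Api changes above.
# These helper names are ours, chosen for illustration only.

def build_headers(token=None):
    # Attach a bearer token only when one is configured
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers

def base_url(api_url):
    # Strip the api/v1/ suffix to recover the gateway base URL
    return api_url.rsplit("api/v1/", 1)[0].rstrip("/")

print(build_headers("secret"))                    # {'Authorization': 'Bearer secret'}
print(base_url("http://localhost:8088/api/v1/"))  # http://localhost:8088
```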


@@ -0,0 +1,131 @@
+import json
+import websockets
+from typing import Optional, AsyncIterator, Dict, Any, Iterator
+
+from . types import Triple
+
+
+class AsyncBulkClient:
+    """Asynchronous bulk operations client"""
+
+    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+        self.url: str = self._convert_to_ws_url(url)
+        self.timeout: int = timeout
+        self.token: Optional[str] = token
+
+    def _convert_to_ws_url(self, url: str) -> str:
+        """Convert HTTP URL to WebSocket URL"""
+        if url.startswith("http://"):
+            return url.replace("http://", "ws://", 1)
+        elif url.startswith("https://"):
+            return url.replace("https://", "wss://", 1)
+        elif url.startswith("ws://") or url.startswith("wss://"):
+            return url
+        else:
+            return f"ws://{url}"
+
+    async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
+        """Bulk import triples via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for triple in triples:
+                message = {
+                    "s": triple.s,
+                    "p": triple.p,
+                    "o": triple.o
+                }
+                await websocket.send(json.dumps(message))
+
+    async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
+        """Bulk export triples via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                data = json.loads(raw_message)
+                yield Triple(
+                    s=data.get("s", ""),
+                    p=data.get("p", ""),
+                    o=data.get("o", "")
+                )
+
+    async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import graph embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for embedding in embeddings:
+                await websocket.send(json.dumps(embedding))
+
+    async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
+        """Bulk export graph embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                yield json.loads(raw_message)
+
+    async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import document embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for embedding in embeddings:
+                await websocket.send(json.dumps(embedding))
+
+    async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
+        """Bulk export document embeddings via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                yield json.loads(raw_message)
+
+    async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import entity contexts via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for context in contexts:
+                await websocket.send(json.dumps(context))
+
+    async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
+        """Bulk export entity contexts via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for raw_message in websocket:
+                yield json.loads(raw_message)
+
+    async def import_objects(self, flow: str, objects: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
+        """Bulk import objects via WebSocket"""
+        ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
+        if self.token:
+            ws_url = f"{ws_url}?token={self.token}"
+
+        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
+            async for obj in objects:
+                await websocket.send(json.dumps(obj))
+
+    async def aclose(self) -> None:
+        """Close connections"""
+        # Cleanup handled by context managers
+        pass
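The HTTP-to-WebSocket scheme mapping in `_convert_to_ws_url` above is pure string logic and can be checked on its own. This standalone copy mirrors the method body; the example hostname is illustrative:

```python
# Standalone copy of the scheme mapping used by AsyncBulkClient._convert_to_ws_url.

def convert_to_ws_url(url):
    if url.startswith("http://"):
        return url.replace("http://", "ws://", 1)
    elif url.startswith("https://"):
        return url.replace("https://", "wss://", 1)
    elif url.startswith("ws://") or url.startswith("wss://"):
        return url
    else:
        # Bare host: assume plain WebSocket
        return f"ws://{url}"

print(convert_to_ws_url("https://gateway.example.com"))  # wss://gateway.example.com
print(convert_to_ws_url("localhost:8088"))               # ws://localhost:8088
```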


@@ -0,0 +1,245 @@
+import aiohttp
+import json
+from typing import Optional, Dict, Any, List
+
+from . exceptions import ProtocolException, ApplicationException
+
+
+def check_error(response):
+    if "error" in response:
+        try:
+            msg = response["error"]["message"]
+            tp = response["error"]["type"]
+        except:
+            raise ApplicationException(response["error"])
+        raise ApplicationException(f"{tp}: {msg}")
+
+
+class AsyncFlow:
+    """Asynchronous REST-based flow interface"""
+
+    def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+        self.url: str = url
+        self.timeout: int = timeout
+        self.token: Optional[str] = token
+
+    async def request(self, path: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Make async HTTP request to Gateway API"""
+        url = f"{self.url}{path}"
+
+        headers = {"Content-Type": "application/json"}
+        if self.token:
+            headers["Authorization"] = f"Bearer {self.token}"
+
+        timeout = aiohttp.ClientTimeout(total=self.timeout)
+
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.post(url, json=request_data, headers=headers) as resp:
+                if resp.status != 200:
+                    raise ProtocolException(f"Status code {resp.status}")
+                try:
+                    obj = await resp.json()
+                except:
+                    raise ProtocolException(f"Expected JSON response")
+                check_error(obj)
+                return obj
+
+    async def list(self) -> List[str]:
+        """List all flows"""
+        result = await self.request("flow", {"operation": "list-flows"})
+        return result.get("flow-ids", [])
+
+    async def get(self, id: str) -> Dict[str, Any]:
+        """Get flow definition"""
+        result = await self.request("flow", {
+            "operation": "get-flow",
+            "flow-id": id
+        })
+        return json.loads(result.get("flow", "{}"))
+
+    async def start(self, class_name: str, id: str, description: str, parameters: Optional[Dict] = None):
+        """Start a flow"""
+        request_data = {
+            "operation": "start-flow",
+            "flow-id": id,
+            "class-name": class_name,
+            "description": description
+        }
+        if parameters:
+            request_data["parameters"] = json.dumps(parameters)
+        await self.request("flow", request_data)
+
+    async def stop(self, id: str):
+        """Stop a flow"""
+        await self.request("flow", {
+            "operation": "stop-flow",
+            "flow-id": id
+        })
+
+    async def list_classes(self) -> List[str]:
+        """List flow classes"""
+        result = await self.request("flow", {"operation": "list-classes"})
+        return result.get("class-names", [])
+
+    async def get_class(self, class_name: str) -> Dict[str, Any]:
+        """Get flow class definition"""
+        result = await self.request("flow", {
+            "operation": "get-class",
+            "class-name": class_name
+        })
+        return json.loads(result.get("class-definition", "{}"))
+
+    async def put_class(self, class_name: str, definition: Dict[str, Any]):
+        """Create/update flow class"""
+        await self.request("flow", {
+            "operation": "put-class",
+            "class-name": class_name,
+            "class-definition": json.dumps(definition)
+        })
+
+    async def delete_class(self, class_name: str):
+        """Delete flow class"""
+        await self.request("flow", {
+            "operation": "delete-class",
+            "class-name": class_name
+        })
+
+    def id(self, flow_id: str):
+        """Get async flow instance"""
+        return AsyncFlowInstance(self, flow_id)
+
+    async def aclose(self) -> None:
+        """Close connection (cleanup handled by aiohttp session)"""
+        pass
+
+
+class AsyncFlowInstance:
+    """Asynchronous REST flow instance"""
+
+    def __init__(self, flow: AsyncFlow, flow_id: str):
+        self.flow = flow
+        self.flow_id = flow_id
+
+    async def request(self, service: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Make request to flow-scoped service"""
+        return await self.flow.request(f"flow/{self.flow_id}/service/{service}", request_data)
+
+    async def agent(self, question: str, user: str, state: Optional[Dict] = None,
+                    group: Optional[str] = None, history: Optional[List] = None, **kwargs: Any) -> Dict[str, Any]:
+        """Execute agent (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "question": question,
+            "user": user,
+            "streaming": False  # REST doesn't support streaming
+        }
+        if state is not None:
+            request_data["state"] = state
+        if group is not None:
+            request_data["group"] = group
+        if history is not None:
+            request_data["history"] = history
+        request_data.update(kwargs)
+        return await self.request("agent", request_data)
+
+    async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
+        """Text completion (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "system": system,
+            "prompt": prompt,
+            "streaming": False
+        }
+        request_data.update(kwargs)
+        result = await self.request("text-completion", request_data)
+        return result.get("response", "")
+
+    async def graph_rag(self, query: str, user: str, collection: str,
+                        max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
+                        max_entity_distance: int = 3, **kwargs: Any) -> str:
+        """Graph RAG (non-streaming, use async_socket for streaming)"""
+        request_data = {
+            "query": query,
+            "user": user,
+            "collection": collection,
+            "max-subgraph-size": max_subgraph_size,
+            "max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"streaming": False
}
request_data.update(kwargs)
result = await self.request("graph-rag", request_data)
return result.get("response", "")
async def document_rag(self, query: str, user: str, collection: str,
doc_limit: int = 10, **kwargs: Any) -> str:
"""Document RAG (non-streaming, use async_socket for streaming)"""
request_data = {
"query": query,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
"streaming": False
}
request_data.update(kwargs)
result = await self.request("document-rag", request_data)
return result.get("response", "")
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs: Any):
"""Query graph embeddings for semantic search"""
request_data = {
"text": text,
"user": user,
"collection": collection,
"limit": limit
}
request_data.update(kwargs)
return await self.request("graph-embeddings", request_data)
async def embeddings(self, text: str, **kwargs: Any):
"""Generate text embeddings"""
request_data = {"text": text}
request_data.update(kwargs)
return await self.request("embeddings", request_data)
async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs: Any):
"""Triple pattern query"""
request_data = {"limit": limit}
if s is not None:
request_data["s"] = str(s)
if p is not None:
request_data["p"] = str(p)
if o is not None:
request_data["o"] = str(o)
if user is not None:
request_data["user"] = user
if collection is not None:
request_data["collection"] = collection
request_data.update(kwargs)
return await self.request("triples", request_data)
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs: Any):
"""GraphQL query"""
request_data = {
"query": query,
"user": user,
"collection": collection
}
if variables:
request_data["variables"] = variables
if operation_name:
request_data["operationName"] = operation_name
request_data.update(kwargs)
return await self.request("objects", request_data)
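The flow-scoped REST calls map Python keyword arguments onto hyphenated wire-format fields. A minimal sketch of the payload `graph_rag` assembles (the helper name `build_graph_rag_request` is illustrative, not part of the SDK):

```python
from typing import Any, Dict

def build_graph_rag_request(
    query: str, user: str, collection: str,
    max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
    max_entity_distance: int = 3, **kwargs: Any,
) -> Dict[str, Any]:
    # Mirrors AsyncFlowInstance.graph_rag: underscore keyword arguments
    # become the hyphenated field names used on the wire.
    request = {
        "query": query,
        "user": user,
        "collection": collection,
        "max-subgraph-size": max_subgraph_size,
        "max-subgraph-count": max_subgraph_count,
        "max-entity-distance": max_entity_distance,
        "streaming": False,
    }
    request.update(kwargs)
    return request

payload = build_graph_rag_request("Who founded ACME?", "trustgraph", "default")
```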

View file

@ -0,0 +1,33 @@
import aiohttp
from typing import Optional, Dict
class AsyncMetrics:
"""Asynchronous metrics client"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
self.url: str = url
self.timeout: int = timeout
self.token: Optional[str] = token
async def get(self) -> str:
"""Get Prometheus metrics as text"""
url: str = f"{self.url}/api/metrics"
headers: Dict[str, str] = {}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
timeout = aiohttp.ClientTimeout(total=self.timeout)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as resp:
if resp.status != 200:
raise Exception(f"Status code {resp.status}")
return await resp.text()
async def aclose(self) -> None:
"""Close connections"""
pass

View file

@ -0,0 +1,343 @@
import json
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union
from .types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
from .exceptions import ProtocolException, ApplicationException
class AsyncSocketClient:
"""Asynchronous WebSocket client"""
def __init__(self, url: str, timeout: int, token: Optional[str]):
self.url = self._convert_to_ws_url(url)
self.timeout = timeout
self.token = token
self._request_counter = 0
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
if url.startswith("http://"):
return url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
return url.replace("https://", "wss://", 1)
elif url.startswith("ws://") or url.startswith("wss://"):
return url
else:
# Assume ws://
return f"ws://{url}"
def flow(self, flow_id: str):
"""Get async flow instance for WebSocket operations"""
return AsyncSocketFlowInstance(self, flow_id)
async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
"""Async WebSocket request implementation (non-streaming)"""
# Generate unique request ID
self._request_counter += 1
request_id = f"req-{self._request_counter}"
# Build WebSocket URL with optional token
ws_url = f"{self.url}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
# Build request message
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
# Connect and send request
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
# Wait for single response
raw_message = await websocket.recv()
response = json.loads(raw_message)
if response.get("id") != request_id:
raise ProtocolException("Response ID mismatch")
if "error" in response:
raise ApplicationException(response["error"])
if "response" not in response:
raise ProtocolException("Missing response in message")
return response["response"]
async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
"""Async WebSocket request implementation (streaming)"""
# Generate unique request ID
self._request_counter += 1
request_id = f"req-{self._request_counter}"
# Build WebSocket URL with optional token
ws_url = f"{self.url}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
# Build request message
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
# Connect and send request
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
# Yield chunks as they arrive
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != request_id:
continue # Ignore messages for other requests
if "error" in response:
raise ApplicationException(response["error"])
if "response" in response:
resp = response["response"]
# Parse different chunk types
chunk = self._parse_chunk(resp)
yield chunk
# Check if this is the final chunk
if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
break
def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type"""
chunk_type = resp.get("chunk_type")
if chunk_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
elif chunk_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
elif chunk_type == "answer" or chunk_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False)
)
elif chunk_type == "action":
# Agent action chunks - treat as thoughts for display purposes
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
else:
# RAG-style chunk (or generic chunk)
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None # Errors are always thrown, never stored
)
async def aclose(self):
"""Close WebSocket connection"""
# Cleanup handled by context manager
pass
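`_parse_chunk` dispatches on the `chunk_type` field, falling back to service-specific content keys for RAG and completion chunks. A self-contained sketch of the same dispatch, using simplified stand-in types rather than the SDK's own classes:

```python
from dataclasses import dataclass
from typing import Any, Dict, Union

# Simplified stand-ins for the SDK chunk types (illustrative only).
@dataclass
class Thought:
    content: str

@dataclass
class Answer:
    content: str
    end_of_dialog: bool

@dataclass
class Chunk:
    content: str
    end_of_stream: bool

def parse_chunk(resp: Dict[str, Any]) -> Union[Thought, Answer, Chunk]:
    # Same dispatch shape as _parse_chunk: agent messages carry chunk_type;
    # text-completion uses "response", RAG uses "chunk", prompt uses "text".
    kind = resp.get("chunk_type")
    if kind in ("thought", "observation", "action"):
        return Thought(resp.get("content", ""))
    if kind in ("answer", "final-answer"):
        return Answer(resp.get("content", ""), resp.get("end_of_dialog", False))
    content = resp.get("response", resp.get("chunk", resp.get("text", "")))
    return Chunk(content, resp.get("end_of_stream", False))

thought = parse_chunk({"chunk_type": "thought", "content": "considering..."})
answer = parse_chunk({"chunk_type": "final-answer", "content": "42", "end_of_dialog": True})
rag = parse_chunk({"chunk": "some text", "end_of_stream": True})
```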
class AsyncSocketFlowInstance:
"""Asynchronous WebSocket flow instance"""
def __init__(self, client: AsyncSocketClient, flow_id: str):
self.client = client
self.flow_id = flow_id
async def agent(self, question: str, user: str, state: Optional[Dict[str, Any]] = None,
group: Optional[str] = None, history: Optional[list] = None,
streaming: bool = False, **kwargs) -> Union[Dict[str, Any], AsyncIterator]:
"""Agent with optional streaming"""
request = {
"question": question,
"user": user,
"streaming": streaming
}
if state is not None:
request["state"] = state
if group is not None:
request["group"] = group
if history is not None:
request["history"] = history
request.update(kwargs)
if streaming:
return self.client._send_request_streaming("agent", self.flow_id, request)
else:
return await self.client._send_request("agent", self.flow_id, request)
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
"""Text completion with optional streaming"""
request = {
"system": system,
"prompt": prompt,
"streaming": streaming
}
request.update(kwargs)
if streaming:
return self._text_completion_streaming(request)
else:
result = await self.client._send_request("text-completion", self.flow_id, request)
return result.get("response", "")
async def _text_completion_streaming(self, request):
"""Helper for streaming text completion"""
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
if hasattr(chunk, 'content'):
yield chunk.content
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
max_entity_distance: int = 3, streaming: bool = False, **kwargs):
"""Graph RAG with optional streaming"""
request = {
"query": query,
"user": user,
"collection": collection,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"streaming": streaming
}
request.update(kwargs)
if streaming:
return self._graph_rag_streaming(request)
else:
result = await self.client._send_request("graph-rag", self.flow_id, request)
return result.get("response", "")
async def _graph_rag_streaming(self, request):
"""Helper for streaming graph RAG"""
async for chunk in self.client._send_request_streaming("graph-rag", self.flow_id, request):
if hasattr(chunk, 'content'):
yield chunk.content
async def document_rag(self, query: str, user: str, collection: str,
doc_limit: int = 10, streaming: bool = False, **kwargs):
"""Document RAG with optional streaming"""
request = {
"query": query,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
"streaming": streaming
}
request.update(kwargs)
if streaming:
return self._document_rag_streaming(request)
else:
result = await self.client._send_request("document-rag", self.flow_id, request)
return result.get("response", "")
async def _document_rag_streaming(self, request):
"""Helper for streaming document RAG"""
async for chunk in self.client._send_request_streaming("document-rag", self.flow_id, request):
if hasattr(chunk, 'content'):
yield chunk.content
async def prompt(self, id: str, variables: Dict[str, str], streaming: bool = False, **kwargs):
"""Execute prompt with optional streaming"""
request = {
"id": id,
"variables": variables,
"streaming": streaming
}
request.update(kwargs)
if streaming:
return self._prompt_streaming(request)
else:
result = await self.client._send_request("prompt", self.flow_id, request)
return result.get("response", "")
async def _prompt_streaming(self, request):
"""Helper for streaming prompt"""
async for chunk in self.client._send_request_streaming("prompt", self.flow_id, request):
if hasattr(chunk, 'content'):
yield chunk.content
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
request = {
"text": text,
"user": user,
"collection": collection,
"limit": limit
}
request.update(kwargs)
return await self.client._send_request("graph-embeddings", self.flow_id, request)
async def embeddings(self, text: str, **kwargs):
"""Generate text embeddings"""
request = {"text": text}
request.update(kwargs)
return await self.client._send_request("embeddings", self.flow_id, request)
async def triples_query(self, s=None, p=None, o=None, user=None, collection=None, limit=100, **kwargs):
"""Triple pattern query"""
request = {"limit": limit}
if s is not None:
request["s"] = str(s)
if p is not None:
request["p"] = str(p)
if o is not None:
request["o"] = str(o)
if user is not None:
request["user"] = user
if collection is not None:
request["collection"] = collection
request.update(kwargs)
return await self.client._send_request("triples", self.flow_id, request)
async def objects_query(self, query: str, user: str, collection: str, variables: Optional[Dict] = None,
operation_name: Optional[str] = None, **kwargs):
"""GraphQL query"""
request = {
"query": query,
"user": user,
"collection": collection
}
if variables:
request["variables"] = variables
if operation_name:
request["operationName"] = operation_name
request.update(kwargs)
return await self.client._send_request("objects", self.flow_id, request)
async def mcp_tool(self, name: str, parameters: Dict[str, Any], **kwargs):
"""Execute MCP tool"""
request = {
"name": name,
"parameters": parameters
}
request.update(kwargs)
return await self.client._send_request("mcp-tool", self.flow_id, request)
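Streaming methods return an async iterator rather than a final value, so callers consume them with `async for`. A self-contained sketch of that consumption pattern, with a simulated chunk source standing in for a live WebSocket:

```python
import asyncio
from typing import AsyncIterator, List

async def fake_stream() -> AsyncIterator[str]:
    # Stand-in for client._send_request_streaming(...) yielding chunk text.
    for piece in ["Graph ", "RAG ", "answer"]:
        await asyncio.sleep(0)  # yield control, as a real socket read would
        yield piece

async def collect(stream: AsyncIterator[str]) -> str:
    parts: List[str] = []
    async for chunk in stream:  # same shape as iterating a streaming RAG call
        parts.append(chunk)
    return "".join(parts)

result = asyncio.run(collect(fake_stream()))
```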

View file

@ -0,0 +1,270 @@
import json
import asyncio
import websockets
from typing import Optional, Iterator, Dict, Any, Coroutine
from .types import Triple
from .exceptions import ProtocolException
class BulkClient:
"""Synchronous bulk operations client"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
if url.startswith("http://"):
return url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
return url.replace("https://", "wss://", 1)
elif url.startswith("ws://") or url.startswith("wss://"):
return url
else:
return f"ws://{url}"
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
"""Run async coroutine synchronously"""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
def import_triples(self, flow: str, triples: Iterator[Triple], **kwargs: Any) -> None:
"""Bulk import triples via WebSocket"""
self._run_async(self._import_triples_async(flow, triples))
async def _import_triples_async(self, flow: str, triples: Iterator[Triple]) -> None:
"""Async implementation of triple import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for triple in triples:
message = {
"s": triple.s,
"p": triple.p,
"o": triple.o
}
await websocket.send(json.dumps(message))
def export_triples(self, flow: str, **kwargs: Any) -> Iterator[Triple]:
"""Bulk export triples via WebSocket"""
async_gen = self._export_triples_async(flow)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
try:
triple = loop.run_until_complete(async_gen.__anext__())
yield triple
except StopAsyncIteration:
break
finally:
try:
loop.run_until_complete(async_gen.aclose())
except Exception:
pass
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
"""Async implementation of triple export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
data = json.loads(raw_message)
yield Triple(
s=data.get("s", ""),
p=data.get("p", ""),
o=data.get("o", "")
)
def import_graph_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import graph embeddings via WebSocket"""
self._run_async(self._import_graph_embeddings_async(flow, embeddings))
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of graph embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
await websocket.send(json.dumps(embedding))
def export_graph_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
"""Bulk export graph embeddings via WebSocket"""
async_gen = self._export_graph_embeddings_async(flow)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
try:
embedding = loop.run_until_complete(async_gen.__anext__())
yield embedding
except StopAsyncIteration:
break
finally:
try:
loop.run_until_complete(async_gen.aclose())
except Exception:
pass
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of graph embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
yield json.loads(raw_message)
def import_document_embeddings(self, flow: str, embeddings: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import document embeddings via WebSocket"""
self._run_async(self._import_document_embeddings_async(flow, embeddings))
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of document embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
await websocket.send(json.dumps(embedding))
def export_document_embeddings(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
"""Bulk export document embeddings via WebSocket"""
async_gen = self._export_document_embeddings_async(flow)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
try:
embedding = loop.run_until_complete(async_gen.__anext__())
yield embedding
except StopAsyncIteration:
break
finally:
try:
loop.run_until_complete(async_gen.aclose())
except Exception:
pass
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of document embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
yield json.loads(raw_message)
def import_entity_contexts(self, flow: str, contexts: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import entity contexts via WebSocket"""
self._run_async(self._import_entity_contexts_async(flow, contexts))
async def _import_entity_contexts_async(self, flow: str, contexts: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of entity contexts import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for context in contexts:
await websocket.send(json.dumps(context))
def export_entity_contexts(self, flow: str, **kwargs: Any) -> Iterator[Dict[str, Any]]:
"""Bulk export entity contexts via WebSocket"""
async_gen = self._export_entity_contexts_async(flow)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
try:
context = loop.run_until_complete(async_gen.__anext__())
yield context
except StopAsyncIteration:
break
finally:
try:
loop.run_until_complete(async_gen.aclose())
except Exception:
pass
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of entity contexts export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
yield json.loads(raw_message)
def import_objects(self, flow: str, objects: Iterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import objects via WebSocket"""
self._run_async(self._import_objects_async(flow, objects))
async def _import_objects_async(self, flow: str, objects: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of objects import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/objects"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for obj in objects:
await websocket.send(json.dumps(obj))
def close(self) -> None:
"""Close connections"""
# Cleanup handled by context managers
pass
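`BulkClient` bridges synchronous callers to async generators by pumping `__anext__` through `run_until_complete` one item at a time. The same technique in isolation, with a stub generator standing in for a WebSocket export stream:

```python
import asyncio
from typing import AsyncIterator, Iterator, List

async def produce() -> AsyncIterator[int]:
    # Stand-in for an async export generator reading from a WebSocket.
    for i in range(3):
        yield i

def consume_sync() -> Iterator[int]:
    # Drive the async generator one item at a time from synchronous code,
    # the same pumping technique export_triples uses.
    loop = asyncio.new_event_loop()
    gen = produce()
    try:
        while True:
            try:
                yield loop.run_until_complete(gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.run_until_complete(gen.aclose())
        loop.close()

items: List[int] = list(consume_sync())
```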

View file

@ -41,9 +41,7 @@ class Collection:
collection = v["collection"],
name = v["name"],
description = v["description"],
- tags = v["tags"],
- created_at = v["created_at"],
- updated_at = v["updated_at"]
+ tags = v["tags"]
)
for v in collections
]
@ -76,9 +74,7 @@ class Collection:
collection = v["collection"],
name = v["name"],
description = v["description"],
- tags = v["tags"],
- created_at = v["created_at"],
- updated_at = v["updated_at"]
+ tags = v["tags"]
)
return None
except Exception as e:

View file

@ -1,6 +1,134 @@
"""
TrustGraph API Exceptions
Exception hierarchy for errors returned by TrustGraph services.
Each service error type maps to a specific exception class.
"""
# Protocol-level exceptions (communication errors)
class ProtocolException(Exception):
"""Raised when WebSocket protocol errors occur"""
pass
# Base class for all TrustGraph application errors
class TrustGraphException(Exception):
"""Base class for all TrustGraph service errors"""
def __init__(self, message: str, error_type: str = None):
super().__init__(message)
self.message = message
self.error_type = error_type
# Service-specific exceptions
class AgentError(TrustGraphException):
"""Agent service error"""
pass
class ConfigError(TrustGraphException):
"""Configuration service error"""
pass
class DocumentRagError(TrustGraphException):
"""Document RAG retrieval error"""
pass
class FlowError(TrustGraphException):
"""Flow management error"""
pass
class GatewayError(TrustGraphException):
"""API Gateway error"""
pass
class GraphRagError(TrustGraphException):
"""Graph RAG retrieval error"""
pass
class LLMError(TrustGraphException):
"""LLM service error"""
pass
class LoadError(TrustGraphException):
"""Data loading error"""
pass
class LookupError(TrustGraphException):
"""Lookup/search error"""
pass
class NLPQueryError(TrustGraphException):
"""NLP query service error"""
pass
class ObjectsQueryError(TrustGraphException):
"""Objects query service error"""
pass
class RequestError(TrustGraphException):
"""Request processing error"""
pass
class StructuredQueryError(TrustGraphException):
"""Structured query service error"""
pass
class UnexpectedError(TrustGraphException):
"""Unexpected/unknown error"""
pass
# Mapping from error type string to exception class
ERROR_TYPE_MAPPING = {
"agent-error": AgentError,
"config-error": ConfigError,
"document-rag-error": DocumentRagError,
"flow-error": FlowError,
"gateway-error": GatewayError,
"graph-rag-error": GraphRagError,
"llm-error": LLMError,
"load-error": LoadError,
"lookup-error": LookupError,
"nlp-query-error": NLPQueryError,
"objects-query-error": ObjectsQueryError,
"request-error": RequestError,
"structured-query-error": StructuredQueryError,
"unexpected-error": UnexpectedError,
}
def raise_from_error_dict(error_dict: dict) -> None:
"""
Raise appropriate exception from TrustGraph error dictionary.
Args:
error_dict: Dictionary with 'type' and 'message' keys
Raises:
Appropriate TrustGraphException subclass based on error type
"""
error_type = error_dict.get("type", "unexpected-error")
message = error_dict.get("message", "Unknown error")
# Look up exception class, default to UnexpectedError
exception_class = ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)
# Raise the appropriate exception
raise exception_class(message, error_type)
# Legacy exception for backwards compatibility
ApplicationException = TrustGraphException
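Callers can catch the base class or a specific subclass; a short sketch of the mapping in use, re-declaring a minimal subset of the hierarchy so it runs standalone:

```python
from typing import Optional

# Minimal re-declaration of the hierarchy so this sketch runs standalone.
class TrustGraphException(Exception):
    def __init__(self, message: str, error_type: Optional[str] = None):
        super().__init__(message)
        self.message = message
        self.error_type = error_type

class LLMError(TrustGraphException):
    pass

class UnexpectedError(TrustGraphException):
    pass

ERROR_TYPE_MAPPING = {"llm-error": LLMError}

def raise_from_error_dict(error_dict: dict) -> None:
    # Unknown error types fall back to UnexpectedError.
    error_type = error_dict.get("type", "unexpected-error")
    message = error_dict.get("message", "Unknown error")
    raise ERROR_TYPE_MAPPING.get(error_type, UnexpectedError)(message, error_type)

try:
    raise_from_error_dict({"type": "llm-error", "message": "rate limited"})
except TrustGraphException as e:
    caught = e
```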

View file

@ -160,14 +160,14 @@ class FlowInstance:
)["answer"]
def graph_rag(
- self, question, user="trustgraph", collection="default",
+ self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=150,
max_path_length=2,
):
# The input consists of a question
input = {
- "query": question,
+ "query": query,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
@ -182,13 +182,13 @@ class FlowInstance:
)["response"]
def document_rag(
- self, question, user="trustgraph", collection="default",
+ self, query, user="trustgraph", collection="default",
doc_limit=10,
):
# The input consists of a question
input = {
- "query": question,
+ "query": query,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
@ -211,6 +211,21 @@ class FlowInstance:
input
)["vectors"]
def graph_embeddings_query(self, text, user, collection, limit=10):
# Query graph embeddings for semantic search
input = {
"text": text,
"user": user,
"collection": collection,
"limit": limit
}
return self.request(
"service/graph-embeddings",
input
)
def prompt(self, id, variables):
input = {

View file

@ -0,0 +1,27 @@
import requests
from typing import Optional, Dict
class Metrics:
"""Synchronous metrics client"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
self.url: str = url
self.timeout: int = timeout
self.token: Optional[str] = token
def get(self) -> str:
"""Get Prometheus metrics as text"""
url: str = f"{self.url}/api/metrics"
headers: Dict[str, str] = {}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
resp = requests.get(url, timeout=self.timeout, headers=headers)
if resp.status_code != 200:
raise Exception(f"Status code {resp.status_code}")
return resp.text
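The endpoint returns the Prometheus text exposition format. A minimal sketch of splitting it into samples (the metric names here are invented for illustration; real label values containing spaces would need a proper parser):

```python
from typing import Dict

def parse_metrics(text: str) -> Dict[str, float]:
    # Keep only sample lines; lines starting with '#' are HELP/TYPE comments.
    samples: Dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, value = line.rpartition(" ")
        samples[name] = float(value)
    return samples

example = """# HELP requests_total Total requests.
# TYPE requests_total counter
requests_total{service="agent"} 42
process_uptime_seconds 981.5
"""
metrics = parse_metrics(example)
```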

View file

@ -0,0 +1,457 @@
import json
import asyncio
import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock
from .types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk
from .exceptions import ProtocolException, raise_from_error_dict
class SocketClient:
"""Synchronous WebSocket client (wraps async websockets library)"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self._connection: Optional[Any] = None
self._request_counter: int = 0
self._lock: Lock = Lock()
self._loop: Optional[asyncio.AbstractEventLoop] = None
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
if url.startswith("http://"):
return url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
return url.replace("https://", "wss://", 1)
elif url.startswith("ws://") or url.startswith("wss://"):
return url
else:
# Assume ws://
return f"ws://{url}"
def flow(self, flow_id: str) -> "SocketFlowInstance":
"""Get flow instance for WebSocket operations"""
return SocketFlowInstance(self, flow_id)
def _send_request_sync(
self,
service: str,
flow: Optional[str],
request: Dict[str, Any],
streaming: bool = False
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
"""Synchronous wrapper around async WebSocket communication"""
# Create event loop if needed
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# If loop is running (e.g., in Jupyter), create new loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if streaming:
# For streaming, we need to return an iterator
# Create a generator that runs async code
return self._streaming_generator(service, flow, request, loop)
else:
# For non-streaming, just run the async code and return result
return loop.run_until_complete(self._send_request_async(service, flow, request))
def _streaming_generator(
self,
service: str,
flow: Optional[str],
request: Dict[str, Any],
loop: asyncio.AbstractEventLoop
) -> Iterator[StreamingChunk]:
"""Generator that yields streaming chunks"""
async_gen = self._send_request_async_streaming(service, flow, request)
try:
while True:
try:
chunk = loop.run_until_complete(async_gen.__anext__())
yield chunk
except StopAsyncIteration:
break
finally:
# Clean up async generator
try:
loop.run_until_complete(async_gen.aclose())
except Exception:
pass
async def _send_request_async(
self,
service: str,
flow: Optional[str],
request: Dict[str, Any]
) -> Dict[str, Any]:
"""Async implementation of WebSocket request (non-streaming)"""
# Generate unique request ID
with self._lock:
self._request_counter += 1
request_id = f"req-{self._request_counter}"
# Build WebSocket URL with optional token
ws_url = f"{self.url}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
# Build request message
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
# Connect and send request
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
# Wait for single response
raw_message = await websocket.recv()
response = json.loads(raw_message)
            if response.get("id") != request_id:
                raise ProtocolException("Response ID mismatch")
            if "error" in response:
                raise_from_error_dict(response["error"])
            if "response" not in response:
                raise ProtocolException("Missing response in message")
return response["response"]
async def _send_request_async_streaming(
self,
service: str,
flow: Optional[str],
request: Dict[str, Any]
    ) -> "AsyncIterator[StreamingChunk]":
"""Async implementation of WebSocket request (streaming)"""
# Generate unique request ID
with self._lock:
self._request_counter += 1
request_id = f"req-{self._request_counter}"
# Build WebSocket URL with optional token
ws_url = f"{self.url}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
# Build request message
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
# Connect and send request
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
# Yield chunks as they arrive
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != request_id:
continue # Ignore messages for other requests
if "error" in response:
raise_from_error_dict(response["error"])
if "response" in response:
resp = response["response"]
# Check for errors in response chunks
if "error" in resp:
raise_from_error_dict(resp["error"])
# Parse different chunk types
chunk = self._parse_chunk(resp)
yield chunk
# Check if this is the final chunk
if resp.get("end_of_stream") or resp.get("end_of_dialog") or response.get("complete"):
break
def _parse_chunk(self, resp: Dict[str, Any]) -> StreamingChunk:
"""Parse response chunk into appropriate type"""
chunk_type = resp.get("chunk_type")
if chunk_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
elif chunk_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
elif chunk_type == "answer" or chunk_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False)
)
elif chunk_type == "action":
# Agent action chunks - treat as thoughts for display purposes
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
else:
# RAG-style chunk (or generic chunk)
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None # Errors are always thrown, never stored
)
def close(self) -> None:
"""Close WebSocket connection"""
# Cleanup handled by context manager in async code
pass
class SocketFlowInstance:
"""Synchronous WebSocket flow instance with same interface as REST FlowInstance"""
def __init__(self, client: SocketClient, flow_id: str) -> None:
self.client: SocketClient = client
self.flow_id: str = flow_id
def agent(
self,
question: str,
user: str,
state: Optional[Dict[str, Any]] = None,
group: Optional[str] = None,
history: Optional[List[Dict[str, Any]]] = None,
streaming: bool = False,
**kwargs: Any
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
"""Agent with optional streaming"""
request = {
"question": question,
"user": user,
"streaming": streaming
}
if state is not None:
request["state"] = state
if group is not None:
request["group"] = group
if history is not None:
request["history"] = history
request.update(kwargs)
return self.client._send_request_sync("agent", self.flow_id, request, streaming)
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
"""Text completion with optional streaming"""
request = {
"system": system,
"prompt": prompt,
"streaming": streaming
}
request.update(kwargs)
        result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming)
        if streaming:
            # Return a generator expression rather than yielding inline: a
            # bare yield would turn this whole method into a generator, so
            # the non-streaming branch could never return a plain string
            return (chunk.content for chunk in result if hasattr(chunk, 'content'))
        return result.get("response", "")
def graph_rag(
self,
query: str,
user: str,
collection: str,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Graph RAG with optional streaming"""
request = {
"query": query,
"user": user,
"collection": collection,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"streaming": streaming
}
request.update(kwargs)
        result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming)
        if streaming:
            # Genexp, not yield: keeps the non-streaming return reachable
            return (chunk.content for chunk in result if hasattr(chunk, 'content'))
        return result.get("response", "")
def document_rag(
self,
query: str,
user: str,
collection: str,
doc_limit: int = 10,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Document RAG with optional streaming"""
request = {
"query": query,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
"streaming": streaming
}
request.update(kwargs)
        result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming)
        if streaming:
            # Genexp, not yield: keeps the non-streaming return reachable
            return (chunk.content for chunk in result if hasattr(chunk, 'content'))
        return result.get("response", "")
def prompt(
self,
id: str,
variables: Dict[str, str],
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute prompt with optional streaming"""
request = {
"id": id,
"variables": variables,
"streaming": streaming
}
request.update(kwargs)
        result = self.client._send_request_sync("prompt", self.flow_id, request, streaming)
        if streaming:
            # Genexp, not yield: keeps the non-streaming return reachable
            return (chunk.content for chunk in result if hasattr(chunk, 'content'))
        return result.get("response", "")
def graph_embeddings_query(
self,
text: str,
user: str,
collection: str,
limit: int = 10,
**kwargs: Any
) -> Dict[str, Any]:
"""Query graph embeddings for semantic search"""
request = {
"text": text,
"user": user,
"collection": collection,
"limit": limit
}
request.update(kwargs)
return self.client._send_request_sync("graph-embeddings", self.flow_id, request, False)
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
"""Generate text embeddings"""
request = {"text": text}
request.update(kwargs)
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
def triples_query(
self,
s: Optional[str] = None,
p: Optional[str] = None,
o: Optional[str] = None,
user: Optional[str] = None,
collection: Optional[str] = None,
limit: int = 100,
**kwargs: Any
) -> Dict[str, Any]:
"""Triple pattern query"""
request = {"limit": limit}
if s is not None:
request["s"] = str(s)
if p is not None:
request["p"] = str(p)
if o is not None:
request["o"] = str(o)
if user is not None:
request["user"] = user
if collection is not None:
request["collection"] = collection
request.update(kwargs)
return self.client._send_request_sync("triples", self.flow_id, request, False)
def objects_query(
self,
query: str,
user: str,
collection: str,
variables: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
**kwargs: Any
) -> Dict[str, Any]:
"""GraphQL query"""
request = {
"query": query,
"user": user,
"collection": collection
}
if variables:
request["variables"] = variables
if operation_name:
request["operationName"] = operation_name
request.update(kwargs)
return self.client._send_request_sync("objects", self.flow_id, request, False)
def mcp_tool(
self,
name: str,
parameters: Dict[str, Any],
**kwargs: Any
) -> Dict[str, Any]:
"""Execute MCP tool"""
request = {
"name": name,
"parameters": parameters
}
request.update(kwargs)
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
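The flow interface above is easiest to see end to end as a consumption loop. The sketch below stubs the chunk stream that `flow.agent(..., streaming=True)` would produce (no gateway is contacted), purely to show the iteration pattern; the `Chunk` stand-in and `fake_agent_stream` are illustrative, not part of the library.

```python
from dataclasses import dataclass

@dataclass
class Chunk:
    # Stand-in for AgentThought / AgentAnswer with just the fields used here
    chunk_type: str
    content: str

def fake_agent_stream():
    # Stub for flow.agent(..., streaming=True), which yields typed chunks
    yield Chunk("thought", "Looking up the answer")
    yield Chunk("final-answer", "42")

answer = ""
for chunk in fake_agent_stream():
    if chunk.chunk_type == "final-answer":
        answer += chunk.content

print(answer)  # 42
```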

View file

@@ -1,7 +1,7 @@
 import dataclasses
 import datetime
-from typing import List
+from typing import List, Optional, Dict, Any

 from .. knowledge import hash, Uri, Literal

 @dataclasses.dataclass
@@ -49,5 +49,34 @@ class CollectionMetadata:
     name : str
     description : str
     tags : List[str]
     created_at : str
     updated_at : str
+
+# Streaming chunk types
+
+@dataclasses.dataclass
+class StreamingChunk:
+    """Base class for streaming chunks"""
+    content: str
+    end_of_message: bool = False
+
+@dataclasses.dataclass
+class AgentThought(StreamingChunk):
+    """Agent reasoning chunk"""
+    chunk_type: str = "thought"
+
+@dataclasses.dataclass
+class AgentObservation(StreamingChunk):
+    """Agent tool observation chunk"""
+    chunk_type: str = "observation"
+
+@dataclasses.dataclass
+class AgentAnswer(StreamingChunk):
+    """Agent final answer chunk"""
+    chunk_type: str = "final-answer"
+    end_of_dialog: bool = False
+
+@dataclasses.dataclass
+class RAGChunk(StreamingChunk):
+    """RAG streaming chunk"""
+    chunk_type: str = "rag"
+    end_of_stream: bool = False
+    error: Optional[Dict[str, str]] = None
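Because the chunk types added above are plain dataclasses in an inheritance hierarchy, client code can dispatch on them with `isinstance`. A runnable sketch, using local trimmed copies of two of the classes:

```python
import dataclasses

# Local copies of the chunk types added above, trimmed to the fields used here
@dataclasses.dataclass
class StreamingChunk:
    content: str
    end_of_message: bool = False

@dataclasses.dataclass
class AgentThought(StreamingChunk):
    chunk_type: str = "thought"

@dataclasses.dataclass
class AgentAnswer(StreamingChunk):
    chunk_type: str = "final-answer"
    end_of_dialog: bool = False

def render(chunk: StreamingChunk) -> str:
    # Each chunk is a distinct dataclass, so isinstance dispatch works
    if isinstance(chunk, AgentAnswer):
        return f"ANSWER: {chunk.content}"
    if isinstance(chunk, AgentThought):
        return f"thinking... {chunk.content}"
    return chunk.content

print(render(AgentAnswer(content="42", end_of_dialog=True)))  # ANSWER: 42
```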

View file

@@ -1,11 +1,12 @@
-from . pubsub import PulsarClient
+from . pubsub import PulsarClient, get_pubsub
 from . async_processor import AsyncProcessor
 from . consumer import Consumer
 from . producer import Producer
 from . publisher import Publisher
 from . subscriber import Subscriber
 from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
+from . logging import add_logging_args, setup_logging
 from . flow_processor import FlowProcessor
 from . consumer_spec import ConsumerSpec
 from . parameter_spec import ParameterSpec
@@ -33,4 +34,5 @@ from . tool_service import ToolService
 from . tool_client import ToolClientSpec
 from . agent_client import AgentClientSpec
 from . structured_query_client import StructuredQueryClientSpec
+from . collection_config_handler import CollectionConfigHandler

View file

@@ -15,10 +15,11 @@ from prometheus_client import start_http_server, Info
 from .. schema import ConfigPush, config_push_queue
 from .. log_level import LogLevel
-from . pubsub import PulsarClient
+from . pubsub import PulsarClient, get_pubsub
 from . producer import Producer
 from . consumer import Consumer
 from . metrics import ProcessorMetrics, ConsumerMetrics
+from . logging import add_logging_args, setup_logging

 default_config_queue = config_push_queue
@@ -33,8 +34,11 @@ class AsyncProcessor:
         # Store the identity
         self.id = params.get("id")

-        # Register a pulsar client
-        self.pulsar_client_object = PulsarClient(**params)
+        # Create pub/sub backend via factory
+        self.pubsub_backend = get_pubsub(**params)
+
+        # Store pulsar_host for backward compatibility
+        self._pulsar_host = params.get("pulsar_host", "pulsar://pulsar:6650")

         # Initialise metrics, records the parameters
         ProcessorMetrics(processor = self.id).info({
@@ -69,7 +73,7 @@ class AsyncProcessor:
         self.config_sub_task = Consumer(
             taskgroup = self.taskgroup,
-            client = self.pulsar_client,
+            backend = self.pubsub_backend,
             subscriber = config_subscriber_id,
             flow = None,
@@ -95,16 +99,16 @@ class AsyncProcessor:
     # This is called to stop all threads. An over-ride point for extra
     # functionality
     def stop(self):
-        self.pulsar_client.close()
+        self.pubsub_backend.close()
         self.running = False

-    # Returns the pulsar host
+    # Returns the pub/sub backend (new interface)
     @property
-    def pulsar_host(self): return self.pulsar_client_object.pulsar_host
+    def pubsub(self): return self.pubsub_backend

-    # Returns the pulsar client
+    # Returns the pulsar host (backward compatibility)
     @property
-    def pulsar_client(self): return self.pulsar_client_object.client
+    def pulsar_host(self): return self._pulsar_host

     # Register a new event handler for configuration change
     def register_config_handler(self, handler):
@@ -165,18 +169,9 @@ class AsyncProcessor:
             raise e

     @classmethod
-    def setup_logging(cls, log_level='INFO'):
+    def setup_logging(cls, args):
         """Configure logging for the entire application"""
-        # Support environment variable override
-        env_log_level = os.environ.get('TRUSTGRAPH_LOG_LEVEL', log_level)
-
-        # Configure logging
-        logging.basicConfig(
-            level=getattr(logging, env_log_level.upper()),
-            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-            handlers=[logging.StreamHandler()]
-        )
-        logger.info(f"Logging configured with level: {env_log_level}")
+        setup_logging(args)

     # Startup fabric. launch calls launch_async in async mode.
     @classmethod
@@ -203,7 +198,7 @@ class AsyncProcessor:
         args = vars(args)

         # Setup logging before anything else
-        cls.setup_logging(args.get('log_level', 'INFO').upper())
+        cls.setup_logging(args)

         # Debug
         logger.debug(f"Arguments: {args}")
@@ -255,12 +250,21 @@ class AsyncProcessor:
     @staticmethod
     def add_args(parser):

+        # Pub/sub backend selection
+        parser.add_argument(
+            '--pubsub-backend',
+            default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
+            choices=['pulsar', 'mqtt'],
+            help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
+        )
+
         PulsarClient.add_args(parser)
+        add_logging_args(parser)

         parser.add_argument(
-            '--config-queue',
+            '--config-push-queue',
             default=default_config_queue,
-            help=f'Config push queue {default_config_queue}',
+            help=f'Config push queue (default: {default_config_queue})',
         )

         parser.add_argument(
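The new `--pubsub-backend` argument feeds a `get_pubsub` factory defined elsewhere in the tree. A hedged sketch of the dispatch shape such a factory might take; the stub backend class and the registry here are invented for illustration and are not the real implementation:

```python
class _StubBackend:
    # Hypothetical stand-in; real backends implement the PubSubBackend protocol
    def __init__(self, **params):
        self.params = params
    def close(self):
        pass

def get_pubsub_sketch(**params):
    # Illustrative dispatch keyed on the --pubsub-backend argument above
    registry = {'pulsar': _StubBackend, 'mqtt': _StubBackend}
    name = params.get('pubsub_backend', 'pulsar')
    if name not in registry:
        raise ValueError(f"Unknown pub/sub backend: {name}")
    return registry[name](**params)

backend = get_pubsub_sketch(pubsub_backend='mqtt')
print(type(backend).__name__)  # _StubBackend
```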

View file

@@ -0,0 +1,148 @@
"""
Backend abstraction interfaces for pub/sub systems.
This module defines Protocol classes that all pub/sub backends must implement,
allowing TrustGraph to work with different messaging systems (Pulsar, MQTT, Kafka, etc.)
"""
from typing import Protocol, Any, runtime_checkable
@runtime_checkable
class Message(Protocol):
"""Protocol for a received message."""
def value(self) -> Any:
"""
Get the deserialized message content.
Returns:
Dataclass instance representing the message
"""
...
def properties(self) -> dict:
"""
Get message properties/metadata.
Returns:
Dictionary of message properties
"""
...
@runtime_checkable
class BackendProducer(Protocol):
"""Protocol for backend-specific producer."""
def send(self, message: Any, properties: dict = {}) -> None:
"""
Send a message (dataclass instance) with optional properties.
Args:
message: Dataclass instance to send
properties: Optional metadata properties
"""
...
def flush(self) -> None:
"""Flush any buffered messages."""
...
def close(self) -> None:
"""Close the producer."""
...
@runtime_checkable
class BackendConsumer(Protocol):
"""Protocol for backend-specific consumer."""
def receive(self, timeout_millis: int = 2000) -> Message:
"""
Receive a message from the topic.
Args:
timeout_millis: Timeout in milliseconds
Returns:
Message object
Raises:
TimeoutError: If no message received within timeout
"""
...
def acknowledge(self, message: Message) -> None:
"""
Acknowledge successful processing of a message.
Args:
message: The message to acknowledge
"""
...
def negative_acknowledge(self, message: Message) -> None:
"""
Negative acknowledge - triggers redelivery.
Args:
message: The message to negatively acknowledge
"""
...
def unsubscribe(self) -> None:
"""Unsubscribe from the topic."""
...
def close(self) -> None:
"""Close the consumer."""
...
@runtime_checkable
class PubSubBackend(Protocol):
"""Protocol defining the interface all pub/sub backends must implement."""
def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
"""
Create a producer for a topic.
Args:
topic: Generic topic format (qos/tenant/namespace/queue)
schema: Dataclass type for messages
**options: Backend-specific options (e.g., chunking_enabled)
Returns:
Backend-specific producer instance
"""
...
def create_consumer(
self,
topic: str,
subscription: str,
schema: type,
initial_position: str = 'latest',
consumer_type: str = 'shared',
**options
) -> BackendConsumer:
"""
Create a consumer for a topic.
Args:
topic: Generic topic format (qos/tenant/namespace/queue)
subscription: Subscription/consumer group name
schema: Dataclass type for messages
initial_position: 'earliest' or 'latest' (some backends may ignore)
consumer_type: 'shared', 'exclusive', 'failover' (some backends may ignore)
**options: Backend-specific options
Returns:
Backend-specific consumer instance
"""
...
def close(self) -> None:
"""Close the backend connection."""
...
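Since the protocols above are `@runtime_checkable`, any class with the right method names passes an `isinstance` check without inheriting from them. A minimal in-memory producer satisfying a trimmed local copy of `BackendProducer`:

```python
from typing import Protocol, Any, runtime_checkable
from collections import deque

@runtime_checkable
class BackendProducer(Protocol):
    # Trimmed copy of the protocol above, enough for the isinstance check
    def send(self, message: Any, properties: dict = {}) -> None: ...
    def flush(self) -> None: ...
    def close(self) -> None: ...

class InMemoryProducer:
    # Toy backend: appends messages to a shared queue instead of publishing
    def __init__(self, queue: deque):
        self.queue = queue
    def send(self, message, properties={}):
        self.queue.append((message, dict(properties)))
    def flush(self):
        pass
    def close(self):
        pass

q = deque()
producer = InMemoryProducer(q)
producer.send({"text": "hello"})
print(isinstance(producer, BackendProducer), len(q))  # True 1
```

Note that `runtime_checkable` only verifies method presence, not signatures, so this check is structural rather than behavioural.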

View file

@@ -15,12 +15,13 @@ def get_cassandra_defaults() -> dict:
     Get default Cassandra configuration values from environment variables or fallback defaults.

     Returns:
-        dict: Dictionary with 'host', 'username', and 'password' keys
+        dict: Dictionary with 'host', 'username', 'password', and 'keyspace' keys
     """
     return {
         'host': os.getenv('CASSANDRA_HOST', 'cassandra'),
         'username': os.getenv('CASSANDRA_USERNAME'),
-        'password': os.getenv('CASSANDRA_PASSWORD')
+        'password': os.getenv('CASSANDRA_PASSWORD'),
+        'keyspace': os.getenv('CASSANDRA_KEYSPACE')
     }
@@ -54,6 +55,12 @@ def add_cassandra_args(parser: argparse.ArgumentParser) -> None:
     if 'CASSANDRA_PASSWORD' in os.environ:
         password_help += " [from CASSANDRA_PASSWORD]"

+    keyspace_help = "Cassandra keyspace (default: service-specific)"
+    if defaults['keyspace']:
+        keyspace_help = f"Cassandra keyspace (default: {defaults['keyspace']})"
+    if 'CASSANDRA_KEYSPACE' in os.environ:
+        keyspace_help += " [from CASSANDRA_KEYSPACE]"
+
     parser.add_argument(
         '--cassandra-host',
         default=defaults['host'],
@@ -72,13 +79,20 @@ def add_cassandra_args(parser: argparse.ArgumentParser) -> None:
         help=password_help
     )

+    parser.add_argument(
+        '--cassandra-keyspace',
+        default=defaults['keyspace'],
+        help=keyspace_help
+    )
+
 def resolve_cassandra_config(
     args: Optional[Any] = None,
     host: Optional[str] = None,
     username: Optional[str] = None,
-    password: Optional[str] = None
-) -> Tuple[List[str], Optional[str], Optional[str]]:
+    password: Optional[str] = None,
+    default_keyspace: Optional[str] = None
+) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
     """
     Resolve Cassandra configuration from various sources.
@@ -86,25 +100,29 @@ def resolve_cassandra_config(
     Converts host string to list format for Cassandra driver.

     Args:
-        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password
+        args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace
         host: Optional explicit host parameter (overrides args)
         username: Optional explicit username parameter (overrides args)
         password: Optional explicit password parameter (overrides args)
+        default_keyspace: Optional default keyspace if not specified elsewhere

     Returns:
-        tuple: (hosts_list, username, password)
+        tuple: (hosts_list, username, password, keyspace)
     """

     # If args provided, extract values
+    keyspace = None
     if args is not None:
         host = host or getattr(args, 'cassandra_host', None)
         username = username or getattr(args, 'cassandra_username', None)
         password = password or getattr(args, 'cassandra_password', None)
+        keyspace = getattr(args, 'cassandra_keyspace', None)

     # Apply defaults if still None
     defaults = get_cassandra_defaults()
     host = host or defaults['host']
     username = username or defaults['username']
     password = password or defaults['password']
+    keyspace = keyspace or defaults['keyspace'] or default_keyspace

     # Convert host string to list
     if isinstance(host, str):
@@ -112,18 +130,22 @@ def resolve_cassandra_config(
     else:
         hosts = host

-    return hosts, username, password
+    return hosts, username, password, keyspace

-def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[str], Optional[str]]:
+def get_cassandra_config_from_params(
+    params: dict,
+    default_keyspace: Optional[str] = None
+) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
     """
     Extract and resolve Cassandra configuration from a parameters dictionary.

     Args:
         params: Dictionary of parameters that may contain Cassandra configuration
+        default_keyspace: Optional default keyspace if not specified in params

     Returns:
-        tuple: (hosts_list, username, password)
+        tuple: (hosts_list, username, password, keyspace)
     """

     # Get Cassandra parameters
     host = params.get('cassandra_host')
@@ -131,4 +153,9 @@ def get_cassandra_config_from_params(params: dict) -> Tuple[List[str], Optional[
     password = params.get('cassandra_password')

     # Use resolve function to handle defaults and list conversion
-    return resolve_cassandra_config(host=host, username=username, password=password)
+    return resolve_cassandra_config(
+        host=host,
+        username=username,
+        password=password,
+        default_keyspace=default_keyspace
+    )
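The keyspace resolution order introduced by this diff can be restated as a tiny runnable sketch: an explicit argument wins, then the `CASSANDRA_KEYSPACE` environment variable, then the service-specific default passed by the caller. The `resolve_keyspace` helper below is illustrative, not the library function itself.

```python
import os

def resolve_keyspace(arg_keyspace=None, default_keyspace=None):
    # Precedence implemented above: CLI argument, then CASSANDRA_KEYSPACE,
    # then the service-specific default supplied by the caller
    return arg_keyspace or os.getenv('CASSANDRA_KEYSPACE') or default_keyspace

os.environ.pop('CASSANDRA_KEYSPACE', None)
print(resolve_keyspace(default_keyspace='librarian'))              # librarian
os.environ['CASSANDRA_KEYSPACE'] = 'shared'
print(resolve_keyspace(default_keyspace='librarian'))              # shared
print(resolve_keyspace('explicit', default_keyspace='librarian'))  # explicit
```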

View file

@@ -0,0 +1,128 @@
"""
Handler for storage services to process collection configuration from config push
"""
import json
import logging
from typing import Dict, Set
logger = logging.getLogger(__name__)
class CollectionConfigHandler:
"""
Handles collection configuration from config push messages for storage services.
Storage services should:
1. Inherit from this class along with their service base class
2. Call register_config_handler(self.on_collection_config) in __init__
3. Implement create_collection(user, collection, metadata) method
4. Implement delete_collection(user, collection) method
"""
def __init__(self, **kwargs):
# Track known collections: {(user, collection): metadata_dict}
self.known_collections: Dict[tuple, dict] = {}
# Pass remaining kwargs up the inheritance chain
super().__init__(**kwargs)
async def on_collection_config(self, config: dict, version: int):
"""
Handle config push messages and extract collection information
Args:
config: Configuration dictionary from ConfigPush message
version: Configuration version number
"""
logger.info(f"Processing collection configuration (version {version})")
# Extract collections from config (treat missing key as empty)
collection_config = config.get("collection", {})
# Track which collections we've seen in this config
current_collections: Set[tuple] = set()
# Process each collection in the config
for key, value_json in collection_config.items():
try:
# Parse user:collection key
if ":" not in key:
logger.warning(f"Invalid collection key format (expected user:collection): {key}")
continue
user, collection = key.split(":", 1)
current_collections.add((user, collection))
# Parse metadata
metadata = json.loads(value_json)
# Check if this is a new collection or updated
collection_key = (user, collection)
if collection_key not in self.known_collections:
logger.info(f"New collection detected: {user}/{collection}")
await self.create_collection(user, collection, metadata)
self.known_collections[collection_key] = metadata
else:
# Collection already exists, update metadata if changed
if self.known_collections[collection_key] != metadata:
logger.info(f"Collection metadata updated: {user}/{collection}")
# Most storage services don't need to do anything for metadata updates
# They just need to know the collection exists
self.known_collections[collection_key] = metadata
except Exception as e:
logger.error(f"Error processing collection config for key {key}: {e}", exc_info=True)
# Find collections that were deleted (in known but not in current)
deleted_collections = set(self.known_collections.keys()) - current_collections
for user, collection in deleted_collections:
logger.info(f"Collection deleted: {user}/{collection}")
try:
# Remove from known_collections FIRST to immediately reject new writes
# This eliminates race condition with worker threads
del self.known_collections[(user, collection)]
# Physical deletion happens after - worker threads already rejecting writes
await self.delete_collection(user, collection)
except Exception as e:
logger.error(f"Error deleting collection {user}/{collection}: {e}", exc_info=True)
# If physical deletion failed, should we re-add to known_collections?
# For now, keep it removed - collection is logically deleted per config
logger.debug(f"Collection config processing complete. Known collections: {len(self.known_collections)}")
async def create_collection(self, user: str, collection: str, metadata: dict):
"""
Create a collection in the storage backend.
Subclasses must implement this method.
Args:
user: User ID
collection: Collection ID
metadata: Collection metadata dictionary
"""
raise NotImplementedError("Storage service must implement create_collection method")
async def delete_collection(self, user: str, collection: str):
"""
Delete a collection from the storage backend.
Subclasses must implement this method.
Args:
user: User ID
collection: Collection ID
"""
raise NotImplementedError("Storage service must implement delete_collection method")
def collection_exists(self, user: str, collection: str) -> bool:
"""
Check if a collection is known to exist
Args:
user: User ID
collection: Collection ID
Returns:
True if collection exists, False otherwise
"""
return (user, collection) in self.known_collections
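The create/update/delete diffing in `on_collection_config` can be exercised without the service base classes. The sketch below is a stand-alone condensed copy of that logic with an in-memory "storage backend" (`MemoryStorage` is invented for illustration; the real handler also logs and validates keys):

```python
import asyncio
import json
from typing import Dict, Set

class MemoryStorage:
    # Condensed copy of CollectionConfigHandler's diffing, for illustration
    def __init__(self):
        self.known_collections: Dict[tuple, dict] = {}
        self.created, self.deleted = [], []

    async def create_collection(self, user, collection, metadata):
        self.created.append((user, collection))

    async def delete_collection(self, user, collection):
        self.deleted.append((user, collection))

    async def on_collection_config(self, config: dict, version: int):
        current: Set[tuple] = set()
        for key, value_json in config.get("collection", {}).items():
            user, collection = key.split(":", 1)
            current.add((user, collection))
            metadata = json.loads(value_json)
            if (user, collection) not in self.known_collections:
                await self.create_collection(user, collection, metadata)
            self.known_collections[(user, collection)] = metadata
        # Remove from known_collections first, then delete physically,
        # mirroring the write-rejection ordering described above
        for user, collection in set(self.known_collections) - current:
            del self.known_collections[(user, collection)]
            await self.delete_collection(user, collection)

store = MemoryStorage()
asyncio.run(store.on_collection_config(
    {"collection": {"alice:docs": '{"name": "Docs"}'}}, 1))
asyncio.run(store.on_collection_config({"collection": {}}, 2))
print(store.created, store.deleted)  # [('alice', 'docs')] [('alice', 'docs')]
```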

View file

@@ -9,9 +9,6 @@
 # one handler, and a single thread of concurrency, nothing too outrageous
 # will happen if synchronous / blocking code is used

-from pulsar.schema import JsonSchema
-import pulsar
-import _pulsar
 import asyncio
 import time
 import logging
@@ -21,10 +18,14 @@ from .. exceptions import TooManyRequests
 # Module logger
 logger = logging.getLogger(__name__)

+# Timeout exception - can come from different backends
+class TimeoutError(Exception):
+    pass
+
 class Consumer:

     def __init__(
-        self, taskgroup, flow, client, topic, subscriber, schema,
+        self, taskgroup, flow, backend, topic, subscriber, schema,
         handler,
         metrics = None,
         start_of_messages=False,
@@ -35,7 +36,7 @@ class Consumer:
         self.taskgroup = taskgroup
         self.flow = flow
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
         self.topic = topic
         self.subscriber = subscriber
         self.schema = schema
@@ -96,18 +97,20 @@ class Consumer:
             logger.info(f"Subscribing to topic: {self.topic}")

+            # Determine initial position
             if self.start_of_messages:
-                pos = pulsar.InitialPosition.Earliest
+                initial_pos = 'earliest'
             else:
-                pos = pulsar.InitialPosition.Latest
+                initial_pos = 'latest'

+            # Create consumer via backend
             self.consumer = await asyncio.to_thread(
-                self.client.subscribe,
+                self.backend.create_consumer,
                 topic = self.topic,
-                subscription_name = self.subscriber,
-                schema = JsonSchema(self.schema),
-                initial_position = pos,
-                consumer_type = pulsar.ConsumerType.Shared,
+                subscription = self.subscriber,
+                schema = self.schema,
+                initial_position = initial_pos,
+                consumer_type = 'shared',
             )

         except Exception as e:
@@ -159,9 +162,10 @@ class Consumer:
                     self.consumer.receive,
                     timeout_millis=2000
                 )
-            except _pulsar.Timeout:
-                continue
             except Exception as e:
+                # Handle timeout from any backend
+                if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
+                    continue
                 raise e

             await self.handle_one_from_queue(msg)
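The backend-agnostic timeout handling above replaces a concrete `_pulsar.Timeout` catch with a string heuristic over the exception's type name and message. A runnable sketch showing what that heuristic does and does not match (`ReceiveTimeout` is an invented example class, not a library exception):

```python
class ReceiveTimeout(Exception):
    pass

def is_timeout(e: Exception) -> bool:
    # Same heuristic as the consumer loop above
    return 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower()

print(is_timeout(ReceiveTimeout()))             # True (type name matches)
print(is_timeout(ValueError("read timeout")))   # True (message matches)
print(is_timeout(ValueError("timed out")))      # False ("timeout" absent)
```

The last case illustrates the trade-off: the heuristic works across backends without importing their exception types, but a backend raising, say, "timed out" would not be caught and would propagate.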

View file

@@ -19,7 +19,7 @@ class ConsumerSpec(Spec):
         consumer = Consumer(
             taskgroup = processor.taskgroup,
             flow = flow,
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
             topic = definition[self.name],
             subscriber = processor.id + "--" + flow.name + "--" + self.name,
             schema = self.schema,

View file

@@ -0,0 +1,159 @@
"""
Centralized logging configuration for TrustGraph server-side components.
This module provides standardized logging setup across all TrustGraph services,
ensuring consistent log formats, levels, and command-line arguments.
Supports dual output to console and Loki for centralized log aggregation.
"""
import logging
import logging.handlers
from queue import Queue
import os
def add_logging_args(parser):
"""
Add standard logging arguments to an argument parser.
Args:
parser: argparse.ArgumentParser instance to add arguments to
"""
parser.add_argument(
'-l', '--log-level',
default='INFO',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
help='Log level (default: INFO)'
)
parser.add_argument(
'--loki-enabled',
action='store_true',
default=True,
help='Enable Loki logging (default: True)'
)
parser.add_argument(
'--no-loki-enabled',
dest='loki_enabled',
action='store_false',
help='Disable Loki logging'
)
parser.add_argument(
'--loki-url',
default=os.getenv('LOKI_URL', 'http://loki:3100/loki/api/v1/push'),
help='Loki push URL (default: http://loki:3100/loki/api/v1/push)'
)
parser.add_argument(
'--loki-username',
default=os.getenv('LOKI_USERNAME', None),
help='Loki username for authentication (optional)'
)
parser.add_argument(
'--loki-password',
default=os.getenv('LOKI_PASSWORD', None),
help='Loki password for authentication (optional)'
)
def setup_logging(args):
"""
Configure logging from parsed command-line arguments.
Sets up logging with a standardized format and output to stdout.
Optionally enables Loki integration for centralized log aggregation.
This should be called early in application startup, before any
logging calls are made.
Args:
args: Dictionary of parsed arguments (typically from vars(args))
Must contain 'log_level' key, optional Loki configuration
"""
log_level = args.get('log_level', 'INFO')
loki_enabled = args.get('loki_enabled', True)
# Build list of handlers starting with console
handlers = [logging.StreamHandler()]
# Add Loki handler if enabled
queue_listener = None
if loki_enabled:
loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
loki_username = args.get('loki_username')
loki_password = args.get('loki_password')
processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion")
try:
from logging_loki import LokiHandler
# Create Loki handler with optional authentication and processor label
loki_handler_kwargs = {
'url': loki_url,
'version': "1",
}
if loki_username and loki_password:
loki_handler_kwargs['auth'] = (loki_username, loki_password)
# Add processor label if available (for consistency with Prometheus metrics)
if processor_id:
loki_handler_kwargs['tags'] = {'processor': processor_id}
loki_handler = LokiHandler(**loki_handler_kwargs)
# Wrap in QueueHandler for non-blocking operation
log_queue = Queue(maxsize=500)
queue_handler = logging.handlers.QueueHandler(log_queue)
handlers.append(queue_handler)
# Start QueueListener in background thread
queue_listener = logging.handlers.QueueListener(
log_queue,
loki_handler,
respect_handler_level=True
)
queue_listener.start()
# Store listener reference for potential cleanup
# (attached to root logger for access if needed)
logging.getLogger().loki_queue_listener = queue_listener
except ImportError:
# Graceful degradation if python-logging-loki not installed
print("WARNING: python-logging-loki not installed, Loki logging disabled")
print("Install with: pip install python-logging-loki")
except Exception as e:
# Graceful degradation if Loki connection fails
print(f"WARNING: Failed to setup Loki logging: {e}")
print("Continuing with console-only logging")
# Get processor ID for log formatting (use 'unknown' if not available)
processor_id = args.get('id', 'unknown')
# Configure logging with all handlers
# Use processor ID as the primary identifier in logs
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s',
handlers=handlers,
force=True # Force reconfiguration if already configured
)
# Prevent recursive logging from Loki's HTTP client
if loki_enabled and queue_listener:
# Disable urllib3 logging to prevent infinite loop
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.info(f"Logging configured with level: {log_level}")
if loki_enabled and queue_listener:
logger.info(f"Loki logging enabled: {loki_url}")
elif loki_enabled:
logger.warning("Loki logging requested but not available")
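The non-blocking arrangement used above — a bounded queue in front of the slow network handler, drained by a background listener thread — can be sketched with the standard library alone (a console handler stands in for the Loki handler):

```python
import io
import logging
import logging.handlers
from queue import Queue

buf = io.StringIO()
slow_handler = logging.StreamHandler(buf)   # stand-in for LokiHandler

log_queue = Queue(maxsize=500)
queue_handler = logging.handlers.QueueHandler(log_queue)

listener = logging.handlers.QueueListener(
    log_queue, slow_handler, respect_handler_level=True
)
listener.start()

logger = logging.getLogger("sketch")
logger.setLevel(logging.INFO)
logger.addHandler(queue_handler)   # callers only pay for a queue put
logger.info("hello loki")

listener.stop()   # drains the queue and joins the background thread
print(buf.getvalue())   # -> "hello loki"
```

The logging call returns as soon as the record is enqueued; if the push endpoint is slow or down, application threads are not blocked — which is the reason the module wraps `LokiHandler` this way.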

View file

@@ -1,5 +1,4 @@
-from pulsar.schema import JsonSchema
 import asyncio
 import logging
@@ -8,10 +7,10 @@ logger = logging.getLogger(__name__)
 class Producer:
-    def __init__(self, client, topic, schema, metrics=None,
+    def __init__(self, backend, topic, schema, metrics=None,
                  chunking_enabled=True):
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
         self.topic = topic
         self.schema = schema
@@ -44,9 +43,9 @@ class Producer:
         try:
             logger.info(f"Connecting publisher to {self.topic}...")
-            self.producer = self.client.create_producer(
+            self.producer = self.backend.create_producer(
                 topic = self.topic,
-                schema = JsonSchema(self.schema),
+                schema = self.schema,
                 chunking_enabled = self.chunking_enabled,
             )
             logger.info(f"Connected publisher to {self.topic}")

View file

@@ -15,7 +15,7 @@ class ProducerSpec(Spec):
         )
         producer = Producer(
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
             topic = definition[self.name],
             schema = self.schema,
             metrics = producer_metrics,

View file

@@ -37,33 +37,34 @@ class PromptClient(RequestResponse):
         else:
             logger.info("DEBUG prompt_client: Streaming path")
-            # Streaming path - collect all chunks
-            full_text = ""
-            full_object = None
+            # Streaming path - just forward chunks, don't accumulate
+            last_text = ""
+            last_object = None

-            async def collect_chunks(resp):
-                nonlocal full_text, full_object
-                logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
+            async def forward_chunks(resp):
+                nonlocal last_text, last_object
+                logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
                 if resp.error:
                     logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
                     raise RuntimeError(resp.error.message)
-                if resp.text:
-                    full_text += resp.text
-                    logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars")
-                    # Call chunk callback if provided
+                end_stream = getattr(resp, 'end_of_stream', False)
+
+                # Always call callback if there's text OR if it's the final message
+                if resp.text is not None:
+                    last_text = resp.text
+                    # Call chunk callback if provided with both chunk and end_of_stream flag
                     if chunk_callback:
-                        logger.info(f"DEBUG prompt_client: Calling chunk_callback")
+                        logger.info(f"DEBUG prompt_client: Calling chunk_callback with end_of_stream={end_stream}")
                         if asyncio.iscoroutinefunction(chunk_callback):
-                            await chunk_callback(resp.text)
+                            await chunk_callback(resp.text, end_stream)
                         else:
-                            chunk_callback(resp.text)
+                            chunk_callback(resp.text, end_stream)
                 elif resp.object:
                     logger.info(f"DEBUG prompt_client: Got object response")
-                    full_object = resp.object
+                    last_object = resp.object

-                end_stream = getattr(resp, 'end_of_stream', False)
                 logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
                 return end_stream
@@ -79,17 +80,17 @@
             logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
             await self.request(
                 req,
-                recipient=collect_chunks,
+                recipient=forward_chunks,
                 timeout=timeout
             )
-            logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars")
+            logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}")
-            if full_text:
-                logger.info("DEBUG prompt_client: Returning full_text")
-                return full_text
+            if last_text:
+                logger.info("DEBUG prompt_client: Returning last_text")
+                return last_text
-            logger.info("DEBUG prompt_client: Returning parsed full_object")
-            return json.loads(full_object)
+            logger.info("DEBUG prompt_client: Returning parsed last_object")
+            return json.loads(last_object) if last_object else None

     async def extract_definitions(self, text, timeout=600):
         return await self.prompt(

View file

@@ -1,9 +1,6 @@
-from pulsar.schema import JsonSchema
 import asyncio
 import time
-import pulsar
 import logging

 # Module logger
@@ -11,9 +8,9 @@
 logger = logging.getLogger(__name__)

 class Publisher:
-    def __init__(self, client, topic, schema=None, max_size=10,
+    def __init__(self, backend, topic, schema=None, max_size=10,
                  chunking_enabled=True, drain_timeout=5.0):
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
         self.topic = topic
         self.schema = schema
         self.q = asyncio.Queue(maxsize=max_size)
@@ -47,9 +44,9 @@ class Publisher:
         try:
-            producer = self.client.create_producer(
+            producer = self.backend.create_producer(
                 topic=self.topic,
-                schema=JsonSchema(self.schema),
+                schema=self.schema,
                 chunking_enabled=self.chunking_enabled,
             )

View file

@@ -4,8 +4,45 @@ import pulsar
 import _pulsar
 import uuid
 from pulsar.schema import JsonSchema
+import logging

 from .. log_level import LogLevel
+from .pulsar_backend import PulsarBackend
+
+logger = logging.getLogger(__name__)
+
+
+def get_pubsub(**config):
+    """
+    Factory function to create a pub/sub backend based on configuration.
+
+    Args:
+        config: Configuration dictionary from command-line args
+                Must include 'pubsub_backend' key
+
+    Returns:
+        Backend instance (PulsarBackend, MQTTBackend, etc.)
+
+    Example:
+        backend = get_pubsub(
+            pubsub_backend='pulsar',
+            pulsar_host='pulsar://localhost:6650'
+        )
+    """
+    backend_type = config.get('pubsub_backend', 'pulsar')
+
+    if backend_type == 'pulsar':
+        return PulsarBackend(
+            host=config.get('pulsar_host', PulsarClient.default_pulsar_host),
+            api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key),
+            listener=config.get('pulsar_listener'),
+        )
+    elif backend_type == 'mqtt':
+        # TODO: Implement MQTT backend
+        raise NotImplementedError("MQTT backend not yet implemented")
+    else:
+        raise ValueError(f"Unknown pub/sub backend: {backend_type}")
+
 class PulsarClient:
@@ -71,10 +108,3 @@ class PulsarClient:
             '--pulsar-listener',
             help=f'Pulsar listener (default: none)',
         )
-        parser.add_argument(
-            '-l', '--log-level',
-            default='INFO',
-            choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
-            help=f'Log level (default: INFO)'
-        )

View file

@@ -0,0 +1,350 @@
"""
Pulsar backend implementation for pub/sub abstraction.
This module provides a Pulsar-specific implementation of the backend interfaces,
handling topic mapping, serialization, and Pulsar client management.
"""
import pulsar
import _pulsar
import json
import logging
import base64
import types
from dataclasses import asdict, is_dataclass
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
logger = logging.getLogger(__name__)
def dataclass_to_dict(obj: Any) -> dict:
"""
Recursively convert a dataclass to a dictionary, handling None values and bytes.
None values are excluded from the dictionary (not serialized).
Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
Handles nested dataclasses, lists, and dictionaries recursively.
"""
if obj is None:
return None
# Handle bytes - decode to UTF-8 for JSON serialization
if isinstance(obj, bytes):
return obj.decode('utf-8')
# Handle dataclass - convert to dict then recursively process all values
if is_dataclass(obj):
result = {}
for key, value in asdict(obj).items():
result[key] = dataclass_to_dict(value) if value is not None else None
return result
# Handle list - recursively process all items
if isinstance(obj, list):
return [dataclass_to_dict(item) for item in obj]
# Handle dict - recursively process all values
if isinstance(obj, dict):
return {k: dataclass_to_dict(v) for k, v in obj.items()}
# Return primitive types as-is
return obj
def dict_to_dataclass(data: dict, cls: type) -> Any:
"""
Convert a dictionary back to a dataclass instance.
Handles nested dataclasses and missing fields.
"""
if data is None:
return None
if not is_dataclass(cls):
return data
# Get field types from the dataclass
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
kwargs = {}
for key, value in data.items():
if key in field_types:
field_type = field_types[key]
# Handle modern union types (X | Y)
if isinstance(field_type, types.UnionType):
# Check if it's Optional (X | None)
if type(None) in field_type.__args__:
# Get the non-None type
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Check if this is a generic type (list, dict, etc.)
elif hasattr(field_type, '__origin__'):
# Handle list[T]
if field_type.__origin__ == list:
item_type = field_type.__args__[0] if field_type.__args__ else None
if item_type and is_dataclass(item_type) and isinstance(value, list):
kwargs[key] = [
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
for item in value
]
else:
kwargs[key] = value
# Handle old-style Optional[T] (which is Union[T, None])
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
# Get the non-None type from Union
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Handle direct dataclass fields
elif is_dataclass(field_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, field_type)
# Handle bytes fields (UTF-8 encoded strings from JSON)
elif field_type == bytes and isinstance(value, str):
kwargs[key] = value.encode('utf-8')
else:
kwargs[key] = value
return cls(**kwargs)
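The two helpers above are inverses: `dataclass_to_dict` flattens for JSON, `dict_to_dataclass` rebuilds typed objects. The shape of that round trip, on a minimal example of our own (not the project's types), looks like this:

```python
from dataclasses import dataclass, asdict, fields, is_dataclass

@dataclass
class Inner:
    name: str

@dataclass
class Outer:
    inner: Inner
    count: int

def from_dict(cls, data):
    # Minimal inverse of asdict(): rebuild nested dataclasses field by field
    kwargs = {}
    for f in fields(cls):
        value = data.get(f.name)
        if is_dataclass(f.type) and isinstance(value, dict):
            value = from_dict(f.type, value)
        kwargs[f.name] = value
    return cls(**kwargs)

original = Outer(inner=Inner(name="x"), count=3)
as_plain = asdict(original)        # {'inner': {'name': 'x'}, 'count': 3}
restored = from_dict(Outer, as_plain)
print(restored == original)        # True: dataclasses compare by field values
```

The production version above additionally handles `Optional[T]`, `list[T]`, and `bytes` fields, which this sketch omits.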
class PulsarMessage:
"""Wrapper for Pulsar messages to match Message protocol."""
def __init__(self, pulsar_msg, schema_cls):
self._msg = pulsar_msg
self._schema_cls = schema_cls
self._value = None
def value(self) -> Any:
"""Deserialize and return the message value as a dataclass."""
if self._value is None:
# Get JSON string from Pulsar message
json_data = self._msg.data().decode('utf-8')
data_dict = json.loads(json_data)
# Convert to dataclass
self._value = dict_to_dataclass(data_dict, self._schema_cls)
return self._value
def properties(self) -> dict:
"""Return message properties."""
return self._msg.properties()
class PulsarBackendProducer:
"""Pulsar-specific producer implementation."""
def __init__(self, pulsar_producer, schema_cls):
self._producer = pulsar_producer
self._schema_cls = schema_cls
def send(self, message: Any, properties: dict = {}) -> None:
"""Send a dataclass message."""
# Convert dataclass to dict, excluding None values
data_dict = dataclass_to_dict(message)
# Serialize to JSON
json_data = json.dumps(data_dict)
# Send via Pulsar
self._producer.send(json_data.encode('utf-8'), properties=properties)
def flush(self) -> None:
"""Flush buffered messages."""
self._producer.flush()
def close(self) -> None:
"""Close the producer."""
self._producer.close()
class PulsarBackendConsumer:
"""Pulsar-specific consumer implementation."""
def __init__(self, pulsar_consumer, schema_cls):
self._consumer = pulsar_consumer
self._schema_cls = schema_cls
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message."""
pulsar_msg = self._consumer.receive(timeout_millis=timeout_millis)
return PulsarMessage(pulsar_msg, self._schema_cls)
def acknowledge(self, message: Message) -> None:
"""Acknowledge a message."""
if isinstance(message, PulsarMessage):
self._consumer.acknowledge(message._msg)
def negative_acknowledge(self, message: Message) -> None:
"""Negative acknowledge a message."""
if isinstance(message, PulsarMessage):
self._consumer.negative_acknowledge(message._msg)
def unsubscribe(self) -> None:
"""Unsubscribe from the topic."""
self._consumer.unsubscribe()
def close(self) -> None:
"""Close the consumer."""
self._consumer.close()
class PulsarBackend:
"""
Pulsar backend implementation.
Handles topic mapping, client management, and creation of Pulsar-specific
producers and consumers.
"""
def __init__(self, host: str, api_key: str = None, listener: str = None):
"""
Initialize Pulsar backend.
Args:
host: Pulsar broker URL (e.g., pulsar://localhost:6650)
api_key: Optional API key for authentication
listener: Optional listener name for multi-homed setups
"""
self.host = host
self.api_key = api_key
self.listener = listener
# Create Pulsar client
client_args = {'service_url': host}
if listener:
client_args['listener_name'] = listener
if api_key:
client_args['authentication'] = pulsar.AuthenticationToken(api_key)
self.client = pulsar.Client(**client_args)
logger.info(f"Pulsar client connected to {host}")
def map_topic(self, generic_topic: str) -> str:
"""
Map generic topic format to Pulsar URI.
Format: qos/tenant/namespace/queue
Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue
Args:
generic_topic: Generic topic string or already-formatted Pulsar URI
Returns:
Pulsar topic URI
"""
# If already a Pulsar URI, return as-is
if '://' in generic_topic:
return generic_topic
parts = generic_topic.split('/', 3)
if len(parts) != 4:
raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue")
qos, tenant, namespace, queue = parts
# Map QoS to persistence
if qos == 'q0':
persistence = 'non-persistent'
elif qos in ['q1', 'q2']:
persistence = 'persistent'
else:
raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2")
return f"{persistence}://{tenant}/{namespace}/{queue}"
def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
"""
Create a Pulsar producer.
Args:
topic: Generic topic format (qos/tenant/namespace/queue)
schema: Dataclass type for messages
**options: Backend-specific options (e.g., chunking_enabled)
Returns:
PulsarBackendProducer instance
"""
pulsar_topic = self.map_topic(topic)
producer_args = {
'topic': pulsar_topic,
'schema': pulsar.schema.BytesSchema(), # We handle serialization ourselves
}
# Add optional parameters
if 'chunking_enabled' in options:
producer_args['chunking_enabled'] = options['chunking_enabled']
pulsar_producer = self.client.create_producer(**producer_args)
logger.debug(f"Created producer for topic: {pulsar_topic}")
return PulsarBackendProducer(pulsar_producer, schema)
def create_consumer(
self,
topic: str,
subscription: str,
schema: type,
initial_position: str = 'latest',
consumer_type: str = 'shared',
**options
) -> BackendConsumer:
"""
Create a Pulsar consumer.
Args:
topic: Generic topic format (qos/tenant/namespace/queue)
subscription: Subscription name
schema: Dataclass type for messages
initial_position: 'earliest' or 'latest'
consumer_type: 'shared', 'exclusive', or 'failover'
**options: Backend-specific options
Returns:
PulsarBackendConsumer instance
"""
pulsar_topic = self.map_topic(topic)
# Map initial position
if initial_position == 'earliest':
pos = pulsar.InitialPosition.Earliest
else:
pos = pulsar.InitialPosition.Latest
# Map consumer type
if consumer_type == 'exclusive':
ctype = pulsar.ConsumerType.Exclusive
elif consumer_type == 'failover':
ctype = pulsar.ConsumerType.Failover
else:
ctype = pulsar.ConsumerType.Shared
consumer_args = {
'topic': pulsar_topic,
'subscription_name': subscription,
'schema': pulsar.schema.BytesSchema(), # We handle deserialization ourselves
'initial_position': pos,
'consumer_type': ctype,
}
pulsar_consumer = self.client.subscribe(**consumer_args)
logger.debug(f"Created consumer for topic: {pulsar_topic}, subscription: {subscription}")
return PulsarBackendConsumer(pulsar_consumer, schema)
def close(self) -> None:
"""Close the Pulsar client."""
self.client.close()
logger.info("Pulsar client closed")
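The topic-mapping rule implemented by `map_topic` above — `qos/tenant/namespace/queue`, with `q0` selecting non-persistent delivery — can be exercised standalone; this sketch abbreviates the validation the real method performs on malformed input:

```python
def map_topic(generic_topic: str) -> str:
    # Pass through anything that is already a broker URI
    if '://' in generic_topic:
        return generic_topic
    qos, tenant, namespace, queue = generic_topic.split('/', 3)
    if qos == 'q0':
        persistence = 'non-persistent'
    elif qos in ('q1', 'q2'):
        persistence = 'persistent'
    else:
        raise ValueError(f"Invalid QoS level: {qos}")
    return f"{persistence}://{tenant}/{namespace}/{queue}"

print(map_topic("q1/tg/flow/my-queue"))   # persistent://tg/flow/my-queue
print(map_topic("q0/tg/flow/events"))     # non-persistent://tg/flow/events
```

Keeping the generic form out of the Pulsar URI syntax is what lets the same topic strings be reused by a future non-Pulsar backend.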

View file

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 class RequestResponse(Subscriber):
     def __init__(
-        self, client, subscription, consumer_name,
+        self, backend, subscription, consumer_name,
         request_topic, request_schema,
         request_metrics,
         response_topic, response_schema,
@@ -22,7 +22,7 @@ class RequestResponse(Subscriber):
     ):
         super(RequestResponse, self).__init__(
-            client = client,
+            backend = backend,
             subscription = subscription,
             consumer_name = consumer_name,
             topic = response_topic,
@@ -31,7 +31,7 @@
         )
         self.producer = Producer(
-            client = client,
+            backend = backend,
             topic = request_topic,
             schema = request_schema,
             metrics = request_metrics,
@@ -126,7 +126,7 @@ class RequestResponseSpec(Spec):
         )
         rr = self.impl(
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
             # Make subscription names unique, so that all subscribers get
             # to see all response messages
View file

@@ -3,9 +3,7 @@
 # off of a queue and make it available using an internal broker system,
 # so suitable for when multiple recipients are reading from the same queue

-from pulsar.schema import JsonSchema
 import asyncio
-import _pulsar
 import time
 import logging
 import uuid
@@ -13,12 +11,16 @@ import uuid
 # Module logger
 logger = logging.getLogger(__name__)

+# Timeout exception - can come from different backends
+class TimeoutError(Exception):
+    pass
+
 class Subscriber:
-    def __init__(self, client, topic, subscription, consumer_name,
+    def __init__(self, backend, topic, subscription, consumer_name,
                  schema=None, max_size=100, metrics=None,
                  backpressure_strategy="block", drain_timeout=5.0):
-        self.client = client
+        self.backend = backend  # Changed from 'client' to 'backend'
         self.topic = topic
         self.subscription = subscription
         self.consumer_name = consumer_name
@@ -43,18 +45,14 @@ class Subscriber:
     async def start(self):
-        # Build subscribe arguments
-        subscribe_args = {
-            'topic': self.topic,
-            'subscription_name': self.subscription,
-            'consumer_name': self.consumer_name,
-        }
-
-        # Only add schema if provided (omit if None)
-        if self.schema is not None:
-            subscribe_args['schema'] = JsonSchema(self.schema)
-
-        self.consumer = self.client.subscribe(**subscribe_args)
+        # Create consumer via backend
+        self.consumer = await asyncio.to_thread(
+            self.backend.create_consumer,
+            topic=self.topic,
+            subscription=self.subscription,
+            schema=self.schema,
+            consumer_type='shared',
+        )
         self.task = asyncio.create_task(self.run())
@@ -94,12 +92,13 @@ class Subscriber:
         drain_end_time = time.time() + self.drain_timeout
         logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")

-        # Stop accepting new messages from Pulsar during drain
-        if self.consumer:
+        # Stop accepting new messages during drain
+        # Note: Not all backends support pausing message listeners
+        if self.consumer and hasattr(self.consumer, 'pause_message_listener'):
             try:
                 self.consumer.pause_message_listener()
-            except _pulsar.InvalidConfiguration:
-                # Not all consumers have message listeners (e.g., blocking receive mode)
+            except Exception:
+                # Not all consumers support message listeners
                 pass

         # Check drain timeout
@@ -133,9 +132,10 @@ class Subscriber:
                     self.consumer.receive,
                     timeout_millis=250
                 )
-            except _pulsar.Timeout:
-                continue
             except Exception as e:
+                # Handle timeout from any backend
+                if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
+                    continue
                 logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
                 raise e
@@ -157,19 +157,20 @@ class Subscriber:
         for msg in self.pending_acks.values():
             try:
                 self.consumer.negative_acknowledge(msg)
-            except _pulsar.AlreadyClosed:
-                pass  # Consumer already closed
+            except Exception:
+                pass  # Consumer already closed or error
         self.pending_acks.clear()

         if self.consumer:
+            if hasattr(self.consumer, 'unsubscribe'):
                 try:
                     self.consumer.unsubscribe()
-                except _pulsar.AlreadyClosed:
-                    pass  # Already closed
+                except Exception:
+                    pass  # Already closed or error
             try:
                 self.consumer.close()
-            except _pulsar.AlreadyClosed:
-                pass  # Already closed
+            except Exception:
+                pass  # Already closed or error
             self.consumer = None

View file

@@ -16,7 +16,7 @@ class SubscriberSpec(Spec):
         )
         subscriber = Subscriber(
-            client = processor.pulsar_client,
+            backend = processor.pubsub,
             topic = definition[self.name],
             subscription = flow.id,
             consumer_name = flow.id,

View file

@@ -7,6 +7,7 @@ import time
 from pulsar.schema import JsonSchema

 from .. exceptions import *
+from ..base.pubsub import get_pubsub

 # Default timeout for a request/response. In seconds.
 DEFAULT_TIMEOUT=300
@@ -39,30 +40,25 @@ class BaseClient:
         if subscriber == None:
             subscriber = str(uuid.uuid4())

-        if pulsar_api_key:
-            auth = pulsar.AuthenticationToken(pulsar_api_key)
-            self.client = pulsar.Client(
-                pulsar_host,
-                logger=pulsar.ConsoleLogger(log_level),
-                authentication=auth,
-                listener=listener,
-            )
-        else:
-            self.client = pulsar.Client(
-                pulsar_host,
-                logger=pulsar.ConsoleLogger(log_level),
-                listener_name=listener,
-            )
+        # Create backend using factory
+        self.backend = get_pubsub(
+            pulsar_host=pulsar_host,
+            pulsar_api_key=pulsar_api_key,
+            pulsar_listener=listener,
+            pubsub_backend='pulsar'
+        )

-        self.producer = self.client.create_producer(
+        self.producer = self.backend.create_producer(
             topic=input_queue,
-            schema=JsonSchema(input_schema),
+            schema=input_schema,
             chunking_enabled=True,
         )

-        self.consumer = self.client.subscribe(
-            output_queue, subscriber,
-            schema=JsonSchema(output_schema),
+        self.consumer = self.backend.create_consumer(
+            topic=output_queue,
+            subscription=subscriber,
+            schema=output_schema,
+            consumer_type='shared',
         )

         self.input_schema = input_schema
@@ -141,5 +137,6 @@ class BaseClient:
         self.producer.flush()
         self.producer.close()
-        self.client.close()
+        if hasattr(self, "backend"):
+            self.backend.close()

View file

@@ -64,7 +64,6 @@ class ConfigClient(BaseClient):
     def get(self, keys, timeout=300):
         resp = self.call(
-            id=id,
             operation="get",
             keys=[
                 ConfigKey(
@@ -88,7 +87,6 @@
     def list(self, type, timeout=300):
         resp = self.call(
-            id=id,
             operation="list",
             type=type,
             timeout=timeout
@@ -99,7 +97,6 @@
     def getvalues(self, type, timeout=300):
         resp = self.call(
-            id=id,
             operation="getvalues",
             type=type,
             timeout=timeout
@@ -117,7 +114,6 @@
     def delete(self, keys, timeout=300):
         resp = self.call(
-            id=id,
             operation="delete",
             keys=[
                 ConfigKey(
@@ -134,7 +130,6 @@
     def put(self, values, timeout=300):
         resp = self.call(
-            id=id,
             operation="put",
             values=[
                 ConfigValue(
@@ -152,7 +147,6 @@
     def config(self, timeout=300):
         resp = self.call(
-            id=id,
             operation="config",
             timeout=timeout
         )

View file

@@ -15,8 +15,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
             name=data.get("name"),
             description=data.get("description"),
             tags=data.get("tags"),
-            created_at=data.get("created_at"),
-            updated_at=data.get("updated_at"),
             tag_filter=data.get("tag_filter"),
             limit=data.get("limit")
         )
@@ -38,10 +36,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
             result["description"] = obj.description
         if obj.tags is not None:
             result["tags"] = list(obj.tags)
-        if obj.created_at is not None:
-            result["created_at"] = obj.created_at
-        if obj.updated_at is not None:
-            result["updated_at"] = obj.updated_at
         if obj.tag_filter is not None:
             result["tag_filter"] = list(obj.tag_filter)
         if obj.limit is not None:
@@ -73,9 +67,7 @@ class CollectionManagementResponseTranslator(MessageTranslator):
                 collection=coll_data.get("collection"),
                 name=coll_data.get("name"),
                 description=coll_data.get("description"),
-                tags=coll_data.get("tags"),
-                created_at=coll_data.get("created_at"),
-                updated_at=coll_data.get("updated_at")
+                tags=coll_data.get("tags", [])
             ))

         return CollectionManagementResponse(
@@ -104,9 +96,7 @@ class CollectionManagementResponseTranslator(MessageTranslator):
                 "collection": coll.collection,
                 "name": coll.name,
                 "description": coll.description,
-                "tags": list(coll.tags) if coll.tags else [],
-                "created_at": coll.created_at,
-                "updated_at": coll.updated_at
+                "tags": list(coll.tags) if coll.tags else []
             })

         print("RESULT IS", result, flush=True)

View file

@@ -57,7 +57,9 @@ class StructuredDataDiagnosisResponseTranslator(MessageTranslator):
             result["descriptor"] = obj.descriptor
         if obj.metadata:
             result["metadata"] = obj.metadata
-        if obj.schema_matches is not None:
+        # For schema-selection, always include schema_matches (even if empty)
+        # For other operations, only include if non-empty
+        if obj.operation == "schema-selection" or obj.schema_matches:
             result["schema-matches"] = obj.schema_matches
         return result
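The distinction between the old `is not None` guard and the new condition matters once the field defaults to a list: a bare truthiness test would drop an empty match list from a schema-selection response, while other operations should still omit the key when there are no matches. A minimal sketch of the new rule (the response class here is a simplified stand-in for the real one):

```python
from dataclasses import dataclass, field

@dataclass
class Resp:  # stand-in for StructuredDataDiagnosisResponse
    operation: str = ""
    schema_matches: list[str] = field(default_factory=list)

def encode(obj: Resp) -> dict:
    result = {}
    # schema-selection always reports its matches, even an empty list;
    # other operations include the key only when matches exist
    if obj.operation == "schema-selection" or obj.schema_matches:
        result["schema-matches"] = obj.schema_matches
    return result
```

So `encode(Resp(operation="schema-selection"))` yields `{"schema-matches": []}`, while `encode(Resp(operation="diagnose"))` yields `{}`.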

View file

@@ -43,11 +43,16 @@ class PromptResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
         result = {}
-        if obj.text:
+        # Include text field if present (even if empty string)
+        if obj.text is not None:
             result["text"] = obj.text
-        if obj.object:
+        # Include object field if present
+        if obj.object is not None:
             result["object"] = obj.object
+        # Always include end_of_stream flag for streaming support
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
         return result

     def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:

View file

@@ -34,15 +34,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
         result = {}
-        # Check if this is a streaming response (has chunk)
-        if hasattr(obj, 'chunk') and obj.chunk:
-            result["chunk"] = obj.chunk
-            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
-        else:
-            # Non-streaming response
-            if obj.response:
-                result["response"] = obj.response
+        # Include response content (even if empty string)
+        if obj.response is not None:
+            result["response"] = obj.response
+        # Include end_of_stream flag
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
         # Always include error if present
         if hasattr(obj, 'error') and obj.error and obj.error.message:
             result["error"] = {"message": obj.error.message, "type": obj.error.type}
@@ -51,13 +49,7 @@ class DocumentRagResponseTranslator(MessageTranslator):
     def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        # For streaming responses, check end_of_stream
-        if hasattr(obj, 'chunk') and obj.chunk:
-            is_final = getattr(obj, 'end_of_stream', False)
-        else:
-            # For non-streaming responses, it's always final
-            is_final = True
+        is_final = getattr(obj, 'end_of_stream', False)
         return self.from_pulsar(obj), is_final
@@ -98,15 +90,13 @@ class GraphRagResponseTranslator(MessageTranslator):
     def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
         result = {}
-        # Check if this is a streaming response (has chunk)
-        if hasattr(obj, 'chunk') and obj.chunk:
-            result["chunk"] = obj.chunk
-            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
-        else:
-            # Non-streaming response
-            if obj.response:
-                result["response"] = obj.response
+        # Include response content (even if empty string)
+        if obj.response is not None:
+            result["response"] = obj.response
+        # Include end_of_stream flag
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
         # Always include error if present
         if hasattr(obj, 'error') and obj.error and obj.error.message:
             result["error"] = {"message": obj.error.message, "type": obj.error.type}
@@ -115,11 +105,5 @@ class GraphRagResponseTranslator(MessageTranslator):
     def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        # For streaming responses, check end_of_stream
-        if hasattr(obj, 'chunk') and obj.chunk:
-            is_final = getattr(obj, 'end_of_stream', False)
-        else:
-            # For non-streaming responses, it's always final
-            is_final = True
+        is_final = getattr(obj, 'end_of_stream', False)
         return self.from_pulsar(obj), is_final

View file

@@ -36,6 +36,9 @@ class TextCompletionResponseTranslator(MessageTranslator):
         if obj.model:
             result["model"] = obj.model
+        # Always include end_of_stream flag for streaming support
+        result["end_of_stream"] = getattr(obj, "end_of_stream", False)
         return result

     def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:

View file

@@ -1,16 +1,14 @@
-from pulsar.schema import Record, String, Array
+from dataclasses import dataclass, field

 from .primitives import Triple

-class Metadata(Record):
+@dataclass
+class Metadata:

     # Source identifier
-    id = String()
+    id: str = ""

     # Subgraph
-    metadata = Array(Triple())
+    metadata: list[Triple] = field(default_factory=list)

     # Collection management
-    user = String()
-    collection = String()
+    user: str = ""
+    collection: str = ""
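The conversion uses `field(default_factory=list)` rather than a bare `[]` default because a mutable default would be shared across every dataclass instance (Python rejects it outright for dataclass fields). A small sketch of the pattern as applied above, with a simplified `Triple` stand-in for the real `.primitives.Triple`:

```python
from dataclasses import dataclass, field

@dataclass
class Triple:  # simplified stand-in for .primitives.Triple
    s: str = ""
    p: str = ""
    o: str = ""

@dataclass
class Metadata:
    id: str = ""
    metadata: list[Triple] = field(default_factory=list)  # fresh list per instance
    user: str = ""
    collection: str = ""

a = Metadata(id="doc-1", user="alice", collection="default")
b = Metadata(id="doc-2")
a.metadata.append(Triple("s", "p", "o"))
# b.metadata is unaffected: each instance got its own list
```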

View file

@@ -1,34 +1,39 @@
-from pulsar.schema import Record, String, Boolean, Array, Integer
+from dataclasses import dataclass, field

-class Error(Record):
-    type = String()
-    message = String()
+@dataclass
+class Error:
+    type: str = ""
+    message: str = ""

-class Value(Record):
-    value = String()
-    is_uri = Boolean()
-    type = String()
+@dataclass
+class Value:
+    value: str = ""
+    is_uri: bool = False
+    type: str = ""

-class Triple(Record):
-    s = Value()
-    p = Value()
-    o = Value()
+@dataclass
+class Triple:
+    s: Value | None = None
+    p: Value | None = None
+    o: Value | None = None

-class Field(Record):
-    name = String()
+@dataclass
+class Field:
+    name: str = ""
     # int, string, long, bool, float, double, timestamp
-    type = String()
-    size = Integer()
-    primary = Boolean()
-    description = String()
+    type: str = ""
+    size: int = 0
+    primary: bool = False
+    description: str = ""
     # NEW FIELDS for structured data:
-    required = Boolean()  # Whether field is required
-    enum_values = Array(String())  # For enum type fields
-    indexed = Boolean()  # Whether field should be indexed
+    required: bool = False  # Whether field is required
+    enum_values: list[str] = field(default_factory=list)  # For enum type fields
+    indexed: bool = False  # Whether field should be indexed

-class RowSchema(Record):
-    name = String()
-    description = String()
-    fields = Array(Field())
+@dataclass
+class RowSchema:
+    name: str = ""
+    description: str = ""
+    fields: list[Field] = field(default_factory=list)

View file

@@ -1,4 +1,23 @@
-def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
-    return f"{kind}://{tenant}/{namespace}/{topic}"
+def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
+    """
+    Create a generic topic identifier that can be mapped by backends.
+
+    Args:
+        queue_name: The queue/topic name
+        qos: Quality of service
+            - 'q0' = best-effort (no ack)
+            - 'q1' = at-least-once (ack required)
+            - 'q2' = exactly-once (two-phase ack)
+        tenant: Tenant identifier for multi-tenancy
+        namespace: Namespace within tenant
+
+    Returns:
+        Generic topic string: qos/tenant/namespace/queue_name
+
+    Examples:
+        topic('my-queue')  # q1/tg/flow/my-queue
+        topic('config', qos='q2', namespace='config')  # q2/tg/config/config
+    """
+    return f"{qos}/{tenant}/{namespace}/{queue_name}"
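Since the full replacement function appears in the diff, its behaviour can be exercised directly. A backend would then map the generic `qos/tenant/namespace/queue` string into its own addressing scheme; the Pulsar-style mapping below is an illustrative assumption (consistent with this PR moving former `non-persistent` topics to `q0` and the persistent config push to `q2`), not code from the diff:

```python
def topic(queue_name, qos='q1', tenant='tg', namespace='flow'):
    # Replacement function, as introduced in the diff
    return f"{qos}/{tenant}/{namespace}/{queue_name}"

def to_pulsar(generic: str) -> str:
    # Illustrative backend mapping (assumed): q0 -> non-persistent,
    # q1/q2 -> persistent Pulsar topics
    qos, tenant, namespace, queue = generic.split("/", 3)
    kind = "non-persistent" if qos == "q0" else "persistent"
    return f"{kind}://{tenant}/{namespace}/{queue}"
```

For example, `to_pulsar(topic('knowledge', qos='q0', namespace='request'))` would recover the old `non-persistent://tg/request/knowledge` address.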

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Bytes
+from dataclasses import dataclass

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -6,24 +6,27 @@ from ..core.topic import topic
 ############################################################################

 # PDF docs etc.

-class Document(Record):
-    metadata = Metadata()
-    data = Bytes()
+@dataclass
+class Document:
+    metadata: Metadata | None = None
+    data: bytes = b""

 ############################################################################

 # Text documents / text from PDF

-class TextDocument(Record):
-    metadata = Metadata()
-    text = Bytes()
+@dataclass
+class TextDocument:
+    metadata: Metadata | None = None
+    text: bytes = b""

 ############################################################################

 # Chunks of text

-class Chunk(Record):
-    metadata = Metadata()
-    chunk = Bytes()
+@dataclass
+class Chunk:
+    metadata: Metadata | None = None
+    chunk: bytes = b""

 ############################################################################

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double, Map
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.primitives import Value, RowSchema
@@ -8,49 +8,55 @@ from ..core.topic import topic

 # Graph embeddings are embeddings associated with a graph entity

-class EntityEmbeddings(Record):
-    entity = Value()
-    vectors = Array(Array(Double()))
+@dataclass
+class EntityEmbeddings:
+    entity: Value | None = None
+    vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data

-class GraphEmbeddings(Record):
-    metadata = Metadata()
-    entities = Array(EntityEmbeddings())
+@dataclass
+class GraphEmbeddings:
+    metadata: Metadata | None = None
+    entities: list[EntityEmbeddings] = field(default_factory=list)

 ############################################################################

 # Document embeddings are embeddings associated with a chunk

-class ChunkEmbeddings(Record):
-    chunk = Bytes()
-    vectors = Array(Array(Double()))
+@dataclass
+class ChunkEmbeddings:
+    chunk: bytes = b""
+    vectors: list[list[float]] = field(default_factory=list)

 # This is a 'batching' mechanism for the above data

-class DocumentEmbeddings(Record):
-    metadata = Metadata()
-    chunks = Array(ChunkEmbeddings())
+@dataclass
+class DocumentEmbeddings:
+    metadata: Metadata | None = None
+    chunks: list[ChunkEmbeddings] = field(default_factory=list)

 ############################################################################

 # Object embeddings are embeddings associated with the primary key of an
 # object

-class ObjectEmbeddings(Record):
-    metadata = Metadata()
-    vectors = Array(Array(Double()))
-    name = String()
-    key_name = String()
-    id = String()
+@dataclass
+class ObjectEmbeddings:
+    metadata: Metadata | None = None
+    vectors: list[list[float]] = field(default_factory=list)
+    name: str = ""
+    key_name: str = ""
+    id: str = ""

 ############################################################################

 # Structured object embeddings with enhanced capabilities

-class StructuredObjectEmbedding(Record):
-    metadata = Metadata()
-    vectors = Array(Array(Double()))
-    schema_name = String()
-    object_id = String()  # Primary key value
-    field_embeddings = Map(Array(Double()))  # Per-field embeddings
+@dataclass
+class StructuredObjectEmbedding:
+    metadata: Metadata | None = None
+    vectors: list[list[float]] = field(default_factory=list)
+    schema_name: str = ""
+    object_id: str = ""  # Primary key value
+    field_embeddings: dict[str, list[float]] = field(default_factory=dict)  # Per-field embeddings

 ############################################################################

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Array
+from dataclasses import dataclass, field

 from ..core.primitives import Value, Triple
 from ..core.metadata import Metadata
@@ -8,21 +8,24 @@ from ..core.topic import topic

 # Entity context are an entity associated with textual context

-class EntityContext(Record):
-    entity = Value()
-    context = String()
+@dataclass
+class EntityContext:
+    entity: Value | None = None
+    context: str = ""

 # This is a 'batching' mechanism for the above data

-class EntityContexts(Record):
-    metadata = Metadata()
-    entities = Array(EntityContext())
+@dataclass
+class EntityContexts:
+    metadata: Metadata | None = None
+    entities: list[EntityContext] = field(default_factory=list)

 ############################################################################

 # Graph triples

-class Triples(Record):
-    metadata = Metadata()
-    triples = Array(Triple())
+@dataclass
+class Triples:
+    metadata: Metadata | None = None
+    triples: list[Triple] = field(default_factory=list)

 ############################################################################

View file

@@ -1,5 +1,4 @@
-from pulsar.schema import Record, Bytes, String, Array, Long, Boolean
+from dataclasses import dataclass, field

 from ..core.primitives import Triple, Error
 from ..core.topic import topic
 from ..core.metadata import Metadata
@@ -22,40 +21,40 @@ from .embeddings import GraphEmbeddings
 # <- ()
 # <- (error)

-class KnowledgeRequest(Record):
+@dataclass
+class KnowledgeRequest:

     # get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
     # load-kg-core, unload-kg-core
-    operation = String()
+    operation: str = ""

     # list-kg-cores, delete-kg-core, put-kg-core
-    user = String()
+    user: str = ""

     # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core,
     # load-kg-core, unload-kg-core
-    id = String()
+    id: str = ""

     # load-kg-core
-    flow = String()
+    flow: str = ""

     # load-kg-core
-    collection = String()
+    collection: str = ""

     # put-kg-core
-    triples = Triples()
-    graph_embeddings = GraphEmbeddings()
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None

-class KnowledgeResponse(Record):
-    error = Error()
-    ids = Array(String())
-    eos = Boolean()  # Indicates end of knowledge core stream
-    triples = Triples()
-    graph_embeddings = GraphEmbeddings()
+@dataclass
+class KnowledgeResponse:
+    error: Error | None = None
+    ids: list[str] = field(default_factory=list)
+    eos: bool = False  # Indicates end of knowledge core stream
+    triples: Triples | None = None
+    graph_embeddings: GraphEmbeddings | None = None

 knowledge_request_queue = topic(
-    'knowledge', kind='non-persistent', namespace='request'
+    'knowledge', qos='q0', namespace='request'
 )
 knowledge_response_queue = topic(
-    'knowledge', kind='non-persistent', namespace='response',
+    'knowledge', qos='q0', namespace='response',
 )

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Boolean
+from dataclasses import dataclass

 from ..core.topic import topic
@@ -6,21 +6,25 @@ from ..core.topic import topic

 # NLP extraction data types

-class Definition(Record):
-    name = String()
-    definition = String()
+@dataclass
+class Definition:
+    name: str = ""
+    definition: str = ""

-class Topic(Record):
-    name = String()
-    definition = String()
+@dataclass
+class Topic:
+    name: str = ""
+    definition: str = ""

-class Relationship(Record):
-    s = String()
-    p = String()
-    o = String()
-    o_entity = Boolean()
+@dataclass
+class Relationship:
+    s: str = ""
+    p: str = ""
+    o: str = ""
+    o_entity: bool = False

-class Fact(Record):
-    s = String()
-    p = String()
-    o = String()
+@dataclass
+class Fact:
+    s: str = ""
+    p: str = ""
+    o: str = ""

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Map, Double, Array
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -7,11 +7,13 @@ from ..core.topic import topic

 # Extracted object from text processing

-class ExtractedObject(Record):
-    metadata = Metadata()
-    schema_name = String()  # Which schema this object belongs to
-    values = Array(Map(String()))  # Array of objects, each object is field name -> value
-    confidence = Double()
-    source_span = String()  # Text span where object was found
+@dataclass
+class ExtractedObject:
+    metadata: Metadata | None = None
+    schema_name: str = ""  # Which schema this object belongs to
+    values: list[dict[str, str]] = field(default_factory=list)  # Array of objects, each object is field name -> value
+    confidence: float = 0.0
+    source_span: str = ""  # Text span where object was found

 ############################################################################

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, Array, Map, String
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.primitives import RowSchema
@@ -8,9 +8,10 @@ from ..core.topic import topic

 # Stores rows of information

-class Rows(Record):
-    metadata = Metadata()
-    row_schema = RowSchema()
-    rows = Array(Map(String()))
+@dataclass
+class Rows:
+    metadata: Metadata | None = None
+    row_schema: RowSchema | None = None
+    rows: list[dict[str, str]] = field(default_factory=list)

 ############################################################################

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Bytes, Map
+from dataclasses import dataclass, field

 from ..core.metadata import Metadata
 from ..core.topic import topic
@@ -7,11 +7,13 @@ from ..core.topic import topic

 # Structured data submission for fire-and-forget processing

-class StructuredDataSubmission(Record):
-    metadata = Metadata()
-    format = String()  # "json", "csv", "xml"
-    schema_name = String()  # Reference to schema in config
-    data = Bytes()  # Raw data to ingest
-    options = Map(String())  # Format-specific options
+@dataclass
+class StructuredDataSubmission:
+    metadata: Metadata | None = None
+    format: str = ""  # "json", "csv", "xml"
+    schema_name: str = ""  # Reference to schema in config
+    data: bytes = b""  # Raw data to ingest
+    options: dict[str, str] = field(default_factory=dict)  # Format-specific options

 ############################################################################

View file

@@ -1,5 +1,5 @@
-from pulsar.schema import Record, String, Array, Map, Boolean
+from dataclasses import dataclass, field

 from ..core.topic import topic
 from ..core.primitives import Error
@@ -8,33 +8,36 @@ from ..core.primitives import Error

 # Prompt services, abstract the prompt generation

-class AgentStep(Record):
-    thought = String()
-    action = String()
-    arguments = Map(String())
-    observation = String()
-    user = String()  # User context for the step
+@dataclass
+class AgentStep:
+    thought: str = ""
+    action: str = ""
+    arguments: dict[str, str] = field(default_factory=dict)
+    observation: str = ""
+    user: str = ""  # User context for the step

-class AgentRequest(Record):
-    question = String()
-    state = String()
-    group = Array(String())
-    history = Array(AgentStep())
-    user = String()  # User context for multi-tenancy
-    streaming = Boolean()  # NEW: Enable streaming response delivery (default false)
+@dataclass
+class AgentRequest:
+    question: str = ""
+    state: str = ""
+    group: list[str] | None = None
+    history: list[AgentStep] = field(default_factory=list)
+    user: str = ""  # User context for multi-tenancy
+    streaming: bool = False  # NEW: Enable streaming response delivery (default false)

-class AgentResponse(Record):
+@dataclass
+class AgentResponse:

     # Streaming-first design
-    chunk_type = String()  # "thought", "action", "observation", "answer", "error"
-    content = String()  # The actual content (interpretation depends on chunk_type)
-    end_of_message = Boolean()  # Current chunk type (thought/action/etc.) is complete
-    end_of_dialog = Boolean()  # Entire agent dialog is complete
+    chunk_type: str = ""  # "thought", "action", "observation", "answer", "error"
+    content: str = ""  # The actual content (interpretation depends on chunk_type)
+    end_of_message: bool = False  # Current chunk type (thought/action/etc.) is complete
+    end_of_dialog: bool = False  # Entire agent dialog is complete

     # Legacy fields (deprecated but kept for backward compatibility)
-    answer = String()
-    error = Error()
-    thought = String()
-    observation = String()
+    answer: str = ""
+    error: Error | None = None
+    thought: str = ""
+    observation: str = ""

 ############################################################################
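The streaming-first `AgentResponse` carries typed chunks with two completion flags: `end_of_message` closes the current `chunk_type`, `end_of_dialog` closes the whole exchange. A hedged sketch of a consumer that groups chunks into messages (the `assemble` loop is illustrative and not from the diff; the dataclass is trimmed to the streaming fields it uses):

```python
from dataclasses import dataclass

@dataclass
class AgentResponse:  # trimmed to the streaming-first fields
    chunk_type: str = ""
    content: str = ""
    end_of_message: bool = False
    end_of_dialog: bool = False

def assemble(stream):
    """Group streamed chunks into (chunk_type, full_text) messages."""
    messages, buf, current = [], [], None
    for r in stream:
        if current is None:
            current = r.chunk_type
        buf.append(r.content)
        if r.end_of_message:
            # Current thought/action/observation/answer is complete
            messages.append((current, "".join(buf)))
            buf, current = [], None
        if r.end_of_dialog:
            break
    return messages
```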

View file

@@ -1,4 +1,4 @@
-from pulsar.schema import Record, String, Integer, Array
+from dataclasses import dataclass, field

 from datetime import datetime
 from ..core.primitives import Error
@@ -10,41 +10,40 @@ from ..core.topic import topic

 # Collection metadata operations (for librarian service)

-class CollectionMetadata(Record):
+@dataclass
+class CollectionMetadata:
     """Collection metadata record"""
-    user = String()
-    collection = String()
-    name = String()
-    description = String()
-    tags = Array(String())
-    created_at = String()  # ISO timestamp
-    updated_at = String()  # ISO timestamp
+    user: str = ""
+    collection: str = ""
+    name: str = ""
+    description: str = ""
+    tags: list[str] = field(default_factory=list)

 ############################################################################

-class CollectionManagementRequest(Record):
+@dataclass
+class CollectionManagementRequest:
     """Request for collection management operations"""
-    operation = String()  # e.g., "delete-collection"
+    operation: str = ""  # e.g., "delete-collection"

     # For 'list-collections'
-    user = String()
-    collection = String()
-    timestamp = String()  # ISO timestamp
-    name = String()
-    description = String()
-    tags = Array(String())
-    created_at = String()  # ISO timestamp
-    updated_at = String()  # ISO timestamp
+    user: str = ""
+    collection: str = ""
+    timestamp: str = ""  # ISO timestamp
+    name: str = ""
+    description: str = ""
+    tags: list[str] = field(default_factory=list)

     # For list
-    tag_filter = Array(String())  # Optional filter by tags
-    limit = Integer()
+    tag_filter: list[str] = field(default_factory=list)  # Optional filter by tags
+    limit: int = 0

-class CollectionManagementResponse(Record):
+@dataclass
+class CollectionManagementResponse:
     """Response for collection management operations"""
-    error = Error()  # Only populated if there's an error
-    timestamp = String()  # ISO timestamp
-    collections = Array(CollectionMetadata())
+    error: Error | None = None  # Only populated if there's an error
+    timestamp: str = ""  # ISO timestamp
+    collections: list[CollectionMetadata] = field(default_factory=list)

 ############################################################################
@@ -52,8 +51,9 @@ class CollectionManagementResponse(Record):

 # Topics

 collection_request_queue = topic(
-    'collection', kind='non-persistent', namespace='request'
+    'collection', qos='q0', namespace='request'
 )
 collection_response_queue = topic(
-    'collection', kind='non-persistent', namespace='response'
+    'collection', qos='q0', namespace='response'
 )

View file

@@ -1,5 +1,5 @@
-from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
+from dataclasses import dataclass, field

 from ..core.topic import topic
 from ..core.primitives import Error
@@ -13,58 +13,61 @@ from ..core.primitives import Error
 # put(values) -> ()
 # delete(keys) -> ()
 # config() -> (version, config)

-class ConfigKey(Record):
-    type = String()
-    key = String()
+@dataclass
+class ConfigKey:
+    type: str = ""
+    key: str = ""

-class ConfigValue(Record):
-    type = String()
-    key = String()
-    value = String()
+@dataclass
+class ConfigValue:
+    type: str = ""
+    key: str = ""
+    value: str = ""

 # Prompt services, abstract the prompt generation

-class ConfigRequest(Record):
+@dataclass
+class ConfigRequest:

-    operation = String()  # get, list, getvalues, delete, put, config
+    operation: str = ""  # get, list, getvalues, delete, put, config

     # get, delete
-    keys = Array(ConfigKey())
+    keys: list[ConfigKey] = field(default_factory=list)

     # list, getvalues
-    type = String()
+    type: str = ""

     # put
-    values = Array(ConfigValue())
+    values: list[ConfigValue] = field(default_factory=list)

-class ConfigResponse(Record):
+@dataclass
+class ConfigResponse:

     # get, list, getvalues, config
-    version = Integer()
+    version: int = 0

     # get, getvalues
-    values = Array(ConfigValue())
+    values: list[ConfigValue] = field(default_factory=list)

     # list
-    directory = Array(String())
+    directory: list[str] = field(default_factory=list)

     # config
-    config = Map(Map(String()))
+    config: dict[str, dict[str, str]] = field(default_factory=dict)

     # Everything
-    error = Error()
+    error: Error | None = None

-class ConfigPush(Record):
-    version = Integer()
-    config = Map(Map(String()))
+@dataclass
+class ConfigPush:
+    version: int = 0
+    config: dict[str, dict[str, str]] = field(default_factory=dict)

 config_request_queue = topic(
-    'config', kind='non-persistent', namespace='request'
+    'config', qos='q0', namespace='request'
 )
 config_response_queue = topic(
-    'config', kind='non-persistent', namespace='response'
+    'config', qos='q0', namespace='response'
 )
 config_push_queue = topic(
-    'config', kind='persistent', namespace='config'
+    'config', qos='q2', namespace='config'
 )

 ############################################################################

View file

@@ -1,33 +1,36 @@
-from pulsar.schema import Record, String, Map, Double, Array
+from dataclasses import dataclass, field

 from ..core.primitives import Error

 ############################################################################

 # Structured data diagnosis services

-class StructuredDataDiagnosisRequest(Record):
-    operation = String()  # "detect-type", "generate-descriptor", "diagnose", or "schema-selection"
-    sample = String()  # Data sample to analyze (text content)
-    type = String()  # Data type (csv, json, xml) - optional, required for generate-descriptor
-    schema_name = String()  # Target schema name for descriptor generation - optional
+@dataclass
+class StructuredDataDiagnosisRequest:
+    operation: str = ""  # "detect-type", "generate-descriptor", "diagnose", or "schema-selection"
+    sample: str = ""  # Data sample to analyze (text content)
+    type: str = ""  # Data type (csv, json, xml) - optional, required for generate-descriptor
+    schema_name: str = ""  # Target schema name for descriptor generation - optional

     # JSON encoded options (e.g., delimiter for CSV)
-    options = Map(String())
+    options: dict[str, str] = field(default_factory=dict)

-class StructuredDataDiagnosisResponse(Record):
-    error = Error()
+@dataclass
+class StructuredDataDiagnosisResponse:
+    error: Error | None = None

-    operation = String()  # The operation that was performed
-    detected_type = String()  # Detected data type (for detect-type/diagnose) - optional
-    confidence = Double()  # Confidence score for type detection - optional
+    operation: str = ""  # The operation that was performed
+    detected_type: str = ""  # Detected data type (for detect-type/diagnose) - optional
+    confidence: float = 0.0  # Confidence score for type detection - optional

     # JSON encoded descriptor (for generate-descriptor/diagnose) - optional
-    descriptor = String()
+    descriptor: str = ""

     # JSON encoded additional metadata (e.g., field count, sample records)
-    metadata = Map(String())
+    metadata: dict[str, str] = field(default_factory=dict)

     # Array of matching schema IDs (for schema-selection operation) - optional
-    schema_matches = Array(String())
+    schema_matches: list[str] = field(default_factory=list)

 ############################################################################

View file

@ -1,5 +1,5 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer from dataclasses import dataclass, field
from ..core.topic import topic from ..core.topic import topic
from ..core.primitives import Error from ..core.primitives import Error
@ -18,54 +18,54 @@ from ..core.primitives import Error
 # stop_flow(flowid) -> ()
 
 # Prompt services, abstract the prompt generation
 
-class FlowRequest(Record):
+@dataclass
+class FlowRequest:
 
-    operation = String()  # list-classes, get-class, put-class, delete-class
+    operation: str = ""   # list-classes, get-class, put-class, delete-class
                           # list-flows, get-flow, start-flow, stop-flow
 
     # get_class, put_class, delete_class, start_flow
-    class_name = String()
+    class_name: str = ""
 
     # put_class
-    class_definition = String()
+    class_definition: str = ""
 
     # start_flow
-    description = String()
+    description: str = ""
 
     # get_flow, start_flow, stop_flow
-    flow_id = String()
+    flow_id: str = ""
 
     # start_flow - optional parameters for flow customization
-    parameters = Map(String())
+    parameters: dict[str, str] = field(default_factory=dict)
 
-class FlowResponse(Record):
+@dataclass
+class FlowResponse:
 
     # list_classes
-    class_names = Array(String())
+    class_names: list[str] = field(default_factory=list)
 
     # list_flows
-    flow_ids = Array(String())
+    flow_ids: list[str] = field(default_factory=list)
 
     # get_class
-    class_definition = String()
+    class_definition: str = ""
 
     # get_flow
-    flow = String()
+    flow: str = ""
 
     # get_flow
-    description = String()
+    description: str = ""
 
     # get_flow - parameters used when flow was started
-    parameters = Map(String())
+    parameters: dict[str, str] = field(default_factory=dict)
 
     # Everything
-    error = Error()
+    error: Error | None = None
 
 flow_request_queue = topic(
-    'flow', kind='non-persistent', namespace='request'
+    'flow', qos='q0', namespace='request'
 )
 
 flow_response_queue = topic(
-    'flow', kind='non-persistent', namespace='response'
+    'flow', qos='q0', namespace='response'
 )
 
 ############################################################################
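This file also shows the second change running through the diff: `topic(..., kind='non-persistent'|'persistent')` becomes `topic(..., qos='q0'|'q1')`, abstracting Pulsar persistence behind a QoS level. The real helper lives in `..core.topic` and is not shown here; below is a hypothetical sketch of the mapping, assuming `q0` corresponds to non-persistent (fire-and-forget) topics and `q1` to persistent ones, with an invented tenant name `tg` for illustration:

```python
# Hypothetical sketch only -- the real topic() helper is in ..core.topic.
# Assumed mapping: q0 -> non-persistent, q1 -> persistent.
QOS_TO_KIND = {"q0": "non-persistent", "q1": "persistent"}

def topic(name: str, qos: str = "q1", namespace: str = "request") -> str:
    # Pulsar topic URLs take the form {kind}://{tenant}/{namespace}/{topic}
    kind = QOS_TO_KIND[qos]
    return f"{kind}://tg/{namespace}/{name}"

flow_request_queue = topic("flow", qos="q0", namespace="request")
```

Centralising the choice behind a QoS label means a later decision (say, making `q0` persistent for debugging) touches one helper rather than every schema module.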

View file

@@ -1,9 +1,8 @@
-from pulsar.schema import Record, Bytes, String, Array, Long
+from dataclasses import dataclass, field
 
 from ..core.primitives import Triple, Error
 from ..core.topic import topic
 from ..core.metadata import Metadata
 
-from ..knowledge.document import Document, TextDocument
+# Note: Document imports will be updated after knowledge schemas are converted
# add-document # add-document
# -> (document_id, document_metadata, content) # -> (document_id, document_metadata, content)
@@ -50,76 +49,79 @@ from ..knowledge.document import Document, TextDocument
 # <- (processing_metadata[])
 # <- (error)
 
-class DocumentMetadata(Record):
-    id = String()
-    time = Long()
-    kind = String()
-    title = String()
-    comments = String()
-    metadata = Array(Triple())
-    user = String()
-    tags = Array(String())
+@dataclass
+class DocumentMetadata:
+    id: str = ""
+    time: int = 0
+    kind: str = ""
+    title: str = ""
+    comments: str = ""
+    metadata: list[Triple] = field(default_factory=list)
+    user: str = ""
+    tags: list[str] = field(default_factory=list)
 
-class ProcessingMetadata(Record):
-    id = String()
-    document_id = String()
-    time = Long()
-    flow = String()
-    user = String()
-    collection = String()
-    tags = Array(String())
+@dataclass
+class ProcessingMetadata:
+    id: str = ""
+    document_id: str = ""
+    time: int = 0
+    flow: str = ""
+    user: str = ""
+    collection: str = ""
+    tags: list[str] = field(default_factory=list)
 
-class Criteria(Record):
-    key = String()
-    value = String()
-    operator = String()
+@dataclass
+class Criteria:
+    key: str = ""
+    value: str = ""
+    operator: str = ""
 
-class LibrarianRequest(Record):
+@dataclass
+class LibrarianRequest:
 
     # add-document, remove-document, update-document, get-document-metadata,
     # get-document-content, add-processing, remove-processing, list-documents,
     # list-processing
-    operation = String()
+    operation: str = ""
 
     # add-document, remove-document, update-document, get-document-metadata,
    # get-document-content
-    document_id = String()
+    document_id: str = ""
 
     # add-processing, remove-processing
-    processing_id = String()
+    processing_id: str = ""
 
     # add-document, update-document
-    document_metadata = DocumentMetadata()
+    document_metadata: DocumentMetadata | None = None
 
     # add-processing
-    processing_metadata = ProcessingMetadata()
+    processing_metadata: ProcessingMetadata | None = None
 
     # add-document
-    content = Bytes()
+    content: bytes = b""
 
     # list-documents, list-processing
-    user = String()
+    user: str = ""
 
     # list-documents?, list-processing?
-    collection = String()
+    collection: str = ""
 
     #
-    criteria = Array(Criteria())
+    criteria: list[Criteria] = field(default_factory=list)
 
-class LibrarianResponse(Record):
-    error = Error()
-    document_metadata = DocumentMetadata()
-    content = Bytes()
-    document_metadatas = Array(DocumentMetadata())
-    processing_metadatas = Array(ProcessingMetadata())
+@dataclass
+class LibrarianResponse:
+    error: Error | None = None
+    document_metadata: DocumentMetadata | None = None
+    content: bytes = b""
+    document_metadatas: list[DocumentMetadata] = field(default_factory=list)
+    processing_metadatas: list[ProcessingMetadata] = field(default_factory=list)
 
 # FIXME: Is this right? Using persistence on librarian so that
 # message chunking works
 librarian_request_queue = topic(
-    'librarian', kind='persistent', namespace='request'
+    'librarian', qos='q1', namespace='request'
 )
 
 librarian_response_queue = topic(
-    'librarian', kind='persistent', namespace='response',
+    'librarian', qos='q1', namespace='response',
 )

Some files were not shown because too many files have changed in this diff.